mindflow.cell.MultiHeadAttention
- class mindflow.cell.MultiHeadAttention(in_channels, num_heads, drop_mode='dropout', dropout_rate=0.0, compute_dtype=mstype.float32)
Multi-head attention, as proposed in Attention Is All You Need.
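For reference, the scaled dot-product attention computed by each head, as defined in the paper, is

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$

where $d_k$ is the per-head channel dimension (assumed here to be in_channels / num_heads); the outputs of the num_heads heads are concatenated and projected back to in_channels.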
- Parameters
in_channels (int) – The input channels.
num_heads (int) – The number of attention heads.
drop_mode (str) – Dropout method, 'dropout' or 'droppath'. Default: 'dropout'.
dropout_rate (float) – The dropout rate, greater than or equal to 0 and less than 1. Default: 0.0.
compute_dtype (mindspore.dtype) – The compute dtype. Default: mstype.float32, which indicates mindspore.float32.
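A minimal construction sketch with non-default arguments; the specific values, the 'droppath' mode, and the float16 compute dtype are illustrative assumptions rather than recommended settings.

>>> from mindspore import dtype as mstype
>>> from mindflow.cell import MultiHeadAttention
>>> attn = MultiHeadAttention(in_channels=256, num_heads=8,
...                           drop_mode='droppath', dropout_rate=0.1,
...                           compute_dtype=mstype.float16)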
- Inputs:
x (Tensor) - Tensor with shape (batch_size, sequence_len, in_channels).
attn_mask (Tensor) - Tensor with shape (batch_size, num_heads, sequence_len, sequence_len), (batch_size, sequence_len, sequence_len) or (sequence_len, sequence_len).
key_padding_mask (Tensor) - Tensor with shape (batch_size, num_heads, sequence_len, sequence_len), (batch_size, sequence_len, sequence_len) or (sequence_len, sequence_len).
- Outputs:
output (Tensor) - Tensor with shape (batch_size, sequence_len, in_channels).
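A sketch of passing both masks, assuming key_padding_mask is accepted as the third positional input in the order listed above; the all-ones masks are placeholders, and the key_padding_mask shape is assumed to follow the same options as attn_mask.

>>> from mindspore import ops
>>> from mindflow.cell import MultiHeadAttention
>>> model = MultiHeadAttention(in_channels=512, num_heads=4)
>>> x = ops.rand((2, 32, 512))
>>> attn_mask = ops.ones((2, 4, 32, 32))         # (batch_size, num_heads, sequence_len, sequence_len)
>>> key_padding_mask = ops.ones((2, 4, 32, 32))  # placeholder mask with the same assumed shape
>>> output = model(x, attn_mask, key_padding_mask)
>>> print(output.shape)
(2, 32, 512)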
- Supported Platforms:
Ascend
CPU
Examples
>>> from mindspore import ops
>>> from mindflow.cell import MultiHeadAttention
>>> model = MultiHeadAttention(in_channels=512, num_heads=4)
>>> x = ops.rand((2, 32, 512))
>>> mask_shape = (2, 4, 32, 32)
>>> mask = ops.ones(mask_shape)
>>> output = model(x, mask)
>>> print(output.shape)
(2, 32, 512)