mindflow.cell.AttentionBlock

class mindflow.cell.AttentionBlock(in_channels, num_heads, drop_mode='dropout', dropout_rate=0.0, compute_dtype=mstype.float32)

AttentionBlock is built by stacking a MultiHeadAttention network and an MLP network.

Parameters:
  • in_channels (int) - The input feature dimension.

  • num_heads (int) - The number of attention heads.

  • drop_mode (str) - The dropout mode. Default: 'dropout'. Supported values: 'dropout', 'droppath'.

  • dropout_rate (float) - The drop rate of the dropout layer, in the range [0, 1]. Default: 0.0.

  • compute_dtype (mindspore.dtype) - The data type of the network layers. Default: mstype.float32, which means mindspore.float32.

Inputs:
  • x (Tensor) - Tensor of shape (batch_size, sequence_len, in_channels).

  • mask (Tensor) - Tensor of shape (batch_size, sequence_len, sequence_len), (sequence_len, sequence_len) or (batch_size, num_heads, sequence_len, sequence_len).
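The (sequence_len, sequence_len) mask shape is the common case of a causal (autoregressive) mask shared across the batch. A minimal sketch of building one in plain Python, assuming the usual convention that 1 marks a position a query may attend to and 0 marks a blocked position (the exact mask semantics used by AttentionBlock should be checked against the source):

```python
def causal_mask(sequence_len):
    # Lower-triangular matrix: query position i may attend
    # only to key positions j <= i (no peeking at the future).
    return [[1 if j <= i else 0 for j in range(sequence_len)]
            for i in range(sequence_len)]

mask = causal_mask(4)
for row in mask:
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```

Such a 2-D mask would then be converted to a Tensor and broadcast over the batch (and heads); the 3-D and 4-D shapes above allow per-sample and per-head masks instead.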

Outputs:
  • output (Tensor) - Tensor of shape (batch_size, sequence_len, in_channels).

Supported Platforms:

Ascend CPU

Examples:

>>> from mindspore import ops
>>> from mindflow.cell import AttentionBlock
>>> model = AttentionBlock(in_channels=256, num_heads=4)
>>> x = ops.rand((4, 100, 256))
>>> output = model(x)
>>> print(output.shape)
(4, 100, 256)