mindflow.cell.AttentionBlock
- class mindflow.cell.AttentionBlock(in_channels, num_heads, drop_mode='dropout', dropout_rate=0.0, compute_dtype=mstype.float32)[源代码]
AttentionBlock 包含 MultiHeadAttention 和 MLP 网络堆叠而成。
- 参数:
in_channels (int) - 输入的输入特征维度。
num_heads (int) - 输出的输出特征维度。
drop_mode (str) - dropout方式。默认值:
dropout
。支持以下类型:dropout
和droppath
。dropout_rate (float) - dropout层丢弃的比率,在
[0, 1]
范围。默认值:0.0
。compute_dtype (mindspore.dtype) - 网络层的数据类型。默认值:
mstype.float32
,表示mindspore.float32
。
- 输入:
x (Tensor) - shape为
的Tensor。mask (Tensor) - shape为
或 或 的Tensor.
- 输出:
output (Tensor) - shape为
的Tensor。
- 支持平台:
Ascend
CPU
样例:
>>> from mindspore import ops >>> from mindflow.cell import AttentionBlock >>> model = AttentionBlock(in_channels=256, num_heads=4) >>> x = ops.rand((4, 100, 256)) >>> output = model(x) >>> print(output.shape) (4, 100, 256)