mindsponge.cell.Attention
- class mindsponge.cell.Attention(num_head, hidden_size, gating, q_data_dim, m_data_dim, output_dim, batch_size=None)[源代码]
多头注意力机制,具体实现请参考 Attention is all you need 。Attention公式如下,query向量长度与输入一致,key向量长度为key长度和目标长度。
\[Attention(query, key, vector) = Concat(head_1, \dots, head_h)W^O\]\(head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)\) 。默认有偏置。如果query,key和value张量相同,则表现为self attention。
- 参数:
num_head (int) - 头的数量。
hidden_size (int) - 输入的隐藏尺寸。
gating (bool) - 判断attention是否经过gating的指示器。
q_data_dim (int) - query的最后一维度的长度。
m_data_dim (int) - key和value最后一维度的长度。
output_dim (int) - 输出的最后一维度的长度。
batch_size (int) - attention中权重的batch size,仅在有while控制流时使用,默认值:
None
。
- 输入:
q_data (Tensor) - shape为 \((batch\_size, query\_seq\_length, q\_data_dim)\) 的query Tensor,其中query_seq_length是query向量的序列长度。
m_data (Tensor) - shape为 \((batch\_size, value\_seq\_length, m\_data_dim)\) 的key和value Tensor,其中value_seq_length是value向量的序列长度。
attention_mask (Tensor) - 注意力矩阵的mask。shape为 \((batch\_size, num\_heads, query\_seq\_length, value\_seq_length)\)。
index (Tensor) - 在while循环中的索引,仅在有while控制流时使用。默认值:
None
。nonbatched_bias (Tensor) - attention矩阵中无batch维的偏置。shape为 \((num\_heads, query\_seq\_length, value\_seq_length)\)。默认值:
None
。
- 输出:
Tensor。Attention层的输出tensor,shape是 \((batch\_size, query\_seq\_length, hidden\_size)\)。
- 支持平台:
Ascend
GPU
样例:
>>> import numpy as np >>> from mindsponge.cell import Attention >>> from mindspore import dtype as mstype >>> from mindspore import Tensor >>> model = Attention(num_head=4, hidden_size=64, gating=True, q_data_dim=64, ... m_data_dim=64, output_dim=64) >>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32) >>> m_data = Tensor(np.ones((32, 256, 64)), mstype.float32) >>> attention_mask = Tensor(np.ones((32, 4, 128, 256)), mstype.float32) >>> attn_out= model(q_data, m_data, attention_mask) >>> print(attn_out.shape) (32, 128, 64)