mindsponge.cell.Attention

查看源文件
class mindsponge.cell.Attention(num_head, hidden_size, gating, q_data_dim, m_data_dim, output_dim, batch_size=None)[源代码]

多头注意力机制,具体实现请参考 Attention is all you need 。Attention公式如下,query向量长度与输入一致,key向量长度为key长度和目标长度。

\[Attention(query, key, vector) = Concat(head_1, \dots, head_h)W^O\]

\(head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)\) 。默认有偏置。如果query,key和value张量相同,则表现为self attention。

参数:
  • num_head (int) - 头的数量。

  • hidden_size (int) - 输入的隐藏尺寸。

  • gating (bool) - 判断attention是否经过gating的指示器。

  • q_data_dim (int) - query的最后一维度的长度。

  • m_data_dim (int) - key和value最后一维度的长度。

  • output_dim (int) - 输出的最后一维度的长度。

  • batch_size (int) - attention中权重的batch size,仅在有while控制流时使用,默认值: None

输入:
  • q_data (Tensor) - shape为 \((batch\_size, query\_seq\_length, q\_data_dim)\) 的query Tensor,其中query_seq_length是query向量的序列长度。

  • m_data (Tensor) - shape为 \((batch\_size, value\_seq\_length, m\_data_dim)\) 的key和value Tensor,其中value_seq_length是value向量的序列长度。

  • attention_mask (Tensor) - 注意力矩阵的mask。shape为 \((batch\_size, num\_heads, query\_seq\_length, value\_seq_length)\)

  • index (Tensor) - 在while循环中的索引,仅在有while控制流时使用。默认值: None

  • nonbatched_bias (Tensor) - attention矩阵中无batch维的偏置。shape为 \((num\_heads, query\_seq\_length, value\_seq_length)\)。默认值: None

输出:

Tensor。Attention层的输出tensor,shape是 \((batch\_size, query\_seq\_length, hidden\_size)\)

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> from mindsponge.cell import Attention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = Attention(num_head=4, hidden_size=64, gating=True, q_data_dim=64,
...                   m_data_dim=64, output_dim=64)
>>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32)
>>> m_data = Tensor(np.ones((32, 256, 64)), mstype.float32)
>>> attention_mask = Tensor(np.ones((32, 4, 128, 256)), mstype.float32)
>>> attn_out= model(q_data, m_data, attention_mask)
>>> print(attn_out.shape)
(32, 128, 64)