文档反馈

问题文档片段

问题文档片段包含公式时,显示为空格。

提交类型
issue

有点复杂...

找人问问吧。

请选择提交类型

问题类型
规范和低错类

- 规范和低错类:

- 错别字或拼写错误,标点符号使用错误、公式错误或显示异常。

- 链接错误、空单元格、格式错误。

- 英文中包含中文字符。

- 界面和描述不一致,但不影响操作。

- 表述不通顺,但不影响理解。

- 版本号不匹配:如软件包名称、界面版本号。

易用性

- 易用性:

- 关键步骤错误或缺失,无法指导用户完成任务。

- 缺少主要功能描述、关键词解释、必要前提条件、注意事项等。

- 描述内容存在歧义指代不明、上下文矛盾。

- 逻辑不清晰,该分类、分项、分步骤的没有给出。

正确性

- 正确性:

- 技术原理、功能、支持平台、参数类型、异常报错等描述和软件实现不一致。

- 原理图、架构图等存在错误。

- 命令、命令参数等错误。

- 代码片段错误。

- 命令无法完成对应功能。

- 界面错误,无法指导操作。

- 代码样例运行报错、运行结果不符。

风险提示

- 风险提示:

- 对重要数据或系统存在风险的操作,缺少安全提示。

内容合规

- 内容合规:

- 违反法律法规,涉及政治、领土主权等敏感词。

- 内容侵权。

请选择问题类型

问题描述

点击输入详细问题描述,以帮助我们快速定位问题。

mindspore.nn.MultiheadAttention

class mindspore.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, has_bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False)[源代码]

论文 Attention Is All You Need 中所述的多头注意力的实现。给定query向量,key向量和value,注意力计算流程如下:

MultiHeadAttention(query,key,vector)=Concat(head1,,headh)WO

其中, headi=Attention(QWiQ,KWiK,VWiV) 。注意:输出层的投影计算中带有偏置参数。

如果query、key和value相同,则上述即为自注意力机制的计算过程。

警告

这是一个实验性API,后续可能修改或删除。

参数:
  • embed_dim (int) - 模型的总维数。

  • num_heads (int) - 并行注意力头的数量。num_heads 需要能够被 embed_dim 整除(每个头的维数为 embed_dim // num_heads)。

  • dropout (float) - 应用到输入 attn_output_weights 上的随机丢弃比例. 默认值: 0.0

  • has_bias (bool) - 是否给输入、输出投射层添加偏置。默认值: True

  • add_bias_kv (bool) - 是否给key、value序列的0维添加偏置。默认值: False

  • add_zero_attn (bool) - 是否给key、value序列的一维添加0。默认值: False

  • kdim (int) - key的总特征数。默认值: None (即 kdim=embed_dim)。

  • vdim (int) - value的总特征数。默认值:None (即 vdim=embed_dim)。

  • batch_first (bool) - 如果为 True,则输入输出Tensor的shape为 (batch,seq,feature) ,否则shape为 (seq,batch,feature) 。 默认值: False

输入:
  • query (Tensor) - Query矩阵。当输入非Batch数据时,Shape为: (L,Eq) 。当输入Batch数据,参数 batch_first=False 时,Shape为 (L,N,Eq) , 当 batch_first=True 时,Shape为 (N,L,Eq)。其中, L 为目标序列的长度, N 为batch size,Eq 为Query矩阵的维数 embed_dim。 注意力机制通过Query与Key-Value运算以生成最终输出。详情请见:”Attention Is All You Need”。

  • key (Tensor) - Key矩阵。当输入非Batch数据时,Shape为: (S,Ek) 。当输入Batch数据,参数 batch_first=False 时,Shape为 (S,N,Ek) , 当 batch_first=True 时,Shape为 (N,S,Ek)。其中, S 为源序列的长度, N 为batch size,Ek 为Key矩阵的维数 kdim。详情请见:”Attention Is All You Need”。

  • value (Tensor) - Value矩阵。当输入非Batch数据时,Shape为: (S,Ev) 。当输入Batch数据,参数 batch_first=False 时,Shape为 (S,N,Ev) , 当 batch_first=True 时,Shape为 (N,S,Ev)。其中, S 为源序列的长度, N 为batch size,Ev 为Key矩阵的维数 vdim。详情请见:”Attention Is All You Need”。

  • key_padding_mask (Tensor, optional) - 如果指定此值,则表示Shape为 (N,S)。当输入非Batch数据时,Shape为: (S) 。 如果输入Tensor为Bool类型,则 key 中对应为 True 的位置将在Attention计算时被忽略。如果输入Tensor为Float类型,则将直接与 key 相加。默认值:None

  • need_weights (bool) - 是否需要返回 attn_output_weights,如果为 True,则输出包含 attn_output_weights。默认值:True

  • attn_mask (Tensor, optional) - 如果指定此值,则表示Shape为 (L,S)(Nnum_heads,L,S) 的掩码将被用于Attention计算。其中 N 为batch size, L 为目标序列长度,S 为源序列长度。如果输入为2维矩阵,则将自动沿batch维广播至3维矩阵。若为3维矩阵,则允许沿batch维使用不同的掩码。如果输入Tensor为Bool类型,则值为 True 对应位置允许被注意力计算。如果输入Tensor为Float类型,则将直接与注意力权重相加。默认值:None

  • average_attn_weights (bool) - 如果为 True, 则返回值 attn_weights 为注意力头的平均值。如果为 False,则 attn_weights 分别返回每个注意力头的值。 本参数仅在 need_weights=True 时生效。默认值: True

输出:

Tuple,表示一个包含(attn_output, attn_output_weights)的元组。

  • attn_output - 注意力机制的输出。当输入非Batch数据时,Shape为: (L,E) 。当输入Batch数据, 参数 batch_first=False 时,Shape为 (L,N,E) , 当 batch_first=True 时,Shape为 (N,L,E)。其中, L 为目标序列的长度, N 为batch size, E 为模型的总维数 embed_dim

  • attn_output_weights - 仅当 need_weights=True 时返回。如果 average_attn_weights=True,则返回值 attn_weights 为注意力头的平均值。当输入非Batch数据时, Shape为: (L,S) ,当输入Batch数据时,Shape为 (N,L,S)。其中 N 为batch size, L 为目标序列的长度,S 为源序列长度。 如果 average_attn_weights=False ,分别返回每个注意力头的值。当输入非Batch数据时,Shape为: (num_heads,L,S) ,当输入Batch数据时,Shape为 (N,num_heads,L,S)

异常:
  • ValueError - 如果 embed_dim 不能被 num_heads 整除。

  • TypeError - 如果 key_padding_mask 不是bool或float类型。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore as ms
>>> import numpy as np
>>> embed_dim, num_heads = 128, 8
>>> seq_length, batch_size = 10, 8
>>> query = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
>>> key = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
>>> value = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
>>> multihead_attn = ms.nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
>>> print(attn_output.shape)
(10, 8, 128)