mindspore.nn.MultiheadAttention

class mindspore.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, has_bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=mstype.float32)

This is an implementation of multi-head attention from the paper Attention Is All You Need. Given the query, key and value tensors, attention is computed as follows:

\[MultiHeadAttention(query, key, value) = Concat(head_1, \dots, head_h)W^O\]

where \(head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)\), and \(W^O\), \(W_i^Q\), \(W_i^K\) and \(W_i^V\) are weight matrices. By default, the input and output projection layers include a bias.

If query, key and value are the same tensor, the operation is self-attention.
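The formula above can be sketched in plain NumPy to show how the projections, per-head attention and concatenation fit together. This is an illustrative re-implementation, not MindSpore's code: it assumes the scaled dot-product form \(Attention(Q, K, V) = softmax(QK^T / \sqrt{d_k})V\) from the original paper, omits biases, dropout and masks, uses the default seq-first layout (batch_first=False), and the function name multihead_attention is hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multihead_attention(query, key, value, w_q, w_k, w_v, w_o, num_heads):
    """Hypothetical sketch of Concat(head_1, ..., head_h) W^O.

    query: (L, N, E); key, value: (S, N, E)  (seq-first, as with batch_first=False)
    w_q, w_k, w_v, w_o: (E, E) projection matrices (biases omitted).
    Returns the output (L, N, E) and head-averaged weights (N, L, S).
    """
    L, N, E = query.shape
    S = key.shape[0]
    d = E // num_heads  # per-head dimension

    def split(x, w, seq):
        # Project, then split the feature axis into heads: (seq, N, E) -> (N, h, seq, d)
        return (x @ w).reshape(seq, N, num_heads, d).transpose(1, 2, 0, 3)

    q = split(query, w_q, L)
    k = split(key, w_k, S)
    v = split(value, w_v, S)
    # Scaled dot-product attention for every head at once
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)  # (N, h, L, S)
    weights = softmax(scores, axis=-1)
    heads = weights @ v  # (N, h, L, d)
    # Concatenate the heads back into (L, N, E), then apply W^O
    concat = heads.transpose(2, 0, 1, 3).reshape(L, N, E)
    return concat @ w_o, weights.mean(axis=1)

rng = np.random.default_rng(0)
E, h, L, N = 16, 4, 5, 2
x = rng.standard_normal((L, N, E))
w_q, w_k, w_v, w_o = (rng.standard_normal((E, E)) * 0.1 for _ in range(4))
# Self-attention: query, key and value are the same tensor
out, attn = multihead_attention(x, x, x, w_q, w_k, w_v, w_o, num_heads=h)
print(out.shape, attn.shape)  # (5, 2, 16) (2, 5, 5)
```

Computing all heads with one reshape instead of a Python loop mirrors the "embed_dim is split across num_heads" description: each head sees a contiguous slice of embed_dim // num_heads features.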

Parameters
  • embed_dim (int) – Total dimension of the model, i.e. the embedding dimension of the query and of the output.

  • num_heads (int) – Number of attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).

  • dropout (float) – Dropout probability applied to attn_output_weights. Default: 0.0.

  • has_bias (bool) – Whether to add bias to the input and output projection layers. Default: True.

  • add_bias_kv (bool) – Whether to add bias to the key and value sequences at axis=0. Default: False.

  • add_zero_attn (bool) – Whether to add a new batch of zeros to the key and value sequences at axis=1. Default: False.

  • kdim (int) – Total number of features for keys. Default: None (kdim=embed_dim).

  • vdim (int) – Total number of features for values. Default: None (vdim=embed_dim).

  • batch_first (bool) – If True, the input and output tensors have shape \((batch, seq, feature)\); otherwise they have shape \((seq, batch, feature)\). Default: False.

  • dtype (mindspore.dtype) – Data type of the Parameters. Default: mstype.float32.
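The rules linking embed_dim, num_heads, kdim, vdim and batch_first in the list above can be summarized in a short sketch. resolve_dims and to_seq_first are hypothetical helpers written for illustration, not part of mindspore.nn; they only encode the documented behavior: a per-head dimension of embed_dim // num_heads, a ValueError when embed_dim is not divisible by num_heads, kdim and vdim defaulting to embed_dim, and the two batched input layouts.

```python
import numpy as np

def resolve_dims(embed_dim, num_heads, kdim=None, vdim=None):
    """Hypothetical helper mirroring the documented constructor rules."""
    if embed_dim % num_heads != 0:
        # Documented under Raises: embed_dim must be divisible by num_heads
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    # kdim and vdim fall back to embed_dim when left as None
    kdim = embed_dim if kdim is None else kdim
    vdim = embed_dim if vdim is None else vdim
    return head_dim, kdim, vdim

def to_seq_first(x, batch_first):
    """Hypothetical helper: view a batched (N, L, E) array as (L, N, E) when batch_first=True."""
    return np.swapaxes(x, 0, 1) if batch_first else x

print(resolve_dims(128, 8))           # (16, 128, 128)
print(resolve_dims(128, 8, kdim=64))  # (16, 64, 128)
x = np.zeros((8, 10, 128))            # (batch, seq, feature)
print(to_seq_first(x, batch_first=True).shape)  # (10, 8, 128)
```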

Inputs:
  • query (Tensor) - The query embeddings. If query is unbatched, the shape is \((L, E_q)\), otherwise the shape is \((L, N, E_q)\) when batch_first=False or \((N, L, E_q)\) when batch_first=True, where \(L\) is the target sequence length, \(N\) is the batch size, and \(E_q\) is the query embedding dimension embed_dim. Supported types: float16, float32, float64. Queries are compared against key-value pairs to produce the output.

  • key (Tensor) - The key embeddings. If key is unbatched, the shape is \((S, E_k)\), otherwise the shape is \((S, N, E_k)\) when batch_first=False or \((N, S, E_k)\) when batch_first=True, where \(S\) is the source sequence length, \(N\) is the batch size, and \(E_k\) is the key embedding dimension kdim. Supported types: float16, float32, float64.

  • value (Tensor) - The value embeddings. If value is unbatched, the shape is \((S, E_v)\), otherwise the shape is \((S, N, E_v)\) when batch_first=False or \((N, S, E_v)\) when batch_first=True, where \(S\) is the source sequence length, \(N\) is the batch size, and \(E_v\) is the value embedding dimension vdim. Supported types: float16, float32, float64.

  • key_padding_mask (Tensor, optional) - If specified, a mask of shape \((N, S)\) indicating which elements within key to ignore for the purpose of attention (i.e. treat as "padding"). For an unbatched query, the shape should be \((S)\). Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding key is ignored for the purpose of attention. For a float mask, its values are added directly to the attention scores of the corresponding keys. Supported float types: float16, float32, float64. Default: None.

  • need_weights (bool) - Whether to return attn_output_weights in addition to attn_output. Default: True.

  • attn_mask (Tensor, optional) - If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape \((L, S)\) or \((N\cdot\text{num\_heads}, L, S)\), where \(N\) is the batch size, \(L\) is the target sequence length, and \(S\) is the source sequence length. A 2D mask is broadcast across the batch, while a 3D mask allows a different mask for each entry in the batch. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values are added to the attention weights. Supported float types: float16, float32, float64. Default: None.

  • average_attn_weights (bool) - If True, the returned attn_output_weights are averaged across heads; otherwise, they are provided separately per head. This flag only has an effect when need_weights=True. Default: True (i.e. average weights across heads).
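The mask semantics described above can be illustrated with NumPy. This sketch assumes the common convention (also used by PyTorch's equivalent layer) that a binary mask becomes an additive mask holding -inf at blocked positions, and that all masks are summed and added to the pre-softmax scores of shape (N, num_heads, L, S). to_additive and combine_masks are hypothetical helpers, and the batch-major ordering of the first axis of a 3D attn_mask is an assumption.

```python
import numpy as np

def to_additive(mask, dtype=np.float32):
    """Hypothetical helper. Bool mask: True -> -inf (blocked), False -> 0. Float mask: used as-is."""
    if mask.dtype == np.bool_:
        return np.where(mask, -np.inf, 0.0).astype(dtype)
    return mask.astype(dtype)

def combine_masks(n, num_heads, L, S, attn_mask=None, key_padding_mask=None):
    """Hypothetical helper returning one additive mask of shape (N, num_heads, L, S).

    If every key is blocked for some query, that row becomes all -inf and a
    subsequent softmax would produce NaN, so callers should avoid that case.
    """
    out = np.zeros((n, num_heads, L, S), dtype=np.float32)
    if attn_mask is not None:
        m = to_additive(attn_mask)
        # A 2D (L, S) mask broadcasts over batch and heads; a 3D
        # (N*num_heads, L, S) mask is assumed to be ordered batch-major.
        out = out + (m if m.ndim == 2 else m.reshape(n, num_heads, L, S))
    if key_padding_mask is not None:
        # (N, S) -> (N, 1, 1, S): padded keys are hidden for every query and head
        out = out + to_additive(key_padding_mask)[:, None, None, :]
    return out

N, h, L = 2, 2, 4
causal = np.triu(np.ones((L, L), dtype=bool), k=1)  # True above the diagonal: future positions
padding = np.array([[False, False, False, True],     # last key of batch entry 0 is padding
                    [False, False, False, False]])
mask = combine_masks(N, h, L, L, attn_mask=causal, key_padding_mask=padding)
print(mask.shape)  # (2, 2, 4, 4)
```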

Outputs:

Tuple, a tuple containing (attn_output, attn_output_weights).

  • attn_output - Attention outputs. If input is unbatched, the output shape is \((L, E)\), otherwise the output shape is \((L, N, E)\) when batch_first=False or \((N, L, E)\) when batch_first=True, where \(L\) is the target sequence length, \(N\) is the batch size, and \(E\) is the embedding dimension embed_dim.

  • attn_output_weights - Only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads with shape \((L, S)\) when input is unbatched or \((N, L, S)\) when input is batched, where \(N\) is the batch size, \(L\) is the target sequence length, and \(S\) is the source sequence length. If average_attn_weights=False, returns attention weights per head of shape \((\text{num\_heads}, L, S)\) when input is unbatched or \((N, \text{num\_heads}, L, S)\) when input is batched.
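To make the two average_attn_weights modes concrete, the following NumPy sketch derives both batched output shapes from random per-head weights. It assumes that averaging is a plain arithmetic mean over the heads axis; the random scores only stand in for real attention scores and do not come from MindSpore.

```python
import numpy as np

rng = np.random.default_rng(0)
N, num_heads, L, S = 3, 4, 5, 6
scores = rng.standard_normal((N, num_heads, L, S))
# Per-head weights: softmax over the source axis S
w = np.exp(scores - scores.max(axis=-1, keepdims=True))
per_head = w / w.sum(axis=-1, keepdims=True)  # like average_attn_weights=False: (N, num_heads, L, S)
averaged = per_head.mean(axis=1)              # like average_attn_weights=True: (N, L, S)
print(per_head.shape, averaged.shape)         # (3, 4, 5, 6) (3, 5, 6)
# A mean of probability rows is still a probability row, so each row sums to 1
print(np.allclose(averaged.sum(axis=-1), 1.0))  # True
```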

Raises
  • ValueError – If the init argument embed_dim is not divisible by num_heads.

  • TypeError – If the input argument key_padding_mask is not of bool or floating-point type.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore as ms
>>> import numpy as np
>>> embed_dim, num_heads = 128, 8
>>> seq_length, batch_size = 10, 8
>>> # batch_first defaults to False, so inputs use the (seq, batch, feature) layout
>>> query = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
>>> key = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
>>> value = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
>>> multihead_attn = ms.nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
>>> print(attn_output.shape)
(10, 8, 128)