mindspore.nn.MultiheadAttention

class mindspore.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, has_bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False)

This is an implementation of the multi-head attention described in the paper "Attention Is All You Need". Given a query of target sequence length and a key and value of source sequence length, attention is computed as follows:

\[MultiHeadAttention(query, key, value) = Concat(head_1, \dots, head_h)W^O\]

where \(head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)\). By default, the projection layers include a bias (see has_bias).
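For reference, each head applies the scaled dot-product attention defined in the paper, where \(d_k\) is the per-head dimension embed_dim // num_heads:

\[Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]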

If query, key and value are the same tensor, the layer computes self-attention.
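To make the formula concrete, here is a minimal NumPy sketch of unbatched multi-head attention. It is an illustration under simplifying assumptions, not MindSpore's implementation: biases and dropout are omitted, all projections are square (E, E) matrices (i.e. kdim = vdim = embed_dim), and the per-head matrices \(W_i^Q, W_i^K, W_i^V\) are taken to be consecutive column blocks of single matrices w_q, w_k, w_v. The function name multihead_attention is hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multihead_attention(query, key, value, w_q, w_k, w_v, w_o, num_heads):
    """Hypothetical unbatched sketch (no bias, no dropout).

    query: (L, E); key, value: (S, E); w_q, w_k, w_v, w_o: (E, E).
    Returns the output (L, E) and per-head weights (num_heads, L, S).
    """
    L, E = query.shape
    S = key.shape[0]
    head_dim = E // num_heads
    # Project, then split the feature axis into heads: (L, E) -> (num_heads, L, head_dim).
    q = (query @ w_q).reshape(L, num_heads, head_dim).transpose(1, 0, 2)
    k = (key @ w_k).reshape(S, num_heads, head_dim).transpose(1, 0, 2)
    v = (value @ w_v).reshape(S, num_heads, head_dim).transpose(1, 0, 2)
    # Scaled dot-product attention for all heads at once: (num_heads, L, S).
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    weights = softmax(scores, axis=-1)
    heads = weights @ v  # (num_heads, L, head_dim)
    # Concat(head_1, ..., head_h) -> (L, E), then apply W^O.
    concat = heads.transpose(1, 0, 2).reshape(L, E)
    return concat @ w_o, weights

rng = np.random.default_rng(0)
E, H, L, S = 16, 4, 5, 7
w_q, w_k, w_v, w_o = (rng.standard_normal((E, E)) for _ in range(4))
query = rng.standard_normal((L, E))
memory = rng.standard_normal((S, E))
out, weights = multihead_attention(query, memory, memory, w_q, w_k, w_v, w_o, H)
print(out.shape, weights.shape)  # (5, 16) (4, 5, 7)
```

Passing the same array as query, key and value gives self-attention, in which case the per-head weights have shape (num_heads, L, L).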

Parameters
  • embed_dim (int) – Total embedding dimension of the model, i.e. the query feature size and the output feature size.

  • num_heads (int) – Number of attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).

  • dropout (float) – Dropout probability of attn_output_weights. Default: 0.0.

  • has_bias (bool) – Whether to add bias to the input and output projection layers. Default: True.

  • add_bias_kv (bool) – Whether to add bias to the key and value sequences at axis=0. Default: False.

  • add_zero_attn (bool) – Whether to add a new batch of zeros to the key and value sequences at axis=1. Default: False.

  • kdim (int) – Total number of features for keys. Default: None (kdim=embed_dim).

  • vdim (int) – Total number of features for values. Default: None (vdim=embed_dim).

  • batch_first (bool) – If True, the input and output shapes are \((batch, seq, feature)\); otherwise they are \((seq, batch, feature)\). Default: False.
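As a sketch of how the shape-related parameters interact, the hypothetical helper resolve_dims below computes the per-head dimension and applies the documented kdim/vdim defaults. The divisibility check is an assumption (the text above only states that each head has dimension embed_dim // num_heads). The last lines show, in NumPy rather than MindSpore, that switching between the two batch_first layouts is a swap of the first two axes.

```python
import numpy as np

def resolve_dims(embed_dim, num_heads, kdim=None, vdim=None):
    """Hypothetical helper: per-head size plus the documented kdim/vdim defaults."""
    # Assumption: embed_dim must split evenly across the heads.
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    kdim = embed_dim if kdim is None else kdim  # Default: kdim=embed_dim
    vdim = embed_dim if vdim is None else vdim  # Default: vdim=embed_dim
    return head_dim, kdim, vdim

print(resolve_dims(128, 8))           # (16, 128, 128)
print(resolve_dims(128, 8, kdim=64))  # (16, 64, 128)

# batch_first=True uses (batch, seq, feature); batch_first=False uses
# (seq, batch, feature). Converting between them swaps the first two axes.
x_batch_first = np.zeros((8, 10, 128))
x_seq_first = np.swapaxes(x_batch_first, 0, 1)
print(x_seq_first.shape)  # (10, 8, 128)
```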

Inputs:
  • query (Tensor): The query embeddings. If query is unbatched, the shape is \((L, E_q)\), otherwise the shape is \((L, N, E_q)\) when batch_first=False or \((N, L, E_q)\) when batch_first=True, where \(L\) is the target sequence length, \(N\) is the batch size, and \(E_q\) is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.

  • key (Tensor): The key embeddings. If key is unbatched, the shape is \((S, E_k)\), otherwise the shape is \((S, N, E_k)\) when batch_first=False or \((N, S, E_k)\) when batch_first=True, where \(S\) is the source sequence length, \(N\) is the batch size, and \(E_k\) is the key embedding dimension kdim. See “Attention Is All You Need” for more details.

  • value (Tensor): The value embeddings. If value is unbatched, the shape is \((S, E_v)\), otherwise the shape is \((S, N, E_v)\) when batch_first=False or \((N, S, E_v)\) when batch_first=True, where \(S\) is the source sequence length, \(N\) is the batch size, and \(E_v\) is the value embedding dimension vdim. See “Attention Is All You Need” for more details.

  • key_padding_mask (Tensor, optional): If specified, a mask of shape \((N, S)\) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, the shape should be \((S)\). Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a float mask, the mask value is added directly to the corresponding key value.

  • need_weights (bool): Whether to return attn_output_weights in addition to attn_output. Default: True.

  • attn_mask (Tensor, optional): If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape \((L, S)\) or \((N\cdot\text{num\_heads}, L, S)\), where \(N\) is the batch size, \(L\) is the target sequence length, and \(S\) is the source sequence length. A 2D mask will be broadcast across the batch, while a 3D mask allows a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight.

  • average_attn_weights (bool): If True, the returned attn_output_weights are averaged across heads; otherwise, attention weights are returned separately for each head. This flag only has an effect when need_weights=True. Default: True (i.e. average weights across heads).
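The mask rules above can be sketched in NumPy for a single head: binary and byte masks become additive float masks holding -inf wherever attention is forbidden, float masks are added unchanged, a 2D attn_mask broadcasts over the batch, and key_padding_mask is broadcast over every query position. This illustrates the documented semantics only; it is not MindSpore's implementation, and to_additive and masked_weights are hypothetical names. It assumes every query row keeps at least one unmasked key, because a fully masked row would yield NaN after the softmax.

```python
import numpy as np

def to_additive(mask):
    """Convert a documented mask into a float mask that is added to the scores."""
    mask = np.asarray(mask)
    if mask.dtype == np.uint8:
        # Byte mask: a non-zero value means "not allowed to attend".
        mask = mask != 0
    if mask.dtype == np.bool_:
        # Binary mask: True means "not allowed to attend" -> -inf before softmax.
        return np.where(mask, -np.inf, 0.0)
    # Float mask: added to the scores unchanged.
    return mask.astype(np.float64)

def masked_weights(scores, key_padding_mask=None, attn_mask=None):
    """Single-head sketch. scores: (N, L, S) raw attention logits."""
    scores = scores.astype(np.float64)
    if attn_mask is not None:
        # A 2D (L, S) mask broadcasts across the batch axis.
        scores = scores + to_additive(attn_mask)
    if key_padding_mask is not None:
        # (N, S) -> (N, 1, S) so each padded key is hidden from every query.
        scores = scores + to_additive(key_padding_mask)[:, None, :]
    # Numerically stable softmax over the key axis (assumes no fully masked row).
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

N, L, S = 2, 4, 4
scores = np.zeros((N, L, S))
causal = np.triu(np.ones((L, S), dtype=bool), k=1)  # True above the diagonal
padding = np.array([[False, False, False, False],
                    [False, False, False, True]])   # last key of sample 1 is padding
w = masked_weights(scores, key_padding_mask=padding, attn_mask=causal)
```

With all-zero scores, each query row spreads its weight uniformly over the keys that survive both masks, so sample 1's last row attends only to keys 0 to 2.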

Outputs:

Tuple, a tuple containing (attn_output, attn_output_weights).

  • attn_output - Attention outputs. If input is unbatched, the output shape is \((L, E)\), otherwise the output shape is \((L, N, E)\) when batch_first=False or \((N, L, E)\) when batch_first=True, where \(L\) is the target sequence length, \(N\) is the batch size, and \(E\) is the embedding dimension embed_dim.

  • attn_output_weights - Only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads with shape \((L, S)\) when input is unbatched or \((N, L, S)\) when input is batched, where \(N\) is the batch size, \(L\) is the target sequence length, and \(S\) is the source sequence length. If average_attn_weights=False, returns attention weights per head of shape \((\text{num\_heads}, L, S)\) when input is unbatched or \((N, \text{num\_heads}, L, S)\) when input is batched.
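The relationship between the two documented shapes of attn_output_weights can be sketched in NumPy: taking the mean of per-head weights over the head axis turns \((N, \text{num\_heads}, L, S)\) into \((N, L, S)\). Random scores stand in for real attention logits here; this is an assumption-laden illustration of the shapes, not of how MindSpore computes the weights.

```python
import numpy as np

rng = np.random.default_rng(0)
N, num_heads, L, S = 2, 4, 3, 5
scores = rng.standard_normal((N, num_heads, L, S))  # stand-in logits

# Softmax over the key axis gives per-head weights: (N, num_heads, L, S).
exp_scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
per_head = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

# average_attn_weights=True is assumed to be a mean over the head axis: (N, L, S).
averaged = per_head.mean(axis=1)
print(per_head.shape, averaged.shape)  # (2, 4, 3, 5) (2, 3, 5)
```

Because each head's rows sum to 1, the averaged rows still sum to 1.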

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, nn
>>> embed_dim, num_heads = 128, 8
>>> seq_length, batch_size = 10, 8
>>> # batch_first=False (the default), so inputs are (seq_length, batch_size, embed_dim)
>>> query = Tensor(np.random.randn(seq_length, batch_size, embed_dim), mindspore.float32)
>>> key = Tensor(np.random.randn(seq_length, batch_size, embed_dim), mindspore.float32)
>>> value = Tensor(np.random.randn(seq_length, batch_size, embed_dim), mindspore.float32)
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
>>> print(attn_output.shape)
(10, 8, 128)
>>> print(attn_output_weights.shape)
(8, 10, 10)