mindspore.nn.Transformer

class mindspore.nn.Transformer(d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Cell, callable] = 'relu', custom_encoder: Optional[Cell] = None, custom_decoder: Optional[Cell] = None, layer_norm_eps: float = 1e-05, batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32)

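Transformer module consisting of an encoder and a decoder, based on the architecture described in Attention Is All You Need. By default (norm_first=False), each sub-layer adds the residual connection before applying layer normalization; set norm_first=True to apply layer normalization before the sub-layer instead. The default activation of the feedforward hidden layer is "relu", as set by the activation parameter.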
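The ordering of residual addition and layer normalization, which the norm_first parameter selects, can be illustrated with a minimal NumPy sketch. This is not the MindSpore implementation: layer_norm, sublayer_block and toy_ffn are hypothetical helpers, and the learnable gain/bias of layer normalization as well as dropout are omitted.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (feature) axis; learnable gain and bias omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sublayer_block(x, sublayer, norm_first=False, eps=1e-5):
    # Hypothetical helper: one residual sub-layer (attention or feedforward).
    if norm_first:
        # norm_first=True: normalize, run the sub-layer, then add the residual.
        return x + sublayer(layer_norm(x, eps))
    # norm_first=False: add the residual first, then normalize the sum.
    return layer_norm(x + sublayer(x), eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))          # (seq, feature), toy sizes
w = rng.standard_normal((8, 8)) * 0.1    # toy feedforward weight

def toy_ffn(h):
    # Stand-in sub-layer: a linear projection followed by ReLU.
    return np.maximum(h @ w, 0.0)

post_ln = sublayer_block(x, toy_ffn, norm_first=False)
pre_ln = sublayer_block(x, toy_ffn, norm_first=True)
```

With norm_first=False, every row of post_ln has (approximately) zero mean and unit variance because normalization is the last step, whereas pre_ln keeps the un-normalized residual stream.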

Parameters
  • d_model (int) – The number of expected features in the input tensors of the encoder and decoder. Default: 512.

  • nhead (int) – The number of heads in the MultiheadAttention modules. Default: 8.

  • num_encoder_layers (int) – The number of layers in the encoder. Default: 6.

  • num_decoder_layers (int) – The number of layers in the decoder. Default: 6.

  • dim_feedforward (int) – The dimension of the hidden layer of the feedforward network. Default: 2048.

  • dropout (float) – The dropout probability. Default: 0.1.

  • activation (Union[str, callable, Cell]) – The activation function of the intermediate layer of the feedforward network. It can be a string ("relu" or "gelu"), a Cell instance (mindspore.nn.ReLU or mindspore.nn.GELU) or a callable (mindspore.ops.relu() or mindspore.ops.gelu()). Default: "relu".

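  • custom_encoder (Cell) – Custom encoder. If None, an encoder is built from the other parameters. Default: None.

  • custom_decoder (Cell) – Custom decoder. If None, a decoder is built from the other parameters. Default: None.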

  • layer_norm_eps (float) – The epsilon value in the layer normalization modules. Default: 1e-5.

  • batch_first (bool) – If batch_first=True, the shape of the input and output tensors is \((batch, seq, feature)\); otherwise it is \((seq, batch, feature)\). Default: False.

  • norm_first (bool) – If norm_first=True, layer normalization is applied before the attention and feedforward operations; if norm_first=False, it is applied after them. Default: False.

  • dtype (mindspore.dtype) – Data type of the parameters. Default: mstype.float32.
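The shape conventions above can be sketched with NumPy (an illustration only, not a MindSpore call). Swapping the first two axes converts between the batch_first=False layout \((S, N, E)\) and the batch_first=True layout \((N, S, E)\). The sketch also assumes, as in standard multi-head attention, that each of the nhead heads works on d_model // nhead features, so d_model should be divisible by nhead (64 per head for the defaults, 32 per head for nhead=16).

```python
import numpy as np

S, N, E = 10, 32, 512   # source length, batch size, d_model
nhead = 8

# batch_first=False layout: (S, N, E); arange makes element positions traceable.
src_seq_first = np.arange(S * N * E).reshape(S, N, E)

# Swapping the first two axes gives the batch_first=True layout: (N, S, E).
src_batch_first = np.transpose(src_seq_first, (1, 0, 2))

# Assumption: d_model is split evenly across the attention heads.
assert E % nhead == 0
head_dim = E // nhead
```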

Inputs:
  • src (Tensor) - The source sequence to the encoder. For unbatched input, the shape is \((S, E)\); otherwise, if batch_first=False, the shape is \((S, N, E)\), and if batch_first=True, the shape is \((N, S, E)\), where \(S\) is the source sequence length, \(N\) is the batch size and \(E\) is the number of features. Supported types: float16, float32, float64.

  • tgt (Tensor) - The target sequence to the decoder. For unbatched input, the shape is \((T, E)\); otherwise, if batch_first=False, the shape is \((T, N, E)\), and if batch_first=True, the shape is \((N, T, E)\), where \(T\) is the target sequence length. Supported types: float16, float32, float64.

  • src_mask (Tensor, optional) - The mask of the src sequence. The shape is \((S, S)\) or \((N*nhead, S, S)\). Supported types: float16, float32, float64, bool. Default: None.

  • tgt_mask (Tensor, optional) - The mask of the tgt sequence. The shape is \((T, T)\) or \((N*nhead, T, T)\). Supported types: float16, float32, float64, bool. Default: None.

  • memory_mask (Tensor, optional) - The additive mask of the encoder output. The shape is \((T, S)\). Supported types: float16, float32, float64, bool. Default: None.

  • src_key_padding_mask (Tensor, optional) - The mask of src keys per batch. The shape is \((S)\) for unbatched input, otherwise \((N, S)\). Supported types: float16, float32, float64, bool. Default: None.

  • tgt_key_padding_mask (Tensor, optional) - The mask of tgt keys per batch. The shape is \((T)\) for unbatched input, otherwise \((N, T)\). Supported types: float16, float32, float64, bool. Default: None.

  • memory_key_padding_mask (Tensor, optional) - The mask of memory keys per batch. The shape is \((S)\) for unbatched input, otherwise \((N, S)\). Supported types: float16, float32, float64, bool. Default: None.
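The masks can be prepared with NumPy before wrapping them in ms.Tensor. This page does not spell out how mask values are interpreted, so the sketch assumes the PyTorch-style convention: an additive float mask holds 0.0 where attention is allowed and -inf where it is blocked, and a bool padding mask uses True to mark keys to ignore. causal_mask and key_padding_mask are hypothetical helper names.

```python
import numpy as np

def causal_mask(size):
    # Additive float mask of shape (size, size): 0.0 where attention is allowed,
    # -inf above the diagonal so position i cannot attend to any position j > i.
    blocked = np.triu(np.ones((size, size), dtype=bool), k=1)
    return np.where(blocked, -np.inf, 0.0).astype(np.float32)

def key_padding_mask(lengths, max_len):
    # Bool mask of shape (N, max_len); True marks padded key positions
    # (assumed convention: True = ignore this key).
    positions = np.arange(max_len)[None, :]
    return positions >= np.asarray(lengths)[:, None]

tgt_mask = causal_mask(4)                            # shape (T, T) with T = 4
src_key_padding_mask = key_padding_mask([3, 5], 5)   # shape (N, S) with N = 2, S = 5
```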

Outputs:

Tensor. The shape is \((T, E)\) for unbatched input; otherwise, if batch_first=False, the shape is \((T, N, E)\), and if batch_first=True, the shape is \((N, T, E)\).

Raises
  • ValueError – If the batch sizes of the inputs src and tgt are not equal.

  • ValueError – If the number of features of the input src or tgt is not equal to d_model.
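The documented shape rules can be expressed as a small pure-Python check on shape tuples. This is a sketch, not MindSpore code: check_transformer_inputs is a hypothetical helper, and the requirement that src and tgt have the same rank is an added assumption of the sketch. It returns the output shape, which always follows the layout of tgt.

```python
def check_transformer_inputs(src_shape, tgt_shape, d_model, batch_first=False):
    """Hypothetical helper mirroring the documented ValueError cases."""
    # Added assumption: both inputs are unbatched (2-D) or both batched (3-D).
    if len(src_shape) != len(tgt_shape) or len(src_shape) not in (2, 3):
        raise ValueError("src and tgt must both be 2-D or both be 3-D")
    if len(src_shape) == 3:
        # The batch axis is 0 for (N, S, E) and 1 for (S, N, E).
        batch_axis = 0 if batch_first else 1
        if src_shape[batch_axis] != tgt_shape[batch_axis]:
            raise ValueError("the batch sizes of src and tgt must be equal")
    if src_shape[-1] != d_model or tgt_shape[-1] != d_model:
        raise ValueError("the number of features of src and tgt must equal d_model")
    # Output layout matches tgt: (T, E), (T, N, E) or (N, T, E).
    return tuple(tgt_shape)
```

For the shapes used in the example, check_transformer_inputs((10, 32, 512), (20, 32, 512), 512) returns (20, 32, 512).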

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore as ms
>>> import numpy as np
>>> transformer_model = ms.nn.Transformer(nhead=16, num_encoder_layers=12)
>>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
>>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
>>> out = transformer_model(src, tgt)
>>> print(out.shape)
(20, 32, 512)