mindspore.nn.TransformerEncoderLayer
- class mindspore.nn.TransformerEncoderLayer(d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Cell, callable] = 'relu', layer_norm_eps: float = 1e-05, batch_first: bool = False, norm_first: bool = False)
Transformer Encoder Layer. This is an implementation of a single transformer encoder layer, consisting of a multi-head attention sublayer and a feedforward sublayer.
- Parameters
d_model (int) – The number of features in the input tensor.
nhead (int) – The number of heads in the MultiheadAttention modules.
dim_feedforward (int) – The dimension of the feedforward layer. Default: 2048.
dropout (float) – The dropout value. Default: 0.1.
activation (Union[str, callable, Cell]) – The activation function of the intermediate layer, can be a string ("relu" or "gelu"), Cell instance (nn.ReLU() or nn.GELU()) or a callable (ops.relu or ops.gelu). Default: "relu".
layer_norm_eps (float) – The epsilon value in LayerNorm modules. Default: 1e-5.
batch_first (bool) – If batch_first = True, then the shape of input and output tensors is \((batch, seq, feature)\), otherwise the shape is \((seq, batch, feature)\). Default: False.
norm_first (bool) – If norm_first = True, layer norm is done prior to attention and feedforward operations, respectively. Default: False.
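The effect of norm_first can be illustrated with a small NumPy sketch. It is not the MindSpore implementation: it assumes the usual pre-norm and post-norm residual orderings, omits dropout and the learnable LayerNorm parameters, and uses hypothetical stand-ins (toy_attn, toy_ff) for the attention and feedforward sub-blocks.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (feature) axis; learnable scale and shift are omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def encoder_layer_sketch(x, self_attn, feedforward, norm_first=False):
    """Show only the ordering of norm, sub-block, and residual add."""
    if norm_first:
        # Pre-norm: normalize the input of each sub-block, then add the residual.
        x = x + self_attn(layer_norm(x))
        x = x + feedforward(layer_norm(x))
    else:
        # Post-norm: add the residual first, then normalize the sum.
        x = layer_norm(x + self_attn(x))
        x = layer_norm(x + feedforward(x))
    return x

# Hypothetical stand-ins for the real attention and feedforward sub-blocks.
toy_attn = lambda t: 0.5 * t
toy_ff = lambda t: np.maximum(t, 0.0)

x = np.random.default_rng(0).standard_normal((10, 32, 512))
post = encoder_layer_sketch(x, toy_attn, toy_ff, norm_first=False)
pre = encoder_layer_sketch(x, toy_attn, toy_ff, norm_first=True)
```

In this sketch the post-norm output is always normalized along the feature axis, while the pre-norm output keeps an unnormalized residual path. Both orderings preserve the input shape.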
- Inputs:
src (Tensor): the sequence to the encoder layer.
src_mask (Tensor, optional): the mask for the src sequence. Default: None.
src_key_padding_mask (Tensor, optional): the mask for the src keys per batch. Default: None.
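This page does not state the shape or value conventions of the two masks. The NumPy sketch below shows one plausible way to build them, assuming the convention documented for MultiheadAttention: src_key_padding_mask has shape (batch, seq) with True marking padded positions, and src_mask has shape (seq, seq) with True marking blocked positions. Both helper names are hypothetical, and the assumed convention should be checked against the MultiheadAttention documentation before use.

```python
import numpy as np

def padding_mask_from_lengths(lengths, max_len):
    """Return a (batch, max_len) boolean mask; True marks padded positions (assumed convention)."""
    positions = np.arange(max_len)[None, :]            # shape (1, max_len)
    return positions >= np.asarray(lengths)[:, None]   # shape (batch, max_len)

def causal_mask(seq_len):
    """Return a (seq_len, seq_len) boolean mask; True blocks attention to later positions."""
    return np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)

# Three sequences of true lengths 4, 2 and 3, padded to length 4.
key_padding = padding_mask_from_lengths([4, 2, 3], max_len=4)
causal = causal_mask(4)
```

Under these assumptions, the arrays would be wrapped in a Tensor and passed through the src_key_padding_mask and src_mask inputs listed above.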
- Outputs:
Tensor.
- Supported Platforms:
Ascend GPU CPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
>>> out = encoder_layer(src)
>>> # Alternatively, when batch_first=True:
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
>>> src = Tensor(np.random.rand(32, 10, 512), mindspore.float32)
>>> out = encoder_layer(src)
>>> print(out.shape)
(32, 10, 512)