mindspore.ops.LayerNorm
- class mindspore.ops.LayerNorm(begin_norm_axis=1, begin_params_axis=1, epsilon=1e-07)
Applies the Layer Normalization to the input tensor.
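This operator normalizes the input tensor over the given axes. LayerNorm is described in the paper Layer Normalization.

$$y = \frac{x - \mathrm{mean}}{\sqrt{\mathrm{variance} + \epsilon}} * \gamma + \beta$$

where $\gamma$ is scale, $\beta$ is bias, $\epsilon$ is epsilon.
- Parameters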
begin_norm_axis (int) – The first axis of input_x to which LayerNorm is applied. The value must be in [-1, rank(input_x)). Default: 1.
begin_params_axis (int) – The first axis of the parameter inputs (gamma, beta) to which LayerNorm is applied. The value must be in [-1, rank(input_x)). Default: 1.
epsilon (float) – A value added to the denominator for numerical stability ($\epsilon$). Default: 1e-7.
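The role of epsilon can be illustrated with a small NumPy sketch. This is not MindSpore code; the helper name `normalize_row` is hypothetical and only mirrors the denominator described above. A constant row has zero variance, so without epsilon the division becomes 0/0.

```python
import numpy as np

def normalize_row(x, epsilon):
    """Hypothetical helper: (x - mean) / sqrt(variance + epsilon) over the last axis."""
    mean = x.mean(axis=-1, keepdims=True)
    variance = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance + epsilon)

constant_row = np.array([[5.0, 5.0, 5.0]], dtype=np.float32)

# With epsilon > 0 the denominator stays positive, so the result is finite.
stable = normalize_row(constant_row, epsilon=1e-7)

# With epsilon = 0 the denominator is sqrt(0) = 0, so 0/0 yields NaN.
with np.errstate(divide="ignore", invalid="ignore"):
    unstable = normalize_row(constant_row, epsilon=0.0)

print(stable)
print(unstable)
```

The first result is all zeros, while the second contains only NaN values.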
- Inputs:
input_x (Tensor) - The input of LayerNorm. Supported dtypes: float16, float32, float64.
gamma (Tensor) - The learnable parameter $\gamma$, used as the scale on norm. Supported dtypes: float16, float32, float64.
beta (Tensor) - The learnable parameter $\beta$, used as the bias on norm. Supported dtypes: float16, float32, float64.
- Outputs:
tuple[Tensor], a tuple of 3 tensors: the normalized input, the mean, and the variance.
output_x (Tensor) - The normalized input. It has the same type and shape as input_x.
mean (Tensor) - The first begin_norm_axis dimensions of the mean's shape are the same as those of input_x, and the remaining dimensions are 1. If the shape of input_x is $(x_1, x_2, \ldots, x_R)$, the shape of the mean is $(x_1, \ldots, x_{begin\_norm\_axis}, 1, \ldots, 1)$ (when begin_norm_axis=0, the shape of the mean is $(1, \ldots, 1)$).
variance (Tensor) - The shape is the same as that of mean.
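The shape rule for mean and variance described above can be sketched in plain Python. The helper `stats_shape` is hypothetical (it is not part of MindSpore), and the mapping of a negative axis to `axis + rank` is an assumption based on the usual convention.

```python
import numpy as np

def stats_shape(input_shape, begin_norm_axis):
    """Hypothetical helper: shape of mean/variance under the rule above."""
    rank = len(input_shape)
    # Assumption: a negative axis such as -1 maps to axis + rank.
    axis = begin_norm_axis + rank if begin_norm_axis < 0 else begin_norm_axis
    # Leading dimensions are kept; normalized dimensions collapse to 1.
    return tuple(input_shape[:axis]) + (1,) * (rank - axis)

print(stats_shape((2, 3), 1))       # (2, 1)
print(stats_shape((4, 5, 6), 2))    # (4, 5, 1)
print(stats_shape((4, 5, 6), 0))    # (1, 1, 1)
print(stats_shape((4, 5, 6), -1))   # (4, 5, 1)

# Cross-check against a NumPy reduction that keeps reduced dimensions.
x = np.zeros((4, 5, 6), dtype=np.float32)
print(x.mean(axis=(1, 2), keepdims=True).shape)  # (4, 1, 1)
```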
- Raises
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)
>>> gamma = Tensor(np.ones([3]), mindspore.float32)
>>> beta = Tensor(np.ones([3]), mindspore.float32)
>>> layer_norm = ops.LayerNorm()
>>> output, mean, variance = layer_norm(input_x, gamma, beta)
>>> print(output)
[[-0.2247448  1.         2.2247448]
 [-0.2247448  1.         2.2247448]]
>>> print(mean)
[[2.]
 [2.]]
>>> print(variance)
[[0.6666667]
 [0.6666667]]
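The numbers in the example above can be cross-checked with a NumPy-only sketch of the same computation. This is not the MindSpore kernel: `layer_norm_reference` is a hypothetical name, it uses the biased (population) variance, and it assumes gamma and beta broadcast against the trailing dimensions of the input.

```python
import numpy as np

def layer_norm_reference(x, gamma, beta, begin_norm_axis=1, epsilon=1e-7):
    """NumPy sketch of layer normalization; not the MindSpore implementation."""
    rank = x.ndim
    # Assumption: a negative axis maps to axis + rank.
    axis = begin_norm_axis + rank if begin_norm_axis < 0 else begin_norm_axis
    reduce_axes = tuple(range(axis, rank))
    mean = x.mean(axis=reduce_axes, keepdims=True)
    variance = x.var(axis=reduce_axes, keepdims=True)  # biased variance
    # Assumption: gamma and beta broadcast against x's trailing dimensions.
    y = (x - mean) / np.sqrt(variance + epsilon) * gamma + beta
    return y, mean, variance

x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
gamma = np.ones(3, dtype=np.float32)
beta = np.ones(3, dtype=np.float32)

y, mean, variance = layer_norm_reference(x, gamma, beta)
print(y)
print(mean)
print(variance)
```

Each row has mean 2 and variance 2/3, so the first element becomes (1 - 2) / sqrt(2/3) * 1 + 1, which is approximately -0.2247, in agreement with the printed output of the operator.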