mindspore.ops.LayerNorm

查看源文件
class mindspore.ops.LayerNorm(begin_norm_axis=1, begin_params_axis=1, epsilon=1e-07)[源代码]

在输入Tensor上应用层归一化(Layer Normalization)。

此算子将在给定的轴上对输入进行层归一化。Layer Normalization 描述了LayerNorm。

\[y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta\]

其中 \(\gamma\) 是Scalar, \(\beta\) 是偏置项, \(\epsilon\) 是精度值。

参数:
  • begin_norm_axis (int) - 指定 input_x 需进行层归一化的起始维度,其值必须在[-1, rank(input_x))范围内。默认值: 1

  • begin_params_axis (int) - 指定输入参数(gamma, beta) 需进行层归一化的开始轴,其值必须在[-1, rank(input_x))范围内。默认值: 1

  • epsilon (float) - 添加到分母中的值(\(\epsilon\)),以确保数据稳定性。默认值: 1e-7

输入:
  • input_x (Tensor) - LayerNorm的输入,shape为 \((N, \ldots)\) 的Tensor。支持的数据类型:float16、float32、float64。

  • gamma (Tensor) - 可学习参数 \(\gamma\) ,shape为 \((P_\text{begin_params_axis}, \ldots, P_\text{rank(input_x)-1})\) 的Tensor。支持的数据类型:float16、float32、float64。

  • beta (Tensor) - 可学习参数 \(\beta\) 。shape为 \((P_\text{begin_params_axis}, \ldots, P_\text{rank(input_x)-1})\) 的Tensor。支持的数据类型:float16、float32、float64。

输出:

tuple[Tensor],3个Tensor组成的tuple,层归一化输入和更新后的参数。

  • output_x (Tensor) - 层归一化输入,数据类型和shape与 input_x 相同。

  • mean (Tensor) - 输入的均值,其shape的前 begin_norm_axis 维与 input_x 相同,其余维度为1。假设输入 input_x 的shape为 \((x_1, x_2, \ldots, x_R)\) , 输出 mean 的shape为 \((x_1, \ldots, x_{begin\_params\_axis}, 1, \ldots, 1)\) (当 begin_params_axis=0 时,mean shape为 \((1, \ldots, 1)\) )。

  • variance (Tensor) - 输入的方差,shape同 mean 一致。

异常:
  • TypeError - begin_norm_axisbegin_params_axis 不是int。

  • TypeError - epsilon 不是float。

  • TypeError - input_xgammabeta 不是Tensor。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)
>>> gamma = Tensor(np.ones([3]), mindspore.float32)
>>> beta = Tensor(np.ones([3]), mindspore.float32)
>>> layer_norm = ops.LayerNorm()
>>> output, mean, variance = layer_norm(input_x, gamma, beta)
>>> print(output)
[[-0.2247448  1.         2.2247448]
 [-0.2247448  1.         2.2247448]]
>>> print(mean)
[[2.]
 [2.]]
>>> print(variance)
[[0.6666667]
 [0.6666667]]