mindspore.ops.group_norm

mindspore.ops.group_norm(input, num_groups, weight=None, bias=None, eps=1e-05)

Group Normalization over a mini-batch of inputs.

Group Normalization is widely used in vision tasks. It applies normalization to each sample in a mini-batch of inputs, as described in the paper Group Normalization. Group Normalization divides the channels into groups, computes the mean and variance within each group for normalization, and performs stably over a wide range of batch sizes. \(\gamma\) and \(\beta\) are trainable scale and shift parameters. It can be described by the following formula:

\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

where \(\gamma\) is weight, \(\beta\) is bias, and \(\epsilon\) is eps.
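
As an illustrative sketch of the formula in plain NumPy (not the operator itself; the shapes and num_groups value below are arbitrary assumptions, and weight and bias are left at their defaults, i.e. \(\gamma = 1\) and \(\beta = 0\)):

>>> import numpy as np
>>> x = np.random.randn(2, 4, 8).astype(np.float32)        # (N, C, *) with C = 4
>>> num_groups, eps = 2, 1e-5                               # assumed values
>>> g = x.reshape(2, num_groups, -1)                        # split C into groups
>>> mean = g.mean(axis=-1, keepdims=True)                   # E[x] per sample and group
>>> var = g.var(axis=-1, keepdims=True)                     # Var[x] per sample and group
>>> y = ((g - mean) / np.sqrt(var + eps)).reshape(x.shape)  # normalized output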

Parameters
  • input (Tensor) – The input feature with shape \((N, C, *)\), where \(*\) means any number of additional dimensions.

  • num_groups (int) – The number of groups into which the channel dimension C is divided; C must be divisible by num_groups.

  • weight (Tensor, optional) – The scale tensor \(\gamma\) of shape \((C,)\), with the same data type as input; see the second example below. Default: None.

  • bias (Tensor, optional) – The shift tensor \(\beta\) of shape \((C,)\), with the same data type as input. Default: None.

  • eps (float, optional) – A value added to the denominator for numerical stability. Default: 1e-5.

Returns

Tensor, the normalized, scaled, and shifted tensor, with the same shape and data type as input.

Raises
  • TypeError – If num_groups is not an int.

  • TypeError – If eps is not a float.

  • ValueError – If num_groups is less than 1.

  • ValueError – If C (the channel dimension of input) is not divisible by num_groups.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore as ms
>>> import numpy as np
>>> from mindspore import ops
>>> x = ms.Tensor(np.ones([1, 2, 4, 4], np.float32))
>>> output = ops.group_norm(x, 2)
>>> print(output)
[[[[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]
  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]]
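
A second sketch demonstrating the optional weight (\(\gamma\)) and bias (\(\beta\)) parameters; the values below are illustrative assumptions. Because the all-ones input normalizes to zero, the output reduces to the bias:

>>> weight = ms.Tensor(np.full([2], 2.0, np.float32))  # gamma, shape (C,)
>>> bias = ms.Tensor(np.ones([2], np.float32))         # beta, shape (C,)
>>> output = ops.group_norm(x, 2, weight, bias)
>>> print(output[0, 0, 0])
[1. 1. 1. 1.]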