mindspore.mint.nn.GroupNorm

class mindspore.mint.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, dtype=None)

Group Normalization over a mini-batch of inputs.

Group Normalization is widely used in convolutional neural networks. It normalizes each training sample independently of the other samples in the mini-batch, as described in the paper Group Normalization.

Group Normalization divides the channels into groups and computes the mean and variance within each group for normalization. Because these statistics do not depend on the batch dimension, it remains stable over a wide range of batch sizes. \(\gamma\) and \(\beta\) are the trainable scale and shift. It can be described using the following formula:

\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

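where \(\mathrm{E}[x]\) and \(\mathrm{Var}[x]\) are the mean and variance computed within each group of each sample, \(\gamma\) is weight, \(\beta\) is bias, and \(\epsilon\) is eps.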
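The formula above can be sketched in plain NumPy to make the grouping explicit. This is an illustrative reference only, not the MindSpore implementation: the function name group_norm_reference is hypothetical, and it assumes that consecutive channels form a group, that the biased (population) variance is used, and that gamma and beta hold one value per channel.

```python
import numpy as np

def group_norm_reference(x, num_groups, gamma=None, beta=None, eps=1e-5):
    """NumPy sketch of group normalization for an array of shape (N, C, *)."""
    n, c = x.shape[0], x.shape[1]
    if c % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    # Collapse each group's channels and all trailing dimensions into one axis.
    grouped = x.reshape(n, num_groups, -1)
    mean = grouped.mean(axis=2, keepdims=True)
    var = grouped.var(axis=2, keepdims=True)  # biased variance (assumption)
    normalized = ((grouped - mean) / np.sqrt(var + eps)).reshape(x.shape)
    # Broadcast the per-channel scale and shift over the trailing dimensions.
    param_shape = (1, c) + (1,) * (x.ndim - 2)
    if gamma is not None:
        normalized = normalized * np.asarray(gamma).reshape(param_shape)
    if beta is not None:
        normalized = normalized + np.asarray(beta).reshape(param_shape)
    return normalized

x = np.arange(16, dtype=np.float32).reshape(1, 4, 2, 2)
y = group_norm_reference(x, num_groups=2)
print(y.shape)
```

With num_groups=2 and four channels, channels 0 and 1 share one set of statistics and channels 2 and 3 share another, so each group of the output has approximately zero mean and unit variance.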

Parameters
  • num_groups (int) – The number of groups to be divided along the channel dimension.

  • num_channels (int) – The number of input channels.

  • eps (float, optional) – A value added to the denominator for numerical stability. Default: 1e-05 .

  • affine (bool, optional) – Whether the parameters \(\gamma\) and \(\beta\) are learnable. They are learnable when set to True . Default: True .

  • dtype (mindspore.dtype, optional) – Dtype of Parameters. Default: None .

Inputs:
  • input (Tensor) - The input feature with shape \((N, C, *)\), where \(*\) means any number of additional dimensions.

Outputs:

Tensor, the normalized, scaled, and shifted tensor, with the same shape and data type as input.

Raises
  • TypeError – If num_groups or num_channels is not an int.

  • TypeError – If eps is not a float.

  • TypeError – If affine is not a bool.

  • ValueError – If num_groups or num_channels is less than 1.

  • ValueError – If num_channels is not divisible by num_groups.
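The conditions listed above can be mirrored in a small pure-Python check, which may help when validating a configuration before building a network. The helper check_group_norm_args is hypothetical and is not part of MindSpore; the order of the checks and the rejection of bool values for the integer arguments are assumptions.

```python
def check_group_norm_args(num_groups, num_channels, eps=1e-5, affine=True):
    """Hypothetical helper mirroring the documented Raises conditions."""
    for name, value in (("num_groups", num_groups), ("num_channels", num_channels)):
        # bool is a subclass of int in Python; rejecting it here is an assumption.
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    if not isinstance(eps, float):
        raise TypeError(f"eps must be a float, got {type(eps).__name__}")
    if not isinstance(affine, bool):
        raise TypeError(f"affine must be a bool, got {type(affine).__name__}")
    if num_groups < 1 or num_channels < 1:
        raise ValueError("num_groups and num_channels must be at least 1")
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")

check_group_norm_args(2, 4)      # valid: 4 channels split into 2 groups
try:
    check_group_norm_args(3, 4)  # invalid: 4 is not divisible by 3
except ValueError as err:
    print(err)
```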

Supported Platforms:

Ascend

Examples

>>> import mindspore as ms
>>> import numpy as np
>>> group_norm_op = ms.mint.nn.GroupNorm(2, 2)
>>> x = ms.Tensor(np.ones([1, 2, 4, 4], np.float32))
>>> output = group_norm_op(x)
>>> print(output)
[[[[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]
  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]]