mindspore.nn.GlobalBatchNorm

class mindspore.nn.GlobalBatchNorm(num_features, eps=1e-05, momentum=0.9, affine=True, gamma_init='ones', beta_init='zeros', moving_mean_init='zeros', moving_var_init='ones', use_batch_statistics=None, device_num_each_group=2)[source]

Global Batch Normalization layer over an N-dimensional input.

Global Batch Normalization is cross-device synchronized Batch Normalization. Standard Batch Normalization only normalizes the data within each device, whereas Global Batch Normalization normalizes the input within a group of devices. It is described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. It rescales and recenters the features using a mini-batch of data and the learned parameters, as described in the following formula:

\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]
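
The following NumPy sketch illustrates the formula above. It is an illustration only, not the MindSpore kernel; the reference function name and the idea of concatenating per-device batches along the batch axis are assumptions used for exposition.

>>> import numpy as np
>>> def group_batch_norm_ref(device_batches, gamma, beta, eps=1e-5):
...     # device_batches: list of (N, C, H, W) arrays, one per device in the group.
...     # gamma, beta: arrays broadcastable to (1, C, 1, 1).
...     # Computing statistics over the whole group is equivalent to concatenating
...     # the per-device batches along the batch axis first.
...     x = np.concatenate(device_batches, axis=0)
...     mean = x.mean(axis=(0, 2, 3), keepdims=True)   # per-channel E[x]
...     var = x.var(axis=(0, 2, 3), keepdims=True)     # per-channel Var[x]
...     return (x - mean) / np.sqrt(var + eps) * gamma + beta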

Note

Currently, GlobalBatchNorm only supports 2D and 4D inputs.

Parameters
  • num_features (int) – C from an expected input of size (N, C, H, W).

  • device_num_each_group (int) – The number of devices in each group. Default: 2.

  • eps (float) – A value added to the denominator for numerical stability. Default: 1e-5.

  • momentum (float) – A floating-point hyperparameter used for the running_mean and running_var computation. Default: 0.9.

  • affine (bool) – A bool value. When set to True, gamma and beta can be learned. Default: True.

  • gamma_init (Union[Tensor, str, Initializer, numbers.Number]) – Initializer for the gamma weight. The values of str refer to the function initializer including ‘zeros’, ‘ones’, ‘xavier_uniform’, ‘he_uniform’, etc. Default: ‘ones’.

  • beta_init (Union[Tensor, str, Initializer, numbers.Number]) – Initializer for the beta weight. The values of str refer to the function initializer including ‘zeros’, ‘ones’, ‘xavier_uniform’, ‘he_uniform’, etc. Default: ‘zeros’.

  • moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]) – Initializer for the moving mean. The values of str refer to the function initializer including ‘zeros’, ‘ones’, ‘xavier_uniform’, ‘he_uniform’, etc. Default: ‘zeros’.

  • moving_var_init (Union[Tensor, str, Initializer, numbers.Number]) – Initializer for the moving variance. The values of str refer to the function initializer including ‘zeros’, ‘ones’, ‘xavier_uniform’, ‘he_uniform’, etc. Default: ‘ones’.

  • use_batch_statistics (bool) – If true, use the mean and variance of the current batch data. If false, use the moving mean and moving variance. If None, the training process uses the mean and variance of the current batch data and tracks the running mean and variance, while the evaluation process uses the running mean and variance. Default: None. The sketch after this parameter list illustrates how the running statistics are updated and selected.
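
The interaction between momentum and use_batch_statistics can be summarized with a short sketch. This is a hedged illustration under the stated assumptions, not the verbatim MindSpore implementation; the helper name and the exact moving-average convention are assumptions.

>>> def normalize_stats(batch_mean, batch_var, moving_mean, moving_var,
...                     momentum=0.9, use_batch_statistics=None, training=True):
...     # Decide which statistics normalize the current input.
...     use_batch = use_batch_statistics if use_batch_statistics is not None else training
...     if use_batch:
...         # Track the running statistics with an exponential moving average:
...         # the previous moving value is weighted by `momentum`, the current
...         # batch statistic by (1 - momentum).
...         moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)
...         moving_var = moving_var * momentum + batch_var * (1 - momentum)
...         return batch_mean, batch_var, moving_mean, moving_var
...     # use_batch_statistics=False (or evaluation with None): normalize with the
...     # moving statistics and leave them unchanged.
...     return moving_mean, moving_var, moving_mean, moving_var

The first two returned values are the statistics used to normalize the input; the last two are the (possibly updated) moving statistics.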

Inputs:
  • x (Tensor) - Tensor of shape \((N, C_{in}, H_{in}, W_{in})\).

Outputs:

Tensor, the normalized, scaled, offset tensor, of shape \((N, C_{out}, H_{out}, W_{out})\).

Raises
  • TypeError – If num_features or device_num_each_group is not an int.

  • TypeError – If eps is not a float.

  • ValueError – If num_features is less than 1.

  • ValueError – If momentum is not in range [0, 1].

  • ValueError – If device_num_each_group is less than 2.

Supported Platforms:

Ascend

Examples

>>> # This example should be run with multiple processes.
>>> # Please refer to the tutorial > Distributed Training on mindspore.cn.
>>> import numpy as np
>>> from mindspore.communication import init
>>> from mindspore import context
>>> from mindspore.context import ParallelMode
>>> from mindspore import nn
>>> from mindspore import Tensor
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
>>> context.reset_auto_parallel_context()
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
>>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=2)
>>> x = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32))
>>> output = global_bn_op(x)
>>> print(output)
[[[[ 0.999995 0.999995 ]
   [ 0.999995 0.999995 ]]
  [[ 0.999995 0.999995 ]
   [ 0.999995 0.999995 ]]
  [[ 0.999995 0.999995 ]
   [ 0.999995 0.999995 ]]]]