mindspore.mint.nn.SyncBatchNorm

class mindspore.mint.nn.SyncBatchNorm(num_features: int, eps: float = 1e-5, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, process_group: Optional[str] = None, dtype=None)

Sync Batch Normalization layer over an N-dimensional input.

Sync Batch Normalization is Batch Normalization synchronized across devices. Standard Batch Normalization normalizes the data within each device only, whereas Sync Batch Normalization normalizes the input across all devices in the process group. Batch Normalization is described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. It rescales and recenters the features using the statistics of a mini-batch of data and the learned parameters \(\gamma\) and \(\beta\), as described by the following formula.

\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]
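As an illustration of how per-device and group-wide statistics differ, here is a minimal pure-Python sketch. It is not MindSpore's implementation: the helper names (`mean_var`, `normalize`, `synced_mean_var`) and the two "device" shards are hypothetical, \(\gamma = 1\) and \(\beta = 0\) are assumed, and combining counts, sums, and sums of squares is only one possible way to obtain group statistics.

```python
import math


def mean_var(values):
    """Return the mean and biased (population) variance of a list of floats."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return mean, var


def normalize(values, mean, var, eps=1e-5, gamma=1.0, beta=0.0):
    """Apply y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta elementwise."""
    return [(v - mean) / math.sqrt(var + eps) * gamma + beta for v in values]


def synced_mean_var(shards):
    """Combine per-device counts, sums, and sums of squares into group statistics.

    This stands in for the reduction a real implementation would run
    across the devices of a process group.
    """
    count = sum(len(s) for s in shards)
    total = sum(sum(s) for s in shards)
    total_sq = sum(sum(v * v for v in s) for s in shards)
    mean = total / count
    var = total_sq / count - mean * mean
    return mean, var


# Hypothetical shards of a single channel held by two devices.
shards = [[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]]

# Plain batch normalization: each device uses only its own statistics.
local_stats = [mean_var(s) for s in shards]
local_out = [normalize(s, m, v) for s, (m, v) in zip(shards, local_stats)]

# Synchronized batch normalization: every device uses the group statistics.
group_mean, group_var = synced_mean_var(shards)
synced_out = [normalize(s, group_mean, group_var) for s in shards]
```

With local statistics both shards normalize to the same values, so the offset between the devices is lost. With group statistics the first shard maps to negative values and the second to positive values, which preserves their relative position within the whole batch.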

Warning

This is an experimental API that is subject to change or deletion.

Parameters
  • num_features (int) – C from an expected input of size \((N, C, +)\).

  • eps (float) – \(\epsilon\), a value added to the denominator for numerical stability. Default: 1e-5 .

  • momentum (float) – The momentum used to compute running_mean and running_var. Default: 0.1 .

  • affine (bool) – If True, \(\gamma\) and \(\beta\) are learnable parameters. Default: True .

  • track_running_stats (bool, optional) – If True, this cell tracks the running mean and variance. If False, this cell does not track such statistics and always uses batch statistics in both training and eval modes. Default: True .

  • process_group (str, optional) – The process group within which statistics are synchronized; synchronization happens within each process group individually. Default: None , which synchronizes across the whole world.

  • dtype (mindspore.dtype, optional) – Dtype of Parameters. Default: None .
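The role of momentum can be sketched in a few lines of plain Python. The update rule below is an assumption: it follows the common convention in which a momentum of 0.1 gives the new batch statistic a weight of 0.1, and this page does not state the exact rule. The function name `update_running` is hypothetical.

```python
def update_running(running, batch_stat, momentum=0.1):
    """Blend a running statistic with the statistic of the current batch.

    Assumed convention: new = (1 - momentum) * old + momentum * batch.
    """
    return (1.0 - momentum) * running + momentum * batch_stat


# Start from 0.0 and observe three batches whose mean is 5.0.
running_mean = 0.0
for batch_mean in [5.0, 5.0, 5.0]:
    running_mean = update_running(running_mean, batch_mean)
```

Under this assumed rule, a momentum of 0 leaves the running statistic unchanged and a momentum of 1 replaces it with the latest batch statistic, which is consistent with the valid range [0, 1].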

Inputs:
  • x (Tensor) - Tensor of shape \((N, C_{in}, +)\).

Outputs:

Tensor, the normalized, scaled, and shifted tensor, of shape \((N, C_{out}, +)\).

Raises
  • TypeError – If num_features is not an int.

  • TypeError – If eps is not a float.

  • ValueError – If num_features is less than 1.

  • ValueError – If momentum is not in range [0, 1].

  • ValueError – If rank_id in process_group is not in range [0, rank_size).

Supported Platforms:

Ascend