mindspore.nn.BatchNorm2d
- class mindspore.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True, gamma_init='ones', beta_init='zeros', moving_mean_init='zeros', moving_var_init='ones', use_batch_statistics=None, data_format='NCHW', dtype=mstype.float32)[source]
Batch Normalization is widely used in convolutional networks. This layer applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with an additional channel dimension) to reduce internal covariate shift, as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. It rescales and recenters the features using a mini-batch of data and the learned parameters, as described by the following formula:
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

Note
The implementation of BatchNorm differs between graph mode and PyNative mode, so the mode cannot be changed after the network is initialized. Note that the formula for updating \(moving\_mean\) and \(moving\_var\) is
\[\begin{split}\text{moving\_mean} &= \text{moving\_mean} \times \text{momentum} + \mu_\beta \times (1 - \text{momentum}) \\ \text{moving\_var} &= \text{moving\_var} \times \text{momentum} + \sigma^2_\beta \times (1 - \text{momentum})\end{split}\]

where \(moving\_mean\) is the updated mean, \(moving\_var\) is the updated variance, and \(\mu_\beta, \sigma^2_\beta\) are the observed mean and variance of each batch of data.
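The normalization formula and the moving-average update above can be sketched in NumPy. This is a minimal illustration of the math, not the MindSpore implementation; the `momentum` and `eps` values mirror the defaults documented below.

```python
import numpy as np

def batch_norm_2d(x, gamma, beta, moving_mean, moving_var,
                  momentum=0.9, eps=1e-5, training=True):
    """Normalize an (N, C, H, W) batch per channel, as in the formula above."""
    if training:
        # Per-channel batch statistics over the N, H, W axes.
        mean = x.mean(axis=(0, 2, 3))
        var = x.var(axis=(0, 2, 3))
        # Momentum update rule for the running statistics.
        moving_mean = moving_mean * momentum + mean * (1 - momentum)
        moving_var = moving_var * momentum + var * (1 - momentum)
    else:
        mean, var = moving_mean, moving_var
    # Broadcast the per-channel statistics to shape (1, C, 1, 1).
    shape = (1, -1, 1, 1)
    y = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    return y * gamma.reshape(shape) + beta.reshape(shape), moving_mean, moving_var

x = np.random.randn(4, 3, 2, 2).astype(np.float32)
y, mm, mv = batch_norm_2d(x, np.ones(3), np.zeros(3), np.zeros(3), np.ones(3))
print(y.mean(axis=(0, 2, 3)))  # each channel mean is ~0 after normalization
```

With the default `gamma_init='ones'` and `beta_init='zeros'`, each channel of the output has mean approximately 0 and variance approximately 1 in training mode.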
- Parameters
num_features (int) – The number of channels of the input tensor. Expected input size is \((N, C, H, W)\), C represents the number of channels.
eps (float) – \(\epsilon\) added to the denominator for numerical stability. Default: 1e-5.
momentum (float) – A floating-point hyperparameter of the momentum for the running_mean and running_var computation. Default: 0.9.
affine (bool) – A bool value. When set to True, \(\gamma\) and \(\beta\) can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]) – Initializer for the \(\gamma\) weight. The values of str refer to the function mindspore.common.initializer, including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]) – Initializer for the \(\beta\) weight. The values of str refer to the function mindspore.common.initializer, including 'zeros', 'ones', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]) – Initializer for the moving mean. The values of str refer to the function mindspore.common.initializer, including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]) – Initializer for the moving variance. The values of str refer to the function mindspore.common.initializer, including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool) – If True, use the mean and variance of the current batch and track the running mean and running variance. If False, use the specified mean and variance values and do not track statistics. If None, use_batch_statistics is automatically set to True or False according to the training and evaluation mode: True during training and False during evaluation. Default: None.
data_format (str) – The optional value for data format, either 'NHWC' or 'NCHW'. Default: 'NCHW'.
dtype (mindspore.dtype) – Dtype of the Parameters. Default: mstype.float32.
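The use_batch_statistics behavior described above can be sketched as two NumPy passes: training normalizes with the batch's own statistics and folds them into the running estimates, while evaluation reuses the stored moving statistics. This is an illustration of the semantics, not the MindSpore implementation.

```python
import numpy as np

eps, momentum = 1e-5, 0.9
moving_mean, moving_var = np.zeros(3), np.ones(3)  # default initializers

# Training (use_batch_statistics resolves to True): normalize with the
# batch's own per-channel statistics and update the running estimates.
x = np.random.randn(8, 3, 4, 4)
mean, var = x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3))
moving_mean = moving_mean * momentum + mean * (1 - momentum)
moving_var = moving_var * momentum + var * (1 - momentum)

# Evaluation (use_batch_statistics resolves to False): normalize new data
# with the stored moving statistics instead of the batch's own.
x_eval = np.random.randn(2, 3, 4, 4)
y_eval = ((x_eval - moving_mean.reshape(1, -1, 1, 1))
          / np.sqrt(moving_var.reshape(1, -1, 1, 1) + eps))
print(y_eval.shape)  # (2, 3, 4, 4)
```

Because evaluation uses fixed statistics, the same input always yields the same output at inference time, regardless of what else is in the batch.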
- Inputs:
x (Tensor) - Tensor of shape \((N, C, H, W)\). Supported types: float16, float32.
- Outputs:
Tensor, the normalized, scaled, offset tensor, of shape \((N, C, H, W)\).
- Raises
TypeError – If num_features is not an int.
TypeError – If eps is not a float.
ValueError – If num_features is less than 1.
ValueError – If momentum is not in range [0, 1].
ValueError – If data_format is neither 'NHWC' nor 'NCHW'.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> import mindspore as ms
>>> net = ms.nn.BatchNorm2d(num_features=3)
>>> x = ms.Tensor(np.ones([1, 3, 2, 2]).astype(np.float32))
>>> output = net(x)
>>> print(output)
[[[[ 0.999995  0.999995 ]
   [ 0.999995  0.999995 ]]
  [[ 0.999995  0.999995 ]
   [ 0.999995  0.999995 ]]
  [[ 0.999995  0.999995 ]
   [ 0.999995  0.999995 ]]]]