mindspore.mint.nn.BatchNorm2d
- class mindspore.mint.nn.BatchNorm2d(num_features: int, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, dtype=None)[source]
Applies Batch Normalization over a 4D input as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
\[y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta\]

The mean and standard deviation are calculated per-dimension over the mini-batches, and \(\gamma\) and \(\beta\) are learnable parameter vectors of size C (where C is the number of features or channels of the input). By default, the elements of \(\gamma\) are set to 1 and the elements of \(\beta\) are set to 0.
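As a worked illustration of the formula (not part of the operator's official example), the following sketch reproduces the per-channel normalization with plain NumPy, using the default \(\gamma\) (ones) and \(\beta\) (zeros); the array shapes and values are illustrative assumptions, and this is a reference computation, not the operator's kernel.

>>> import numpy as np
>>> eps = 1e-5
>>> x = np.random.randn(4, 3, 8, 8).astype(np.float32)        # assumed (N, C, H, W) input
>>> mean = x.mean(axis=(0, 2, 3), keepdims=True)               # per-channel mean over N, H, W
>>> var = x.var(axis=(0, 2, 3), keepdims=True)                 # per-channel (biased) variance
>>> gamma = np.ones((1, 3, 1, 1), dtype=np.float32)            # default gamma: all ones
>>> beta = np.zeros((1, 3, 1, 1), dtype=np.float32)            # default beta: all zeros
>>> y = (x - mean) / np.sqrt(var + eps) * gamma + beta         # each channel of y now has ~zero mean, ~unit variance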
Warning
This API does not support Dynamic Rank. This is an experimental API that is subject to change or deletion.
- Parameters
num_features (int) – C from an expected input of shape \((N, C, H, W)\).
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5.
momentum (float, optional) – the value used for the running_mean and running_var computation (see the sketch after this parameter list). Can be set to None for cumulative moving average. Default: 0.1.
affine (bool, optional) – if set to True, this cell has learnable affine parameters. Default: True.
track_running_stats (bool, optional) – if set to True, this cell tracks the running mean and variance; if set to False, this cell does not track such statistics and always uses batch statistics in both training and eval modes. Default: True.
dtype (mindspore.dtype, optional) – Dtype of the Parameters. Default: None.
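The exact running-statistics update that momentum controls is not spelled out above; assuming the exponential-moving-average rule used by comparable BatchNorm implementations, a minimal sketch of its role looks like this (the batch statistics below are made-up numbers):

>>> momentum = 0.1
>>> running_mean, running_var = 0.0, 1.0                   # initial running statistics
>>> batch_mean, batch_var = 0.5, 2.0                       # hypothetical statistics of the current batch
>>> running_mean = (1 - momentum) * running_mean + momentum * batch_mean
>>> running_var = (1 - momentum) * running_var + momentum * batch_var
>>> print(running_mean, running_var)
0.05 1.1

With momentum set to None, a cumulative average over all batches seen so far would be used instead of this exponential moving average.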
- Inputs:
input (Tensor) - The input with shape \((N, C, H, W)\).
- Outputs:
Tensor, has the same type and shape as input.
- Raises
TypeError – If num_features is not an int.
TypeError – If eps is not a float.
ValueError – If num_features is less than 1.
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> from mindspore import Tensor, mint
>>> input_x = mindspore.Tensor([0.3, 0.4, 0.5, 0.3])
>>> input_x = input_x.reshape((2, 2, 1, 1))
>>> net = mint.nn.BatchNorm2d(2)
>>> output = net(input_x)
>>> print(output)
[[[[-0.99950075]]
  [[0.9980087]]]
 [[[0.999501]]
  [[-0.9980097]]]]
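A usage note on training and evaluation modes (an assumption based on standard Cell behavior, not shown in the example above): with track_running_stats=True, the cell normalizes with batch statistics and updates running_mean and running_var while training, and uses the tracked running statistics once switched to evaluation mode, for example:

>>> net.set_train(False)        # assumed Cell API: switch the cell above to evaluation mode
>>> eval_output = net(input_x)  # normalization now uses the tracked running statistics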