mindspore.mint.nn.BatchNorm1d
- class mindspore.mint.nn.BatchNorm1d[source]
Applies Batch Normalization over a 2D or 3D input as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
The mean and standard deviation are calculated per-dimension over the mini-batches. γ and β are learnable parameter vectors of size C (where C is the number of features or channels of the input). By default, the elements of γ are set to 1 and the elements of β are set to 0.
Warning
This API does not support Dynamic Rank. This is an experimental API that is subject to change or deletion.
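The normalization described above can be sketched in plain NumPy (an illustrative reimplementation for a 2D input, not MindSpore code): each feature column is shifted by its batch mean, scaled by its batch standard deviation, then transformed by the affine parameters γ and β.

```python
import numpy as np

# Illustrative sketch of BatchNorm1d in training mode for a (N, C) input:
#   y = (x - mean) / sqrt(var + eps) * gamma + beta
# `batch_norm_1d` is a hypothetical helper, not a MindSpore API.
def batch_norm_1d(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=0)                 # per-feature mean over the batch
    var = x.var(axis=0)                   # biased (population) variance over the batch
    x_hat = (x - mean) / np.sqrt(var + eps)
    return x_hat * gamma + beta

x = np.array([[0.7, 0.5, 0.5, 0.6],
              [0.5, 0.4, 0.6, 0.9]], dtype=np.float32)
gamma = np.ones(4, dtype=np.float32)      # elements of gamma default to 1
beta = np.zeros(4, dtype=np.float32)      # elements of beta default to 0
out = batch_norm_1d(x, gamma, beta)
print(out)
```

With the default γ and β, this sketch reproduces the values shown in the Examples section below to within floating-point tolerance.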
- Parameters
  - num_features (int) – C from an expected input of shape (N, C) or (N, C, L).
  - eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5.
  - momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average. Default: 0.1.
  - affine (bool, optional) – a boolean value that when set to True, this cell has learnable affine parameters. Default: True.
  - track_running_stats (bool, optional) – a boolean value that when set to True, this cell tracks the running mean and variance, and when set to False, this cell does not track such statistics and always uses batch statistics in both training and eval modes. Default: True.
  - dtype (mindspore.dtype, optional) – Dtype of Parameters. Default: None.
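The role of momentum can be sketched with a hypothetical NumPy snippet (this follows the update convention batch-normalization layers commonly use; the exact MindSpore internals are not shown in this page):

```python
import numpy as np

# Conventional running-statistics update with a fixed momentum:
#   running = (1 - momentum) * running + momentum * batch_stat
# (With momentum=None, a cumulative moving average over all seen batches
# is typically kept instead.) Names here are illustrative.
momentum = 0.1
running_mean = np.zeros(4, dtype=np.float32)      # initial running mean

batch_mean = np.array([0.6, 0.45, 0.55, 0.75], dtype=np.float32)
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
print(running_mean)   # first step: momentum * batch_mean
```

The running statistics are what the cell uses at inference time when track_running_stats is True; with track_running_stats set to False, the per-batch statistics are used in both training and eval modes.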
- Inputs:
  - input (Tensor) - The input with shape (N, C) or (N, C, L), where N means batch, C means the number of features or the number of channels, and L is the length of sequence.
- Outputs:
Tensor, has the same type and shape as input.
- Raises
TypeError – If num_features is not an int.
TypeError – If eps is not a float.
ValueError – If num_features is less than 1.
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> from mindspore import Tensor, mint
>>> input_x = mindspore.Tensor([[0.7, 0.5, 0.5, 0.6], [0.5, 0.4, 0.6, 0.9]])
>>> net = mint.nn.BatchNorm1d(4)
>>> output = net(input_x)
>>> print(output)
[[ 0.99950075  0.9980011  -0.9980068  -0.9997783]
 [-0.9995012  -0.99799967  0.9980068   0.9997778]]