mindspore.ops.BNTrainingUpdate
- class mindspore.ops.BNTrainingUpdate(*args, **kwargs)
For the BatchNorm operation, this operator updates the moving averages during training. It is used in conjunction with BNTrainingReduce.
- Parameters
- Inputs:
x (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape \((N, C, A, B)\).
sum (Tensor) - A 1-D Tensor with float16 or float32 data type, holding the sum output of operator BNTrainingReduce. Tensor of shape \((C,)\).
square_sum (Tensor) - A 1-D Tensor with float16 or float32 data type, holding the square-sum output of operator BNTrainingReduce. Tensor of shape \((C,)\).
scale (Tensor) - A 1-D Tensor with float16 or float32 data type, for the scaling factor. Tensor of shape \((C,)\).
offset (Tensor) - A 1-D Tensor with float16 or float32 data type, for the offset added after scaling. Tensor of shape \((C,)\).
mean (Tensor) - A 1-D Tensor with float16 or float32 data type, for the moving mean to be updated. Tensor of shape \((C,)\).
variance (Tensor) - A 1-D Tensor with float16 or float32 data type, for the moving variance to be updated. Tensor of shape \((C,)\).
- Outputs:
y (Tensor) - Tensor, has the same shape and data type as x.
mean (Tensor) - Tensor for the updated mean, with float32 data type. Has the same shape as the input variance.
variance (Tensor) - Tensor for the updated variance, with float32 data type. Has the same shape as the input variance.
batch_mean (Tensor) - Tensor for the batch mean of x, with float32 data type. Has the same shape as the input variance.
batch_variance (Tensor) - Tensor for the batch variance of x, with float32 data type. Has the same shape as the input variance.
- Raises
- Supported Platforms:
Ascend
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.ones([1, 2, 2, 2]), mindspore.float32)
>>> sum = Tensor(np.ones([2]), mindspore.float32)
>>> square_sum = Tensor(np.ones([2]), mindspore.float32)
>>> scale = Tensor(np.ones([2]), mindspore.float32)
>>> offset = Tensor(np.ones([2]), mindspore.float32)
>>> mean = Tensor(np.ones([2]), mindspore.float32)
>>> variance = Tensor(np.ones([2]), mindspore.float32)
>>> bn_training_update = ops.BNTrainingUpdate()
>>> output = bn_training_update(input_x, sum, square_sum, scale, offset, mean, variance)
>>> print(output)
(Tensor(shape=[1, 2, 2, 2], dtype=Float32, value=
[[[[ 2.73200464e+00, 2.73200464e+00],
   [ 2.73200464e+00, 2.73200464e+00]],
  [[ 2.73200464e+00, 2.73200464e+00],
   [ 2.73200464e+00, 2.73200464e+00]]]]),
 Tensor(shape=[2], dtype=Float32, value= [ 9.24999952e-01, 9.24999952e-01]),
 Tensor(shape=[2], dtype=Float32, value= [ 9.24999952e-01, 9.24999952e-01]),
 Tensor(shape=[2], dtype=Float32, value= [ 2.50000000e-01, 2.50000000e-01]),
 Tensor(shape=[2], dtype=Float32, value= [ 1.87500000e-01, 1.87500000e-01]))
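The numbers in the example above can be reproduced without an Ascend device by a small pure-Python sketch of the arithmetic. This is not the operator's implementation: the function name `bn_training_update_reference` is hypothetical, and the values `factor=0.1` and `epsilon=1e-5`, as well as the use of an unbiased variance in the moving-average update, are assumptions chosen because they are consistent with the printed output.

```python
import math


def bn_training_update_reference(x, sum_, square_sum, scale, offset,
                                 mean, variance, factor=0.1, epsilon=1e-5):
    """Hypothetical reference sketch; x is a nested list of shape (N, C, A, B).

    factor and epsilon are assumed values, not taken from this page.
    """
    n, c = len(x), len(x[0])
    a, b = len(x[0][0]), len(x[0][0][0])
    num = n * a * b  # number of elements reduced per channel

    # Per-channel batch statistics derived from the reduced sums.
    batch_mean = [s / num for s in sum_]
    batch_variance = [sq / num - m * m
                      for sq, m in zip(square_sum, batch_mean)]

    # Normalize each channel, then apply scale and offset.
    inv_std = [1.0 / math.sqrt(v + epsilon) for v in batch_variance]
    y = [[[[scale[ch] * (x[i][ch][h][w] - batch_mean[ch]) * inv_std[ch]
            + offset[ch]
            for w in range(b)]
           for h in range(a)]
          for ch in range(c)]
         for i in range(n)]

    # Assumption: the moving variance is updated with the unbiased
    # (Bessel-corrected) batch variance, which requires num > 1.
    unbiased = [v * num / (num - 1) for v in batch_variance]
    new_mean = [(1 - factor) * m + factor * bm
                for m, bm in zip(mean, batch_mean)]
    new_variance = [(1 - factor) * v + factor * u
                    for v, u in zip(variance, unbiased)]
    return y, new_mean, new_variance, batch_mean, batch_variance


# Same all-ones inputs as the example above.
ones = [1.0, 1.0]
x = [[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]
y, new_mean, new_variance, batch_mean, batch_variance = (
    bn_training_update_reference(x, ones, ones, ones, ones, ones, ones))
print(y[0][0][0][0], new_mean, new_variance, batch_mean, batch_variance)
```

Under these assumptions the sketch agrees with the example output to float32 precision. Note that the example passes all-ones tensors for sum and square_sum rather than values actually reduced from input_x, so batch_mean is 1/4 here even though every element of input_x is 1.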