mindspore.nn.Moments
- class mindspore.nn.Moments(axis=None, keep_dims=None)
Calculates the mean and variance of x along the specified axis.
- Parameters
- Inputs:
x (Tensor) - The tensor whose mean and variance are calculated. Only float16 and float32 are supported. The shape is \((N,*)\), where \(*\) means any number of additional dimensions.
- Outputs:
mean (Tensor) - The mean of x, with the same data type as input x.
variance (Tensor) - The variance of x, with the same data type as input x.
- Raises
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> x = Tensor(np.array([[[[1, 2, 3, 4], [3, 4, 5, 6]]]]), mindspore.float32)
>>> net = nn.Moments(axis=0, keep_dims=True)
>>> output = net(x)
>>> print(output)
(Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
[[[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
   [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00, 6.00000000e+00]]]]), Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
[[[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
   [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]]))
>>> net = nn.Moments(axis=1, keep_dims=True)
>>> output = net(x)
>>> print(output)
(Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
[[[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
   [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00, 6.00000000e+00]]]]), Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
[[[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
   [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]]))
>>> net = nn.Moments(axis=2, keep_dims=True)
>>> output = net(x)
>>> print(output)
(Tensor(shape=[1, 1, 1, 4], dtype=Float32, value=
[[[[ 2.00000000e+00, 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]]]]), Tensor(shape=[1, 1, 1, 4], dtype=Float32, value=
[[[[ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00]]]]))
>>> net = nn.Moments(axis=3, keep_dims=True)
>>> output = net(x)
>>> print(output)
(Tensor(shape=[1, 1, 2, 1], dtype=Float32, value=
[[[[ 2.50000000e+00],
   [ 4.50000000e+00]]]]), Tensor(shape=[1, 1, 2, 1], dtype=Float32, value=
[[[[ 1.25000000e+00],
   [ 1.25000000e+00]]]]))
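The values in the examples above can be cross-checked without MindSpore. The following is a minimal NumPy sketch, not the MindSpore implementation. The helper name `moments_reference` is hypothetical, and it assumes the variance is the population variance (divide by N), which is consistent with the 1.25 shown for `axis=3`. Treating `keep_dims=None` as `False` is also an assumption.

```python
import numpy as np


def moments_reference(x, axis=None, keep_dims=False):
    """Hypothetical NumPy reference: mean and population variance along `axis`."""
    x = np.asarray(x, dtype=np.float32)
    # Keep the reduced axes for now so the mean broadcasts against x.
    mean = np.mean(x, axis=axis, keepdims=True)
    # Population variance: mean squared deviation (divide by N, not N - 1).
    variance = np.mean(np.square(x - mean), axis=axis, keepdims=True)
    if not keep_dims:
        # Drop the reduced axes (all axes when axis is None).
        mean = np.squeeze(mean, axis=axis)
        variance = np.squeeze(variance, axis=axis)
    return mean, variance


# Same input as in the examples above.
x = np.array([[[[1, 2, 3, 4], [3, 4, 5, 6]]]], dtype=np.float32)
mean, variance = moments_reference(x, axis=3, keep_dims=True)
print(mean.shape)
print(mean.ravel(), variance.ravel())
```

With `axis=3` and `keep_dims=True`, the sketch returns arrays of shape (1, 1, 2, 1) holding the means 2.5 and 4.5 and the variance 1.25, matching the last example above.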