Function Differences with torch.var_mean
torch.var_mean
torch.var_mean(input, dim, unbiased=True, keepdim=False, *, out=None)
For more information, see torch.var_mean.
mindspore.ops.var_mean
mindspore.ops.var_mean(input, axis=None, ddof=0, keepdims=False)
For more information, see mindspore.ops.var_mean.
Differences
PyTorch: Returns the variance and mean of the tensor over all dimensions, or over the dimension(s) specified by `dim`. If `unbiased` is True, Bessel's correction is applied (the divisor is N-1); if False, the biased estimator is used (the divisor is N). `keepdim` controls whether the output has the same number of dimensions as the input.
MindSpore: Returns the variance and mean of the tensor over all dimensions, or over the dimension(s) specified by `axis`. If `ddof` is a boolean, it has the same effect as `unbiased`; if `ddof` is an integer, the divisor used in the calculation is N-ddof, where N denotes the number of elements. `keepdims` controls whether the output has the same number of dimensions as the input.
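The role of the divisor N-ddof can be illustrated with a minimal pure-Python sketch. The helper `var_mean_1d` below is hypothetical and exists only for illustration; it is not part of PyTorch or MindSpore, and it assumes a boolean `ddof` of True corresponds to Bessel's correction (divisor N-1), as described above.

```python
def var_mean_1d(values, ddof=0):
    """Hypothetical helper (not a PyTorch or MindSpore API).

    Returns (variance, mean) of a flat list, dividing the sum of
    squared deviations by N - ddof.
    """
    n = len(values)
    mean = sum(values) / n
    # bool is a subclass of int in Python, so True -> 1 and False -> 0.
    divisor = n - int(ddof)
    var = sum((v - mean) ** 2 for v in values) / divisor
    return var, mean


row = [9, 7, 4, -10]
biased_var, mean = var_mean_1d(row, ddof=0)    # divisor N = 4
unbiased_var, _ = var_mean_1d(row, ddof=True)  # divisor N - 1 = 3
print(mean, biased_var, unbiased_var)
```

For this row the sum of squared deviations is 221, so the biased variance is 221/4 = 55.25 and the corrected variance is 221/3, approximately 73.6667.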
| Categories | Subcategories | PyTorch | MindSpore | Differences |
| --- | --- | --- | --- | --- |
| Parameters | Parameter 1 | input | input | Same function, same parameter name |
| | Parameter 2 | dim | axis | Same function, different parameter names |
| | Parameter 3 | unbiased | ddof | If ddof is a boolean, it has the same effect as unbiased; if ddof is an integer, the divisor is N-ddof. The defaults differ: unbiased=True in PyTorch, ddof=0 in MindSpore |
| | Parameter 4 | keepdim | keepdims | Same function, different parameter names |
| | Parameter 5 | out | - | MindSpore does not have this parameter |
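The parameter mapping in the table can be sketched as a small translation helper. The function `torch_to_mindspore_kwargs` is hypothetical, written only to illustrate the mapping; it assumes the signatures listed at the top of this page and simply drops `out`, which has no MindSpore counterpart.

```python
def torch_to_mindspore_kwargs(dim, unbiased=True, keepdim=False):
    """Hypothetical helper: translate torch.var_mean keyword arguments
    into keyword arguments for mindspore.ops.var_mean."""
    return {
        "axis": dim,              # dim -> axis
        # Pass ddof explicitly: PyTorch defaults to unbiased=True,
        # while MindSpore defaults to ddof=0.
        "ddof": bool(unbiased),   # unbiased -> ddof (boolean form)
        "keepdims": keepdim,      # keepdim -> keepdims
    }


kwargs = torch_to_mindspore_kwargs(dim=2, unbiased=True, keepdim=True)
print(kwargs)
# {'axis': 2, 'ddof': True, 'keepdims': True}
```

Under these assumptions the resulting dictionary could be unpacked into the MindSpore call, for example `ms.ops.var_mean(x, **kwargs)`. Setting `ddof` explicitly matters because the two defaults produce different variances.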
Code Example
# PyTorch
import torch
input = torch.tensor([[[9, 7, 4, -10],
                       [-9, -2, 1, -2]]], dtype=torch.float32)
print(torch.var_mean(input, dim=2, unbiased=True, keepdim=True))
# (tensor([[[73.6667],
#           [18.0000]]]), tensor([[[ 2.5000],
#           [-3.0000]]]))
# MindSpore
import mindspore as ms
input = ms.Tensor([[[9, 7, 4, -10],
                    [-9, -2, 1, -2]]], ms.float32)
print(ms.ops.var_mean(input, axis=2, ddof=True, keepdims=True))
# (Tensor(shape=[1, 2, 1], dtype=Float32, value=
#  [[[ 7.36666641e+01],
#    [ 1.79999981e+01]]]), Tensor(shape=[1, 2, 1], dtype=Float32, value=
#  [[[ 2.50000000e+00],
#    [-3.00000000e+00]]]))