mindspore.nn.MSELoss
- class mindspore.nn.MSELoss(reduction='mean')
Calculates the mean squared error between the predicted value and the label value.
For simplicity, let \(x\) and \(y\) be 1-dimensional Tensors of length \(N\). The unreduced loss (i.e. with argument reduction set to ‘none’) of \(x\) and \(y\) is given by:
\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad \text{with} \quad l_n = (x_n - y_n)^2,\]
where \(N\) is the batch size. If reduction is not ‘none’, then:
\[\begin{split}\ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} \end{cases}\end{split}\]
- Parameters
reduction (str) – Type of reduction to be applied to loss. The optional values are “mean”, “sum”, and “none”. Default: “mean”.
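For illustration, the three reduction modes can be reproduced with plain NumPy (a minimal sketch for intuition only, not the MindSpore implementation):

>>> import numpy as np
>>> x = np.array([1., 2., 3.])
>>> y = np.array([1., 1., 1.])
>>> L = (x - y) ** 2    # unreduced loss, i.e. reduction='none'
>>> print(L)
[0. 1. 4.]
>>> print(L.mean())     # reduction='mean'
1.6666666666666667
>>> print(L.sum())      # reduction='sum'
5.0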
- Inputs:
logits (Tensor) - The predicted value of the input. Tensor of any dimension.
labels (Tensor) - The input label. Tensor of any dimension, usually with the same shape as logits. However, logits and labels may have different shapes, provided they can be broadcast to each other.
- Outputs:
Tensor, the loss, with a float dtype. If reduction is ‘mean’ or ‘sum’, the output is a zero-dimensional (scalar) Tensor; if reduction is ‘none’, the output has the broadcast shape of logits and labels.
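The shape behavior can be checked directly; a minimal sketch (the shapes below follow from NumPy-style broadcasting of (2, 3) against (3,)):

>>> import mindspore
>>> from mindspore import Tensor, nn
>>> import numpy as np
>>> logits = Tensor(np.ones((2, 3)), mindspore.float32)
>>> labels = Tensor(np.zeros((3,)), mindspore.float32)
>>> # 'mean' (and 'sum') reduce to a zero-dimensional Tensor
>>> print(nn.MSELoss(reduction='mean')(logits, labels).shape)
()
>>> # 'none' keeps the broadcast shape
>>> print(nn.MSELoss(reduction='none')(logits, labels).shape)
(2, 3)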
- Raises
ValueError – If reduction is not one of ‘none’, ‘mean’ or ‘sum’.
ValueError – If logits and labels have different shapes and cannot be broadcast to each other.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> import numpy as np
>>> # Case 1: logits.shape = labels.shape = (3,)
>>> loss = nn.MSELoss()
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> labels = Tensor(np.array([1, 1, 1]), mindspore.float32)
>>> output = loss(logits, labels)
>>> print(output)
1.6666667
>>> # Case 2: logits.shape = (3,), labels.shape = (2, 3)
>>> loss = nn.MSELoss(reduction='none')
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> labels = Tensor(np.array([[1, 1, 1], [1, 2, 2]]), mindspore.float32)
>>> output = loss(logits, labels)
>>> print(output)
[[0. 1. 4.]
 [0. 0. 1.]]
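For completeness, a sketch of the remaining reduction mode, ‘sum’ (same imports as above; by the formula the value is \(0 + 1 + 4 = 5\)):

>>> # Case 3: reduction='sum' adds up the unreduced losses
>>> loss = nn.MSELoss(reduction='sum')
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> labels = Tensor(np.array([1, 1, 1]), mindspore.float32)
>>> output = loss(logits, labels)
>>> print(output)
5.0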