mindspore.nn.MSELoss

class mindspore.nn.MSELoss(reduction='mean')[source]

MSELoss creates a criterion to measure the mean squared error (squared L2-norm) between \(x\) and \(y\) element-wise, where \(x\) is the input and \(y\) is the target.

For simplicity, let \(x\) and \(y\) be 1-dimensional Tensor with length \(N\), the unreduced loss (i.e. with argument reduction set to ‘none’) of \(x\) and \(y\) is given as:

\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad \text{with} \quad l_n = (x_n - y_n)^2.\]

where \(N\) is the batch size. If reduction is not ‘none’, then:

\[\begin{split}\ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases}\end{split}\]
Parameters

reduction (str) – Type of reduction to be applied to loss. The optional values are “mean”, “sum”, and “none”. Default: “mean”.

Inputs:
  • logits (Tensor) - Tensor of shape \((N, *)\) where \(*\) means, any number of additional dimensions.

  • labels (Tensor) - Tensor of shape \((N, *)\), same shape as the logits in common cases. However, it supports the shape of logits is different from the shape of labels and they should be broadcasted to each other.

Outputs:

Tensor, loss float tensor, the shape is zero if reduction is ‘mean’ or ‘sum’, while the shape of output is the broadcasted shape if reduction is ‘none’.

Raises

ValueError – If reduction is not one of ‘none’, ‘mean’, ‘sum’.

Supported Platforms:

Ascend GPU CPU

Examples

>>> # Case 1: logits.shape = labels.shape = (3,)
>>> loss = nn.MSELoss()
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> labels = Tensor(np.array([1, 1, 1]), mindspore.float32)
>>> output = loss(logits, labels)
>>> print(output)
1.6666667
>>> # Case 2: logits.shape = (3,), labels.shape = (2, 3)
>>> loss = nn.MSELoss(reduction='none')
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> labels = Tensor(np.array([[1, 1, 1], [1, 2, 2]]), mindspore.float32)
>>> output = loss(logits, labels)
>>> print(output)
[[0. 1. 4.]
 [0. 0. 1.]]