mindspore.nn.MSELoss
- class mindspore.nn.MSELoss(reduction='mean')[source]
MSELoss creates a criterion to measure the mean squared error (squared L2-norm) between \(x\) and \(y\) element-wise, where \(x\) is the input and \(y\) is the labels.
For simplicity, let \(x\) and \(y\) be 1-dimensional Tensor with length \(N\), the unreduced loss (i.e. with argument reduction set to ‘none’) of \(x\) and \(y\) is given as:
\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad \text{with} \quad l_n = (x_n - y_n)^2.\]where \(N\) is the batch size. If reduction is not ‘none’, then:
\[\begin{split}\ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases}\end{split}\]- Parameters
reduction (str) – Type of reduction to be applied to loss. The optional values are “mean”, “sum”, and “none”. Default: “mean”.
- Inputs:
logits (Tensor) - Tensor of shape \((N, *)\) where \(*\) means, any number of additional dimensions.
labels (Tensor) - Tensor of shape \((N, *)\), same shape as the logits in common cases. However, it supports the shape of logits is different from the shape of labels and they should be broadcasted to each other.
- Outputs:
Tensor, loss float tensor, the shape is zero if reduction is ‘mean’ or ‘sum’, while the shape of output is the broadcasted shape if reduction is ‘none’.
- Raises
ValueError – If reduction is not one of ‘none’, ‘mean’, ‘sum’.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> # Case 1: logits.shape = labels.shape = (3,) >>> loss = nn.MSELoss() >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32) >>> labels = Tensor(np.array([1, 1, 1]), mindspore.float32) >>> output = loss(logits, labels) >>> print(output) 1.6666667 >>> # Case 2: logits.shape = (3,), labels.shape = (2, 3) >>> loss = nn.MSELoss(reduction='none') >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32) >>> labels = Tensor(np.array([[1, 1, 1], [1, 2, 2]]), mindspore.float32) >>> output = loss(logits, labels) >>> print(output) [[0. 1. 4.] [0. 0. 1.]]