mindspore.nn.L1Loss

class mindspore.nn.L1Loss(reduction='mean')[source]

L1Loss creates a criterion to measure the mean absolute error (MAE) between \(x\) and \(y\) element-wise, where \(x\) is the input Tensor and \(y\) is the target Tensor.

For simplicity, let \(x\) and \(y\) be 1-dimensional Tensor with length \(N\), the unreduced loss (i.e. with argument reduction set to ‘none’) of \(x\) and \(y\) is given as:

\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad \text{with } l_n = \left| x_n - y_n \right|,\]

where \(N\) is the batch size. If reduction is not ‘none’, then:

\[\begin{split}\ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} \end{cases}\end{split}\]
Parameters

reduction (str) – Type of reduction to be applied to loss. The optional values are “mean”, “sum”, and “none”. Default: “mean”.

Inputs:
  • logits (Tensor) - Tensor of shape \((N, *)\) where \(*\) means, any number of additional dimensions.

  • labels (Tensor) - Tensor of shape \((N, *)\), same shape as the logits in common cases. However, it supports the shape of logits is different from the shape of labels and they should be broadcasted to each other.

Outputs:

Tensor, loss float tensor, the shape is zero if reduction is ‘mean’ or ‘sum’, while the shape of output is the broadcasted shape if reduction is ‘none’.

Raises

ValueError – If reduction is not one of ‘none’, ‘mean’, ‘sum’.

Supported Platforms:

Ascend GPU CPU

Examples

>>> # Case 1: logits.shape = labels.shape = (3,)
>>> loss = nn.L1Loss()
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
>>> output = loss(logits, labels)
>>> print(output)
0.33333334
>>> # Case 2: logits.shape = (3,), labels.shape = (2, 3)
>>> loss = nn.L1Loss(reduction='none')
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> labels = Tensor(np.array([[1, 1, 1], [1, 2, 2]]), mindspore.float32)
>>> output = loss(logits, labels)
>>> print(output)
[[0. 1. 2.]
 [0. 0. 1.]]