mindspore.nn.MultilabelMarginLoss
- class mindspore.nn.MultilabelMarginLoss(reduction='mean')
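Creates a loss criterion that minimizes the hinge loss for multi-class, multi-label classification tasks. It takes two inputs: a 2D mini-batch Tensor $x$ of predictions and a 2D Tensor $y$ of target class indices. For each sample in the mini-batch, the loss is computed as follows:

$$\text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)}$$

where $i \in \{0, \cdots, \text{x.size}(0) - 1\}$, $j \in \{0, \cdots, \text{y.size}(0) - 1\}$, $0 \leq y[j] \leq \text{x.size}(0) - 1$, and $i \neq y[j]$ for all $i$ and $j$. Furthermore, $y$ and $x$ must have identical sizes.
Note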
For this operator, only the contiguous run of non-negative targets at the beginning of each target row is considered; the first -1 ends the run, and any entries after it are ignored. This means that different samples can have different numbers of target classes.
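The per-sample formula and the truncation rule in the note can be sketched in plain Python as follows. This is a minimal illustration assuming Python lists instead of Tensors; `multilabel_margin_sample` is a hypothetical helper, not part of MindSpore, and it is not the operator's actual implementation.

```python
def multilabel_margin_sample(x, y):
    """Per-sample multilabel margin loss (hypothetical helper, for illustration).

    x: list of C class scores for one sample.
    y: list of C target class indices, padded with -1.
    """
    c = len(x)
    # Only the leading run of non-negative targets counts; the first -1 ends it.
    targets = []
    for label in y:
        if label < 0:
            break
        targets.append(label)
    target_set = set(targets)
    total = 0.0
    for t in targets:  # t plays the role of y[j]
        for i in range(c):
            if i in target_set:
                continue  # i must differ from every target class y[j]
            total += max(0.0, 1.0 - (x[t] - x[i]))
    # Divide by x.size(0), i.e. the number of classes C.
    return total / c
```

For the second sample of the example below, `multilabel_margin_sample([0.2, 0.3, 0.5, 0.7], [2, 3, -1, 1])` only uses targets 2 and 3; the trailing 1 after -1 is ignored.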
- Parameters
reduction (str, optional) – Apply a specific reduction method to the output: 'none', 'mean' or 'sum'. Default: 'mean'.
'none': no reduction will be applied.
'mean': compute and return the mean of elements in the output.
'sum': the output elements will be summed.
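The three reduction modes can be sketched as a function over a list of per-sample losses. This is a hypothetical helper written for illustration under the assumption that per-sample losses are already computed; it is not MindSpore's code.

```python
def reduce_loss(per_sample, reduction="mean"):
    """Apply a reduction to per-sample losses (hypothetical helper)."""
    if reduction == "none":
        return list(per_sample)  # one loss value per sample, shape (N)
    if reduction == "mean":
        return sum(per_sample) / len(per_sample)
    if reduction == "sum":
        return sum(per_sample)
    raise ValueError("reduction must be one of 'none', 'mean', 'sum'")
```

With the per-sample losses `[0.0, 0.65]` of the example below, `'mean'` gives 0.325 and `'sum'` gives 0.65.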
- Inputs:
x (Tensor) - Predicted data. Tensor of shape $(C)$ or $(N, C)$, where $N$ is the batch size and $C$ is the number of classes. Data type must be float16 or float32.
target (Tensor) - Ground truth data, with the same shape as x. Data type must be int32, and label targets are padded with -1.
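Because samples can have different numbers of labels, the target rows must be padded with -1 to the common width $C$. A minimal sketch of that padding step, assuming the labels start out as Python lists; `pad_targets` is a hypothetical helper, not a MindSpore API.

```python
def pad_targets(label_lists, num_classes):
    """Build an (N, C) target matrix padded with -1 (hypothetical helper)."""
    padded = []
    for labels in label_lists:
        if len(labels) > num_classes:
            raise ValueError("a sample cannot have more than C target classes")
        # Valid targets first, then -1 to mark the end of the valid run.
        padded.append(list(labels) + [-1] * (num_classes - len(labels)))
    return padded
```

The result can then be wrapped as in the example, e.g. `ms.Tensor(np.array(padded), ms.int32)`.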
- Outputs:
y (Union[Tensor, Scalar]) - The loss of MultilabelMarginLoss. If reduction is "none", its shape is $(N)$. Otherwise, a scalar value is returned.
- Raises
TypeError – If x or target is not a Tensor.
TypeError – If dtype of x is neither float16 nor float32.
TypeError – If dtype of target is not int32.
ValueError – If length of shape of x is neither 1 nor 2.
ValueError – If shape of x is not the same as target.
ValueError – If reduction is not one of 'none', 'mean' or 'sum'.
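The shape and reduction checks listed above can be mirrored on plain shape tuples, which may help when validating inputs before building Tensors. This sketch is hypothetical (`check_args` is not a MindSpore function) and omits the dtype checks, which depend on Tensor types.

```python
def check_args(x_shape, target_shape, reduction="mean"):
    """Mirror the documented ValueError conditions (hypothetical helper)."""
    if len(x_shape) not in (1, 2):
        raise ValueError(f"x must be 1D or 2D, got {len(x_shape)}D")
    if tuple(x_shape) != tuple(target_shape):
        raise ValueError(f"x shape {x_shape} != target shape {target_shape}")
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"unsupported reduction: {reduction!r}")
```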
- Supported Platforms:
Ascend
GPU
Examples
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> import numpy as np
>>> loss = nn.MultilabelMarginLoss()
>>> x = ms.Tensor(np.array([[0.1, 0.2, 0.4, 0.8], [0.2, 0.3, 0.5, 0.7]]), ms.float32)
>>> target = ms.Tensor(np.array([[1, 2, 0, 3], [2, 3, -1, 1]]), ms.int32)
>>> output = loss(x, target)
>>> print(output)
0.325
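The printed value can be checked by hand against the loss formula. In the first sample every class is a target, so no valid $i$ exists and its loss is 0. In the second sample only targets 2 and 3 precede the -1, so the non-target classes are $i \in \{0, 1\}$ and all four hinge terms are positive:

$$
\begin{aligned}
\ell_1 &= 0\\
\ell_2 &= \tfrac{1}{4}\big[(1-(0.5-0.2)) + (1-(0.5-0.3)) + (1-(0.7-0.2)) + (1-(0.7-0.3))\big] = \tfrac{0.7+0.8+0.5+0.6}{4} = 0.65\\
\text{mean} &= \tfrac{\ell_1+\ell_2}{2} = 0.325
\end{aligned}
$$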