mindspore.nn.MarginRankingLoss
- class mindspore.nn.MarginRankingLoss(margin=0.0, reduction='mean')[source]
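MarginRankingLoss creates a criterion that measures the margin ranking loss between two inputs, given a label that indicates which input should be ranked higher.
Given two tensors \(x1\) and \(x2\) and a label tensor \(y\) with values 1 or -1, the loss is computed element-wise as follows: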
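\[\text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin})\]
- Parameters
margin (float) - Margin inside the hinge. An element contributes zero loss only when \(y * (x1 - x2) \geq \text{margin}\). Default: 0.0.
reduction (str) - Reduction applied to the element-wise loss: 'none' (no reduction), 'mean' (average of all elements) or 'sum' (sum of all elements). Default: 'mean'.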
- Inputs:
input1 (Tensor) - Tensor of shape \((N, *)\), where \(*\) means any number of additional dimensions.
input2 (Tensor) - Tensor of shape \((N, *)\), same shape and dtype as input1.
target (Tensor) - Contains values 1 or -1. If the shape of input1 is \((x_1, x_2, x_3, ..., x_R)\), then the shape of target must also be \((x_1, x_2, x_3, ..., x_R)\), and its dtype must match input1.
- Outputs:
Tensor or Scalar. If reduction is 'none', a Tensor with the same shape as target is returned. Otherwise, a scalar value is returned.
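To make the formula and the three reduction modes concrete, here is a plain-Python sketch that works on flat lists of floats. It is not MindSpore's implementation: the function name `margin_ranking_loss` is hypothetical, it handles only 1-D inputs, and it skips the Tensor and dtype checks listed under Raises.

```python
def margin_ranking_loss(input1, input2, target, margin=0.0, reduction="mean"):
    """Hypothetical sketch of max(0, -y * (x1 - x2) + margin) on flat lists."""
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")
    if not (len(input1) == len(input2) == len(target)):
        raise ValueError("input1, input2 and target must have the same length")
    # Element-wise hinge on the signed difference between the two inputs.
    losses = [max(0.0, -y * (a - b) + margin) for a, b, y in zip(input1, input2, target)]
    if reduction == "none":
        return losses
    total = sum(losses)
    # 'mean' averages over every element; 'sum' returns the plain total.
    return total / len(losses) if reduction == "mean" else total


# Same values as the Examples section; target is sign([-2, -2, 3]).
x1 = [0.3864, -2.4093, -1.4076]
x2 = [-0.6012, -1.6681, 1.2928]
y = [-1.0, -1.0, 1.0]
print(margin_ranking_loss(x1, x2, y, reduction="none"))
print(margin_ranking_loss(x1, x2, y, reduction="mean"))
print(margin_ranking_loss(x1, x2, y, reduction="sum"))
```

With these inputs the sketch returns approximately [0.9876, 0.0, 2.7004], 1.2293 and 3.688, which agrees with the MindSpore output in the Examples section up to float32 rounding. Raising margin to 1.0 makes the second element non-zero (about 0.2588), since \(y * (x1 - x2) = -0.7412\) no longer clears the margin.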
- Raises
TypeError – If margin is not a float.
TypeError – If input1, input2 or target is not a Tensor.
TypeError – If the types of input1 and input2 are inconsistent.
TypeError – If the types of input1 and target are inconsistent.
ValueError – If the shapes of input1 and input2 are inconsistent.
ValueError – If the shapes of input1 and target are inconsistent.
ValueError – If reduction is not one of ‘none’, ‘mean’, ‘sum’.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> import numpy as np
>>> loss1 = nn.MarginRankingLoss(reduction='none')
>>> loss2 = nn.MarginRankingLoss(reduction='mean')
>>> loss3 = nn.MarginRankingLoss(reduction='sum')
>>> sign = ops.Sign()
>>> input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32)
>>> input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32)
>>> # Sign maps [-2, -2, 3] to the labels [-1, -1, 1].
>>> target = sign(Tensor(np.array([-2, -2, 3]), ms.float32))
>>> output1 = loss1(input1, input2, target)
>>> print(output1)
[0.98759997 0.         2.7003999 ]
>>> output2 = loss2(input1, input2, target)
>>> print(output2)
1.2293333
>>> output3 = loss3(input1, input2, target)
>>> print(output3)
3.6879997
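Working the example through the formula by hand (with the default margin of 0) shows where the printed values come from; the differences in the last digits are float32 rounding:

\[
\begin{aligned}
x1 - x2 &= (0.9876,\ -0.7412,\ -2.7004), \qquad y = \operatorname{sign}(-2,\ -2,\ 3) = (-1,\ -1,\ 1)\\
-y * (x1 - x2) + 0 &= (0.9876,\ -0.7412,\ 2.7004)\\
\text{loss} = \max(0,\ \cdot) &= (0.9876,\ 0,\ 2.7004)\\
\text{sum} &= 3.6880, \qquad \text{mean} = 3.6880 / 3 \approx 1.2293
\end{aligned}
\]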