mindspore.nn.FocalLoss

class mindspore.nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')

Focal loss, proposed by Kaiming He's team in the paper Focal Loss for Dense Object Detection, improves the performance of image object detection. It addresses class imbalance and the difference in classification difficulty between examples. For more details, refer to the paper. The function is defined as follows:

\[FL(p_t) = -(1-p_t)^\gamma \log(p_t)\]
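As a rough numeric illustration of the formula above, the sketch below evaluates the focal term for a single example in plain Python. It assumes that \(p_t\) is the predicted probability of the true class, as in the paper; the helper name focal_term is hypothetical and not part of MindSpore. Setting gamma to 0 recovers ordinary cross-entropy, which makes the down-weighting of well-classified examples visible.

```python
import math

def focal_term(p_t, gamma=2.0):
    """Return -(1 - p_t)**gamma * log(p_t) for one example (hypothetical helper)."""
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

# Assumption: p_t is the predicted probability of the true class.
for name, p_t in [("easy", 0.9), ("hard", 0.1)]:
    ce = focal_term(p_t, gamma=0.0)  # gamma = 0 reduces to cross-entropy
    fl = focal_term(p_t, gamma=2.0)  # default gamma of FocalLoss
    print(f"{name}: p_t={p_t} cross-entropy={ce:.4f} focal={fl:.4f} scale={fl / ce:.2f}")
```

With gamma = 2, the easy example is scaled by \((1 - 0.9)^2 = 0.01\), while the hard example keeps \((1 - 0.1)^2 = 0.81\) of its cross-entropy loss.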
Parameters
  • gamma (float) – Adjusts the steepness of the weighting curve in focal loss. Default: 2.0.

  • weight (Union[Tensor, None]) – A rescaling weight applied to the loss of each batch element. weight must be a 1-D tensor. If None, no weights are applied. Default: None.

  • reduction (str) – Type of reduction to be applied to loss. The optional values are “mean”, “sum”, and “none”. If “none”, do not perform reduction. Default: “mean”.

Inputs:
  • logits (Tensor) - Tensor of shape \((B, C)\), \((B, C, H)\), or \((B, C, H, W)\), where \(C\) is the number of classes and must be greater than 1. If the shape is \((B, C, H, W)\) or \((B, C, H)\), \(H\) or the product of \(H\) and \(W\) must be the same as that of labels.

  • labels (Tensor) - Tensor of shape \((B, C)\), \((B, C, H)\), or \((B, C, H, W)\). The value of \(C\) is either 1 or the same as the \(C\) of logits. If \(C\) is not 1, the shape of labels must be the same as that of logits, where \(C\) is the number of classes. If the shape is \((B, C, H, W)\) or \((B, C, H)\), \(H\) or the product of \(H\) and \(W\) must be the same as that of logits.
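The two label layouts described above can be illustrated in plain Python. This sketch assumes that labels with \(C = 1\) hold class indices and that labels with the same shape as logits are one-hot encoded; the one-hot interpretation, the helper names, and the probability values are assumptions made for illustration, not statements from this page.

```python
def to_one_hot(index_labels, num_classes):
    """Expand (B, 1) class-index labels to (B, C) one-hot labels (hypothetical helper)."""
    return [[1 if c == row[0] else 0 for c in range(num_classes)]
            for row in index_labels]

def pick_by_index(probs, index_labels):
    """Select the true-class probability using (B, 1) index labels."""
    return [p[row[0]] for p, row in zip(probs, index_labels)]

def pick_by_one_hot(probs, one_hot_labels):
    """Select the true-class probability using (B, C) one-hot labels."""
    return [sum(p_c * y_c for p_c, y_c in zip(p, y))
            for p, y in zip(probs, one_hot_labels)]

# Illustrative per-class probabilities for B = 3 samples and C = 2 classes.
probs = [[0.35, 0.65], [0.40, 0.60], [0.57, 0.43]]
index_labels = [[1], [1], [0]]                      # shape (B, 1)
one_hot_labels = to_one_hot(index_labels, 2)        # shape (B, C)

print(one_hot_labels)
print(pick_by_index(probs, index_labels))
print(pick_by_one_hot(probs, one_hot_labels))
```

Under these assumptions both layouts select the same true-class probability for every sample, so either one identifies the same \(p_t\).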

Outputs:

Tensor or Scalar. If reduction is “none”, the output has the same shape as logits. Otherwise, a scalar value is returned.

Raises
  • TypeError – If gamma is not a float.

  • TypeError – If weight is not a Tensor.

  • ValueError – If the number of dimensions of labels is different from that of logits.

  • ValueError – If the channel dimension of labels is not 1 and the shape of labels is different from that of logits.

  • ValueError – If reduction is not one of ‘none’, ‘mean’, ‘sum’.

Supported Platforms:

Ascend

Example

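>>> import mindspore.common.dtype as mstype
>>> from mindspore import Tensor, nn
>>> # logits has shape (B, C) = (3, 2); labels has shape (B, 1) and holds class indices
>>> logits = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
>>> labels = Tensor([[1], [1], [0]], mstype.int32)
>>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')
>>> output = focalloss(logits, labels)
>>> print(output)
0.12516622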
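
As a cross-check of the example above, the following plain-Python sketch recomputes the loss by hand. It assumes that \(p_t\) is the softmax probability of the labelled class, that weight holds one entry per class and is selected by the label, and that “mean” averages over the batch. These assumptions are inferred from the example and are not a description of the library's implementation; the function names are hypothetical. Under them the result is about 0.12517, in line with the output printed above.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def focal_loss_sketch(logits, labels, weight, gamma=2.0):
    """Mean focal loss for (B, C) logits and (B, 1) class-index labels.

    Assumptions: p_t is the softmax probability of the labelled class,
    and weight[label] rescales each sample's loss.
    """
    losses = []
    for row, (label,) in zip(logits, labels):
        p_t = softmax(row)[label]
        losses.append(-weight[label] * (1.0 - p_t) ** gamma * math.log(p_t))
    return sum(losses) / len(losses)

logits = [[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]]
labels = [[1], [1], [0]]
loss = focal_loss_sketch(logits, labels, weight=[1.0, 2.0], gamma=2.0)
print(round(loss, 5))
```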