mindspore.nn.FocalLoss
- class mindspore.nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')
The loss function proposed by Kaiming He's team in the paper "Focal Loss for Dense Object Detection" improves image object detection. It is designed to handle class imbalance and differences in classification difficulty: the modulating factor $(1 - p_t)^{\gamma}$ down-weights well-classified examples so that training concentrates on hard ones. For more details, refer to the paper. The function is defined as follows:
$$FL(p_t) = -(1 - p_t)^{\gamma}\log(p_t)$$
where $p_t$ is the predicted probability of the ground-truth class.
- Parameters
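gamma (float) – The focusing parameter γ in the modulating factor (1 - p_t)^γ. It adjusts the steepness of the weight curve: larger values down-weight well-classified examples more strongly, and γ = 0 reduces focal loss to cross-entropy. Default: 2.0.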
weight (Union[Tensor, None]) – A rescaling weight applied to the loss of each class. It must be 1-D, with one entry per class (for example, weight=Tensor([1, 2]) for C = 2). If None, no weight is applied. Default: None.
reduction (str) – Type of reduction to apply to the loss. The optional values are “mean”, “sum”, and “none”: “mean” averages the per-element losses, “sum” adds them up, and “none” performs no reduction. Default: “mean”.
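To make the role of gamma concrete, here is a small framework-free sketch of the modulating factor and the per-example focal term from the formula above. The function names modulating_factor and focal_term are hypothetical illustrations, not part of MindSpore; p_t is assumed to be an already-computed probability of the true class.

```python
import math


def modulating_factor(p_t: float, gamma: float) -> float:
    """Return (1 - p_t) ** gamma, the factor that down-weights easy examples.

    Hypothetical helper for illustration; not a MindSpore API.
    """
    if not 0.0 <= p_t <= 1.0:
        raise ValueError("p_t must be a probability in [0, 1]")
    return (1.0 - p_t) ** gamma


def focal_term(p_t: float, gamma: float = 2.0) -> float:
    """Focal loss for a single example: -(1 - p_t) ** gamma * log(p_t)."""
    return -modulating_factor(p_t, gamma) * math.log(p_t)


# Compare how strongly easy (p_t = 0.9) and hard (p_t = 0.1) examples are weighted.
for gamma in (0.0, 2.0, 5.0):
    factors = [round(modulating_factor(p, gamma), 4) for p in (0.1, 0.5, 0.9)]
    print(f"gamma={gamma}: {factors}")
```

With gamma = 0 every factor is 1, so the term is plain cross-entropy; as gamma grows, the factor for confident predictions shrinks much faster than for uncertain ones.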
- Inputs:
logits (Tensor) - Tensor of shape (B, C), (B, C, H) or (B, C, H, W), where B is the batch size and C is the number of classes, which must be greater than 1. If the shape is (B, C, H) or (B, C, H, W), H (or the product H × W) must be the same as that of labels.
labels (Tensor) - Tensor of shape (B, C), (B, C, H) or (B, C, H, W). Here C is either 1 or equal to the C of logits. If C is not 1, labels must have the same shape as logits, where C is the number of classes. If the shape is (B, C, H) or (B, C, H, W), H (or the product H × W) must be the same as that of logits.
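The shape rules above can be summarized as a validation step followed by flattening the spatial dimensions into one axis. The sketch below is a hypothetical helper (flattened_shapes is not a MindSpore function) that checks the documented constraints on plain shape tuples and returns (B, C, N) shapes, treating 2-D inputs as N = 1; that flattening convention is an assumption made for illustration.

```python
from math import prod
from typing import Tuple

Shape = Tuple[int, ...]


def flattened_shapes(logits_shape: Shape, labels_shape: Shape) -> Tuple[Shape, Shape]:
    """Check the documented shape rules and return (B, C, N) shapes.

    Hypothetical helper for illustration; not part of MindSpore.
    """
    if len(logits_shape) not in (2, 3, 4):
        raise ValueError("logits must be (B, C), (B, C, H) or (B, C, H, W)")
    if len(labels_shape) != len(logits_shape):
        raise ValueError("labels must have as many dimensions as logits")
    batch, num_classes = logits_shape[0], logits_shape[1]
    if num_classes <= 1:
        raise ValueError("the number of classes C must be greater than 1")
    label_channels = labels_shape[1]
    if label_channels != 1 and tuple(labels_shape) != tuple(logits_shape):
        raise ValueError("labels with C != 1 must have the same shape as logits")
    # Spatial extent: H, or H * W; prod(()) == 1 covers the 2-D case.
    n_logits = prod(logits_shape[2:])
    n_labels = prod(labels_shape[2:])
    if n_logits != n_labels:
        raise ValueError("H (or H * W) of labels must match logits")
    return (batch, num_classes, n_logits), (labels_shape[0], label_channels, n_labels)


print(flattened_shapes((2, 5, 4, 4), (2, 1, 4, 4)))
```

Index-style labels (C = 1) keep a single channel, while one-hot-style labels must mirror logits exactly.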
- Outputs:
Tensor or Scalar. If reduction is “none”, a Tensor with the same shape as logits is returned; otherwise, a scalar is returned.
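The three reduction modes follow the usual convention for loss functions. The sketch below applies them to a plain list of per-element losses; reduce_loss is a hypothetical name used only for illustration, and the list input stands in for the unreduced loss Tensor.

```python
from typing import List, Union


def reduce_loss(per_element: List[float], reduction: str = "mean") -> Union[float, List[float]]:
    """Apply 'none', 'sum' or 'mean' reduction to per-element losses.

    Hypothetical helper for illustration; not a MindSpore API.
    """
    if reduction == "none":
        return list(per_element)  # keep every element, no reduction
    if reduction == "sum":
        return sum(per_element)
    if reduction == "mean":
        return sum(per_element) / len(per_element)
    raise ValueError("reduction must be one of 'none', 'mean', 'sum'")


losses = [1.0, 2.0, 3.0]
print(reduce_loss(losses, "none"), reduce_loss(losses, "sum"), reduce_loss(losses))
```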
- Raises
TypeError – If the data type of gamma is not a float.
TypeError – If weight is not a Tensor.
ValueError – If the number of dimensions of labels differs from that of logits.
ValueError – If the channel dimension of labels is not 1 and the shape of labels differs from that of logits.
ValueError – If reduction is not one of ‘none’, ‘mean’, ‘sum’.
ValueError – If a value in labels is not in the range [-C, C), where C is the number of classes in logits.
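The label-range condition can be expressed as a simple check on class indices. This is a minimal sketch with a hypothetical function name; it only enforces the documented interval [-C, C) and does not assume anything about how negative indices are interpreted afterwards.

```python
from typing import Iterable


def check_label_range(labels: Iterable[int], num_classes: int) -> None:
    """Raise ValueError unless every label lies in [-num_classes, num_classes).

    Hypothetical helper for illustration; not a MindSpore API.
    """
    for value in labels:
        if not -num_classes <= value < num_classes:
            raise ValueError(f"label {value} is outside [-{num_classes}, {num_classes})")


check_label_range([0, 1, -2], num_classes=2)  # passes silently
```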
- Supported Platforms:
Ascend
Example
>>> from mindspore import Tensor, nn
>>> import mindspore.common.dtype as mstype
>>> logits = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
>>> labels = Tensor([[1], [1], [0]], mstype.int32)
>>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')
>>> output = focalloss(logits, labels)
>>> print(output)
0.12516622
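As a cross-check of the example, the sketch below recomputes the loss in plain Python. It assumes, as our reading rather than something stated above, that class-index labels select the softmax probability p_t of the true class and that weight is indexed by class (the example passes two weights for a batch of three). The name focal_loss_reference is hypothetical, this is not MindSpore's implementation, and only 2-D logits of shape (B, C) are handled.

```python
import math
from typing import List, Optional, Sequence


def focal_loss_reference(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    weight: Optional[Sequence[float]] = None,
    gamma: float = 2.0,
    reduction: str = "mean",
):
    """Pure-Python focal loss for (B, C) logits and class-index labels.

    Illustrative sketch: softmax over classes, then
    -weight[y] * (1 - p_y) ** gamma * log(p_y) for each row.
    """
    losses: List[float] = []
    for row, y in zip(logits, labels):
        # Numerically stable log-softmax of the true class.
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        log_p = row[y] - log_z
        p = math.exp(log_p)
        w = 1.0 if weight is None else weight[y]
        losses.append(-w * (1.0 - p) ** gamma * log_p)
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError("reduction must be one of 'none', 'mean', 'sum'")


logits = [[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]]
labels = [1, 1, 0]  # the (3, 1) labels from the example, flattened to indices
print(round(focal_loss_reference(logits, labels, weight=[1.0, 2.0]), 6))  # prints 0.125166
```

Under these assumptions the result agrees with the 0.12516622 printed by the doctest up to float32 precision, which supports reading weight as a per-class factor.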