mindspore.nn.FocalLoss
- class mindspore.nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')
Focal loss, proposed by Kaiming He's team in the paper Focal Loss for Dense Object Detection, improves the accuracy of object detection in images. It addresses class imbalance and the difference in classification difficulty between samples. For details, refer to the paper: https://arxiv.org/pdf/1708.02002.pdf. The function is defined as follows, where p_t is the predicted probability of the ground-truth class:
\[FL(p_t) = -(1-p_t)^\gamma \log(p_t)\]
- Parameters
gamma (float) – Adjusts the steepness of the weight curve in focal loss. Default: 2.0.
weight (Union[Tensor, None]) – A rescaling weight applied to the loss of each batch element. It must be a 1-D Tensor. If None, no weights are applied. Default: None.
reduction (str) – Type of reduction to be applied to loss. The optional values are “mean”, “sum”, and “none”. If “none”, do not perform reduction. Default: “mean”.
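The effect of gamma on the formula above can be illustrated with a short pure-Python sketch. It does not use MindSpore; the helper name focal_term and the probabilities 0.9 ("easy" sample) and 0.1 ("hard" sample) are assumptions chosen for illustration only.

```python
import math

def focal_term(p_t, gamma=2.0):
    """Focal loss of one sample, given the probability p_t of its true class.

    Hypothetical helper implementing FL(p_t) = -(1 - p_t)^gamma * log(p_t).
    """
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

easy, hard = 0.9, 0.1  # assumed probabilities of the true class

# With gamma = 0 the modulating factor is 1, so the loss is plain cross-entropy.
ratio_ce = focal_term(hard, 0.0) / focal_term(easy, 0.0)      # roughly 22

# With gamma = 2 the easy sample is down-weighted much more than the hard one.
ratio_focal = focal_term(hard, 2.0) / focal_term(easy, 2.0)   # roughly 1770

print(ratio_ce, ratio_focal)
```

A larger gamma shifts more of the total loss toward hard, misclassified samples, which is the intended use of the parameter.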
- Inputs:
logits (Tensor) - Tensor of shape (B, C), (B, C, H), or (B, C, H, W), where C is the number of classes and must be greater than 1. If the shape is (B, C, H, W) or (B, C, H), H or the product of H and W must be the same as that of labels.
labels (Tensor) - Tensor of shape (B, C), (B, C, H), or (B, C, H, W). C must be 1 or equal to the C of logits. If C is not 1, the shape of labels must be the same as that of logits, where C is the number of classes. If the shape is (B, C, H, W) or (B, C, H), H or the product of H and W must be the same as that of logits.
- Outputs:
Tensor with the same shape and type as the input logits.
- Raises
TypeError – If the data type of gamma is not float.
TypeError – If weight is not a Tensor.
ValueError – If the number of dimensions of labels is different from that of logits.
ValueError – If the channel size of labels is not 1 and the shape of labels is different from that of logits.
ValueError – If reduction is not one of ‘none’, ‘mean’, ‘sum’.
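The two shape-related ValueError conditions can be sketched as a standalone check on shape tuples. The function check_focal_loss_shapes is a hypothetical helper written for illustration and is not part of MindSpore; it only mirrors the rules as stated in this section.

```python
def check_focal_loss_shapes(logits_shape, labels_shape):
    """Hypothetical helper mirroring the documented ValueError conditions."""
    # Rule 1: labels and logits must have the same number of dimensions.
    if len(labels_shape) != len(logits_shape):
        raise ValueError("labels and logits must have the same number of dimensions")
    # Rule 2: the channel size of labels is 1, or labels has the shape of logits.
    if labels_shape[1] != 1 and tuple(labels_shape) != tuple(logits_shape):
        raise ValueError("labels must have channel size 1 or the same shape as logits")

check_focal_loss_shapes((3, 2), (3, 1))        # channel size 1: accepted
check_focal_loss_shapes((3, 2, 4), (3, 2, 4))  # same shape as logits: accepted
```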
- Supported Platforms:
Ascend
GPU
Example
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> logits = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
>>> labels = Tensor([[1], [1], [0]], mstype.int32)
>>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')
>>> output = focalloss(logits, labels)
>>> print(output)
0.12516622
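To see where the value in the example above comes from, the computation can be reproduced by hand in pure Python. This is a minimal sketch, not the library's implementation. It assumes that p_t is obtained from a softmax over the class axis, that labels with a channel size of 1 hold class indices, and that weight is looked up by class index (consistent with the example, where weight has length C = 2). The name focal_loss_reference is hypothetical.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def focal_loss_reference(logits, labels, weight=None, gamma=2.0, reduction="mean"):
    """Hypothetical reference for 2-D logits and class-index labels."""
    losses = []
    for row, label in zip(logits, labels):
        p_t = softmax(row)[label]                      # probability of the true class
        w = 1.0 if weight is None else weight[label]   # assumed per-class weight
        losses.append(-w * (1.0 - p_t) ** gamma * math.log(p_t))
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    if reduction == "none":
        return losses
    raise ValueError("reduction must be one of 'none', 'mean', 'sum'")

logits = [[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]]
labels = [1, 1, 0]
result = focal_loss_reference(logits, labels, weight=[1.0, 2.0], gamma=2.0)
print(result)  # close to the 0.12516622 printed in the example above
```

Under these assumptions the sketch agrees with the printed output to within float32 rounding, which suggests the example applies the weight per class rather than per batch element.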