mindspore.nn.FocalLoss
- class mindspore.nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')
A loss function that addresses class imbalance and differences in classification difficulty between samples. Proposed by Kaiming He's team in the paper Focal Loss for Dense Object Detection, it improves the performance of dense image object detectors. The function is defined as follows:
\[FL(p_t) = -(1-p_t)^\gamma \log(p_t)\]
where \(p_t\) is the model's estimated probability for the ground-truth class.
- Parameters
gamma (float) – Focusing parameter that adjusts the steepness of the modulating factor \((1-p_t)^\gamma\); larger values reduce the loss of well-classified examples more strongly. Default: 2.0.
weight (Union[Tensor, None]) – A rescaling weight for each class: the loss term of a sample is scaled by the weight of its class, so the Tensor should be 1-D with one entry per class. If None, no weight is applied. Default: None.
reduction (str) – Type of reduction to apply to the loss. The optional values are 'mean', 'sum', and 'none'. If 'none', no reduction is performed. Default: 'mean'.
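To illustrate what gamma does to the modulating factor \((1-p_t)^\gamma\) in the formula above, the following plain-Python sketch compares the focal term with plain cross-entropy for an easy example (\(p_t = 0.9\)) and a hard one (\(p_t = 0.1\)). The helpers `cross_entropy` and `focal_term` are hypothetical; they evaluate only the formula (natural logarithm, no class weights) and are not MindSpore's implementation.

```python
import math

def cross_entropy(p_t: float) -> float:
    """Plain cross-entropy term -log(p_t)."""
    return -math.log(p_t)

def focal_term(p_t: float, gamma: float = 2.0) -> float:
    """Focal loss term -(1 - p_t)**gamma * log(p_t), without class weights."""
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

# An easy example (p_t = 0.9) versus a hard one (p_t = 0.1) for several gammas.
for gamma in (0.0, 2.0, 5.0):
    easy = focal_term(0.9, gamma) / cross_entropy(0.9)
    hard = focal_term(0.1, gamma) / cross_entropy(0.1)
    # Each ratio equals the modulating factor (1 - p_t)**gamma.
    print(f"gamma={gamma}: easy keeps {easy:.4f} of CE, hard keeps {hard:.4f}")
```

With gamma = 2.0, the easy example keeps 1% of its cross-entropy while the hard one keeps 81%; gamma = 0.0 recovers plain cross-entropy.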
- Inputs:
logits (Tensor) - Input tensor of shape \((N, C)\), \((N, C, H)\) or \((N, C, H, W)\), where \(C\) is the number of classes and must be greater than 1. For the shapes \((N, C, H)\) and \((N, C, H, W)\), \(H\) (or the product of \(H\) and \(W\)) must be the same as that of labels.
labels (Tensor) - Tensor of shape \((N, C)\), \((N, C, H)\) or \((N, C, H, W)\). Its \(C\) must be either 1 or equal to the \(C\) of logits; if it is not 1, the shape of labels must be the same as that of logits. For the shapes \((N, C, H)\) and \((N, C, H, W)\), \(H\) (or the product of \(H\) and \(W\)) must be the same as that of logits. The values of labels should be in the range \([-C, C)\), where \(C\) is the number of classes in logits.
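The shape rules for logits and labels can be condensed into a small plain-Python check on shape tuples. `check_focal_shapes` is a hypothetical helper written only from the rules above; it does not check the batch size N or the label values, and MindSpore performs its own validation when the loss is called.

```python
import math
from typing import Tuple

def check_focal_shapes(logits_shape: Tuple[int, ...], labels_shape: Tuple[int, ...]) -> None:
    """Raise ValueError if the shapes violate the documented logits/labels rules."""
    if len(logits_shape) not in (2, 3, 4):
        raise ValueError(f"logits must be (N, C), (N, C, H) or (N, C, H, W), got {logits_shape}")
    if len(labels_shape) != len(logits_shape):
        raise ValueError("labels must have the same number of dimensions as logits")
    num_classes = logits_shape[1]
    if num_classes <= 1:
        raise ValueError(f"the number of classes C must be greater than 1, got {num_classes}")
    # Labels carry either class indices (C == 1) or one value per class (same shape as logits).
    if labels_shape[1] != 1 and tuple(labels_shape) != tuple(logits_shape):
        raise ValueError("if the C of labels is not 1, labels must have the same shape as logits")
    # Spatial extent: H for 3-D inputs, H * W for 4-D inputs, empty product (1) for 2-D inputs.
    if math.prod(logits_shape[2:]) != math.prod(labels_shape[2:]):
        raise ValueError("H (or H * W) of labels must match that of logits")

check_focal_shapes((8, 5, 32, 32), (8, 1, 32, 32))  # class-index labels: accepted
check_focal_shapes((8, 5, 32, 32), (8, 5, 32, 32))  # per-class labels: accepted
try:
    check_focal_shapes((8, 5, 32, 32), (8, 3, 32, 32))
except ValueError as err:
    print("rejected:", err)
```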
- Outputs:
Tensor or Scalar. If reduction is 'none', the output is a Tensor whose shape is the same as that of logits; otherwise, a scalar is returned.
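As a rough picture of how the reduction modes turn per-element losses into the documented output, here is a plain-Python sketch over a list of loss values. `apply_reduction` is a hypothetical helper that mirrors only the documented semantics; MindSpore performs the equivalent reduction internally on tensors.

```python
from typing import List, Union

def apply_reduction(losses: List[float], reduction: str = "mean") -> Union[float, List[float]]:
    """Reduce per-element losses the way the `reduction` argument describes."""
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"reduction must be one of 'none', 'mean', 'sum', got {reduction!r}")
    if reduction == "none":
        return list(losses)  # keep one loss value per element
    total = sum(losses)
    if reduction == "sum":
        return total  # a single scalar: the sum
    return total / len(losses)  # 'mean': a single scalar: the average

per_element = [0.5, 1.0, 1.5]
print(apply_reduction(per_element, "none"))  # [0.5, 1.0, 1.5]
print(apply_reduction(per_element, "sum"))   # 3.0
print(apply_reduction(per_element, "mean"))  # 1.0
```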
- Raises
TypeError – If gamma is not a float.
TypeError – If weight is neither a Tensor nor None.
ValueError – If the number of dimensions of labels differs from that of logits.
ValueError – If the channel dimension \(C\) of labels is not 1 and the shape of labels differs from that of logits.
ValueError – If reduction is not one of 'none', 'mean', 'sum'.
- Supported Platforms:
Ascend
Examples
>>> import mindspore.common.dtype as mstype
>>> from mindspore import Tensor, nn
>>> logits = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
>>> labels = Tensor([[1], [1], [0]], mstype.int32)
>>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')
>>> output = focalloss(logits, labels)
>>> print(output)
0.12516622
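To cross-check the printed value, here is a plain-Python re-computation of the example for \((N, C)\) logits with class-index labels. It assumes a softmax over the class axis, per-class weights indexed by the label (weight[c] scales the log-probability of class c), and reduction='mean' over the batch. Under these assumptions it should match 0.12516622 up to float32 rounding, but it is an illustration, not MindSpore's implementation; `focal_loss_reference` is a hypothetical name.

```python
import math
from typing import List, Optional, Sequence

def focal_loss_reference(logits: Sequence[Sequence[float]],
                         labels: Sequence[int],
                         weight: Optional[Sequence[float]] = None,
                         gamma: float = 2.0) -> float:
    """Mean focal loss for (N, C) logits and class-index labels (pure-Python sketch)."""
    losses: List[float] = []
    for row, label in zip(logits, labels):
        # Numerically stable log-softmax over the class axis.
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        log_p_t = row[label] - log_z  # log-probability of the ground-truth class
        p_t = math.exp(log_p_t)
        w = 1.0 if weight is None else weight[label]  # per-class weight (assumed)
        losses.append(-((1.0 - p_t) ** gamma) * w * log_p_t)
    return sum(losses) / len(losses)  # reduction='mean'

logits = [[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]]
labels = [1, 1, 0]
print(round(focal_loss_reference(logits, labels, weight=[1.0, 2.0], gamma=2.0), 6))  # 0.125166
```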