import torch.nn.functional as F
def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weight and reduce the loss.

    Args:
        loss (Tensor): Elementwise loss tensor.
        weight (Tensor, optional): Element-wise weights. Defaults to None.
        reduction (str): Options are "none", "mean" and "sum".
        avg_factor (float, optional): Average factor used when computing the
            mean of the loss. Defaults to None.

    Return:
        Tensor: Reduced (and optionally weighted) loss tensor.
    """
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
            loss = loss.sum() / avg_factor
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss
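
# A quick worked example of avg_factor (illustrative values, not from the
# original source): with loss = Tensor([1., 2., 3.]), weight=None,
# reduction='mean' and avg_factor=2, the result is (1 + 2 + 3) / 2 = 3.0,
# i.e. avg_factor replaces the element count that a plain mean would use
# (the plain mean here would be 2.0).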
def py_sigmoid_focal_loss(pred, target, weight=None, gamma=2.0, alpha=0.25,
                          reduction='mean', avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target +
                    (1 - alpha) * (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed,
                # e.g. in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
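
# Usage sketch for the PyTorch version (shapes and values are assumed for
# illustration; they are not part of the original snippet): 8 priors and
# 4 classes, with a binary label per prior/class pair.
import torch

torch.manual_seed(0)
pred = torch.randn(8, 4)                       # (num_priors, num_classes) logits
target = torch.randint(0, 2, (8, 4)).float()   # 0/1 label per prior and class
loss = py_sigmoid_focal_loss(pred, target, gamma=2.0, alpha=0.25,
                             reduction='mean')
print(loss)  # a scalar: the mean focal loss over all prior/class pairs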

# --- MindSpore implementation of the same sigmoid focal loss ---
import mindspore as ms
from mindspore import nn, ops
class SigmoidFocalLoss(nn.Cell):
    def __init__(self, weight=None, gamma=2.0, alpha=0.25, reduction='mean',
                 avg_factor=None):
        super(SigmoidFocalLoss, self).__init__()
        self.sigmoid = ops.Sigmoid()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = ms.Tensor(weight) if weight is not None else weight
        self.reduction = reduction
        self.avg_factor = avg_factor
        self.binary_cross_entropy_with_logits = nn.BCEWithLogitsLoss(reduction="none")
        self.is_weight = (weight is not None)
    def reduce_loss(self, loss):
        """Reduce loss as specified.

        Args:
            loss (Tensor): Elementwise loss tensor.

        Return:
            Tensor: Reduced loss tensor.
        """
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss
    def weight_reduce_loss(self, loss):
        # if avg_factor is not specified, just reduce the loss
        if self.avg_factor is None:
            loss = self.reduce_loss(loss)
        else:
            # if reduction is mean, then average the loss by avg_factor
            if self.reduction == 'mean':
                loss = loss.sum() / self.avg_factor
            # if reduction is 'none', then do nothing, otherwise raise an error
            elif self.reduction != 'none':
                raise ValueError('avg_factor can not be used with reduction="sum"')
        return loss
    def construct(self, pred, target):
        pred_sigmoid = self.sigmoid(pred)
        target = ops.cast(target, pred.dtype)
        pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (self.alpha * target +
                        (1 - self.alpha) * (1 - target)) * ops.pow(pt, self.gamma)
        loss = self.binary_cross_entropy_with_logits(pred, target) * focal_weight
        if self.is_weight:
            weight = self.weight
            if self.weight.shape != loss.shape:
                if self.weight.shape[0] == loss.shape[0]:
                    # For most cases, weight is of shape (num_priors, ),
                    # which means it does not have the second axis num_class
                    weight = self.weight.view(-1, 1)
                elif self.weight.size == loss.size:
                    # Sometimes, weight per anchor per class is also needed,
                    # e.g. in FSAF. But it may be flattened of shape
                    # (num_priors x num_class, ), while loss is still of shape
                    # (num_priors, num_class).
                    weight = self.weight.view(loss.shape[0], -1)
                elif self.weight.ndim != loss.ndim:
                    raise ValueError(f"weight shape {self.weight.shape} does not "
                                     f"match loss shape {loss.shape}")
            loss = loss * weight
        loss = self.weight_reduce_loss(loss)
        return loss
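
# Usage sketch for the MindSpore version (shapes and values are assumed for
# illustration; they are not part of the original snippet). The Cell is
# called like a function, which runs construct().
import numpy as np

np.random.seed(0)
pred = ms.Tensor(np.random.randn(8, 4).astype(np.float32))
target = ms.Tensor(np.random.randint(0, 2, (8, 4)).astype(np.float32))

focal_loss = SigmoidFocalLoss(gamma=2.0, alpha=0.25, reduction='mean')
loss = focal_loss(pred, target)
print(loss)  # a scalar Tensor: the mean focal loss over all prior/class pairs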