mindspore.mint.nn.KLDivLoss

class mindspore.mint.nn.KLDivLoss(reduction='mean', log_target=False)

Computes the Kullback-Leibler divergence between the input and the target.

For tensors x and y of the same shape, KLDivLoss computes the elementwise loss as follows,

L(x, y) = y \cdot (\log y - x)

Then,

\ell(x, y) =
\begin{cases}
L(x, y), & \text{if reduction} = \text{'none'}; \\
\operatorname{mean}(L(x, y)), & \text{if reduction} = \text{'mean'}; \\
\operatorname{sum}(L(x, y)) / x.\text{shape}[0], & \text{if reduction} = \text{'batchmean'}; \\
\operatorname{sum}(L(x, y)), & \text{if reduction} = \text{'sum'}.
\end{cases}

where x represents input (expected to be in log space), y represents target, and \ell(x, y) represents the output.
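
For intuition, the elementwise term can be reproduced directly with NumPy. The following is a minimal sketch, reusing the values from the Examples section below; entries where y = 0 contribute 0, following the convention 0 · log 0 = 0:

>>> import numpy as np
>>> x = np.array([[0.5, 0.5], [0.4, 0.6]])  # input, already in log space
>>> y = np.array([[0., 1.], [1., 0.]])      # target probabilities
>>> L = np.zeros_like(y)                    # y == 0 entries contribute 0
>>> m = y > 0
>>> L[m] = y[m] * (np.log(y[m]) - x[m])
>>> print(L.mean())                         # reduction='mean'
-0.225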

Note

The output aligns with the mathematical definition of Kullback-Leibler divergence only when reduction is set to 'batchmean'.
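
As a minimal sketch of the difference, 'mean' divides the summed loss by the number of elements (4 below), while 'batchmean' divides by the batch size x.shape[0] (2 below):

>>> import mindspore as ms
>>> from mindspore import mint
>>> import numpy as np
>>> input = ms.Tensor(np.array([[0.5, 0.5], [0.4, 0.6]]), ms.float32)
>>> target = ms.Tensor(np.array([[0., 1.], [1., 0.]]), ms.float32)
>>> print(mint.nn.KLDivLoss(reduction='mean')(input, target))       # sum / 4
-0.225
>>> print(mint.nn.KLDivLoss(reduction='batchmean')(input, target))  # sum / 2
-0.45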

Parameters
  • reduction (str, optional) – Specifies the reduction to apply to the output, one of 'none', 'mean', 'batchmean' or 'sum'. Default: 'mean'.

  • log_target (bool, optional) – Specifies whether target is passed in log space; see the sketch after this list. Default: False.
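
When the target is already stored as log-probabilities, log_target=True lets it be passed directly. A minimal sketch of the equivalence (the target values are chosen non-zero so the logarithm is finite):

>>> import mindspore as ms
>>> from mindspore import mint
>>> import numpy as np
>>> input = ms.Tensor(np.array([[0.5, 0.5], [0.4, 0.6]]), ms.float32)
>>> t = np.array([[0.3, 0.7], [0.6, 0.4]])
>>> out1 = mint.nn.KLDivLoss(log_target=False)(input, ms.Tensor(t, ms.float32))
>>> out2 = mint.nn.KLDivLoss(log_target=True)(input, ms.Tensor(np.log(t), ms.float32))
>>> # out1 and out2 agree up to floating-point rounding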

Inputs:
  • input (Tensor) - The input Tensor. The data type must be float16, float32 or bfloat16 (the latter is only supported by Atlas A2 training series products).

  • target (Tensor) - The target Tensor, which has the same dtype as input. The shapes of target and input should be broadcastable.
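
Because only broadcastability is required, input and target need not have identical shapes. A brief sketch, assuming broadcasting behaves as documented:

>>> import mindspore as ms
>>> from mindspore import mint
>>> import numpy as np
>>> input = ms.Tensor(np.array([[0.5, 0.5], [0.4, 0.6]]), ms.float32)
>>> target = ms.Tensor(np.array([0.3, 0.7]), ms.float32)  # shape (2,) broadcasts against (2, 2)
>>> out = mint.nn.KLDivLoss(reduction='none')(input, target)
>>> print(out.shape)
(2, 2)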

Outputs:

Tensor, with the same dtype as input. If reduction is 'none', the output has the same shape as the broadcast result of input and target; otherwise, it is a scalar Tensor.
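
To make the shape contract concrete, a brief sketch checking both cases:

>>> import mindspore as ms
>>> from mindspore import mint
>>> import numpy as np
>>> input = ms.Tensor(np.array([[0.5, 0.5], [0.4, 0.6]]), ms.float32)
>>> target = ms.Tensor(np.array([[0., 1.], [1., 0.]]), ms.float32)
>>> print(mint.nn.KLDivLoss(reduction='none')(input, target).shape)
(2, 2)
>>> print(mint.nn.KLDivLoss(reduction='mean')(input, target).shape)
()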

Raises
  • TypeError – If input or target is not a Tensor.

  • TypeError – If dtype of input or target is not float16, float32 or bfloat16.

  • TypeError – If dtype of target is not the same as input.

  • ValueError – If reduction is not one of 'none', 'mean', 'sum', 'batchmean'.

  • ValueError – If the shapes of target and input are not broadcastable.

Supported Platforms:

Ascend

Examples

>>> import mindspore as ms
>>> from mindspore import mint
>>> import numpy as np
>>> input = ms.Tensor(np.array([[0.5, 0.5], [0.4, 0.6]]), ms.float32)
>>> target = ms.Tensor(np.array([[0., 1.], [1., 0.]]), ms.float32)
>>> loss = mint.nn.KLDivLoss(reduction='mean', log_target=False)
>>> output = loss(input, target)
>>> print(output)
-0.225