mindspore.nn.CTCLoss
- class mindspore.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)[source]
Calculates the CTC (Connectionist Temporal Classification) loss. It is mainly used to calculate the loss between a continuous, unsegmented time series and a target sequence.
For the CTC algorithm, refer to Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with Recurrent Neural Networks.
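As a rough illustration of what the loss measures, the following NumPy sketch computes the CTC negative log-likelihood for a single sequence with the standard forward (alpha) recursion. It is not MindSpore's implementation; the function name `ctc_neg_log_likelihood` is hypothetical, and the input is assumed to already contain log-probabilities of shape (T, C).

```python
import numpy as np

def ctc_neg_log_likelihood(log_probs, target, blank=0):
    """Reference CTC loss for ONE sequence (illustrative, not MindSpore code).

    log_probs: array of shape (T, C) with log-probabilities per time step.
    target:    1-D sequence of class indices without blanks.
    """
    T = log_probs.shape[0]
    # Extended target: a blank before, between and after every label.
    ext = [blank]
    for label in target:
        ext += [int(label), blank]
    L = len(ext)

    # alpha[s]: log-probability of all partial alignments ending at ext[s].
    alpha = np.full(L, -np.inf)
    alpha[0] = log_probs[0, ext[0]]
    if L > 1:
        alpha[1] = log_probs[0, ext[1]]

    for t in range(1, T):
        prev = alpha
        alpha = np.full(L, -np.inf)
        for s in range(L):
            candidates = [prev[s]]
            if s >= 1:
                candidates.append(prev[s - 1])
            # A blank may be skipped only between two different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append(prev[s - 2])
            alpha[s] = np.logaddexp.reduce(candidates) + log_probs[t, ext[s]]

    # Valid alignments end on the last label or on the trailing blank.
    tail = [alpha[L - 1]]
    if L > 1:
        tail.append(alpha[L - 2])
    return -np.logaddexp.reduce(tail)

# Two time steps, two classes, uniform probabilities.
uniform = np.log(np.full((2, 2), 0.5))
print(ctc_neg_log_likelihood(uniform, [1]))     # approximately 0.2877, i.e. -log(0.75)
print(ctc_neg_log_likelihood(uniform, [1, 1]))  # inf: a repeated label needs at least 3 steps
```

With uniform probabilities, three of the four possible length-2 paths collapse to the target `[1]`, which gives the value -log(0.75). The second call shows how an infeasible alignment produces an infinite loss.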
- Parameters
blank (int, optional) – The blank label. Default: 0.
reduction (str, optional) – The reduction method applied to the output: 'none', 'mean', or 'sum'. Default: 'mean'.
zero_infinity (bool, optional) – Whether to set an infinite loss and its associated gradient to zero. Default: False.
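The effect of zero_infinity on the loss values can be pictured with a small NumPy sketch. The helper `apply_zero_infinity` is hypothetical and only mimics the value replacement described above. It does not model the gradient handling, and it is not the operator's actual code.

```python
import numpy as np

def apply_zero_infinity(per_sample_loss, zero_infinity=False):
    """Replace infinite per-sample losses with zero when requested (illustrative)."""
    loss = np.asarray(per_sample_loss, dtype=np.float64)
    if zero_infinity:
        loss = np.where(np.isinf(loss), 0.0, loss)
    return loss

losses = [1.5, np.inf, 0.7]  # assumed per-sample losses; the second is infeasible
print(apply_zero_infinity(losses, zero_infinity=False).sum())  # inf
print(apply_zero_infinity(losses, zero_infinity=True))         # [1.5 0.  0.7]
```

Without the replacement, a single infinite sample makes any aggregated loss infinite, which is the situation the parameter is meant to avoid.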
- Inputs:
log_probs (Tensor) - A tensor of shape \((T, N, C)\) or \((T, C)\), where T is the input length, N is the batch size, and C is the number of classes. T, N and C are positive integers.
targets (Tensor) - The target sequences. A tensor of shape \((N, S)\) or (sum(target_lengths)), where S is the maximum target length.
input_lengths (Union[tuple, Tensor]) - A tuple or Tensor of shape \((N)\). The lengths of the input sequences.
target_lengths (Union[tuple, Tensor]) - A tuple or Tensor of shape \((N)\). The lengths of the target sequences.
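The two accepted layouts for targets, padded \((N, S)\) and concatenated (sum(target_lengths)), carry the same information. The NumPy sketch below converts one into the other. The helper name `padded_to_concatenated` and the sample values are made up for illustration.

```python
import numpy as np

def padded_to_concatenated(targets, target_lengths):
    """Flatten padded (N, S) targets into a 1-D array of length sum(target_lengths)."""
    targets = np.asarray(targets)
    # Keep only the first target_lengths[i] entries of row i, then join the rows.
    return np.concatenate([row[:n] for row, n in zip(targets, target_lengths)])

padded = np.array([[1, 2, 0],    # length 2, last entry is padding
                   [3, 1, 2]])   # length 3
lengths = [2, 3]
flat = padded_to_concatenated(padded, lengths)
print(flat)        # [1 2 3 1 2]
print(flat.shape)  # (5,)
```

In the padded layout, entries beyond target_lengths[i] are treated here as padding and dropped.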
- Outputs:
neg_log_likelihood (Tensor) - A loss value which is differentiable with respect to each input node.
- Raises
TypeError – If log_probs or targets is not a Tensor.
TypeError – If zero_infinity is not a bool or reduction is not a string.
TypeError – If the dtype of log_probs is not float or double.
TypeError – If the dtype of targets, input_lengths or target_lengths is not int32 or int64.
ValueError – If reduction is not “none”, “mean” or “sum”.
ValueError – If the value of blank is not in range [0, C), where C is the number of classes of log_probs.
ValueError – If the shape of log_probs is \((T, C)\) and the dimension of targets is not 1 or 2.
ValueError – If the shape of log_probs is \((T, C)\) and the first dimension of 2-D targets is not 1.
RuntimeError – If any value of input_lengths is larger than T, where T is the length of log_probs.
RuntimeError – If any target_lengths[i] is not in range [0, input_lengths[i]].
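The value constraints listed above can be checked before calling the loss, for example in a data pipeline. The following is a minimal sketch of such a pre-check in NumPy. `check_ctc_args` is a hypothetical helper, not part of MindSpore, and it covers only the range checks, not the dtype or shape checks.

```python
import numpy as np

def check_ctc_args(log_probs_shape, blank, input_lengths, target_lengths):
    """Hypothetical pre-check mirroring the value constraints listed above."""
    T, C = log_probs_shape[0], log_probs_shape[-1]
    input_lengths = np.asarray(input_lengths)
    target_lengths = np.asarray(target_lengths)
    if not 0 <= blank < C:
        raise ValueError(f"blank={blank} is not in range [0, {C})")
    if np.any(input_lengths > T):
        raise RuntimeError(f"input_lengths must not exceed T={T}")
    if np.any(target_lengths < 0) or np.any(target_lengths > input_lengths):
        raise RuntimeError("each target_lengths[i] must lie in [0, input_lengths[i]]")

# Shapes taken from the example on this page: T=5, N=2, C=2.
check_ctc_args((5, 2, 2), blank=0, input_lengths=[5, 5], target_lengths=[2, 2])
```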
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> import numpy as np
>>> T = 5      # Input sequence length
>>> C = 2      # Number of classes
>>> N = 2      # Batch size
>>> S = 3      # Target sequence length of longest target in batch (padding length)
>>> S_min = 2  # Minimum target length, for demonstration purposes
>>> arr = np.arange(T*N*C).reshape((T, N, C))
>>> ms_input = ms.Tensor(arr, dtype=ms.float32)
>>> input_lengths = np.full(shape=(N), fill_value=T)
>>> input_lengths = ms.Tensor(input_lengths, dtype=ms.int32)
>>> target_lengths = np.full(shape=(N), fill_value=S_min)
>>> target_lengths = ms.Tensor(target_lengths, dtype=ms.int32)
>>> target = np.random.randint(1, C, size=(N, S))
>>> target = ms.Tensor(target, dtype=ms.int32)
>>> ctc_loss = nn.CTCLoss(blank=0, reduction='none', zero_infinity=False)
>>> loss = ctc_loss(ms_input, target, input_lengths, target_lengths)
>>> print(loss)
[-45.79497 -55.794968]
>>> arr = np.arange(T*C).reshape((T, C))
>>> ms_input = ms.Tensor(arr, dtype=ms.float32)
>>> input_lengths = ms.Tensor([T], dtype=ms.int32)
>>> target_lengths = ms.Tensor([S_min], dtype=ms.int32)
>>> target = np.random.randint(1, C, size=(S_min,))
>>> target = ms.Tensor(target, dtype=ms.int32)
>>> ctc_loss = nn.CTCLoss(blank=0, reduction='none', zero_infinity=False)
>>> loss = ctc_loss(ms_input, target, input_lengths, target_lengths)
>>> print(loss)
-25.794968