mindspore.ops.CTCLossV2
- class mindspore.ops.CTCLossV2(blank=0, reduction='none', zero_infinity=False)[源代码]
计算CTC(Connectionist Temporal Classification)损失和梯度。
CTC算法是在 Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with Recurrent Neural Networks 中提出的。
警告
这是一个实验性API,后续可能修改或删除。
- 参数:
blank (int,可选) - 空白标签。默认值:0。
reduction (str,可选) - 对输出应用特定的缩减方法。目前仅支持“none”,不区分大小写。默认值:“none”。
zero_infinity (bool,可选) - 在损失无限大的时候,是否将无限损失和相关梯度置为零。默认值:False。
- 输入:
log_probs (Tensor) - 输入Tensor,其shape为 \((T, C, N)\) 的三维Tensor。 \(T\) 表示输入长度, \(N\) 表示批大小, \(C\) 表示类别数,包含空白标签。
targets (Tensor) - 标签序列。其shape为 \((N, S)\) 的三维Tensor。 \(S\) 表示最大标签长度。
input_lengths (Union(Tuple, Tensor)) - 输入的长度。其shape为 \((N)\) 。
target_lengths (Union(Tuple, Tensor)) - 标签的长度。其shape为 \((N)\) 。
- 输出:
neg_log_likelihood (Tensor) - 相对于每个输入节点可微分的损失值。
log_alpha (Tensor) - 输入到目标的可能跟踪概率。
- 异常:
TypeError - 如果 zero_infinity 不是bool类型。
TypeError - 如果 reduction 不是string类型。
TypeError - 如果 log_probs 的dtype不是float类型或double类型。
TypeError - 如果 targets、 input_lengths 或 target_lengths 的dtype不是int32类型或int64类型。
ValueError - 如果 log_probs 的秩不等于2。
ValueError - 如果 targets 的秩不等于2。
ValueError - 如果 input_lengths 的shape与批大小 \(N\) 不匹配。
ValueError - 如果 targets 的shape与批大小 \(N\) 不匹配。
TypeError - 如果 targets、 input_lengths 或 target_lengths 的类型不同。
ValueError - 如果 blank 的数值不在[0, C)范围内。
RuntimeError - 如果 input_lengths 中任意一个元素值大于(num_labels|C)。
RuntimeError - 如果任何 target_lengths[i] 不在范围 [0, input_length[i]] 范围内。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> log_probs = Tensor(np.array([[[0.3, 0.6, 0.6]], ... [[0.9, 0.4, 0.2]]]).astype(np.float32)) >>> targets = Tensor(np.array([[0, 1]]), mstype.int32) >>> input_lengths = Tensor(np.array([2]), mstype.int32) >>> target_lengths = Tensor(np.array([1]), mstype.int32) >>> CTCLossV2 = ops.CTCLossV2(blank=0, reduction='none', zero_infinity=False) >>> neg_log_hood, log_alpha = CTCLossV2( ... log_probs, targets, input_lengths, target_lengths) >>> print(neg_log_hood) [-2.2986124] >>> print(log_alpha) [[[0.3 0.3 -inf -inf -inf] [1.2 1.8931472 1.2 -inf -inf]]]