mindspore.ops.ctc_greedy_decoder
- mindspore.ops.ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True)[源代码]
对输入中给定的logits执行贪婪解码。
- 参数:
inputs (Tensor) - shape:
,数据类型必须是float32或者float64。num_classes 为 num_labels + 1 classes,其中 num_labels 表示实际标签的个数,空标签默认使用 num_classes - 1。sequence_length (Tensor) - shape:
,数据类型必须是int32,并且Tensor中的数值必须小于等于 max_time。merge_repeated (bool) - True表示返回的结果中会合并重复的类。默认值为True。
- 返回:
decoded_indices (Tensor) - shape:
,数据类型为int64。decoded_values (Tensor) - shape:
,数据类型为int64。decoded_shape (Tensor) - shape:
,数据类型为int64。log_probability (Tensor) - shape:
,包含序列的对数概率,其数据类型与 inputs 保持一致。
- 异常:
TypeError - merge_repeated 不是一个布尔值。
ValueError - inputs 的shape长度不等于3。
ValueError - sequence_length 的shape长度不等于1。
ValueError - sequence_length 中的数值大于 max_time。
- 支持平台:
Ascend
CPU
样例:
>>> inputs = Tensor(np.array([[[0.6, 0.4, 0.2], [0.8, 0.6, 0.3]], ... [[0.0, 0.6, 0.0], [0.5, 0.4, 0.5]]]), mindspore.float32) >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32) >>> decoded_indices, decoded_values, decoded_shape, log_probability = ops.ctc_greedy_decoder(inputs, ... sequence_length) >>> print(decoded_indices) [[0 0] [0 1] [1 0]] >>> print(decoded_values) [0 1 0] >>> print(decoded_shape) [2 2] >>> print(log_probability) [[-1.2] [-1.3]]