mindspore.ops.ctc_greedy_decoder

mindspore.ops.ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True)[源代码]

对输入中给定的logits执行贪婪解码。

参数:
  • inputs (Tensor) - shape: \((max\_time, batch\_size, num\_classes)\),数据类型必须是float32或者float64。num_classesnum_labels + 1 classes,其中 num_labels 表示实际标签的个数,空标签默认使用 num_classes - 1

  • sequence_length (Tensor) - shape: \((batch\_size, )\),数据类型必须是int32,并且Tensor中的数值必须小于等于 max_time

  • merge_repeated (bool) - True表示返回的结果中会合并重复的类。默认值为True。

返回:
  • decoded_indices (Tensor) - shape: \((total\_decoded\_outputs, 2)\),数据类型为int64。

  • decoded_values (Tensor) - shape: \((total\_decoded\_outputs, )\),数据类型为int64。

  • decoded_shape (Tensor) - shape: \((batch\_size, max\_decoded\_length)\),数据类型为int64。

  • log_probability (Tensor) - shape: \((batch\_size, 1)\),包含序列的对数概率,其数据类型与 inputs 保持一致。

异常:
  • TypeError - merge_repeated 不是一个布尔值。

  • ValueError - inputs 的shape长度不等于3。

  • ValueError - sequence_length 的shape长度不等于1。

  • ValueError - sequence_length 中的数值大于 max_time

支持平台:

Ascend CPU

样例:

>>> inputs = Tensor(np.array([[[0.6, 0.4, 0.2], [0.8, 0.6, 0.3]],
...                           [[0.0, 0.6, 0.0], [0.5, 0.4, 0.5]]]), mindspore.float32)
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
>>> decoded_indices, decoded_values, decoded_shape, log_probability = ops.ctc_greedy_decoder(inputs,
...                                                                                          sequence_length)
>>> print(decoded_indices)
[[0 0]
 [0 1]
 [1 0]]
>>> print(decoded_values)
[0 1 0]
>>> print(decoded_shape)
[2 2]
>>> print(log_probability)
[[-1.2]
 [-1.3]]