mindspore.dataset.transforms.c_transforms.OneHot

class mindspore.dataset.transforms.c_transforms.OneHot(num_classes)[源代码]

将Tensor进行OneHot编码。

参数:

  • num_classes (int) - 数据集的类别数,它应该大于数据集中最大的label编号。

异常:

  • TypeError - 参数 num_classes 类型不为int。

  • RuntimeError - 输入Tensor的数据类型不为int。

  • RuntimeError - 参数Tensor的shape不是1-D。

支持平台:

CPU

样例:

>>> # Assume that dataset has 10 classes, thus the label ranges from 0 to 9
>>> onehot_op = c_transforms.OneHot(num_classes=10)
>>> mnist_dataset = mnist_dataset.map(operations=onehot_op, input_columns=["label"])