mindspore.dataset.transforms.OneHot
- class mindspore.dataset.transforms.OneHot(num_classes, smoothing_rate=0.0)[源代码]
将Tensor进行OneHot编码。
- 参数:
num_classes (int) - 数据集的类别数,它应该大于数据集中最大的label编号。
smoothing_rate (float,可选) - 标签平滑的系数。默认值:0.0。
- 异常:
TypeError - 参数 num_classes 类型不为int。
TypeError - 参数 smoothing_rate 类型不为float。
ValueError - 参数 smoothing_rate 取值范围不为[0.0, 1.0]。
RuntimeError - 输入Tensor的数据类型不为int。
RuntimeError - 参数Tensor的shape不是1-D。
- 支持平台:
CPU
样例:
>>> # Assume that dataset has 10 classes, thus the label ranges from 0 to 9 >>> onehot_op = transforms.OneHot(num_classes=10) >>> mnist_dataset = mnist_dataset.map(operations=onehot_op, input_columns=["label"])