mindspore.mint.nn.functional.one_hot
- mindspore.mint.nn.functional.one_hot(tensor, num_classes=- 1)[源代码]
返回一个one-hot类型的Tensor。
生成一个新的Tensor,由索引 tensor 表示的位置取值为 1 ,而在其他所有位置取值为 0 。
- 参数:
tensor (Tensor) - 输入索引,shape为 \((X_0, \ldots, X_n)\) 的Tensor。数据类型必须为int32或int64。
num_classes (int) - 输入的Scalar,定义one-hot的深度,默认值:
-1
。
- 返回:
Tensor,one-hot类型的Tensor。
- 异常:
TypeError - num_classes 的数据类型不是int。
TypeError - tensor 的数据类型不是int32或者int64。
ValueError - num_classes 的输入值小于-1。
- 支持平台:
Ascend
样例:
>>> import mindspore >>> import numpy as np >>> from mindspore import Tensor, mint >>> tensor = Tensor(np.array([0, 1, 2]), mindspore.int32) >>> num_classes = 3 >>> output = mint.nn.functional.one_hot(tensor, num_classes) >>> print(output) [[1 0 0] [0 1 0] [0 0 1]]