mindspore.ops.random_categorical
- mindspore.ops.random_categorical(logits, num_sample, seed=0, dtype=mstype.int64)[源代码]
从一个分类分布中生成随机样本。
警告
Ascend后端不支持随机数重现功能, seed 参数不起作用。
- 参数:
logits (Tensor) - 输入Tensor。Shape为
的二维Tensor。num_sample (int) - 要抽取的样本数。只允许使用常量值。
seed (int) - 随机种子。只允许使用常量值。默认
0
。dtype (mindspore.dtype) - 输出的类型。其值必须是mindspore.int16、mindspore.int32或mindspore.int64之一。默认
mstype.int64
。
- 返回:
Tensor,Shape为
的输出Tensor。- 异常:
TypeError - 如果 dtype 不是以下类型之一:mindspore.int16、mindspore.int32、mindspore.int64。
TypeError - 如果 logits 不是Tensor。
TypeError - 如果 num_sample 或 seed 不是int类型。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> from mindspore import ops >>> from mindspore import Tensor >>> import mindspore.common.dtype as mstype >>> import numpy as np >>> logits = Tensor(np.random.random((10, 5)).astype(np.float32), mstype.float32) >>> net = ops.random_categorical(logits, 8) >>> result = net.shape >>> print(result) (10, 8)