mindspore.ops.RandomCategorical
- class mindspore.ops.RandomCategorical(dtype=mstype.int64)[源代码]
从分类分布中抽取样本。
警告
Ascend后端不支持随机数生成结果复现, seed 参数将失效。
- 参数:
dtype (mindspore.dtype,可选) - 输出的类型。其值必须是mstype.int16、mstype.int32或mstype.int64。默认值:
mstype.int64
。
- 输入:
logits (Tensor) - 输入Tensor,是一个shape为
的二维Tensor。num_sample (int) - 要抽取的样本数。只允许使用常量值。
seed (int) - 随机种子值,仅支持常量值。默认值:
0
。
- 输出:
output (Tensor) - 输出Tensor,其shape为
。
- 异常:
TypeError - 如果 dtype 不是mstype.int16、mstype.int32或mstype.int64。
TypeError - 如果 logits 不是Tensor。
TypeError - 如果 num_sample 或者 seed 不是int。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore >>> import numpy as np >>> from mindspore import nn, ops, Tensor >>> class Net(nn.Cell): ... def __init__(self, num_sample): ... super(Net, self).__init__() ... self.random_categorical = ops.RandomCategorical(mindspore.int64) ... self.num_sample = num_sample ... def construct(self, logits, seed=0): ... return self.random_categorical(logits, self.num_sample, seed) ... >>> x = np.random.random((10, 5)).astype(np.float32) >>> net = Net(8) >>> output = net(Tensor(x)) >>> result = output.shape >>> print(result) (10, 8)