mindspore.ops.RandomCategorical
- class mindspore.ops.RandomCategorical(dtype=mstype.int64)
Generates random samples from the categorical distributions defined by an input tensor of logits.
- Parameters
dtype (mindspore.dtype) – The data type of the output. Its value must be one of mindspore.int16, mindspore.int32, or mindspore.int64. Default: mindspore.int64.
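Each output element is a sampled class index in the range [0, num_classes), so the chosen dtype limits how many classes can be indexed: mindspore.int16 can represent indices only up to 32767. The sketch below is a hypothetical helper, not a MindSpore API. It works on dtype names as plain strings so that it runs without MindSpore installed, and it assumes indices are non-negative and start at 0.

```python
# Hypothetical helper (not part of MindSpore): pick the narrowest dtype name,
# among those RandomCategorical accepts, that can hold every class index.
INT_MAX = {
    "int16": 2**15 - 1,
    "int32": 2**31 - 1,
    "int64": 2**63 - 1,
}


def narrowest_index_dtype(num_classes: int) -> str:
    """Return the smallest allowed dtype name whose range covers 0..num_classes-1."""
    if num_classes < 1:
        raise ValueError("num_classes must be at least 1")
    largest_index = num_classes - 1
    for name in ("int16", "int32", "int64"):
        if largest_index <= INT_MAX[name]:
            return name
    raise ValueError("num_classes is too large for int64 indices")


print(narrowest_index_dtype(5))       # int16
print(narrowest_index_dtype(40_000))  # int32
```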
- Inputs:
logits (Tensor) - The input tensor of logits, a 2-D Tensor with shape [batch_size, num_classes]. Each row parameterizes one categorical distribution over num_classes classes.
num_sample (int) - Number of samples to draw for each row of logits. Only constant values are allowed.
seed (int) - Random seed. Default: 0. Only constant values are allowed.
- Outputs:
output (Tensor) - The output Tensor with shape [batch_size, num_sample], whose data type is given by dtype.
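To make the input and output semantics concrete, here is a minimal pure-Python reference sketch, assuming (as the name logits suggests) that each row of logits holds unnormalized log-probabilities over num_classes classes. For every row it draws num_sample independent class indices, so the result has shape [batch_size, num_sample]. The function categorical_reference is hypothetical and is not how MindSpore implements the operator; its seed argument seeds Python's random.Random and says nothing about how the operator itself interprets seed, including the default of 0.

```python
import math
import random


def categorical_reference(logits, num_sample, seed=0):
    """Hypothetical pure-Python sketch of categorical sampling (not MindSpore's kernel).

    logits: list of rows, shape [batch_size, num_classes], treated here as
            unnormalized log-probabilities (an assumption of this sketch).
    Returns a list of rows, shape [batch_size, num_sample], of class indices.
    """
    rng = random.Random(seed)  # a fixed seed makes this sketch's draws reproducible
    output = []
    for row in logits:
        # Exponentiate after subtracting the row maximum for numerical stability;
        # the resulting weights are proportional to softmax(row).
        m = max(row)
        weights = [math.exp(x - m) for x in row]
        # random.choices samples with replacement, in proportion to the weights.
        output.append(rng.choices(range(len(row)), weights=weights, k=num_sample))
    return output


logits = [[0.0, 1.0, 2.0], [5.0, 0.0, -5.0]]
samples = categorical_reference(logits, num_sample=4, seed=1)
print(len(samples), len(samples[0]))  # 2 4
```

Because the draws are independent, the same index can appear several times in one output row, and a row with heavily skewed logits will mostly return its largest-logit class.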
- Raises
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn, ops
>>> class Net(nn.Cell):
...     def __init__(self, num_sample):
...         super(Net, self).__init__()
...         self.random_categorical = ops.RandomCategorical(mindspore.int64)
...         self.num_sample = num_sample
...     def construct(self, logits, seed=0):
...         # Draw self.num_sample class indices for each row of logits.
...         return self.random_categorical(logits, self.num_sample, seed)
...
>>> x = np.random.random((10, 5)).astype(np.float32)
>>> net = Net(8)
>>> output = net(Tensor(x))
>>> result = output.shape
>>> print(result)
(10, 8)
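The example above only checks the output shape. A simple way to check that samples follow the intended distribution is to compare empirical class frequencies against softmax(logits). The pure-Python sketch below does this, using the Gumbel-max trick (the argmax of the logits plus independent standard Gumbel noise is distributed as softmax(logits)) as a stand-in sampler. This is a swapped-in technique, and nothing here implies that RandomCategorical uses it internally; the helpers softmax and gumbel_max_sample are hypothetical.

```python
import math
import random


def softmax(row):
    """Normalize a row of logits into probabilities (max-subtracted for stability)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_max_sample(row, rng):
    """Draw one class index as argmax(logit + standard Gumbel noise)."""
    noisy = []
    for logit in row:
        u = rng.random()
        while u == 0.0:  # need U strictly inside (0, 1) for the logs below
            u = rng.random()
        # -log(-log(U)) with U ~ Uniform(0, 1) is a standard Gumbel sample.
        noisy.append(logit - math.log(-math.log(u)))
    return max(range(len(row)), key=noisy.__getitem__)


logits = [0.0, 1.0, 2.0]
rng = random.Random(0)
n = 100_000
counts = [0] * len(logits)
for _ in range(n):
    counts[gumbel_max_sample(logits, rng)] += 1

expected = softmax(logits)
for k, (c, p) in enumerate(zip(counts, expected)):
    print(f"class {k}: empirical {c / n:.3f}  softmax {p:.3f}")
```

The same frequency comparison could be applied to counts taken from the operator's output, keeping in mind that the sampling error shrinks roughly as 1/sqrt(n).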