mindspore.ops.gumbel_softmax
- mindspore.ops.gumbel_softmax(logits, tau=1, hard=False, dim=-1)
Returns samples from the Gumbel-Softmax distribution and optionally discretizes them. If hard is True, the returned samples are one-hot vectors; otherwise they are probability distributions that sum to 1 across dim.
- Parameters
logits (Tensor) – Unnormalized log probabilities. The data type must be float16 or float32.
tau (float) – The scalar temperature, which is a positive number. Default: 1.0.
hard (bool) – If True, the returned samples are discretized as one-hot vectors, but are differentiated in autograd as if they were the soft samples. Default: False.
dim (int) – The dimension along which softmax is computed. Default: -1.
- Returns
Tensor, has the same dtype and shape as logits.
- Raises
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> output = ops.gumbel_softmax(input_x, 1.0, True, -1)
>>> print(output.shape)
(2, 3)
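The sampling described above can be illustrated with a minimal pure-Python sketch. It assumes the standard Gumbel-Softmax construction (add Gumbel(0, 1) noise to the logits, divide by tau, apply softmax) and is not MindSpore's implementation. The helper name gumbel_softmax_sketch and its rng parameter are hypothetical, it handles a single 1-D list of logits, and it does not model the autograd behavior of hard=True.

```python
import math
import random


def gumbel_softmax_sketch(logits, tau=1.0, hard=False, rng=None):
    """Illustrative Gumbel-Softmax sample for one 1-D list of logits.

    Hypothetical helper for explanation only; not part of MindSpore.
    """
    if tau <= 0:
        raise ValueError("tau must be positive")
    rng = rng or random.Random()

    # Perturb each logit with Gumbel(0, 1) noise g = -log(-log(u)),
    # where u ~ Uniform(0, 1), then scale by the temperature.
    noisy = []
    for x in logits:
        u = rng.random()
        while u == 0.0:  # random() may return 0.0; avoid log(0)
            u = rng.random()
        g = -math.log(-math.log(u))
        noisy.append((x + g) / tau)

    # Numerically stable softmax over the perturbed logits.
    m = max(noisy)
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    soft = [e / total for e in exps]
    if not hard:
        return soft

    # hard=True: one-hot vector at the position of the largest probability.
    k = soft.index(max(soft))
    return [1.0 if i == k else 0.0 for i in range(len(soft))]


rng = random.Random(0)  # seeded only to make the sketch reproducible
soft_sample = gumbel_softmax_sketch([-0.1, 0.3, 3.6], tau=1.0, rng=rng)
hard_sample = gumbel_softmax_sketch([-0.1, 0.3, 3.6], tau=1.0, hard=True, rng=rng)
```

Here soft_sample is a list of three probabilities that sum to 1, and hard_sample contains a single 1.0 with zeros elsewhere. Under this construction, a smaller tau pushes the soft sample closer to one-hot, while a larger tau flattens it toward uniform.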