mindspore.ops.gumbel_softmax

mindspore.ops.gumbel_softmax(logits, tau=1, hard=False, dim=- 1)[源代码]

返回Gumbel-Softmax分布的Tensor。在 hard = True 的时候,返回one-hot形式的离散型Tensor,hard = False 时返回在dim维进行过softmax的Tensor。

参数:
  • logits (Tensor) - 输入,是一个非标准化的对数概率分布。只支持float16和float32。

  • tau (float) - 标量温度,正数。默认值:1.0。

  • hard (bool) - 为True时返回one-hot离散型Tensor,可反向求导。默认值:False。

  • dim (int) - 给softmax使用的参数,在dim维上做softmax操作。默认值:-1。

返回:

Tensor,shape与dtype和输入 logits 相同。

异常:
  • TypeError - logits 不是Tensor。

  • TypeError - logits 不是float16或float32。

  • TypeError - tau 不是float。

  • TypeError - hard 不是bool。

  • TypeError - dim 不是int。

  • ValueError - tau 不是正数。

支持平台:

Ascend GPU CPU

样例:

>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> output = ops.gumbel_softmax(input_x, 1.0, True, -1)
>>> print(output.shape)
(2, 3)