mindspore.ops.csr_softmax

查看源文件
mindspore.ops.csr_softmax(logits: CSRTensor, dtype: mstype)[源代码]

计算 CSRTensorMatrix 的 softmax 。

参数:
  • logits (CSRTensor) - 输入稀疏的 CSRTensor。

  • dtype (dtype) - 输入的数据类型。

返回:
  • CSRTensor (CSRTensor) - 一个 csr_tensor 包含

    • indptr - 指示每行中非零值的起始点和结束点。

    • indices - 输入中所有非零值的列位置。

    • values - 稠密张量的非零值。

    • shape - csr_tensor的shape。

支持平台:

GPU CPU

样例:

>>> import mindspore as ms
>>> from mindspore import ops
>>> import mindspore.common.dtype as mstype
>>> from mindspore import Tensor, CSRTensor
>>> logits_indptr = Tensor([0, 4, 6], dtype=mstype.int32)
>>> logits_indices = Tensor([0, 2, 3, 4, 3, 4], dtype=mstype.int32)
>>> logits_values = Tensor([1, 2, 3, 4, 1, 2], dtype=mstype.float32)
>>> shape = (2, 6)
>>> logits = CSRTensor(logits_indptr, logits_indices, logits_values, shape)
>>> out = ops.csr_softmax(logits, dtype=mstype.float32)
>>> print(out)
CSRTensor(shape=[2, 6], dtype=Float32, indptr=Tensor(shape=[3], dtype=Int32, value=[0 4 6]),
               indices=Tensor(shape=[6], dtype=Int32, value=[0 2 3 4 3 4]),
               values=Tensor(shape=[6], dtype=Float32, value=[ 3.20586003e-02  8.71443152e-02  2.36882806e-01
               6.43914223e-01  2.68941432e-01  7.31058598e-01]))