mindspore.ops.csr_softmax
- mindspore.ops.csr_softmax(logits: CSRTensor, dtype: mstype)[源代码]
计算 CSRTensorMatrix 的 softmax 。
- 参数:
logits (CSRTensor) - 输入稀疏的 CSRTensor。
dtype (dtype) - 输入的数据类型。
- 返回:
CSRTensor (CSRTensor) - 一个 csr_tensor 包含
indptr - 指示每行中非零值的起始点和结束点。
indices - 输入中所有非零值的列位置。
values - 稠密张量的非零值。
shape - csr_tensor的shape。
- 支持平台:
GPU
CPU
样例:
>>> import mindspore as ms >>> from mindspore import ops >>> import mindspore.common.dtype as mstype >>> from mindspore import Tensor, CSRTensor >>> logits_indptr = Tensor([0, 4, 6], dtype=mstype.int32) >>> logits_indices = Tensor([0, 2, 3, 4, 3, 4], dtype=mstype.int32) >>> logits_values = Tensor([1, 2, 3, 4, 1, 2], dtype=mstype.float32) >>> shape = (2, 6) >>> logits = CSRTensor(logits_indptr, logits_indices, logits_values, shape) >>> out = ops.csr_softmax(logits, dtype=mstype.float32) >>> print(out) CSRTensor(shape=[2, 6], dtype=Float32, indptr=Tensor(shape=[3], dtype=Int32, value=[0 4 6]), indices=Tensor(shape=[6], dtype=Int32, value=[0 2 3 4 3 4]), values=Tensor(shape=[6], dtype=Float32, value=[ 3.20586003e-02 8.71443152e-02 2.36882806e-01 6.43914223e-01 2.68941432e-01 7.31058598e-01]))