mindspore.ops.csr_mm

查看源文件
mindspore.ops.csr_mm(a: CSRTensor, b: CSRTensor, trans_a: bool = False, trans_b: bool = False, adjoint_a: bool = False, adjoint_b: bool = False)[源代码]

返回稀疏矩阵a与稀疏矩阵或稠密矩阵b的矩阵乘法结果。

说明

若右矩阵为Tensor,则仅支持安装了LLVM12.0.1及以上版本的CPU后端或GPU后端。 若右矩阵为CSRTensor, 则仅支持GPU后端。

参数:
  • a (CSRTensor) - 稀疏的 CSRTensor。

  • b (CSRTensor) - 稀疏的 CSRTensor或稠密矩阵。

  • trans_a (bool, 可选) - 是否对矩阵a进行转置。默认值: False

  • trans_b (bool, 可选) - 是否对矩阵b进行转置。默认值: False

  • adjoint_a (bool, 可选) - 是否对矩阵a进行共轭。默认值: False

  • adjoint_b (bool, 可选) - 是否对矩阵b进行共轭。默认值: False

返回:

返回稀疏矩阵,类型为CSRTensor。

支持平台:

GPU

样例:

>>> from mindspore import Tensor, CSRTensor
>>> from mindspore import dtype as mstype
>>> import mindspore.ops as ops
>>> a_shape = (4, 5)
>>> a_indptr = Tensor([0, 1, 1, 3, 4], dtype=mstype.int32)
>>> a_indices = Tensor([0, 3, 4, 0],dtype=mstype.int32)
>>> a_values = Tensor([1.0, 5.0, -1.0, -2.0], dtype=mstype.float32)
>>> b_shape = (5, 3)
>>> b_indptr = Tensor([0, 1, 1, 3, 3, 3], dtype=mstype.int32)
>>> b_indices = Tensor([0, 0, 1],dtype=mstype.int32)
>>> b_values = Tensor([2.0, 7.0, 8.0], dtype=mstype.float32)
>>> a = CSRTensor(a_indptr, a_indices, a_values, a_shape)
>>> b = CSRTensor(b_indptr, b_indices, b_values, b_shape)
>>> c = ops.csr_mm(a, b)
>>> print(c.shape)
(4, 3)
>>> print(c.values)
[2. -4.]
>>> print(c.indptr)
[0 1 1 1 2]
>>> print(c.indices)
[0 0]