mindspore.ops.csr_mm
- mindspore.ops.csr_mm(a: CSRTensor, b: CSRTensor, trans_a: bool = False, trans_b: bool = False, adjoint_a: bool = False, adjoint_b: bool = False)[源代码]
返回稀疏矩阵a与稀疏矩阵或稠密矩阵b的矩阵乘法结果。
说明
若右矩阵为Tensor,则仅支持安装了LLVM12.0.1及以上版本的CPU后端或GPU后端。 若右矩阵为CSRTensor, 则仅支持GPU后端。
- 参数:
a (CSRTensor) - 稀疏的 CSRTensor。
b (CSRTensor) - 稀疏的 CSRTensor或稠密矩阵。
trans_a (bool, 可选) - 是否对矩阵a进行转置。默认值:
False
。trans_b (bool, 可选) - 是否对矩阵b进行转置。默认值:
False
。adjoint_a (bool, 可选) - 是否对矩阵a进行共轭。默认值:
False
。adjoint_b (bool, 可选) - 是否对矩阵b进行共轭。默认值:
False
。
- 返回:
返回稀疏矩阵,类型为CSRTensor。
- 支持平台:
GPU
样例:
>>> from mindspore import Tensor, CSRTensor >>> from mindspore import dtype as mstype >>> from mindspore import ops >>> a_shape = (4, 5) >>> a_indptr = Tensor([0, 1, 1, 3, 4], dtype=mstype.int32) >>> a_indices = Tensor([0, 3, 4, 0],dtype=mstype.int32) >>> a_values = Tensor([1.0, 5.0, -1.0, -2.0], dtype=mstype.float32) >>> b_shape = (5, 3) >>> b_indptr = Tensor([0, 1, 1, 3, 3, 3], dtype=mstype.int32) >>> b_indices = Tensor([0, 0, 1],dtype=mstype.int32) >>> b_values = Tensor([2.0, 7.0, 8.0], dtype=mstype.float32) >>> a = CSRTensor(a_indptr, a_indices, a_values, a_shape) >>> b = CSRTensor(b_indptr, b_indices, b_values, b_shape) >>> c = ops.csr_mm(a, b) >>> print(c.shape) (4, 3) >>> print(c.values) [2. -4.] >>> print(c.indptr) [0 1 1 1 2] >>> print(c.indices) [0 0]