mindspore.SparseTensor

class mindspore.SparseTensor(indices, values, shape)[源代码]

用来表示某一张量在给定索引上非零元素的集合。

SparseTensor 只能在 Cell 的构造方法中使用。

Note

目前不支持PyNative模式。

对于稠密张量,其 SparseTensor(indices, values, shape) 具有 dense[indices[i]] = values[i]

参数:

  • indices (Tensor) - 形状为 [N, ndims] 的二维整数张量,其中N和ndims分别表示稀疏张量中 values 的数量和SparseTensor维度的数量。

  • values (Tensor) - 形状为[N]的一维张量,其内部可以为任何数据类型,用来给 indices 中的每个元素提供数值。

  • shape (tuple(int)) - 形状为ndims的整数元组,用来指定稀疏矩阵的稠密形状。

返回:

SparseTensor,由 indicesvaluesshape 组成。

样例:

>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, SparseTensor
>>> indices = Tensor([[0, 1], [1, 2]])
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> shape = (3, 4)
>>> x = SparseTensor(indices, values, shape)
>>> print(x.values)
[1. 2.]
>>> print(x.indices)
[[0 1]
 [1 2]]
>>> print(x.shape)
(3, 4)