mindspore.SparseTensor

class mindspore.SparseTensor(indices, values, dense_shape)

A sparse representation of a set of nonzero elements from a tensor at given indices.

SparseTensor can only be used in the Cell’s construct method.

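For a dense tensor dense, the corresponding SparseTensor(indices, values, dense_shape) satisfies dense[indices[i]] = values[i] for each i, where indices[i] is the coordinate of the i-th value. All other elements of dense are zero.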

For example, if indices is [[0, 1], [1, 2]], values is [1, 2], dense_shape is (3, 4), then the dense representation of the sparse tensor will be:

[[0, 1, 0, 0],
 [0, 0, 2, 0],
 [0, 0, 0, 0]]
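
The mapping from (indices, values, dense_shape) to the dense layout above can be sketched in plain Python. This is an illustration only, not how MindSpore performs the conversion: the helper name to_dense is hypothetical, it works on nested lists rather than Tensors, and it handles only the 2-D case.

```python
def to_dense(indices, values, dense_shape):
    """Hypothetical helper: expand a 2-D sparse triple into a nested list."""
    rows, cols = dense_shape
    # Start from an all-zero matrix of the requested dense shape.
    dense = [[0] * cols for _ in range(rows)]
    # indices[i] is the (row, col) coordinate of values[i].
    for (r, c), v in zip(indices, values):
        dense[r][c] = v
    return dense


dense = to_dense([[0, 1], [1, 2]], [1, 2], (3, 4))
for row in dense:
    print(row)
```

Each row of indices selects one position in the dense matrix, and every position not listed stays at zero.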

Note

SparseTensor is not supported in PyNative mode at the moment.

Parameters
  • indices (Tensor) – A 2-D integer Tensor of shape [N, ndims], where N and ndims are the number of values and number of dimensions in the SparseTensor, respectively.

  • values (Tensor) – A 1-D tensor of any type and shape [N], which supplies the value for each row of indices.

  • dense_shape (tuple(int)) – An integer tuple of size ndims, which specifies the shape of the dense representation of the sparse tensor.
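
The shape relationships among the three parameters can be checked with a short sketch like the one below. The helper check_sparse_args is hypothetical and operates on plain Python lists and tuples, not on Tensors. The bounds check on each coordinate is an assumption that follows from the dense representation; the parameter descriptions do not state it explicitly.

```python
def check_sparse_args(indices, values, dense_shape):
    """Hypothetical helper: return (N, ndims) or raise ValueError."""
    n = len(indices)          # N: number of stored values
    ndims = len(dense_shape)  # ndims: rank of the dense tensor
    if len(values) != n:
        raise ValueError(f"values has {len(values)} entries, expected N={n}")
    for i, idx in enumerate(indices):
        if len(idx) != ndims:
            raise ValueError(f"indices[{i}] has {len(idx)} coordinates, expected {ndims}")
        # Assumption (not stated in the parameter docs): every coordinate
        # must fall inside the corresponding dimension of dense_shape.
        for coord, size in zip(idx, dense_shape):
            if not 0 <= coord < size:
                raise ValueError(f"indices[{i}] = {idx} is outside {dense_shape}")
    return n, ndims


print(check_sparse_args([[0, 1], [1, 2]], [1, 2], (3, 4)))
```

For the example values used on this page, N is 2 and ndims is 2.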

Returns

SparseTensor, composed of indices, values, and dense_shape.

Examples

>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, SparseTensor
>>> class Net(nn.Cell):
...     def __init__(self, dense_shape):
...         super(Net, self).__init__()
...         self.dense_shape = dense_shape
...     def construct(self, indices, values):
...         x = SparseTensor(indices, values, self.dense_shape)
...         return x.values, x.indices, x.dense_shape
>>>
>>> indices = Tensor([[0, 1], [1, 2]])
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> out = Net((3, 4))(indices, values)
>>> print(out[0])
[1. 2.]
>>> print(out[1])
[[0 1]
 [1 2]]
>>> print(out[2])
(3, 4)