mindspore.SparseTensor
- class mindspore.SparseTensor(indices, values, dense_shape)
A sparse representation of a set of nonzero elements from a tensor at given indices.
SparseTensor can only be used in the Cell’s construct method.
Given a dense tensor dense, SparseTensor(indices, values, dense_shape) satisfies dense[indices[i]] = values[i] for every i; all other elements of dense are zero.
For example, if indices is [[0, 1], [1, 2]], values is [1, 2], and dense_shape is (3, 4), the dense representation of the sparse tensor is:
[[0, 1, 0, 0],
 [0, 0, 2, 0],
 [0, 0, 0, 0]]
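The rule dense[indices[i]] = values[i] can be checked outside MindSpore with a small NumPy sketch. The helpers to_dense and from_dense are hypothetical illustrations, not part of the MindSpore API, and the sketch assumes every coordinate in indices is in bounds and appears at most once.

```python
import numpy as np

def to_dense(indices, values, dense_shape):
    """Hypothetical helper: scatter values into zeros so that dense[indices[i]] = values[i]."""
    indices = np.asarray(indices)
    values = np.asarray(values)
    dense = np.zeros(dense_shape, dtype=values.dtype)
    # Each row of indices is one coordinate; transposing yields one index array per axis.
    dense[tuple(indices.T)] = values
    return dense

def from_dense(dense):
    """Hypothetical helper: collect the nonzero entries as (indices, values, dense_shape)."""
    dense = np.asarray(dense)
    indices = np.argwhere(dense)  # shape [N, ndims], coordinates in row-major order
    values = dense[tuple(indices.T)]
    return indices, values, dense.shape

dense = to_dense([[0, 1], [1, 2]], [1, 2], (3, 4))
print(dense.tolist())  # [[0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]

indices, values, shape = from_dense(dense)
print(indices.tolist(), values.tolist(), shape)  # [[0, 1], [1, 2]] [1, 2] (3, 4)
```

The round trip reproduces the example: only the two stored coordinates carry values, and every other position of the 3 x 4 array stays zero.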
Note
SparseTensor is not supported in PyNative mode at the moment.
- Parameters
indices (Tensor) – A 2-D integer Tensor of shape [N, ndims], where N and ndims are the number of values and number of dimensions in the SparseTensor, respectively.
values (Tensor) – A 1-D Tensor of any data type with shape [N], which supplies the value for each coordinate in indices.
dense_shape (tuple(int)) – An integer tuple of size ndims, which specifies the shape of the dense tensor.
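The shape constraints on the three parameters fit together: N links indices to values, and ndims links indices to dense_shape. The following is a minimal NumPy sketch of those checks. check_sparse_args is a hypothetical helper, not how MindSpore validates its inputs, and it deliberately does not check that indices are within dense_shape, since the parameter descriptions do not state that rule.

```python
import numpy as np

def check_sparse_args(indices, values, dense_shape):
    """Hypothetical check of the documented shape rules; returns (N, ndims)."""
    indices = np.asarray(indices)
    values = np.asarray(values)
    # indices: 2-D integer array of shape [N, ndims]
    if indices.ndim != 2 or not np.issubdtype(indices.dtype, np.integer):
        raise ValueError("indices must be a 2-D integer array of shape [N, ndims]")
    n, ndims = indices.shape
    # values: 1-D array with exactly one entry per row of indices
    if values.shape != (n,):
        raise ValueError(f"values must have shape [{n}], got {list(values.shape)}")
    # dense_shape: tuple of ndims integers
    if len(dense_shape) != ndims or not all(isinstance(d, int) for d in dense_shape):
        raise ValueError(f"dense_shape must be a tuple of {ndims} integers")
    return n, ndims

print(check_sparse_args([[0, 1], [1, 2]], [1.0, 2.0], (3, 4)))  # (2, 2)

try:
    check_sparse_args([[0, 1], [1, 2]], [1.0, 2.0, 3.0], (3, 4))
except ValueError as err:
    print(err)  # values must have shape [2], got [3]
```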
- Returns
SparseTensor, composed of indices, values, and dense_shape.
Examples
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, SparseTensor
>>> class Net(nn.Cell):
...     def __init__(self, dense_shape):
...         super(Net, self).__init__()
...         self.dense_shape = dense_shape
...     def construct(self, indices, values):
...         x = SparseTensor(indices, values, self.dense_shape)
...         return x.values, x.indices, x.dense_shape
>>>
>>> indices = Tensor([[0, 1], [1, 2]])
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> out = Net((3, 4))(indices, values)
>>> print(out[0])
[1. 2.]
>>> print(out[1])
[[0 1]
 [1 2]]
>>> print(out[2])
(3, 4)