mindspore.Tensor.chunk

Tensor.chunk(chunks, dim=0) → Tuple of Tensors

Split this Tensor into sub-tensors along the specified dimension, where chunks is the requested number of sub-tensors.

Note

This function may return fewer than the specified number of chunks.
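The note above can be illustrated with a small pure-Python sketch of the size arithmetic. It assumes the ceiling-division convention that chunk functions in deep-learning frameworks commonly follow (every piece has size ceil(length / chunks), except possibly a smaller last one); this page does not state the exact rule, and the helper name chunk_sizes is hypothetical, not part of MindSpore.

```python
def chunk_sizes(length, chunks):
    """Return the piece sizes when `length` elements are cut into at most `chunks` pieces.

    Assumption (not confirmed by this page): each piece has size
    ceil(length / chunks), and a smaller remainder piece is appended if needed.
    """
    if chunks <= 0:
        raise ValueError("chunks must be a positive number")
    if length == 0:
        return []  # edge case handled for this sketch only
    size = (length + chunks - 1) // chunks  # ceiling division on integers
    full, rest = divmod(length, size)
    return [size] * full + ([rest] if rest else [])


print(chunk_sizes(9, 3))  # [3, 3, 3]
print(chunk_sizes(5, 3))  # [2, 2, 1]
print(chunk_sizes(6, 4))  # [2, 2, 2] -> only 3 pieces although 4 were requested
```

Under this assumption, asking for 4 chunks of a length-6 dimension yields 3 pieces of size 2, which is the situation the note warns about.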

Warning

This is an experimental API that is subject to change or deletion.

Parameters
  • chunks (int) – Number of sub-tensors to cut the Tensor into.

  • dim (int, optional) – The dimension along which to split. Default: 0 .

Returns

A tuple of sub-tensors.

Raises
  • TypeError – If argument chunks is not int.

  • TypeError – If argument dim is not int.

  • ValueError – If argument dim is outside the range \([-self.ndim, self.ndim)\) .

  • ValueError – If argument chunks is not a positive number.
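As a rough illustration of the conditions listed above, the following pure-Python sketch checks the arguments in the same order and maps a negative dim to its non-negative equivalent. The helper name check_chunk_args and the error messages are hypothetical; the real checks happen inside MindSpore and may differ in detail (for example in how bool values are treated, which this sketch does not address).

```python
def check_chunk_args(ndim, chunks, dim=0):
    """Validate `chunks` and `dim` for a tensor with `ndim` dimensions.

    Hypothetical helper mirroring the documented error conditions;
    returns `dim` normalized into the range [0, ndim).
    """
    if not isinstance(chunks, int):
        raise TypeError("chunks must be int")
    if not isinstance(dim, int):
        raise TypeError("dim must be int")
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim must be in [-{ndim}, {ndim})")
    if chunks <= 0:
        raise ValueError("chunks must be a positive number")
    return dim % ndim  # e.g. dim=-1 on a 2-D tensor refers to dimension 1


print(check_chunk_args(ndim=2, chunks=3, dim=-1))  # 1
```

For a 2-D tensor, dim may therefore be any of -2, -1, 0, or 1, while dim=2 would raise ValueError.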

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> input_x = Tensor(np.arange(9).astype("float32"))
>>> output = input_x.chunk(3, dim=0)
>>> print(output)
(Tensor(shape=[3], dtype=Float32, value= [ 0.00000000e+00,  1.00000000e+00,  2.00000000e+00]),
    Tensor(shape=[3], dtype=Float32, value= [ 3.00000000e+00,  4.00000000e+00,  5.00000000e+00]),
    Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00,  7.00000000e+00,  8.00000000e+00]))

Tensor.chunk(chunks, axis=0) → Tuple of Tensors

Split this Tensor into sub-tensors along the specified axis, where chunks is the requested number of sub-tensors.

Note

This function may return fewer than the specified number of chunks.

Parameters
  • chunks (int) – Number of sub-tensors to cut the Tensor into.

  • axis (int, optional) – The axis along which to split. Default: 0 .

Returns

A tuple of sub-tensors.

Raises
  • TypeError – If argument chunks is not int.

  • TypeError – If argument axis is not int.

  • ValueError – If argument axis is outside the range \([-self.ndim, self.ndim)\) .

  • ValueError – If argument chunks is not a positive number.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor
>>> input_x = Tensor(np.arange(9).astype("float32"))
>>> output = input_x.chunk(3, axis=0)
>>> print(output)
(Tensor(shape=[3], dtype=Float32, value= [ 0.00000000e+00,  1.00000000e+00,  2.00000000e+00]),
    Tensor(shape=[3], dtype=Float32, value= [ 3.00000000e+00,  4.00000000e+00,  5.00000000e+00]),
    Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00,  7.00000000e+00,  8.00000000e+00]))