mindspore.numpy.split

mindspore.numpy.split(x, indices_or_sections, axis=0)[source]

Splits a tensor into multiple sub-tensors along the given axis.

Parameters
  • x (Tensor) – The tensor to split.

  • indices_or_sections (Union[int, tuple(int), list(int)]) – If an integer \(N\), the tensor is divided into \(N\) equal sub-tensors along axis. If a tuple or list of sorted integers, the entries indicate where along axis the tensor is split. For example, \([2, 3]\) with \(axis=0\) results in three sub-tensors: \(x[:2]\), \(x[2:3]\) and \(x[3:]\). If an index exceeds the dimension of the tensor along axis, a corresponding empty sub-tensor is returned.

  • axis (int) – The axis along which to split. Default: 0.
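The way indices_or_sections maps to slices can be sketched on a plain Python list. The helper `split_list` below is hypothetical and is not MindSpore code; it only illustrates the slicing rule described above. The integer case raises `ValueError` on an uneven division, which is an assumption borrowed from `numpy.split` rather than something this page states.

```python
from typing import List, Sequence, Union


def split_list(seq: Sequence, indices_or_sections: Union[int, Sequence[int]]) -> List[list]:
    """Hypothetical 1-D illustration of split semantics; not MindSpore code."""
    length = len(seq)
    if isinstance(indices_or_sections, int):
        n = indices_or_sections
        # Assumption (mirrors numpy.split): N must divide the length evenly.
        if n <= 0 or length % n != 0:
            raise ValueError("indices_or_sections does not divide the axis evenly")
        step = length // n
        points = [step * i for i in range(1, n)]
    else:
        points = list(indices_or_sections)
    # Split points become slice edges; 0 and the length close both ends.
    edges = [0, *points, length]
    # Python slicing clamps out-of-range bounds, so an index past the end
    # (or a stop below its start) yields an empty piece.
    return [list(seq[a:b]) for a, b in zip(edges[:-1], edges[1:])]


print(split_list(list(range(5)), [2, 3]))  # [[0, 1], [2], [3, 4]]
print(split_list(list(range(5)), [2, 7]))  # [[0, 1], [2, 3, 4], []]
```

The second call shows the out-of-range case: the index 7 exceeds the length 5, so the last piece is empty.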

Returns

A tuple of sub-tensors.

Raises
  • TypeError – If indices_or_sections is not an integer, tuple(int) or list(int), or if axis is not an integer.

  • ValueError – If axis is outside the range \([-x.ndim, x.ndim)\).
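The axis checks listed under Raises amount to a type check plus a half-open range check, after which a negative axis counts from the end. The helper `normalize_axis` below is a hypothetical sketch of that rule, not MindSpore's internal validation; rejecting `bool` is an extra assumption.

```python
def normalize_axis(axis: int, ndim: int) -> int:
    """Hypothetical sketch of the axis rule above; not MindSpore code."""
    # Non-integer axes are rejected (bool excluded as an assumption).
    if isinstance(axis, bool) or not isinstance(axis, int):
        raise TypeError(f"axis must be an int, got {type(axis).__name__}")
    # Valid axes lie in the half-open range [-ndim, ndim).
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for a {ndim}-D tensor")
    # Negative axes count from the end: -1 is the last axis.
    return axis % ndim


print(normalize_axis(-1, 3))  # 2
```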

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore.numpy as np
>>> input_x = np.arange(9).astype("float32")
>>> output = np.split(input_x, 3)
>>> print(output)
(Tensor(shape=[3], dtype=Float32,
  value= [ 0.00000000e+00,  1.00000000e+00,  2.00000000e+00]),
 Tensor(shape=[3], dtype=Float32,
  value= [ 3.00000000e+00,  4.00000000e+00,  5.00000000e+00]),
 Tensor(shape=[3], dtype=Float32,
  value= [ 6.00000000e+00,  7.00000000e+00,  8.00000000e+00]))