mindspore.numpy.split

查看源文件
mindspore.numpy.split(x, indices_or_sections, axis=0)[源代码]

将一个Tensor沿指定轴分割为多个sub-tensor。

参数:
  • x (Tensor) - 待分割的Tensor。

  • indices_or_sections (Union[int, tuple(int), list(int)]) - 如果是整数 \(N\) ,Tensor将沿轴分割为 \(N\) 个相等的sub-tensor。如果是tuple(int)、list(int)或排序后的整数,则指示沿轴的分割位置。例如,对于 \(axis=0\)\([2,3]\) 将产生三个sub-tensor: \(x[:2]\)\(x[2:3]\)\(x[3:]\) 。如果索引超出轴上数组的维度,则相应地返回空子数组。

  • axis (int,可选) - 指定进行分割的轴。默认值: 0

返回:

sub-tensor的tuple。

异常:
  • TypeError - 如果参数 indices_or_sections 不是整数、tuple(int)或list(int),或者参数 axis 不是整数。

  • ValueError - 如果参数 axis 超出范围 \([-x.ndim, x.ndim)\)

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore.numpy as np
>>> input_x = np.arange(9).astype("float32")
>>> output = np.split(input_x, 3)
>>> print(output)
(Tensor(shape=[3], dtype=Float32,
  value= [ 0.00000000e+00,  1.00000000e+00,  2.00000000e+00]),
 Tensor(shape=[3], dtype=Float32,
  value= [ 3.00000000e+00,  4.00000000e+00,  5.00000000e+00]),
 Tensor(shape=[3], dtype=Float32,
  value= [ 6.00000000e+00,  7.00000000e+00,  8.00000000e+00]))