mindspore.Tensor.split

mindspore.Tensor.split(axis=0, output_num=1)[源代码]

根据指定的轴和分割数量对Tensor进行分割。

Tensor将被分割为相同shape的子Tensor,且要求 self.shape(axis) 可被 output_num 整除。

参数:
  • axis (int) - 指定分割轴。默认值:0。

  • output_num (int) - 指定分割数量。其值为正整数。默认值:1。

返回:

tuple[Tensor],每个输出Tensor的shape相同,即 \((y_1, y_2, ..., y_S)\) 。数据类型与Tensor相同。

异常:
  • TypeError - axisoutput_num 不是int。

  • ValueError - axis 超出[-len(self.shape), len(self.shape))范围。或 output_num 小于或等于0。

  • ValueError - self.shape(axis) 不可被 output_num 整除。

支持平台:

Ascend GPU CPU

样例:

>>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]), mindspore.int32)
>>> print(x)
[[1 1 1 1]
 [2 2 2 2]]
>>> output = x.split(1, 2)
>>> print(output)
(Tensor(shape=[2, 2], dtype=Int32, value=
[[1, 1],
 [2, 2]]), Tensor(shape=[2, 2], dtype=Int32, value=
[[1, 1],
 [2, 2]]))