mindspore.ops.tensor_split

查看源文件
mindspore.ops.tensor_split(input, indices_or_sections, axis=0)[源代码]

根据指定索引或份数,将输入tensor拆分成多个子tensor。

参数:
  • input (Tensor) - 输入tensor。

  • indices_or_sections (Union[int, tuple(int), list(int)]) - 指定索引或份数。

    • 如果是int类型,输入tensor将被拆分成 indices_or_sections 份。

      • 如果 input.shape[axis] 能被 indices_or_sections 整除,那么子切片为相同大小 input.shape[axis]/n

      • 如果 input.shape[axis] 不能被 indices_or_sections 整除,那么前 input.shape[axis]modn 个切片的大小为 input.shape[axis]//n+1 ,其余切片的大小为 input.shape[axis]//n

    • 如果是tuple(int)或list(int)类型,则表示索引,输入tensor在索引处被拆分。

  • axis (int,可选) - indices_or_sections 所在的轴。默认 0

返回:

由多个tensor组成的tuple。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> input = mindspore.tensor([0, 1, 2, 3, 4, 5, 6, 7])
>>> mindspore.ops.tensor_split(input, 3)
(Tensor(shape=[3], dtype=Int64, value= [0, 1, 2]),
 Tensor(shape=[3], dtype=Int64, value= [3, 4, 5]),
 Tensor(shape=[2], dtype=Int64, value= [6, 7]))
>>> input = mindspore.tensor([0, 1, 2, 3, 4, 5, 6])
>>> mindspore.ops.tensor_split(input, 3)
(Tensor(shape=[3], dtype=Int64, value= [0, 1, 2]),
 Tensor(shape=[2], dtype=Int64, value= [3, 4]),
 Tensor(shape=[2], dtype=Int64, value= [5, 6]))
>>> mindspore.ops.tensor_split(input, (1, 6))
(Tensor(shape=[1], dtype=Int64, value= [0]),
 Tensor(shape=[5], dtype=Int64, value= [1, 2, 3, 4, 5]),
 Tensor(shape=[1], dtype=Int64, value= [6]))
>>> input = mindspore.tensor([[ 0,  1,  2,  3,  4,  5,  6],
...                           [ 7,  8,  9, 10, 11, 12, 13]])
>>> mindspore.ops.tensor_split(input, 3, axis=1)
(Tensor(shape=[2, 3], dtype=Int64, value=
 [[0, 1, 2],
  [7, 8, 9]]),
 Tensor(shape=[2, 2], dtype=Int64, value=
 [[ 3,  4],
  [10, 11]]),
 Tensor(shape=[2, 2], dtype=Int64, value=
 [[ 5,  6],
  [12, 13]]))
>>> mindspore.ops.tensor_split(input, (1, 6), axis=1)
(Tensor(shape=[2, 1], dtype=Int64, value=
 [[0],
  [7]]),
 Tensor(shape=[2, 5], dtype=Int64, value=
 [[ 1,  2,  3,  4,  5],
  [ 8,  9, 10, 11, 12]]),
 Tensor(shape=[2, 1], dtype=Int64, value=
 [[ 6],
  [13]]))