mindspore.ops.vsplit

查看源文件
mindspore.ops.vsplit(input, indices_or_sections)[源代码]

根据 indices_or_sections 将至少有两维的输入tensor垂直分割成多个子tensor。

等同于 axis=0 时的 ops.tensor_split

参数:
  • input (Tensor) - 输入tensor。

  • indices_or_sections (Union[int, tuple(int), list(int)]) - 参考 mindspore.ops.tensor_split() 中的 indices_or_sections 参数。

返回:

由多个tensor组成的tuple。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> input = mindspore.tensor([[ 0,  1,  2,  3],
...                           [ 4,  5,  6,  7],
...                           [ 8,  9, 10, 11],
...                           [12, 13, 14, 15]])
>>> mindspore.ops.vsplit(input, 2)
(Tensor(shape=[2, 4], dtype=Int64, value=
 [[0, 1, 2, 3],
  [4, 5, 6, 7]]),
 Tensor(shape=[2, 4], dtype=Int64, value=
 [[ 8,  9, 10, 11],
  [12, 13, 14, 15]]))
>>> mindspore.ops.vsplit(input, [3, 6])
(Tensor(shape=[3, 4], dtype=Int64, value=
 [[ 0,  1,  2,  3],
  [ 4,  5,  6,  7],
  [ 8,  9, 10, 11]]),
 Tensor(shape=[1, 4], dtype=Int64, value=
 [[12, 13, 14, 15]]),
 Tensor(shape=[0, 4], dtype=Int64, value=
 ))