mindspore.ops.tensor_split
- mindspore.ops.tensor_split(input, indices_or_sections, axis=0)
Split the input tensor along axis into multiple sub-tensors, either into a given number of sections or at the given indices.
- Parameters
input (Tensor) – The input tensor.
indices_or_sections (Union[int, tuple(int), list(int)]) –
The specified indices or chunks.
If it is an integer, the input tensor is split into indices_or_sections sections along axis. Let $N$ = input.shape[axis] and $n$ = indices_or_sections.
If $N$ is divisible by $n$, every section has size $N / n$.
If $N$ is not divisible by $n$, the first $N \bmod n$ sections have size $\lfloor N / n \rfloor + 1$, and the remaining sections have size $\lfloor N / n \rfloor$.
If it is a tuple(int) or list(int), its values are treated as indices along axis, and the input tensor is split before each index.
axis (int, optional) – The axis along which to split. Default: 0.
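The integer rule above can be checked with plain Python arithmetic. The sketch below is not MindSpore code: `section_sizes` and `sections_to_indices` are hypothetical helpers, and `length` stands in for input.shape[axis]. It computes each section's size and then converts the sizes into the equivalent split indices, showing how an integer indices_or_sections relates to the tuple/list form.

```python
def section_sizes(length, sections):
    """Sizes along `axis` for an integer `indices_or_sections` (hypothetical helper).

    `length` plays the role of input.shape[axis].
    """
    if sections <= 0:
        raise ValueError("sections must be a positive integer")
    base, extra = divmod(length, sections)
    # The first `extra` (= length % sections) sections get one additional element.
    return [base + 1] * extra + [base] * (sections - extra)


def sections_to_indices(length, sections):
    """Cumulative split points equivalent to an integer `sections` (hypothetical helper)."""
    indices = []
    position = 0
    # Every section except the last ends at a split point.
    for size in section_sizes(length, sections)[:-1]:
        position += size
        indices.append(position)
    return tuple(indices)


print(section_sizes(8, 3))        # [3, 3, 2]
print(section_sizes(7, 3))        # [3, 2, 2]
print(sections_to_indices(7, 3))  # (3, 5)
```

Under the documented rule, splitting a length-7 axis into 3 sections should give the same pieces as splitting it at indices (3, 5).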
- Returns
Tuple of tensors.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore
>>> input = mindspore.tensor([0, 1, 2, 3, 4, 5, 6, 7])
>>> mindspore.ops.tensor_split(input, 3)
(Tensor(shape=[3], dtype=Int64, value= [0, 1, 2]),
 Tensor(shape=[3], dtype=Int64, value= [3, 4, 5]),
 Tensor(shape=[2], dtype=Int64, value= [6, 7]))
>>> input = mindspore.tensor([0, 1, 2, 3, 4, 5, 6])
>>> mindspore.ops.tensor_split(input, 3)
(Tensor(shape=[3], dtype=Int64, value= [0, 1, 2]),
 Tensor(shape=[2], dtype=Int64, value= [3, 4]),
 Tensor(shape=[2], dtype=Int64, value= [5, 6]))
>>> mindspore.ops.tensor_split(input, (1, 6))
(Tensor(shape=[1], dtype=Int64, value= [0]),
 Tensor(shape=[5], dtype=Int64, value= [1, 2, 3, 4, 5]),
 Tensor(shape=[1], dtype=Int64, value= [6]))
>>> input = mindspore.tensor([[ 0,  1,  2,  3,  4,  5,  6],
...                           [ 7,  8,  9, 10, 11, 12, 13]])
>>> mindspore.ops.tensor_split(input, 3, axis=1)
(Tensor(shape=[2, 3], dtype=Int64, value=
[[0, 1, 2],
 [7, 8, 9]]),
 Tensor(shape=[2, 2], dtype=Int64, value=
[[ 3,  4],
 [10, 11]]),
 Tensor(shape=[2, 2], dtype=Int64, value=
[[ 5,  6],
 [12, 13]]))
>>> mindspore.ops.tensor_split(input, (1, 6), axis=1)
(Tensor(shape=[2, 1], dtype=Int64, value=
[[0],
 [7]]),
 Tensor(shape=[2, 5], dtype=Int64, value=
[[ 1,  2,  3,  4,  5],
 [ 8,  9, 10, 11, 12]]),
 Tensor(shape=[2, 1], dtype=Int64, value=
[[ 6],
 [13]]))
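For the tuple/list form, the results above correspond to slicing between consecutive boundaries, where each index starts a new sub-tensor. The pure-Python sketch below mirrors that behaviour for a 1-D list. `split_at` is a hypothetical name, not part of MindSpore, and it assumes the indices are ascending and within range; how MindSpore treats out-of-range indices is not covered here.

```python
def split_at(values, indices):
    """Split a 1-D list before each index (hypothetical helper, not a MindSpore API).

    Assumes `indices` is ascending and every index lies in [0, len(values)].
    """
    bounds = [0, *indices, len(values)]
    # Each consecutive pair of boundaries gives the [start, stop) range of one piece.
    return tuple(values[start:stop] for start, stop in zip(bounds, bounds[1:]))


print(split_at(list(range(7)), (1, 6)))  # ([0], [1, 2, 3, 4, 5], [6])
print(split_at(list(range(7)), (3, 5)))  # ([0, 1, 2], [3, 4], [5, 6])
```

The second call matches the result of mindspore.ops.tensor_split(input, 3) on the 7-element input in the examples, consistent with the integer form being equivalent to a particular tuple of indices.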