mindspore.ops.Split
- class mindspore.ops.Split(axis=0, output_num=1)
Splits the input tensor into output_num tensors of equal shape along the given axis.
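Refer to mindspore.ops.split() for more details.
- Parameters:
axis (int) - Index of the axis along which to split. Default: 0.
output_num (int) - The number of output tensors. Must be a positive integer. Default: 1.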
- Inputs:
input_x (Tensor) - The shape of the tensor is \((x_0, x_1, \ldots, x_{R-1})\), where \(R \geq 1\).
- Outputs:
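tuple[Tensor] of output_num tensors. Every output tensor has the same shape, \((x_0, x_1, \ldots, x_{\mathrm{axis}} / \mathrm{output\_num}, \ldots, x_{R-1})\), so \(x_{\mathrm{axis}}\) is expected to be divisible by output_num. The data type is the same as that of input_x.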
- Supported Platforms:
Ascend GPU CPU
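The output-shape rule from the Outputs description can be checked without MindSpore. The sketch below is a minimal pure-Python illustration; `split_output_shape` is a hypothetical helper, not part of the MindSpore API, and the range check and handling of negative axis values are assumptions modeled on NumPy-style axis indexing.

```python
from typing import List, Sequence


def split_output_shape(shape: Sequence[int], axis: int = 0, output_num: int = 1) -> List[int]:
    """Shape shared by every output of an equal split.

    Hypothetical helper for illustration only; not a MindSpore API.
    """
    rank = len(shape)
    if rank < 1:
        raise ValueError("input must have rank R >= 1")
    if output_num < 1:
        raise ValueError("output_num must be a positive integer")
    # Assumption: a negative axis counts from the end, as in NumPy.
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    axis %= rank
    # Equal-shaped outputs are only possible if the split dimension divides evenly.
    if shape[axis] % output_num != 0:
        raise ValueError(
            f"dimension {shape[axis]} is not divisible by output_num={output_num}"
        )
    out = list(shape)
    out[axis] = shape[axis] // output_num  # only the split axis shrinks
    return out


print(split_output_shape((2, 4), axis=1, output_num=2))  # [2, 2]
print(split_output_shape((2, 4), axis=1, output_num=4))  # [2, 1]
```

These two calls mirror the shapes produced by the doctest below for a 2 x 4 input.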
Examples
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> split = ops.Split(1, 2)
>>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]), mindspore.int32)
>>> print(x)
[[1 1 1 1]
 [2 2 2 2]]
>>> output = split(x)
>>> print(output)
(Tensor(shape=[2, 2], dtype=Int32, value=
[[1, 1],
 [2, 2]]), Tensor(shape=[2, 2], dtype=Int32, value=
[[1, 1],
 [2, 2]]))
>>> split = ops.Split(1, 4)
>>> output = split(x)
>>> print(output)
(Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
 [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
 [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
 [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
 [2]]))
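Readers without a MindSpore installation can reproduce the same values with NumPy. This is a cross-check under the assumption that an equal split in NumPy matches the semantics described here; it is not how MindSpore implements the operator. Note that `np.split` returns a list, whereas `ops.Split` returns a tuple.

```python
import numpy as np

# Same input as the doctest above.
x = np.array([[1, 1, 1, 1], [2, 2, 2, 2]], dtype=np.int32)

# With an integer section count, np.split performs an equal split along `axis`
# and raises ValueError if that dimension is not divisible by the count.
halves = np.split(x, 2, axis=1)    # analogous to ops.Split(1, 2)
quarters = np.split(x, 4, axis=1)  # analogous to ops.Split(1, 4)

print([h.tolist() for h in halves])  # [[[1, 1], [2, 2]], [[1, 1], [2, 2]]]
print([q.shape for q in quarters])   # [(2, 1), (2, 1), (2, 1), (2, 1)]
```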