mindspore.numpy.dsplit

查看源文件
mindspore.numpy.dsplit(x, indices_or_sections)[源代码]

沿第三轴(深度)将Tensor分割为多个sub-tensor。此操作等同于使用 axis=2 (默认值)进行分割,不论数组的维度如何,始终沿第三轴分割。

参数:
  • x (Tensor) - 待分割的Tensor。

  • indices_or_sections (Union[int, tuple(int), list(int)]) - 如果是整数 N ,Tensor将沿轴分割为 N 个相等的sub-tensor。如果是tuple(int)、list(int)或排序后的整数,则指示沿轴的分割位置。例如,对于 axis=0[2,3] 将产生三个sub-tensor: x[:2]x[2:3]x[3:] 。如果索引超出轴上数组的维度,则相应地返回空子数组。

返回:

sub-tensor列表。

异常:
  • TypeError - 如果参数 indices_or_sections 不是整数。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore.numpy as np
>>> input_x = np.arange(6).reshape((1, 2, 3)).astype('float32')
>>> output = np.dsplit(input_x, 3)
>>> print(output)
(Tensor(shape=[1, 2, 1], dtype=Float32,
value=[[[ 0.00000000e+00],
        [ 3.00000000e+00]]]),
Tensor(shape=[1, 2, 1], dtype=Float32,
value=[[[ 1.00000000e+00],
        [ 4.00000000e+00]]]),
Tensor(shape=[1, 2, 1], dtype=Float32,
value=[[[ 2.00000000e+00],
        [ 5.00000000e+00]]]))