mindspore.ops.repeat_elements

mindspore.ops.repeat_elements(x, rep, axis=0)[源代码]

在指定轴上复制输入Tensor的元素,类似 np.repeat 的功能。

参数:

  • x (Tensor) - 输入Tensor。类型为float16、float32、int8、uint8、int16、int32或int64。

  • rep (int) - 指定复制次数,为正数。

  • axis (int) - 指定复制轴,默认值:0。

输出:

Tensor,值沿指定轴复制。如果 x 的shape为 \((s1, s2, ..., sn)\) ,轴为i,则输出的shape为 \((s1, s2, ..., si * rep, ..., sn)\) 。输出的数据类型与 x 相同。

支持平台:

Ascend GPU CPU

样例:

>>> from mindspore import Tensor, ops
>>> import mindspore
>>> import numpy as np
>>> # case 1 : repeat on axis 0
>>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
>>> output = ops.repeat_elements(x, rep = 2, axis = 0)
>>> print(output)
[[0 1 2]
 [0 1 2]
 [3 4 5]
 [3 4 5]]
>>> # case 2 : repeat on axis 1
>>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
>>> output = ops.repeat_elements(x, rep = 2, axis = 1)
>>> print(output)
[[0 0 1 1 2 2]
 [3 3 4 4 5 5]]