mindspore.ops.repeat_elements
- mindspore.ops.repeat_elements(x, rep, axis=0)[源代码]
在指定轴上复制输入Tensor的元素,类似 numpy.repeat 的功能。
说明
推荐使用
mindspore.mint.repeat_interleave()
,输入 x 的维度最大可支持8,并获得更好的性能。- 参数:
x (Tensor) - 输入Tensor。类型为float16、float32、int8、uint8、int16、int32或int64。 x 的维度必须小于等于7。
rep (int) - 指定复制次数,为正数。
axis (int) - 指定复制轴,默认值:
0
。
- 返回:
Tensor,值沿指定轴复制。如果 x 的shape为 \((s1, s2, ..., sn)\) ,轴为i,则输出的shape为 \((s1, s2, ..., si * rep, ..., sn)\) 。输出的数据类型与 x 相同。
- 异常:
ValueError - 如果 x 的维度大于7。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore >>> import numpy as np >>> from mindspore import Tensor, ops >>> # case 1 : repeat on axis 0 >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32) >>> output = ops.repeat_elements(x, rep = 2, axis = 0) >>> print(output) [[0 1 2] [0 1 2] [3 4 5] [3 4 5]] >>> # case 2 : repeat on axis 1 >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32) >>> output = ops.repeat_elements(x, rep = 2, axis = 1) >>> print(output) [[0 0 1 1 2 2] [3 3 4 4 5 5]]