mindspore.nn.Roll

class mindspore.nn.Roll(shift, axis)[源代码]

沿轴移动Tensor的元素。

元素沿着 axis 维度按照 shift 偏移(朝着较大的索引)正向移动。 shift 为负值则使元素向相反方向移动。移动最后位置的元素将绕到第一个位置,反之亦然。可以指定沿多个轴的多个偏移。

参数:
  • shift (Union[list(int), tuple(int), int]) - 指定元素移动方式,如果为整数,则元素沿指定维度正向移动(朝向较大的索引)的位置数。负偏移将向相反的方向滚动元素。

  • axis (Union[list(int), tuple(int), int]) - 指定需移动维度的轴。

输入:
  • input_x (Tensor) - 输入Tensor。

输出:

Tensor,shape和数据类型与输入的 input_x 相同。

异常:
  • TypeError - shift 不是int、tuple或list。

  • TypeError - axis 不是int、tuple或list。

  • TypeError - shift 的元素不是int。

  • TypeError - axis 的元素不是int。

  • ValueError - axis 超出[-len(input_x.shape), len(input_x.shape))范围。

  • ValueError - shift 的shape长度不等于 axis 的shape长度。

支持平台:

Ascend GPU

样例:

>>> input_x = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32))
>>> op = nn.Roll(shift=2, axis=0)
>>> output = op(input_x)
>>> print(output)
[3. 4. 0. 1. 2.]
>>> input_x = Tensor(np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]).astype(np.float32))
>>> op = nn.Roll(shift=[1, -2], axis=[0, 1])
>>> output = op(input_x)
>>> print(output)
[[7. 8. 9. 5. 6.]
 [2. 3. 4. 0. 1.]]