mindspore.ops.Scan

查看源文件
class mindspore.ops.Scan[源代码]

将一个函数循环作用于一个数组,且对当前元素的处理依赖上一个元素的执行结果。 Scan算子的执行逻辑可以近似表示为如下代码:

def Scan(loop_func, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = loop_func(carry, x)
        ys.append(y)
    return carry, ys

当前Scan算子存在以下语法限制:

  • 暂不支持 loop_func 为副作用函数,如:对Parameter、全局变量的修改等操作。

  • 暂不支持 loop_func 的返回值的第一个元素与初始值 init 的类型或形状不同。

警告

这是一个实验性API,后续可能修改或删除。

输入:
  • loop_func (Function) - 循环体函数。

  • init (Union[Tensor, number, str, bool, list, tuple, dict]) - 循环的初始值。

  • xs (Union[tuple, list, None]) - 用于执行循环扫描的数组。

  • length (Union[int, None], 可选) - 数组xs的长度,默认值: None

  • unroll (bool, 可选) - 是否在编译阶段进行循环展开,默认值: True

输出:

Tuple(Union[Tensor, number, str, bool, list, tuple, dict], list), 由两个元素组成的tuple,第一个元素为循环的最终结果,和 init 参数保持一样的类型和形状; 第二个元素是一个列表,包含每次循环的执行结果。

异常:
  • TypeError - loop_func 不是一个函数。

  • TypeError - xs 不是一个tuple、一个list或者None。

  • TypeError - length 不是一个整数或者None。

  • TypeError - unroll 不是一个布尔值。

  • ValueError - loop_func 不能接受 init 以及 xs 的元素作为参数。

  • ValueError - loop_func 的返回值不是一个包含两个元素的tuple,或者tuple的第一个元素与 init 的类型或形状不同。

支持平台:

Ascend GPU CPU

样例:

>>> from mindspore import ops
>>> def cumsum(res, el):
...     res = res + el
...     return res, res
...
>>> a = [1, 2, 3, 4]
>>> result_init = 0
>>> scan_op = ops.Scan()
>>> result = scan_op(cumsum, result_init, a)
>>> print(result == (10, [1, 3, 6, 10]))
True