mindspore.jit

查看源文件
mindspore.jit(fn=None, mode='PSJit', input_signature=None, hash_args=None, jit_config=None, compile_once=False)[源代码]

将Python函数编译为一张可调用的MindSpore图。

MindSpore可以在运行时对图进行优化。

参数:
  • fn (Function) - 要编译成图的Python函数。默认值: None

  • mode (str) - 使用jit的类型,可选值有 "PSJit""PIJit" 。默认值: "PSJit"

    • PSJit :解析python的ast以构建静态图。

    • PIJit :在运行时解析python字节码以构建静态图。

  • input_signature (Union[Tuple, List, Dict, Tensor]) - 输入的Tensor是用于描述输入参数的。Tensor的shape和dtype将被配置到函数中去。如果指定了 input_signature,则 fn 的输入参数不接受 **kwargs 类型,并且实际输入的shape和dtype需要与 input_signature 相匹配。否则,将会抛出TypeError异常。 input_signature 有两种模式:

    • 全量配置模式:参数为Tuple、List或者Tensor,它们将被用作图编译时的完整编译参数。

    • 增量配置模式:参数为Dict,它将被配置到图的部分输入上,替换图编译对应位置上的参数。

    默认值: None

  • hash_args (Union[Object, List or Tuple of Objects]) - fn 里面用到的自由变量,比如外部函数或类对象,再次调用时若 hash_args 出现变化会触发重新编译。默认值: None

  • jit_config (JitConfig) - 编译时所使用的JitConfig配置项,详细可参考 mindspore.JitConfig。默认值: None

  • compile_once (bool) - True: 函数多次重新创建只编译一次,如果函数里面的自由变量有变化,设置True是有正确性风险; False: 函数重新创建会触发重新编译。默认值: False

说明

  • 如果指定了 input_signature ,则 fn 的每个输入都必须是Tensor。并且 fn 的输入参数将不会接受 **kwargs 参数。

返回:

函数,如果 fn 不是None,则返回一个已经将输入 fn 编译成图的可执行函数;如果 fn 为None,则返回一个装饰器。当这个装饰器使用单个 fn 参数进行调用时,等价于 fn 不是None的场景。

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import ops
>>> from mindspore import jit
...
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
...
>>> # create a callable MindSpore graph by calling decorator @jit
>>> def tensor_add(x, y):
...     z = x + y
...     return z
...
>>> tensor_add_graph = jit(fn=tensor_add)
>>> out = tensor_add_graph(x, y)
...
>>> # create a callable MindSpore graph through decorator @jit
>>> @jit
... def tensor_add_with_dec(x, y):
...     z = x + y
...     return z
...
>>> out = tensor_add_with_dec(x, y)
...
>>> # create a callable MindSpore graph through decorator @jit with input_signature parameter
>>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
...                       Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
... def tensor_add_with_sig(x, y):
...     z = x + y
...     return z
...
>>> out = tensor_add_with_sig(x, y)
...
>>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))})
... def tensor_add_with_sig_1(x, y):
...     z = x + y
...     return z
...
>>> out1 = tensor_add_with_sig_1(x, y)
...
... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused.
... # While fn differs during calling again, recompilation will be triggered.
>>> def func(x):
...     return ops.exp(x)
...
>>> def closure_fn(x, fn):
...     @jit(hash_args=fn)
...     def inner_fn(a):
...         return fn(a)
...     return inner_fn(x)
...
>>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
>>> for i in range(10):
...     closure_fn(inputs, func)
...
... # Set compile_once = True, otherwise the train_step will be compiled again.
>>> def train(x):
...     @jit(compile_once = True)
...     def train_step(x):
...         return ops.exp(x)
...     for i in range(10):
...         train_step(x)
...
>>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
>>> for i in range(10):
...     train(inputs)