mindspore.jit

mindspore.jit(fn=None, mode='PSJit', input_signature=None, hash_args=None, jit_config=None, compile_once=False)

Create a callable MindSpore graph from a Python function.

This allows the MindSpore runtime to apply graph-based optimizations.

Parameters
  • fn (Function) – The Python function that will be run as a graph. Default: None.

  • mode (str) –

    The type of JIT to use. The value of mode must be PSJit or PIJit. Default: PSJit.

    • PSJit: Parses the Python AST to build the graph.

    • PIJit: Parses Python bytecode at runtime to build the graph.

  • input_signature (Union[Tuple, List, Dict, Tensor]) –

    A Tensor, or a Tuple, List or Dict of Tensors, describing the input arguments of fn. The shape and dtype of these Tensors will be supplied to fn. If input_signature is specified, fn cannot accept **kwargs, and the shape and dtype of the actual inputs must match input_signature; otherwise, a TypeError is raised. input_signature supports two modes:

    • Full mode: The argument is a Tuple, List or Tensor, which is used as the complete set of inputs for graph compiling.

    • Incremental mode: The argument is a Dict whose keys are parameter names of fn. It sets only some of the graph inputs, and each value is substituted into the input at the corresponding position for graph compiling.

    Default: None.

  • hash_args (Union[Object, List or Tuple of Objects]) – The local free variables used inside fn, such as functions or objects of classes defined outside fn. Calling fn again with changed hash_args triggers recompilation. Default: None.

  • jit_config (JitConfig) – JIT configuration used for compiling. Default: None.

  • compile_once (bool) – If True, the function is compiled only once even if it is created many times; however, the result may be wrong if its free variables change. If False, the function is recompiled each time it is created. Default: False.
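The mode parameter chooses which representation of fn is parsed: PSJit starts from the Python AST, PIJit from the bytecode at runtime. As a rough illustration only, the sketch below uses the standard ast and dis modules to show both representations of a small function; it is not MindSpore's parser, and the graph-building step is not modeled.

```python
import ast
import dis

# Source of a small function, kept in a string so the sketch does not rely on
# inspect.getsource() being able to locate a file on disk.
SRC = """
def tensor_add(x, y):
    z = x + y
    return z
"""

# PSJit-style starting point: the abstract syntax tree of the source text.
tree = ast.parse(SRC)
ast_nodes = [type(node).__name__ for node in ast.walk(tree.body[0])]

# PIJit-style starting point: the bytecode CPython compiled for the function.
namespace = {}
exec(compile(SRC, "<tensor_add>", "exec"), namespace)
opnames = [ins.opname for ins in dis.get_instructions(namespace["tensor_add"])]

print("AST nodes:   ", ast_nodes)
print("Bytecode ops:", opnames)
```

The addition appears as a BinOp node in the AST and as a BINARY_ADD or BINARY_OP instruction in the bytecode, depending on the CPython version.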

Note

If input_signature is specified, each input of fn must be a Tensor, and fn does not accept **kwargs.
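The input_signature rules can be summarized with a toy resolver. This is a minimal sketch under stated assumptions: Spec is a hypothetical stand-in for Tensor, resolve_compile_inputs is a hypothetical helper that is not part of MindSpore, Dict keys are assumed to be parameter names of fn (as in the {"y": ...} example in Examples), and the check that actual inputs match the signature's shape and dtype is omitted.

```python
import inspect
from dataclasses import dataclass


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for a Tensor: only shape and dtype are modeled."""
    shape: tuple
    dtype: str


def resolve_compile_inputs(fn, args, input_signature=None):
    """Return the input specs a toy compiler would use for fn.

    Hypothetical helper for illustration; it is not a MindSpore API.
    """
    params = inspect.signature(fn).parameters  # ordered: name -> Parameter
    if input_signature is None:
        return list(args)  # no signature: compile from the actual inputs
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        raise TypeError("fn must not accept **kwargs when input_signature is set")
    if isinstance(input_signature, dict):
        # Incremental mode: replace only the named inputs, position by position.
        unknown = set(input_signature) - set(params)
        if unknown:
            raise TypeError(f"unknown parameter names: {sorted(unknown)}")
        return [input_signature.get(name, arg) for name, arg in zip(params, args)]
    # Full mode: one Spec, or a tuple/list that describes every input.
    specs = [input_signature] if isinstance(input_signature, Spec) else list(input_signature)
    if len(specs) != len(args):
        raise TypeError(f"expected {len(specs)} inputs, got {len(args)}")
    return specs


def tensor_add(x, y):
    return x + y


x = Spec((1, 1, 3, 3), "float32")       # descriptions of the actual inputs
y = Spec((1, 1, 3, 3), "float32")
x_sig = Spec((1, 1, 3, 3), "float32")   # descriptions given via input_signature
y_sig = Spec((1, 1, 3, 3), "float32")

incremental = resolve_compile_inputs(tensor_add, (x, y), {"y": y_sig})
full = resolve_compile_inputs(tensor_add, (x, y), [x_sig, y_sig])
print(incremental[0] is x, incremental[1] is y_sig)  # x kept, y substituted
print(full[0] is x_sig, full[1] is y_sig)            # every input from the signature
```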

Returns

Function. If fn is not None, returns a callable that executes the compiled function. If fn is None, returns a decorator; applying this decorator to a single fn argument yields the same callable as in the case where fn is not None.
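The two return cases follow the common optional-argument decorator pattern. The sketch below shows that pattern with a hypothetical jit_like that performs no compilation and only records the chosen mode, so that the three calling styles used in Examples (jit(fn=...), bare @jit, and @jit(...) with arguments) can be compared. It is not MindSpore's implementation.

```python
import functools


def jit_like(fn=None, *, mode="PSJit"):
    """Hypothetical decorator mirroring jit's fn=None calling convention.

    It compiles nothing; it only wraps fn and records the requested mode.
    """
    def wrap(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)
        wrapper.jit_mode = mode  # hypothetical attribute, for inspection only
        return wrapper

    if fn is None:
        # Called as @jit_like(...): return a decorator that waits for fn.
        return wrap
    # Called as jit_like(fn) or bare @jit_like: wrap fn immediately.
    return wrap(fn)


def add(x, y):
    return x + y


add_direct = jit_like(add)          # fn is not None: returns the callable


@jit_like                           # bare decorator: same as jit_like(add2)
def add2(x, y):
    return x + y


@jit_like(mode="PIJit")             # fn is None: a decorator is returned first
def add3(x, y):
    return x + y


print(add_direct(1, 2), add2(1, 2), add3(1, 2), add3.jit_mode)
```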

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import ops
>>> from mindspore import jit
...
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
...
>>> # create a callable MindSpore graph by calling jit as a function
>>> def tensor_add(x, y):
...     z = x + y
...     return z
...
>>> tensor_add_graph = jit(fn=tensor_add)
>>> out = tensor_add_graph(x, y)
...
>>> # create a callable MindSpore graph with the decorator @jit
>>> @jit
... def tensor_add_with_dec(x, y):
...     z = x + y
...     return z
...
>>> out = tensor_add_with_dec(x, y)
...
>>> # create a callable MindSpore graph with the decorator @jit and the input_signature parameter
>>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
...                       Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
... def tensor_add_with_sig(x, y):
...     z = x + y
...     return z
...
>>> out = tensor_add_with_sig(x, y)
...
>>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))})
... def tensor_add_with_sig_1(x, y):
...     z = x + y
...     return z
...
>>> out1 = tensor_add_with_sig_1(x, y)
...
>>> # Set hash_args to fn; otherwise the compiled cache of closure_fn will not be reused.
>>> # If fn differs on a later call, recompilation will be triggered.
>>> def func(x):
...     return ops.exp(x)
...
>>> def closure_fn(x, fn):
...     @jit(hash_args=fn)
...     def inner_fn(a):
...         return fn(a)
...     return inner_fn(x)
...
>>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
>>> for i in range(10):
...     closure_fn(inputs, func)
...
>>> # Set compile_once=True; otherwise train_step will be compiled again on every call to train.
>>> def train(x):
...     @jit(compile_once=True)
...     def train_step(x):
...         return ops.exp(x)
...     for i in range(10):
...         train_step(x)
...
>>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
>>> for i in range(10):
...     train(inputs)
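The hash_args and compile_once examples depend on how compiled results are cached when a jit-decorated function is re-created. The toy model below is plain Python and not MindSpore's implementation. It assumes the cache key combines the function's code object with a token derived from hash_args, a fixed token when compile_once=True, or a fresh token otherwise. Under that assumption it reproduces the documented behavior, including the stale result compile_once can give when free variables change. toy_jit, stats and closure_fn's jit_options are hypothetical names.

```python
import itertools

stats = {"compiles": 0}      # how many toy "compilations" happened
_cache = {}                  # toy compile cache: key -> first function seen
_fresh = itertools.count()   # source of never-repeating tokens


def toy_jit(fn=None, *, hash_args=None, compile_once=False):
    """Hypothetical model of how hash_args and compile_once affect caching.

    Not MindSpore code: "compiling" only stores the function under a key.
    """
    def wrap(f):
        if hash_args is not None:
            objs = hash_args if isinstance(hash_args, (list, tuple)) else (hash_args,)
            extra = tuple(id(o) for o in objs)  # new key when hash_args changes
        elif compile_once:
            extra = None                        # same key every time f is re-created
        else:
            extra = next(_fresh)                # new key per creation: recompile
        key = (f.__code__, extra)

        def wrapper(*args):
            if key not in _cache:
                stats["compiles"] += 1
                _cache[key] = f                 # later re-creations reuse this f
            return _cache[key](*args)
        return wrapper

    return wrap if fn is None else wrap(fn)


def closure_fn(x, fn, **jit_options):
    @toy_jit(**jit_options)
    def inner_fn(a):
        return fn(a)
    return inner_fn(x)


def double(v):
    return 2 * v


def square(v):
    return v * v


for _ in range(10):
    closure_fn(3, double)                       # no hash_args: 10 compilations
no_hash_compiles = stats["compiles"]

stats["compiles"] = 0
for _ in range(10):
    closure_fn(3, double, hash_args=double)     # cached after the first call
closure_fn(3, square, hash_args=square)         # hash_args changed: recompile
hashed_compiles = stats["compiles"]

stats["compiles"] = 0
closure_fn(3, double, compile_once=True)
stale_result = closure_fn(3, square, compile_once=True)  # reuses double's entry

print(no_hash_compiles, hashed_compiles, stats["compiles"], stale_result)
```

In this model the last call returns 6 rather than 9, which is the kind of wrong result the compile_once description warns about when free variables change.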