mindspore.jit
- mindspore.jit(fn=None, mode='PSJit', input_signature=None, hash_args=None, jit_config=None, compile_once=False)
Create a callable MindSpore graph from a Python function.
This allows the MindSpore runtime to apply graph-based optimizations.
- Parameters
fn (Function) – The Python function that will be run as a graph. Default: None.
mode (str) – The type of jit to use. The value must be PIJit or PSJit. Default: PSJit.
input_signature (Tensor) – The Tensor that describes the input arguments. Its shape and dtype are supplied to this function. If input_signature is specified, each input to fn must be a Tensor, and fn cannot accept **kwargs. The shape and dtype of the actual inputs must match input_signature; otherwise, a TypeError is raised. Default: None.
hash_args (Union[Object, List or Tuple of Objects]) – The local free variables used inside fn, such as functions or objects of classes defined outside fn. Calling fn again with changed hash_args triggers recompilation. Default: None.
jit_config (JitConfig) – Jit config for compilation. Default: None.
compile_once (bool) – True: the function is compiled only once, even if it is created many times; the result may be wrong if its free variables have changed. False: the function is recompiled each time it is created again. Default: False.
Note
If input_signature is specified, each input of fn must be a Tensor, and fn cannot accept **kwargs.
- Returns
Function. If fn is not None, returns a callable that executes the compiled function. If fn is None, returns a decorator; when that decorator is invoked with a single fn argument, the resulting callable is the same as in the case where fn is not None.
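The two return cases correspond to a common Python pattern for decorators with optional arguments. The sketch below shows that pattern in plain Python; `jit_like` is a hypothetical stand-in, not the MindSpore function, and it calls the original function instead of compiling anything. Unlike the real signature, it makes `mode` keyword-only for simplicity.

```python
import functools


def jit_like(fn=None, *, mode="PSJit"):
    """Hypothetical stand-in illustrating the fn=None decorator pattern."""
    def wrap(func):
        @functools.wraps(func)
        def compiled(*args, **kwargs):
            # A real implementation would execute a compiled graph here.
            return func(*args, **kwargs)
        compiled.mode = mode  # record the option for inspection
        return compiled

    if fn is not None:
        # Called as jit_like(fn) or used as a bare @jit_like decorator.
        return wrap(fn)
    # Called as jit_like(mode=...): return a decorator awaiting fn.
    return wrap


def add(x, y):
    return x + y


direct = jit_like(add)       # fn given: returns the callable


@jit_like                    # bare decorator: same as jit_like(add_bare)
def add_bare(x, y):
    return x + y


@jit_like(mode="PIJit")      # fn is None: returns a decorator first
def add_cfg(x, y):
    return x + y
```

All three forms yield an equivalent callable; only the way the options reach the wrapper differs.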
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import ops
>>> from mindspore import jit
...
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
...
>>> # create a callable MindSpore graph by calling jit directly
>>> def tensor_add(x, y):
...     z = x + y
...     return z
...
>>> tensor_add_graph = jit(fn=tensor_add)
>>> out = tensor_add_graph(x, y)
...
>>> # create a callable MindSpore graph through the decorator @jit
>>> @jit
... def tensor_add_with_dec(x, y):
...     z = x + y
...     return z
...
>>> out = tensor_add_with_dec(x, y)
...
>>> # create a callable MindSpore graph through the decorator @jit with the input_signature parameter
>>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
...                       Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
... def tensor_add_with_sig(x, y):
...     z = x + y
...     return z
...
>>> out = tensor_add_with_sig(x, y)
...
>>> # Set hash_args to fn; otherwise the cache of the compiled closure_fn will not be reused.
>>> # If fn differs on a later call, recompilation is triggered.
>>> def func(x):
...     return ops.exp(x)
...
>>> def closure_fn(x, fn):
...     @jit(hash_args=fn)
...     def inner_fn(a):
...         return fn(a)
...     return inner_fn(x)
...
>>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
>>> for i in range(10):
...     closure_fn(inputs, func)
...
>>> # Set compile_once=True; otherwise train_step will be compiled again on each call to train.
>>> def train(x):
...     @jit(compile_once=True)
...     def train_step(x):
...         return ops.exp(x)
...     for i in range(10):
...         train_step(x)
...
>>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
>>> for i in range(10):
...     train(inputs)