mindspore.set_context

mindspore.set_context(**kwargs)

Set context for running environment.

Context should be configured before running your program. If there is no configuration, it will be automatically set according to the device target by default.

Note

Attributes must be set by name, as keyword arguments. Changing the mode after the network has been initialized is not recommended, because some operations are implemented differently in graph mode and PyNative mode. Default: PYNATIVE_MODE.

Some configurations are device-specific, and some parameters will be deprecated and removed in a future version (marked D in the second column); please use the replacement given in the fourth column. See the table below for details:

Parameters
  • device_id (int) – ID of the target device. The value must be in [0, device_num_per_host-1], and device_num_per_host must be no more than 4096. Default: 0. This parameter will be deprecated and removed in a future version. Please use the API mindspore.set_device() with 'device_id' instead.

  • device_target (str) – The target device to run on. Supported values are "Ascend", "GPU" and "CPU". If device_target is not set, the device matching the installed MindSpore package is used. This parameter will be deprecated and removed in a future version. Please use the API mindspore.set_device() with 'device_target' instead.
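The range rule for device_id can be checked before the value is handed to MindSpore. The sketch below is a hypothetical helper (check_device_id is not a MindSpore API) that encodes only the constraints stated above; requiring at least one device per host is an added assumption so that the ID range is never empty.

```python
def check_device_id(device_id: int, device_num_per_host: int) -> int:
    """Hypothetical validator for the documented device_id constraints (not a MindSpore API)."""
    # Assumption: at least one device per host, so the ID range is never empty.
    if not 1 <= device_num_per_host <= 4096:
        raise ValueError(
            f"device_num_per_host must be in [1, 4096], got {device_num_per_host}")
    # Documented rule: device_id must lie in [0, device_num_per_host - 1].
    if not 0 <= device_id <= device_num_per_host - 1:
        raise ValueError(
            f"device_id must be in [0, {device_num_per_host - 1}], got {device_id}")
    return device_id
```

A validated ID would then go to the replacement API, for example `ms.set_device("Ascend", check_device_id(0, 8))`, assuming set_device takes the device target first and the device ID second; check the set_device reference for the exact signature.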

  • max_device_memory (str) – Set the maximum memory available for devices. The format is "xxGB". Default: "1024GB". The memory actually used is the smaller of the device's available memory and max_device_memory. max_device_memory should be set before the program runs. When virtual memory is enabled, too small a max_device_memory causes frequent defragmentation, which hurts performance. This parameter will be deprecated and removed in a future version. Please use the API mindspore.runtime.set_memory() instead.

  • variable_memory_max_size (str) – This parameter will be deprecated and will be removed in future versions. Please use the api mindspore.runtime.set_memory() instead.

  • mempool_block_size (str) – Set the size of the memory pool block for devices. It takes effect only when virtual memory is turned off. The format is "xxGB". Default: "1GB". The minimum size is "1GB". The block size actually used is the smaller of the device's available memory and mempool_block_size. When there is enough memory, the memory pool is expanded by this value. This parameter will be deprecated and removed in a future version. Please use the API mindspore.runtime.set_memory() instead.
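The "minimum of available memory and the configured value" rule for max_device_memory and mempool_block_size can be illustrated with plain arithmetic. This is a sketch under assumptions: parse_gb, effective_max_memory and effective_block_size are hypothetical names, and the regular expression accepts only the "xxGB" form described above, while MindSpore's own parser may accept more.

```python
import re

# Accepts the documented "xxGB" form, e.g. "3.5GB" or "1024GB" (assumed grammar).
_GB_PATTERN = re.compile(r"^\s*(\d+(?:\.\d+)?)\s*GB\s*$")


def parse_gb(value: str) -> float:
    """Parse a size string such as "3.5GB" into a number of gigabytes."""
    match = _GB_PATTERN.match(value)
    if match is None:
        raise ValueError(f'expected a size like "3.5GB", got {value!r}')
    return float(match.group(1))


def effective_max_memory(max_device_memory: str, available_gb: float) -> float:
    """Memory actually usable: the smaller of available memory and the cap."""
    return min(available_gb, parse_gb(max_device_memory))


def effective_block_size(mempool_block_size: str, available_gb: float) -> float:
    """Block size actually used; requests below the documented 1GB minimum are rejected."""
    requested = parse_gb(mempool_block_size)
    if requested < 1.0:
        raise ValueError("mempool_block_size must be at least 1GB")
    return min(available_gb, requested)
```

On a device with 32 GB free, the default cap of "1024GB" yields 32 GB, while "3.5GB" yields 3.5 GB.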

  • op_timeout (int) – Set the maximum duration, in seconds, of executing an operator. If the execution time exceeds this value, the system terminates the task. 0 means waiting indefinitely. The defaults for AI Core and AICPU operators vary across hardware. For more information, please refer to the Ascend Community document about aclrtSetOpExecuteTimeOut. Default: 900.

  • save_graphs (bool or int) –

    Whether to save intermediate compilation graphs. Default: 0 . Available values are:

    • False or 0: disable saving of intermediate compilation graphs.

    • 1: some intermediate files will be generated during graph compilation.

    • True or 2: generate more IR files related to the backend process.

    • 3: generate visualized computing graphs and detailed frontend IR graphs.

    When the network structure is complex, setting save_graphs attribute to 2 or 3 may take too long. If you need quick problem locating, you can switch to 1 first.

    When save_graphs is set to True, 1, 2 or 3, save_graphs_path sets the storage path for the intermediate compilation graphs. By default, the graphs are saved in the current directory. This parameter will be deprecated and removed in a future version. Please use the environment variable MS_DEV_SAVE_GRAPHS instead.

  • save_graphs_path (str) – Path to save graphs. Default: ".". If the specified directory does not exist, the system will automatically create the directory. During distributed training, graphs will be saved to the directory of save_graphs_path/rank_${rank_id}/. rank_id is the ID of the current device in the cluster. This parameter will be deprecated and removed in a future version. Please use the environment variable MS_DEV_SAVE_GRAPHS_PATH instead.
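The mapping from save_graphs values to levels, and the per-rank output directory, can be sketched as follows. save_graphs_level and rank_graph_dir are hypothetical helpers that only restate the rules above (False is 0, True is 2, integers 0 to 3 pass through, and each rank writes to save_graphs_path/rank_{rank_id}).

```python
import os


def save_graphs_level(save_graphs) -> int:
    """Normalize a save_graphs value to its documented level 0-3."""
    # Check bools first: bool is a subclass of int in Python.
    if save_graphs is False:
        return 0
    if save_graphs is True:
        return 2  # True is documented as equivalent to 2
    if isinstance(save_graphs, int) and 0 <= save_graphs <= 3:
        return save_graphs
    raise ValueError(f"save_graphs must be a bool or an int in [0, 3], got {save_graphs!r}")


def rank_graph_dir(save_graphs_path: str, rank_id: int) -> str:
    """Directory that receives one rank's graphs in distributed training."""
    return os.path.join(save_graphs_path, f"rank_{rank_id}")
```

When migrating to the environment variables, a shell setting such as `MS_DEV_SAVE_GRAPHS=2` together with `MS_DEV_SAVE_GRAPHS_PATH=./ir` is the presumed equivalent, assuming the variable accepts the same level numbers.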

  • deterministic (str) –

    Whether to run operators in deterministic mode. The value must be 'ON' or 'OFF', and the default value is 'OFF'.

    • "ON": Enable operator deterministic running mode.

    • "OFF": Disable operator deterministic running mode.

    When deterministic mode is on, model operators run deterministically on Ascend: running an operator multiple times with the same inputs on the same hardware produces exactly the same outputs each time. This is useful for debugging models. This parameter will be deprecated and removed in a future version. Please use the API mindspore.set_deterministic() instead.

  • print_file_path (str) – This parameter will be deprecated and will be removed in future versions.

  • env_config_path (str) – This parameter will be deprecated and will be removed in future versions.

  • precompile_only (bool) – Whether to only precompile the network. Default: False . If set to True , the network will only be compiled, not executed. This parameter will be deprecated and removed in a future version. Please use the environment variable MS_DEV_PRECOMPILE_ONLY instead.

  • pynative_synchronize (bool) – Whether to enable synchronous execution of the device in PyNative mode. Default: False. When set to False, operators are executed asynchronously on the device, so when an operator fails, the location of the offending script code cannot be determined. When set to True, operators are executed synchronously on the device, which reduces execution performance; in exchange, when an operator fails, the location of the offending script code can be found from the call stack of the error. This parameter will be deprecated and removed in a future version. Please use the API mindspore.runtime.launch_blocking() instead.

  • mode (int) – Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Both modes support all backends. Default: PYNATIVE_MODE .

  • enable_reduce_precision (bool) – Whether to enable precision reduction. If the operator does not support the user-specified precision, the precision will be changed automatically. Default: True .

  • aoe_tune_mode (str) – AOE tuning mode setting, which is not set by default. When set to "online", online tuning is turned on. When set to "offline", the GE graph is saved for offline tuning.

  • aoe_config (dict) –

    Set the parameters specific to Ascend Optimization Engine. It is not set by default.

    • job_type (str): Mode type setting, default value is "2".

      • "1": subgraph tuning;

      • "2": operator tuning.

  • check_bprop (bool) – Whether to check back propagation nodes. The check ensures that the shape and dtype of the back propagation node outputs are the same as those of the input parameters. Default: False. This parameter will be deprecated and removed in a future version.

  • max_call_depth (int) – The maximum depth of function calls. Must be a positive integer. Default: 1000. Set max_call_depth when nested calls are too deep or there are too many subgraphs. If max_call_depth is increased, the system's maximum stack size should be increased as well; otherwise the process may crash with a core dump due to a system stack overflow. This parameter will be deprecated and removed in a future version. Please use the API mindspore.set_recursion_limit() instead.

  • grad_for_scalar (bool) – Whether to compute gradients for scalars. Default: False. When grad_for_scalar is set to True, the scalar inputs of a function can be differentiated. Because the backend does not currently support scalar operations, this interface only supports simple operations that can be deduced by the frontend. This parameter will be deprecated and removed in a future version. Please take derivatives with respect to tensors instead.

  • enable_compile_cache (bool) – Whether to save or load the compile cache of the graph. After enable_compile_cache is set to True, a compile cache is generated and exported to a MINDIR file during the first execution. When the network is executed again, if enable_compile_cache is still True and the network scripts have not changed, the compile cache is loaded. Note that only limited automatic detection of changes to Python scripts is currently supported, so there is a correctness risk. Default: False. Graphs larger than 2 GB after compilation are not currently supported. This is an experimental prototype that is subject to change and/or deletion. This parameter will be deprecated and removed in a future version. Please use the environment variable MS_COMPILER_CACHE_ENABLE instead.

  • compile_cache_path (str) – Path to save the compile cache. Default: ".". If the specified directory does not exist, the system will automatically create the directory. The cache will be saved to the directory of compile_cache_path/rank_${rank_id}/. The rank_id is the ID of the current device in the cluster. This parameter will be deprecated and removed in a future version. Please use the environment variable MS_COMPILER_CACHE_PATH instead.
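Migrating enable_compile_cache and compile_cache_path to their replacement environment variables amounts to exporting two variables. The sketch assumes that MS_COMPILER_CACHE_ENABLE is switched on with the value "1" and that both variables should be set before MindSpore is imported; compile_cache_env is a hypothetical helper name.

```python
import os


def compile_cache_env(cache_path: str = ".") -> dict:
    """Export the environment variables documented as replacements for
    enable_compile_cache and compile_cache_path.

    Assumption: MS_COMPILER_CACHE_ENABLE is switched on with the value "1".
    """
    settings = {
        "MS_COMPILER_CACHE_ENABLE": "1",
        "MS_COMPILER_CACHE_PATH": cache_path,
    }
    # Update this process's environment; child processes inherit it.
    os.environ.update(settings)
    return settings
```

As with the context option, the cache of each rank is expected under cache_path/rank_{rank_id}/.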

  • inter_op_parallel_num (int) – The number of threads used for inter-operator parallelism, that is, for running operators in parallel at the same time. Default: 0, which means the framework default is used.

  • runtime_num_threads (int) – The number of threads in the thread pool used by CPU kernels at runtime; it must be greater than or equal to 0. Default: 30. If you run many processes at the same time, set a smaller value to avoid thread contention. If runtime_num_threads is set to 1, the runtime asynchronous pipeline cannot be enabled, which may affect performance. This parameter will be deprecated and removed in a future version. Please use the API mindspore.device_context.cpu.op_tuning.threads_num() instead.
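The advice to lower runtime_num_threads when many processes run together can be turned into a simple heuristic. threads_per_process is hypothetical and is not MindSpore's policy: it splits the CPU cores evenly across processes, caps the result at the documented default of 30, and keeps at least 2 threads because a value of 1 disables the asynchronous pipeline.

```python
import os


def threads_per_process(num_processes: int, cpu_count=None, default: int = 30) -> int:
    """Hypothetical heuristic for sharing CPU cores among concurrent processes."""
    if num_processes < 1:
        raise ValueError("num_processes must be at least 1")
    cores = cpu_count if cpu_count is not None else (os.cpu_count() or 1)
    # Never exceed the documented default, and never drop to 1, which would
    # disable the runtime asynchronous pipeline.
    return max(2, min(default, cores // num_processes))
```

For eight processes on a 64-core host this gives 8 threads each. The floor of 2 trades a little oversubscription on small hosts for keeping the pipeline enabled; the result would be passed as `ms.set_context(runtime_num_threads=...)` or to the replacement API.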

  • disable_format_transform (bool) – Whether to disable the automatic format transform function from NCHW to NHWC. When the network training performance of fp16 is worse than fp32, disable_format_transform can be set to True to try to improve training performance. Default: False .

  • support_binary (bool) – Whether to support running .pyc or .so files in graph mode. To do so, set support_binary to True and run the .py file once. MindSpore then saves the source code of the interfaces it compiles into the interface definition .py file, which must be writable. After that, compile the .py file into a .pyc or .so file, which can then be run in graph mode.
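The second step of that workflow, compiling the script after it has run once with support_binary=True, can use the standard library for the .pyc case. This sketch covers only .pyc output via py_compile; producing a .so needs a separate tool and is not shown. compile_to_pyc is a hypothetical helper name.

```python
import pathlib
import py_compile


def compile_to_pyc(source: str) -> str:
    """Byte-compile a .py file into a .pyc placed next to it and return the .pyc path.

    Assumes the script was already run once with support_binary=True.
    """
    target = pathlib.Path(source).with_suffix(".pyc")
    # doraise=True turns syntax errors into py_compile.PyCompileError.
    return py_compile.compile(source, cfile=str(target), doraise=True)
```

The resulting file can be launched with `python net.pyc`, since a .pyc written next to its source with a matching interpreter version is directly executable.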

  • memory_optimize_level (str) –

    The memory optimization level. Default: O1 on the Ascend hardware platform and O0 on other hardware platforms. The value must be in ['O0', 'O1'].

    • O0: priority performance option, disable SOMAS (Safe Optimized Memory Allocation Solver) and some other memory optimizations.

    • O1: priority memory option, enable SOMAS and some other memory optimizations.

    This parameter will be deprecated and will be removed in future versions. Please use the api mindspore.runtime.set_memory() instead.

  • memory_offload (str) –

    Whether to enable the memory offload function. When it is enabled, the idle data will be temporarily copied to the host side in the case of insufficient device memory. The value must be in the range of ['ON', 'OFF'], and the default value is 'OFF' .

      • ON: Enable the memory offload function. On the Ascend hardware platform, this parameter does not take effect when the graph compilation level is not 'O0'. This parameter also does not take effect when memory_optimize_level is set to 'O1'.

    • OFF: Turn off the memory Offload function.
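Because memory_optimize_level defaults to O1 on Ascend, memory_offload='ON' alone has no effect there. The sketch below restates the documented conditions as a hypothetical check. It assumes that "graph compilation level" refers to the resolved jit_level of jit_config, and the caller passes that resolved level.

```python
def default_memory_optimize_level(device_target: str) -> str:
    """Documented default: O1 on Ascend, O0 on other platforms."""
    return "O1" if device_target == "Ascend" else "O0"


def memory_offload_effective(device_target: str, memory_offload: str = "OFF",
                             memory_optimize_level=None, jit_level: str = "O0") -> bool:
    """Return True if memory_offload='ON' would take effect under the documented rules.

    jit_level is the resolved compilation level (assumed to be what the
    'graph compilation level' in the description refers to).
    """
    if memory_offload not in ("ON", "OFF"):
        raise ValueError("memory_offload must be 'ON' or 'OFF'")
    if memory_offload == "OFF":
        return False
    level = memory_optimize_level or default_memory_optimize_level(device_target)
    if level == "O1":
        return False  # ignored when memory_optimize_level is 'O1'
    if device_target == "Ascend" and jit_level != "O0":
        return False  # on Ascend, offload needs compilation level 'O0'
    return True
```

This is why the examples below set memory_optimize_level='O0' before memory_offload='ON'.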

  • ascend_config (dict) –

    Set the parameters specific to the Ascend hardware platform. It is not set by default. The default values of precision_mode, jit_compile and atomic_clean_policy are experimental and may change in the future.

    • precision_mode (str): Mixed precision mode setting. The default value for inference networks is force_fp16. The value range is as follows:

      • force_fp16: When the operator supports both float16 and float32, select float16 directly.

      • allow_fp32_to_fp16: For cube operators, use float16. For vector operators, prefer to keep the original dtype: if the operator in the model supports float32, the original dtype is kept; otherwise it is reduced to float16.

      • allow_mix_precision: Automatic mixed precision: for operators across the whole network, the precision of some operators is automatically reduced to float16 or bfloat16 according to the built-in optimization strategy.

      • must_keep_origin_dtype: Keep the original precision of the graph.

      • force_fp32: When the input of a matrix calculation operator is float16 and the output supports both float16 and float32, the output is forced to float32.

      • allow_fp32_to_bf16: For cube operators, use bfloat16. For vector operators, prefer to keep the original dtype: if the operator in the model supports float32, the original dtype is kept; otherwise it is reduced to bfloat16.

      • allow_mix_precision_fp16: Automatic mixed precision: for operators across the whole network, the precision of some operators is automatically reduced to float16 according to the built-in optimization strategy.

      • allow_mix_precision_bf16: Automatic mixed precision: for operators across the whole network, the precision of some operators is automatically reduced to bfloat16 according to the built-in optimization strategy.

    • jit_compile (bool): Whether to prefer online compilation. When set to True, online compilation is prioritized. When set to False, precompiled operator binary files are prioritized to improve compilation performance. By default, online compilation is used for static shapes and precompiled operator binary files are used for dynamic shapes.

    • atomic_clean_policy (int): The policy for cleaning memory occupied by atomic operators in the network. Default: 1 .

      • 0: The memory occupied by all atomic operators in the network is cleaned centrally.

      • 1: Memory is not cleaned centrally and each atomic operator in the network is cleaned separately. When the memory of the network exceeds the limit, you may try this cleaning policy, but it may cause performance loss.

    • matmul_allow_hf32 (bool): Whether to convert FP32 to HF32 for Matmul operators. Default value: False. This is an experimental prototype that is subject to change and/or deletion. For detailed information, please refer to Ascend community .

    • conv_allow_hf32 (bool): Whether to convert FP32 to HF32 for Conv operators. Default value: True. This is an experimental prototype that is subject to change and/or deletion. For detailed information, please refer to Ascend community .

    • exception_dump (str): Enable exception dump for Ascend operators, providing the input and output data for failing Ascend operators. The value can be "0" , "1" and "2". For "0" , exception dump is turned off; for "1", all inputs and outputs will be dumped for AICore exception operators; for "2", inputs will be dumped for AICore exception operators, reducing the saved information but improving performance. Default: "2" .

    • op_precision_mode (str): Path to config file of op precision mode. For detailed information, please refer to Ascend community .

    • op_debug_option (str): Enable debugging options for Ascend operators. Not enabled by default. Currently, the only supported value is "oom".

      • "oom": When there is a memory out of bounds during the execution of an operator, AscendCL will return an error code of EZ9999.

    • ge_options (dict): Set options for CANN. The options are divided into two categories: global and session. This is an experimental prototype that is subject to change and/or deletion. For detailed information, please refer to Ascend community . The configuration options in ge_options may be duplicated with the options in ascend_config. If the same configuration options are set in both ascend_config and ge_options, the one set in ge_options shall prevail.

      • global (dict): Set global options.

      • session (dict): Set session options.

    • parallel_speed_up_json_path (Union[str, None]): The path to the parallel speed-up JSON file; for the configuration format, refer to parallel_speed_up.json . If its value is None or '', it does not take effect. Default: None.

      • recompute_comm_overlap (bool): Enable overlap between recompute ops and communication ops if True. Default: False.

      • matmul_grad_comm_overlap (bool): Enable overlap between dw matmul and tensor parallel communication ops if True. Default: False.

      • recompute_allgather_overlap_fagrad (bool): Enable overlap between duplicated allgather by recomputing in sequence parallel and flashattentionscoregrad ops if True. Default: False.

      • enable_task_opt (bool): Enable communication fusion to optimize the number of communication operator tasks if True. Default: False.

      • enable_grad_comm_opt (bool): Enable overlap between dx ops and data parallel communication ops if True. LazyInline <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.lazy_inline.html> is not currently supported. Default: False.

      • enable_opt_shard_comm_opt (bool): Enable overlap between forward ops and optimizer parallel allgather communication if True. LazyInline <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.lazy_inline.html> is not currently supported. Default: False.

      • compute_communicate_fusion_level (int): Enable the fusion between compute and communicate. Default: 0. Note: This function must be used with Ascend Training Solution 24.0.RC2 or later.

        • 0: Disable fusion.

        • 1: Apply fusion to forward nodes.

        • 2: Apply fusion to backward nodes.

        • 3: Apply fusion to all nodes.

      • dataset_broadcast_opt_level (int): Optimize scenarios where the dataset is read repeatedly. Only supported at jit_level O0/O1; it does not take effect at O2. Default: 0.

        • 0: Disable this optimization.

        • 1: Optimize dataset reader between pipeline stage.

        • 2: Optimize dataset reader within pipeline stage.

        • 3: Optimize dataset reader with all scenes.

      • bias_add_comm_swap (bool): Swap the execution order of communication operators and add operators if True. Only 1-dimensional bias nodes are supported. Default: False.

      • enable_allreduce_slice_to_reducescatter (bool): Enable allreduce optimization. In the scenario where the batchmatmul model introduces allreduce in parallel, if the subsequent nodes are stridedslice operators with model parallelism, allreduce is optimized into reducescatter according to the identified patterns. Typically used in MoE modules with groupwise alltoall. Default: False.

      • enable_interleave_split_concat_branch (bool): Enable communication/computation parallel optimization for branches formed by split and concat operators with the enable_interleave attribute. It is typically used in MoE parallel scenarios: after the input data is split, each slice is processed by the MoE module, and then the branch results are concatenated. When the optimization is enabled, communication and computation are executed in parallel between branches. Default: False.

      • enable_interleave_parallel_branch (bool): Enable communication/computation parallel optimization for parallel branches with the parallel_branch attribute in the branch merge node. It is typically used in MoE parallel scenarios with routed and shared experts. When the optimization is enabled, communication and computation are executed in parallel between branches. Default: False.

    • host_scheduling_max_threshold (int): The maximum threshold that controls whether the dynamic shape process is used when running a static graph. Default: 0. When the number of operations in the static graph is less than the threshold, the graph is executed in the dynamic shape process; in large model scenarios, this can save stream resources. If the number of operations in the static graph is greater than the threshold, the graph is executed in the original static process.
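An ascend_config dictionary can be checked against the documented value ranges before it is passed to set_context, and the parallel speed-up options can be written to the JSON file that parallel_speed_up_json_path points to. This is a sketch under assumptions: make_ascend_config and write_speed_up_json are hypothetical helpers, and the JSON file is assumed to be a flat object keyed by the option names listed above.

```python
import json
import os

# precision_mode values listed in the documentation above.
PRECISION_MODES = {
    "force_fp16", "allow_fp32_to_fp16", "allow_mix_precision",
    "must_keep_origin_dtype", "force_fp32", "allow_fp32_to_bf16",
    "allow_mix_precision_fp16", "allow_mix_precision_bf16",
}


def write_speed_up_json(options: dict, directory: str) -> str:
    """Write parallel speed-up switches to a JSON file and return its path.

    Assumption: the file is a flat JSON object keyed by the documented option names.
    """
    path = os.path.join(directory, "parallel_speed_up.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(options, f, indent=2)
    return path


def make_ascend_config(precision_mode="force_fp16", atomic_clean_policy=1,
                       exception_dump="2", speed_up_json=None) -> dict:
    """Hypothetical builder that checks values against the documented ranges."""
    if precision_mode not in PRECISION_MODES:
        raise ValueError(f"unknown precision_mode {precision_mode!r}")
    if atomic_clean_policy not in (0, 1):
        raise ValueError("atomic_clean_policy must be 0 or 1")
    if exception_dump not in ("0", "1", "2"):
        raise ValueError('exception_dump must be "0", "1" or "2"')
    config = {
        "precision_mode": precision_mode,
        "atomic_clean_policy": atomic_clean_policy,
        "exception_dump": exception_dump,
    }
    # None or '' means the speed-up file does not take effect, so omit the key.
    if speed_up_json:
        config["parallel_speed_up_json_path"] = speed_up_json
    return config
```

The returned dictionary would be passed as `ms.set_context(ascend_config=...)`.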

  • jit_syntax_level (int) –

    Set JIT syntax level for graph compiling, triggered by GRAPH_MODE and @jit decorator. The value must be STRICT or LAX . Default: LAX . All levels support all backends.

    • STRICT : Only basic syntax is supported, and execution performance is optimal. Can be used for MindIR load and export.

    • LAX : Compatible with all Python syntax as much as possible. However, execution performance may be affected and not optimal. Cannot be used for MindIR load and export due to some syntax that may not be able to be exported.

  • debug_level (int) –

    Set config for debugging. Default value: RELEASE.

    • RELEASE: Used for normal running; some debug information is discarded to get better compilation performance.

    • DEBUG: Used for debugging when errors occur; more information is recorded during compilation.

    This parameter will be deprecated and removed in a future version.

  • gpu_config (dict) –

    Set the parameters specific to the GPU hardware platform. It is not set by default. Currently, only conv_fprop_algo, conv_dgrad_algo, conv_wgrad_algo, conv_allow_tf32 and matmul_allow_tf32 are supported on the GPU hardware platform.

    • conv_fprop_algo (str): Specifies the convolution forward algorithm. The default value is 'normal'. The value range is as follows:

      • normal: Use the heuristic search algorithm.

      • performance: Use the trial search algorithm.

      • implicit_gemm: This algorithm expresses the convolution as a matrix product without actually explicitly forming the matrix that holds the input tensor data.

      • implicit_precomp_gemm: This algorithm expresses convolution as a matrix product without actually explicitly forming the matrix that holds the input tensor data, but still needs some memory workspace to precompute some indices in order to facilitate the implicit construction of the matrix that holds the input tensor data.

      • gemm: This algorithm expresses the convolution as an explicit matrix product. A significant memory workspace is needed to store the matrix that holds the input tensor data.

      • direct: This algorithm expresses the convolution as a direct convolution (for example, without implicitly or explicitly doing a matrix multiplication).

      • fft: This algorithm uses the Fast-Fourier Transform approach to compute the convolution. A significant memory workspace is needed to store intermediate results.

      • fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles. A significant memory workspace is needed to store intermediate results but less than fft algorithm for large size images.

      • winograd: This algorithm uses the Winograd Transform approach to compute the convolution. A reasonably sized workspace is needed to store intermediate results.

      • winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution. A significant workspace may be needed to store intermediate results.

    • conv_dgrad_algo (str): Specifies the convolution data gradient algorithm. The default value is 'normal'. The value range is as follows:

      • normal: Use the heuristic search algorithm.

      • performance: Use the trial search algorithm.

      • algo_0: This algorithm expresses the convolution as a sum of matrix products without actually explicitly forming the matrix that holds the input tensor data. The sum is done using the atomic add operation, thus the results are non-deterministic.

      • algo_1: This algorithm expresses the convolution as a matrix product without actually explicitly forming the matrix that holds the input tensor data. The results are deterministic.

      • fft: This algorithm uses a Fast-Fourier Transform approach to compute the convolution. A significant memory workspace is needed to store intermediate results. The results are deterministic.

      • fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles. A significant memory workspace is needed to store intermediate results but less than fft for large size images. The results are deterministic.

      • winograd: This algorithm uses the Winograd Transform approach to compute the convolution. A reasonably sized workspace is needed to store intermediate results. The results are deterministic.

      • winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution. A significant workspace may be needed to store intermediate results. The results are deterministic.

    • conv_wgrad_algo (str): Specifies the convolution filter gradient algorithm. The default value is 'normal'. The value range is as follows:

      • normal: Use the heuristic search algorithm.

      • performance: Use the trial search algorithm.

      • algo_0: This algorithm expresses the convolution as a sum of matrix products without actually explicitly forming the matrix that holds the input tensor data. The sum is done using the atomic add operation, thus the results are non-deterministic.

      • algo_1: This algorithm expresses the convolution as a matrix product without actually explicitly forming the matrix that holds the input tensor data. The results are deterministic.

      • fft: This algorithm uses a Fast-Fourier Transform approach to compute the convolution. A significant memory workspace is needed to store intermediate results. The results are deterministic.

      • algo_3: This algorithm is similar to algo_0 but uses some small workspace to precompute some indices. The results are also non-deterministic.

      • winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution. A significant workspace may be needed to store intermediate results. The results are deterministic.

      • fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles. A significant memory workspace is needed to store intermediate results but less than fft for large size images. The results are deterministic.

    • conv_allow_tf32 (bool): Whether to allow Tensor Core TF32 computation in cuDNN. The default value is True.

    • matmul_allow_tf32 (bool): Whether to allow Tensor Core TF32 computation in cuBLAS. The default value is False.
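Each convolution direction accepts a different set of algorithm names, and two gradient algorithms are documented as non-deterministic. The hypothetical builder below encodes those lists so that a typo, or a non-deterministic choice when reproducibility matters, fails early. Whether "normal" and "performance" select deterministic algorithms is not documented, so the check does not cover them.

```python
FPROP_ALGOS = {"normal", "performance", "implicit_gemm", "implicit_precomp_gemm",
               "gemm", "direct", "fft", "fft_tiling", "winograd", "winograd_nonfused"}
DGRAD_ALGOS = {"normal", "performance", "algo_0", "algo_1", "fft", "fft_tiling",
               "winograd", "winograd_nonfused"}
WGRAD_ALGOS = {"normal", "performance", "algo_0", "algo_1", "fft", "algo_3",
               "winograd_nonfused", "fft_tiling"}
# Gradient algorithms documented as non-deterministic (they sum with atomic adds).
NON_DETERMINISTIC_GRAD = {"algo_0", "algo_3"}


def make_gpu_config(fprop="normal", dgrad="normal", wgrad="normal",
                    conv_allow_tf32=True, matmul_allow_tf32=False,
                    require_deterministic=False) -> dict:
    """Hypothetical builder for gpu_config; rejects values not listed in the docs."""
    for key, value, allowed in (("conv_fprop_algo", fprop, FPROP_ALGOS),
                                ("conv_dgrad_algo", dgrad, DGRAD_ALGOS),
                                ("conv_wgrad_algo", wgrad, WGRAD_ALGOS)):
        if value not in allowed:
            raise ValueError(f"{key}={value!r} is not a documented value")
    if require_deterministic and {dgrad, wgrad} & NON_DETERMINISTIC_GRAD:
        raise ValueError("algo_0 and algo_3 give non-deterministic gradients")
    return {
        "conv_fprop_algo": fprop,
        "conv_dgrad_algo": dgrad,
        "conv_wgrad_algo": wgrad,
        "conv_allow_tf32": conv_allow_tf32,
        "matmul_allow_tf32": matmul_allow_tf32,
    }
```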

  • jit_config (dict) –

    Set the global JIT config for compilation; it takes effect on networks defined in a Cell or with the jit decorator. It is not set by default. The setting in context is the global JIT config, while JitConfig is the local network's JIT config. When both exist, the global JIT config does not overwrite the local network's JIT config.

    • jit_level (str): Controls the compilation optimization level. Default: "". With the default, the framework selects the level automatically based on the product: O2 for Atlas training products and O0 for all other products. For dynamic shapes, the level must be O0 or O1; O2 is not supported. The value range is as follows:

      • "O0": Except for optimizations that may affect functionality, all other optimizations are turned off, adopt KernelByKernel execution mode.

      • "O1": Using commonly used optimizations and automatic operator fusion optimizations, adopt KernelByKernel execution mode. This optimization level is experimental and is being improved.

      • "O2": Ultimate performance optimization, adopt Sink execution mode.

    • infer_boost (str): Controls the inference mode. Default: "off". The value range is as follows:

      • "on": Enable the inference mode for better inference performance.

      • "off": Disable the inference mode and use the forward pass for inference, with lower performance.
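The jit_config rules above (automatic level by product, no O2 for dynamic shapes, on/off for infer_boost) can be expressed as a small validator. default_jit_level and make_jit_config are hypothetical helpers; whether the product is an Atlas training product has to be supplied by the caller.

```python
JIT_LEVELS = {"", "O0", "O1", "O2"}


def default_jit_level(is_atlas_training_product: bool) -> str:
    """What the empty default "" resolves to, per the description above."""
    return "O2" if is_atlas_training_product else "O0"


def make_jit_config(jit_level: str = "", infer_boost: str = "off",
                    dynamic_shape: bool = False) -> dict:
    """Hypothetical validator for an explicit jit_config dictionary."""
    if jit_level not in JIT_LEVELS:
        raise ValueError(f"jit_level must be one of {sorted(JIT_LEVELS)}")
    if infer_boost not in ("on", "off"):
        raise ValueError("infer_boost must be 'on' or 'off'")
    if dynamic_shape and jit_level == "O2":
        raise ValueError("dynamic shape requires jit_level 'O0' or 'O1'")
    return {"jit_level": jit_level, "infer_boost": infer_boost}
```

The validated dictionary matches the form used in the examples, for instance `ms.set_context(jit_config={"jit_level": "O0"})` plus the infer_boost key.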

  • exec_order (str) –

    Set the sorting method for operator execution in GRAPH_MODE. Currently, three sorting methods are supported: bfs, dfs and gpto. The default method is bfs.

    • "bfs": The default sorting method, breadth-first, with good communication masking and relatively good performance.

    • "dfs": An optional sorting method, depth-first sorting. The performance is relatively worse than that of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where other execution orders run out of memory (OOM).

    • "gpto": An optional sorting method. This method combines multiple execution orders and selects a method with relatively good performance. There may be some performance gains in scenarios with multiple replicas running in parallel.
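The deprecation notes scattered through the parameter descriptions can be collected into one lookup table. The sketch below uses hypothetical names (DEPRECATED and deprecation_report are not MindSpore APIs); it reports which keywords passed to set_context have a documented replacement, with None marking parameters whose description names no replacement.

```python
# Replacements named in the parameter descriptions above; None means none is given.
DEPRECATED = {
    "device_id": "mindspore.set_device()",
    "device_target": "mindspore.set_device()",
    "max_device_memory": "mindspore.runtime.set_memory()",
    "variable_memory_max_size": "mindspore.runtime.set_memory()",
    "mempool_block_size": "mindspore.runtime.set_memory()",
    "memory_optimize_level": "mindspore.runtime.set_memory()",
    "save_graphs": "environment variable MS_DEV_SAVE_GRAPHS",
    "save_graphs_path": "environment variable MS_DEV_SAVE_GRAPHS_PATH",
    "deterministic": "mindspore.set_deterministic()",
    "precompile_only": "environment variable MS_DEV_PRECOMPILE_ONLY",
    "pynative_synchronize": "mindspore.runtime.launch_blocking()",
    "max_call_depth": "mindspore.set_recursion_limit()",
    "grad_for_scalar": "take derivatives with respect to tensors",
    "enable_compile_cache": "environment variable MS_COMPILER_CACHE_ENABLE",
    "compile_cache_path": "environment variable MS_COMPILER_CACHE_PATH",
    "runtime_num_threads": "mindspore.device_context.cpu.op_tuning.threads_num()",
    "print_file_path": None,
    "env_config_path": None,
    "check_bprop": None,
    "debug_level": None,
}


def deprecation_report(**kwargs) -> dict:
    """Map each deprecated keyword in kwargs to its documented replacement."""
    return {key: DEPRECATED[key] for key in kwargs if key in DEPRECATED}
```

Calling it with the same keywords as set_context, for example `deprecation_report(mode=0, device_id=1)`, returns only the deprecated entries, which can then be logged before the real call.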

Raises

ValueError – If input key is not an attribute in context.

Examples

>>> import mindspore as ms
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> ms.set_context(precompile_only=True)
>>> ms.set_context(device_target="Ascend")
>>> ms.set_context(device_id=0)
>>> ms.set_context(save_graphs=True, save_graphs_path="./model.ms")
>>> ms.set_context(enable_reduce_precision=True)
>>> ms.set_context(reserve_class_name_in_scope=True)
>>> ms.set_context(variable_memory_max_size="6GB")
>>> ms.set_context(aoe_tune_mode="online")
>>> ms.set_context(aoe_config={"job_type": "2"})
>>> ms.set_context(check_bprop=True)
>>> ms.set_context(max_device_memory="3.5GB")
>>> ms.set_context(mempool_block_size="1GB")
>>> ms.set_context(print_file_path="print.pb")
>>> ms.set_context(max_call_depth=80)
>>> ms.set_context(env_config_path="./env_config.json")
>>> ms.set_context(grad_for_scalar=True)
>>> ms.set_context(enable_compile_cache=True, compile_cache_path="./cache.ms")
>>> ms.set_context(pynative_synchronize=True)
>>> ms.set_context(runtime_num_threads=10)
>>> ms.set_context(inter_op_parallel_num=4)
>>> ms.set_context(disable_format_transform=True)
>>> ms.set_context(memory_optimize_level='O0')
>>> ms.set_context(memory_offload='ON')
>>> ms.set_context(deterministic='ON')
>>> ms.set_context(ascend_config={"precision_mode": "force_fp16", "jit_compile": True,
...                "atomic_clean_policy": 1, "op_precision_mode": "./op_precision_config_file",
...                "op_debug_option": "oom",
...                "ge_options": {"global": {"ge.opSelectImplmode": "high_precision"},
...                               "session": {"ge.exec.atomicCleanPolicy": "0"}}})
>>> ms.set_context(jit_syntax_level=ms.STRICT)
>>> ms.set_context(debug_level=ms.context.DEBUG)
>>> ms.set_context(gpu_config={"conv_fprop_algo": "performance", "conv_allow_tf32": True,
...                "matmul_allow_tf32": True})
>>> ms.set_context(jit_config={"jit_level": "O0"})
>>> ms.set_context(exec_order="gpto")