mindspore.Profiler

class mindspore.Profiler(**kwargs)

This class enables profiling of MindSpore neural networks. Import mindspore.Profiler and initialize a Profiler object to start profiling, then call Profiler.analyse() to stop profiling and analyze the results. The results can be visualized with the MindSpore Insight tool. Profiler currently supports analysis of AICORE operator, AICPU operator, HostCPU operator, memory, correspondence, and cluster data, among others.

Parameters
  • output_path (str, optional) – Output data path. Default: "./data" .

  • profiler_level (ProfilerLevel, optional) –

    (Ascend only) The level of profiling. Default: None.

    • ProfilerLevel.Level0: The leanest level of data collection. Collects the elapsed time of computational operators on the NPU and information about large communication operators.

    • ProfilerLevel.Level1: On top of Level0, additionally collects CANN-layer AscendCL data, AICore performance metrics, and information about small communication operators.

    • ProfilerLevel.Level2: On top of Level1, additionally collects GE and Runtime information in the CANN layer.

  • op_time (bool, optional) – (Ascend/GPU) Whether to collect operator performance data. Default value: True.

  • profile_communication (bool, optional) – (Ascend only) Whether to collect communication performance data during multi-device training; data is collected when True. This parameter has no effect in single-card training. When using this parameter, op_time must be set to True . Default: False .

  • profile_memory (bool, optional) – (Ascend only) Whether to collect tensor memory data; data is collected when True . When using this parameter, op_time must be set to True. When the graph compilation level is O2, operator memory data must be collected starting from the first step. Default: False .

  • parallel_strategy (bool, optional) – (Ascend only) Whether to collect parallel strategy performance data. Default value: False .

  • start_profile (bool, optional) – Whether to start collecting performance data as soon as the Profiler is initialized. Set it to False to start collection later with start(), for example at a chosen step or epoch. Default: True .

  • aicore_metrics (int, optional) –

    (Ascend only) Type of AICORE performance data to collect. When using this parameter, op_time must be set to True , and the value must be one of [-1, 0, 1, 2, 3, 4, 5, 6]. Default: 0 . The data items contained in each metric are as follows:

    • -1: Does not collect AICORE data.

    • 0: ArithmeticUtilization contains mac_fp16/int8_ratio, vec_fp32/fp16/int32_ratio, vec_misc_ratio etc.

    • 1: PipeUtilization contains vec_ratio, mac_ratio, scalar_ratio, mte1/mte2/mte3_ratio, icache_miss_rate etc.

    • 2: Memory contains ub_read/write_bw, l1_read/write_bw, l2_read/write_bw, main_mem_read/write_bw etc.

    • 3: MemoryL0 contains l0a_read/write_bw, l0b_read/write_bw, l0c_read/write_bw etc.

    • 4: ResourceConflictRatio contains vec_bankgroup/bank/resc_cflt_ratio etc.

    • 5: MemoryUB contains ub_read/write_bw_mte, ub_read/write_bw_vector, ub_read/write_bw_scalar etc.

    • 6: L2Cache contains write_cache_hit, write_cache_miss_allocate, r0_read_cache_hit, r1_read_cache_hit etc. This metric is only supported on Atlas A2 training series products.

  • l2_cache (bool, optional) – (Ascend only) Whether to collect L2 cache data; data is collected when True. Default: False .

  • hbm_ddr (bool, optional) – (Ascend only) Whether to collect On-Chip Memory/DDR read and write rate data; data is collected when True. Default: False .

  • pcie (bool, optional) – (Ascend only) Whether to collect PCIe bandwidth data; data is collected when True. Default: False .

  • sync_enable (bool, optional) –

    (GPU only) Whether the profiler times operators synchronously. Default: True .

    • True: The synchronous way. The CPU records a start timestamp before sending the operator to the GPU, and records an end timestamp after the operator finishes executing and returns to the CPU. The operator's duration is the difference between the two timestamps.

    • False: The asynchronous way. The recorded duration is the time taken to send the operator from the CPU to the GPU. This reduces the profiler's impact on overall training time.

  • data_process (bool, optional) – (Ascend/GPU) Whether to collect data preparation performance data. Default value: False .

  • timeline_limit (int, optional) – (Ascend/GPU) Maximum storage size of the timeline file, in MB. When using this parameter, op_time must be set to True. Default value: 500 .

  • profile_framework (str, optional) –

    (Ascend/GPU) The host information to collect. It must be one of ["all", "time", None]. When it is not None, host profiler data is collected. When using this parameter, op_time must be enabled. Default: None.

    • "all": Records host timestamps.

    • "time": The same as "all".

    • None: Does not record host information.

  • data_simplification (bool, optional) – (Ascend only) Whether to remove FRAMEWORK data and other redundant data. If set to True, only the profiler deliverables and the original performance data in the PROF_XXX directory are retained, to save disk space. Default value: True .

  • with_stack (bool, optional) – (Ascend) Whether to collect Python-side framework host call stack data. This data is presented as a flame graph in the timeline. When using this parameter, the op_time and profile_framework parameters must be enabled. Default value: False .

  • analyse_only (bool, optional) – (Ascend/GPU) Whether to only parse performance data without collecting it. This is an experimental parameter and does not need to be set by users. Default value: False .

  • rank_id (int, optional) – (Ascend/GPU) The rank ID used during parsing. This is an experimental parameter and does not need to be set by users. Default value: 0 .

  • env_enable (bool, optional) – (Ascend/GPU) Whether to enable the collection of environment variables. This is an experimental parameter and does not need to be set by users. Default value: False .
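Several parameters above only take effect together with others: profile_communication, profile_memory, aicore_metrics, timeline_limit and profile_framework require op_time=True, and with_stack additionally requires profile_framework. The sketch below encodes those documented dependencies in a hypothetical helper, check_profiler_kwargs, so a configuration can be checked before it is passed to Profiler(**kwargs). The helper is not part of MindSpore, and the decision of when a value counts as "using" a parameter (for example, aicore_metrics=-1 counting as off) is an assumption.

```python
# Hypothetical pre-flight check; check_profiler_kwargs is NOT a MindSpore API.
# It only restates the dependencies documented in the parameter list above.

# Parameters documented as requiring op_time=True when they are used.
NEEDS_OP_TIME = ("profile_communication", "profile_memory", "aicore_metrics",
                 "timeline_limit", "profile_framework", "with_stack")
VALID_AICORE_METRICS = (-1, 0, 1, 2, 3, 4, 5, 6)
VALID_PROFILE_FRAMEWORK = ("all", "time", None)


def _is_used(name, value):
    """Decide whether a parameter value actually turns its feature on (assumed)."""
    if name == "aicore_metrics":
        return value != -1          # -1 means "do not collect AICORE data"
    if name == "profile_framework":
        return value is not None    # None means "no host information"
    return bool(value)


def check_profiler_kwargs(kwargs):
    """Return a list of rule violations in a Profiler keyword-argument dict."""
    problems = []
    op_time = kwargs.get("op_time", True)  # documented default: True
    for name in NEEDS_OP_TIME:
        if name in kwargs and _is_used(name, kwargs[name]) and not op_time:
            problems.append(f"{name} requires op_time=True")
    if kwargs.get("aicore_metrics", 0) not in VALID_AICORE_METRICS:
        problems.append("aicore_metrics must be one of -1..6")
    if kwargs.get("profile_framework") not in VALID_PROFILE_FRAMEWORK:
        problems.append('profile_framework must be "all", "time" or None')
    if kwargs.get("with_stack") and kwargs.get("profile_framework") is None:
        problems.append("with_stack requires profile_framework to be set")
    return problems


if __name__ == "__main__":
    good = {"op_time": True, "profile_memory": True, "aicore_metrics": 1}
    bad = {"op_time": False, "profile_communication": True, "with_stack": True}
    print(check_profiler_kwargs(good))  # []
    print(check_profiler_kwargs(bad))   # three violations
    # A dict that passes could then be used as Profiler(**good).
```

An empty list only means the documented dependencies are satisfied; device-specific restrictions such as "(Ascend only)" are not checked.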

Raises

RuntimeError – If the CANN version does not match the MindSpore version, MindSpore cannot parse the generated ascend_job_id directory structure.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn
>>> import mindspore.dataset as ds
>>> from mindspore import Profiler
>>> from mindspore.profiler import ProfilerLevel
>>>
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.fc = nn.Dense(2,2)
...     def construct(self, x):
...         return self.fc(x)
>>>
>>> def generator():
...     for i in range(2):
...         yield (np.ones([2, 2]).astype(np.float32), np.ones([2]).astype(np.int32))
>>>
>>> def train(net):
...     optimizer = nn.Momentum(net.trainable_params(), 1, 0.9)
...     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
...     data = ds.GeneratorDataset(generator, ["data", "label"])
...     model = ms.train.Model(net, loss, optimizer)
...     model.train(1, data)
>>>
>>> if __name__ == '__main__':
...     # If the device_target is GPU, set the device_target to "GPU"
...     ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
...
...     # Init Profiler
...     # Note that the Profiler should be initialized before model.train
...     profiler = Profiler(profiler_level=ProfilerLevel.Level0)
...
...     # Train Model
...     net = Net()
...     train(net)
...
...     # Profiler end
...     profiler.analyse()
add_metadata(key: str, value: str)

Reports custom metadata as a key-value pair.

Parameters
  • key (str) – The metadata key.

  • value (str) – The metadata value.

Examples

>>> from mindspore import Profiler
>>> # Profiler init.
>>> profiler = Profiler()
>>> # Call Profiler add_metadata
>>> profiler.add_metadata("test_key", "test_value")
>>> # Profiler end
>>> profiler.analyse()
add_metadata_json(key: str, value: str)

Reports custom metadata as a key-value pair whose value is a JSON string.

Parameters
  • key (str) – The metadata key.

  • value (str) – The metadata value, as a JSON-formatted string.

Examples

>>> import json
>>> from mindspore import Profiler
>>> # Profiler init.
>>> profiler = Profiler()
>>> # Call Profiler add_metadata_json
>>> profiler.add_metadata_json("test_key", json.dumps({"key1": 1, "key2": 2}))
>>> # Profiler end, metadata will be saved in profiler_metadata.json
>>> profiler.analyse()
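Because add_metadata_json expects value to already be a JSON string, it can help to serialize and check the value before reporting it. The two helpers below, to_metadata_json and is_json_string, are hypothetical (not MindSpore APIs) and rely only on the standard json module.

```python
import json


def to_metadata_json(obj):
    """Serialize obj into a JSON string for add_metadata_json (hypothetical helper).

    json.dumps raises TypeError for values JSON cannot represent (e.g. sets);
    sort_keys=True keeps the string identical across runs.
    """
    return json.dumps(obj, sort_keys=True)


def is_json_string(value):
    """Return True only if value is a str that parses as JSON."""
    if not isinstance(value, str):
        return False
    try:
        json.loads(value)
    except json.JSONDecodeError:
        return False
    return True


if __name__ == "__main__":
    value = to_metadata_json({"key2": 2, "key1": 1})
    print(value)                           # {"key1": 1, "key2": 2}
    print(is_json_string("{'key1': 1}"))   # False: single quotes are not JSON
    # profiler.add_metadata_json("test_key", value)
```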
analyse(offline_path=None, pretty=False, step_list=None, mode='sync')

Collects and analyzes training performance data. It can be called during or after training; see the class-level example above.

Parameters
  • offline_path (Union[str, None], optional) – The path of the data to analyze in offline mode. Offline mode is used after an abnormal exit. Set this parameter to None for online mode. Default: None.

  • pretty (bool, optional) – Whether to pretty-print the JSON files. Default: False.

  • step_list (list, optional) – The steps to analyze. They must be consecutive integers. Default: None, in which case all steps are analyzed.

  • mode (str, optional) –

    Analysis mode. It must be one of ["sync", "async"]. Default: "sync".

    • sync: Analyzes data in the current process, blocking it until analysis finishes.

    • async: Analyzes data in a subprocess without blocking the current process. Because parsing consumes extra CPU resources, enable this mode only when enough resources are available.
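The step_list and mode constraints above can be checked before analyse() is called. The sketch below uses two hypothetical helpers (not MindSpore APIs): consecutive_steps builds a step list for a profiling window, and validate_analyse_args rejects unknown modes and non-consecutive lists. It assumes "consecutive" means ascending by exactly 1, as in [2, 3, 4].

```python
def consecutive_steps(first, last):
    """Return [first, first + 1, ..., last] for use as step_list."""
    if last < first:
        raise ValueError("last must be >= first")
    return list(range(first, last + 1))


def validate_analyse_args(step_list=None, mode="sync"):
    """Raise ValueError if the arguments break the documented analyse() rules.

    Hypothetical helper, not a MindSpore API. "Consecutive" is assumed here
    to mean ascending by exactly 1, e.g. [2, 3, 4].
    """
    if mode not in ("sync", "async"):
        raise ValueError('mode must be "sync" or "async"')
    if step_list is None:
        return  # None: all steps are analyzed
    if not all(isinstance(step, int) for step in step_list):
        raise ValueError("step_list must contain integers only")
    if any(b - a != 1 for a, b in zip(step_list, step_list[1:])):
        raise ValueError("steps in step_list must be consecutive integers")


if __name__ == "__main__":
    steps = consecutive_steps(2, 4)
    print(steps)                               # [2, 3, 4]
    validate_analyse_args(steps, mode="sync")  # passes silently
    # profiler.analyse(step_list=steps, mode="sync")
```

For the StopAtStep callback in the example, consecutive_steps(2, 4) produces the same [2, 3, 4] passed to analyse().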

Examples

>>> from mindspore.train import Callback
>>> from mindspore import Profiler
>>> class StopAtStep(Callback):
...     def __init__(self, start_step=1, stop_step=5):
...         super(StopAtStep, self).__init__()
...         self.start_step = start_step
...         self.stop_step = stop_step
...         self.profiler = Profiler(start_profile=False)
...
...     def step_begin(self, run_context):
...         cb_params = run_context.original_args()
...         step_num = cb_params.cur_step_num
...         if step_num == self.start_step:
...             self.profiler.start()
...
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         step_num = cb_params.cur_step_num
...         if step_num == self.stop_step:
...             self.profiler.stop()
...
...     def end(self, run_context):
...         self.profiler.analyse(step_list=[2,3,4], mode="sync")
classmethod offline_analyse(path: str, pretty=False, step_list=None, data_simplification=True)

Analyzes training performance data offline. Call it after performance data collection has finished.

Parameters
  • path (str) – The path of the profiling data to analyze offline. This path must contain a profiler directory.

  • pretty (bool, optional) – Whether to pretty-print the JSON files. Default: False.

  • step_list (list, optional) – The steps to analyze. They must be consecutive integers. Default: None, in which case all steps are analyzed.

  • data_simplification (bool, optional) – Whether to enable data simplification. Default: True.
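Before calling offline_analyse, a quick check can confirm that path actually contains the profiler directory mentioned above. This is a sketch under the assumption that the directory is named literally `profiler`; has_profiler_dir is a hypothetical helper, not a MindSpore API.

```python
import tempfile
from pathlib import Path


def has_profiler_dir(path):
    """Return True if `path` contains a sub-directory named "profiler".

    Hypothetical helper; assumes offline_analyse looks for that exact name.
    """
    return (Path(path) / "profiler").is_dir()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        print(has_profiler_dir(root))   # False: nothing collected yet
        (Path(root) / "profiler").mkdir()
        print(has_profiler_dir(root))   # True
        # if has_profiler_dir(root):
        #     Profiler.offline_analyse(root)
```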

Examples

>>> from mindspore import Profiler
>>> Profiler.offline_analyse("./profiling_path")
op_analyse(op_name, device_id=None)

Profiler users can use this interface to obtain operator performance data.

Parameters
  • op_name (str or list) – The primitive operator name, or a list of names, to query.

  • device_id (int, optional) – ID of the target device. This parameter is optional during network training or inference; use it to specify the card whose operator performance data should be parsed. When this interface is used for offline data parsing, it defaults to 0 .

Raises
  • TypeError – If the op_name parameter type is incorrect.

  • TypeError – If the device_id parameter type is incorrect.

  • RuntimeError – If MindSpore runs on Ascend, where this interface is not supported.

Supported Platforms:

GPU CPU

Examples

>>> from mindspore import Profiler
>>> from mindspore import nn
>>> from mindspore import Model
>>> # Profiler init.
>>> profiler = Profiler()
>>> # Train Model or eval Model, taking LeNet5 as an example.
>>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> # Create the dataset taking MNIST as an example.
>>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
>>> dataloader = create_dataset()
>>> model = Model(net, loss, optimizer)
>>> model.train(5, dataloader, dataset_sink_mode=False)
>>>
>>> # Profiler end
>>> profiler.analyse()
>>>
>>> profiler.op_analyse(op_name=["BiasAdd", "Conv2D"])
start()

Starts profiling on Ascend or GPU. Profiling can be started at a chosen step or epoch.

Raises
  • RuntimeError – If the profiler has already started.

  • RuntimeError – If the start_profile parameter is not set or is set to True, since profiling then starts as soon as the Profiler is initialized.

Examples

>>> from mindspore.train import Callback
>>> from mindspore import Profiler
>>> class StopAtStep(Callback):
...     def __init__(self, start_step, stop_step):
...         super(StopAtStep, self).__init__()
...         self.start_step = start_step
...         self.stop_step = stop_step
...         self.profiler = Profiler(start_profile=False)
...
...     def step_begin(self, run_context):
...         cb_params = run_context.original_args()
...         step_num = cb_params.cur_step_num
...         if step_num == self.start_step:
...             self.profiler.start()
...
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         step_num = cb_params.cur_step_num
...         if step_num == self.stop_step:
...             self.profiler.stop()
...
...     def end(self, run_context):
...         self.profiler.analyse()
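In the callback above, start() runs at the beginning of start_step and stop() at the end of stop_step, so steps start_step through stop_step inclusive are profiled. The sketch below replays that logic without MindSpore, using a hypothetical RecordingProfiler stub in place of Profiler and a plain loop in place of Model.train. It assumes cur_step_num counts from 1.

```python
class RecordingProfiler:
    """Hypothetical stub standing in for mindspore.Profiler; it only tracks state."""

    def __init__(self):
        self.active = False

    def start(self):
        self.active = True

    def stop(self):
        self.active = False


def profiled_steps(start_step, stop_step, total_steps):
    """Return the steps during which the stub profiler is active."""
    profiler = RecordingProfiler()  # like Profiler(start_profile=False)
    covered = []
    for step in range(1, total_steps + 1):  # assumes cur_step_num starts at 1
        if step == start_step:  # mirrors step_begin
            profiler.start()
        if profiler.active:
            covered.append(step)
        if step == stop_step:  # mirrors step_end
            profiler.stop()
    return covered


if __name__ == "__main__":
    print(profiled_steps(start_step=2, stop_step=4, total_steps=6))  # [2, 3, 4]
```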
stop()

Stops profiling on Ascend or GPU. Profiling can be stopped at a chosen step or epoch.

Raises

RuntimeError – If the profiler has not been started.
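Taken together, the RuntimeError conditions documented for start() and stop() describe a two-state machine. The sketch below models them with a hypothetical ProfilerState class (not MindSpore code, and its error messages are illustrative): with the default start_profile=True, collection is already running, so start() raises; stop() raises unless collection is running.

```python
class ProfilerState:
    """Hypothetical model of the documented start()/stop() rules; not MindSpore code."""

    def __init__(self, start_profile=True):
        # With the default start_profile=True, collection begins at init.
        self.running = start_profile

    def start(self):
        if self.running:
            raise RuntimeError("The profiler has already started.")
        self.running = True

    def stop(self):
        if not self.running:
            raise RuntimeError("The profiler has not started.")
        self.running = False


if __name__ == "__main__":
    state = ProfilerState(start_profile=False)
    state.start()
    state.stop()
    print(state.running)  # False
```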

Examples

>>> from mindspore.train import Callback
>>> from mindspore import Profiler
>>> class StopAtEpoch(Callback):
...     def __init__(self, start_epoch, stop_epoch):
...         super(StopAtEpoch, self).__init__()
...         self.start_epoch = start_epoch
...         self.stop_epoch = stop_epoch
...         self.profiler = Profiler(start_profile=False)
...
...     def epoch_begin(self, run_context):
...         cb_params = run_context.original_args()
...         epoch_num = cb_params.cur_epoch_num
...         if epoch_num == self.start_epoch:
...             self.profiler.start()
...
...     def epoch_end(self, run_context):
...         cb_params = run_context.original_args()
...         epoch_num = cb_params.cur_epoch_num
...         if epoch_num == self.stop_epoch:
...             self.profiler.stop()
...
...     def end(self, run_context):
...         self.profiler.analyse()