mindspore.Profiler
- class mindspore.Profiler(**kwargs)[源代码]
MindSpore用户能够通过该类对神经网络的性能进行采集。可以通过导入 mindspore.Profiler 然后初始化Profiler对象以开始分析,使用 Profiler.analyse() 停止收集和分析。可通过Mindinsight工具可视化分析结果。目前,Profiler支持AICORE算子、AICPU算子、HostCPU算子、内存、设备通信、集群等数据的分析。
- 参数:
output_path (str, 可选) - 表示输出数据的路径。默认值:”./data”。
profile_communication (bool, 可选) - (仅限Ascend)表示是否在多设备训练中收集通信性能数据。当值为True时,收集这些数据。在单台设备训练中,该参数的设置无效。默认值:False。
profile_memory (bool, 可选) - (仅限Ascend)表示是否收集Tensor内存数据。当值为True时,收集这些数据。默认值:False。
start_profile (bool, 可选) - 该参数控制是否在Profiler初始化的时候开启数据采集。默认值:True。
- 异常:
RuntimeError - 当CANN的版本与MindSpore版本不匹配时,生成的ascend_job_id目录结构MindSpore无法解析。
- 支持平台:
Ascend
GPU
样例:
>>> import numpy as np >>> import mindspore as ms >>> from mindspore import nn >>> import mindspore.dataset as ds >>> from mindspore import Profiler >>> >>> >>> class Net(nn.Cell): ... def __init__(self): ... super(Net, self).__init__() ... self.fc = nn.Dense(2,2) ... def construct(self, x): ... return self.fc(x) >>> >>> def generator(): ... for i in range(2): ... yield (np.ones([2, 2]).astype(np.float32), np.ones([2]).astype(np.int32)) >>> >>> def train(net): ... optimizer = nn.Momentum(net.trainable_params(), 1, 0.9) ... loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) ... data = ds.GeneratorDataset(generator, ["data", "label"]) ... model = ms.Model(net, loss, optimizer) ... model.train(1, data) >>> >>> if __name__ == '__main__': ... # If the device_target is GPU, set the device_target to "GPU" ... ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") ... ... # Init Profiler ... # Note that the Profiler should be initialized before model.train ... profiler = Profiler() ... ... # Train Model ... net = Net() ... train(net) ... ... # Profiler end ... profiler.analyse()
- op_analyse(op_name, device_id=None)[源代码]
获取primitive类型的算子性能数据。
- 参数:
op_name (str 或 list) - 表示要查询的primitive算子类型。
device_id (int, 可选) - 设备卡号,表示指定解析哪张卡的算子性能数据。在网络训练或者推理时使用,该参数可选。基于离线数据解析使用该接口时,默认值:0。
- 异常:
TypeError - op_name参数类型不正确。
TypeError - device_id参数类型不正确。
RunTimeError - 在Ascend上使用该接口获取性能数据。
- 支持平台:
GPU
CPU
样例:
>>> from mindspore import Profiler >>> ... # Profiler init. ... profiler = Profiler() ... ... # Train Model or eval Model. ... net = Net() ... train(net) ... ... # Profiler end ... profiler.analyse() ... ... profiler.op_analyse(op_name=["BiasAdd", "Conv2D"]) ... >>> from mindspore import Profiler >>> ... # Profiler init. ... profiler = Profiler(output_path="my_profiler_path") ... profiler.op_analyse(op_name="Conv2D")
- start()[源代码]
开启Profiler数据采集。可以按条件开启Profiler。
- 异常:
RuntimeError - profiler已经开启。
RuntimeError - 停止Minddata采集后,不支持重复开启。
RuntimeError - 如果start_profile参数未设置或设置为True。
样例:
>>> class StopAtStep(Callback): >>> def __init__(self, start_step, stop_step): ... super(StopAtStep, self).__init__() ... self.start_step = start_step ... self.stop_step = stop_step ... self.profiler = Profiler(start_profile=False) ... >>> def step_begin(self, run_context): ... cb_params = run_context.original_args() ... step_num = cb_params.cur_step_num ... if step_num == self.start_step: ... self.profiler.start() ... >>> def step_end(self, run_context): ... cb_params = run_context.original_args() ... step_num = cb_params.cur_step_num ... if step_num == self.stop_step: ... self.profiler.stop() ... >>> def end(self, run_context): ... self.profiler.analyse()
- stop()[源代码]
停止Profiler。可以按条件停止Profiler。
- 异常:
RuntimeError - profiler没有开启。
样例:
>>> class StopAtEpoch(Callback): >>> def __init__(self, start_epoch, stop_epoch): ... super(StopAtEpoch, self).__init__() ... self.start_epoch = start_epoch ... self.stop_epoch = stop_epoch ... self.profiler = Profiler(start_profile=False) ... >>> def epoch_begin(self, run_context): ... cb_params = run_context.original_args() ... epoch_num = cb_params.cur_epoch_num ... if epoch_num == self.start_epoch: ... self.profiler.start() ... >>> def epoch_end(self, run_context): ... cb_params = run_context.original_args() ... epoch_num = cb_params.cur_epoch_num ... if epoch_num == self.stop_epoch: ... self.profiler.stop() ... >>> def end(self, run_context): ... self.profiler.analyse()