mindspore.profiler.tensor_board_trace_handler

查看源文件
mindspore.profiler.tensor_board_trace_handler()[源代码]

对动态图模式的每一个 step,调用该方法进行在线解析。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import context, nn, Profiler
>>> from mindspore.profiler import schedule, tensor_board_trace_handler
>>>
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.fc = nn.Dense(2, 2)
...
...     def construct(self, x):
...         return self.fc(x)
>>>
>>> def generator_net():
...     for _ in range(2):
...         yield np.ones([2, 2]).astype(np.float32), np.ones([2]).astype(np.int32)
>>>
>>> def train(test_net):
...     optimizer = nn.Momentum(test_net.trainable_params(), 1, 0.9)
...     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
...     data = ds.GeneratorDataset(generator_net(), ["data", "label"])
...     model = ms.train.Model(test_net, loss, optimizer)
...     model.train(1, data)
>>>
>>> if __name__ == '__main__':
...     context.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")
...
...     net = Net()
...     STEP_NUM = 15
...
...     with Profiler(schedule=schedule(wait=1, warmup=1, active=2, repeat=1, skip_first=2),
...                   on_trace_ready=tensor_board_trace_handler) as prof:
...         for i in range(STEP_NUM):
...             train(net)
...             prof.step()