mindelec.vision.MonitorTrain

class mindelec.vision.MonitorTrain(per_print_times=1, summary_dir='./summary_train')[源代码]

训练损失监测器。

说明

如果 per_print_times 为0,则不打印loss。

参数:
  • per_print_times (int) - 打印loss间隔。默认值:1。

  • summary_dir (str) - 摘要保存路径。默认值:’./summary_t’。

支持平台:

Ascend

样例:

>>> from mindelec.vision import MonitorTrain
>>> per_print_times = 1
>>> summary_dir = './summary_train'
>>> MonitorTrain(per_print_times, summary_dir)
step_end(run_context)[源代码]

在epoch结束时评估模型。

参数:
  • run_context (RunContext) - 训练运行的上下文。