mindelec.vision.MonitorEval

class mindelec.vision.MonitorEval(summary_dir='./summary_eval', model=None, eval_ds=None, eval_interval=10, draw_flag=True)

A LossMonitor used for evaluation: it evaluates the model on the evaluation dataset during training.

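Parameters:
  • summary_dir (str) - Path where the summary is saved. Default: './summary_eval'.

  • model (Solver) - The model object to be evaluated. Default: None.

  • eval_ds (Dataset) - The evaluation dataset. Default: None.

  • eval_interval (int) - Evaluation interval. Default: 10.

  • draw_flag (bool) - Whether to save summary records. Default: True.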
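The following is a minimal sketch of how `eval_interval` and `draw_flag` might interact. It assumes that `eval_interval` counts epochs (evaluation happens at epoch end) and that `draw_flag=False` skips recording results. `IntervalEvaluator`, `on_epoch_end`, `num_evals` and the `evaluate` callable (a stand-in for running `model` on `eval_ds`) are hypothetical names, not part of MindElec.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict


@dataclass
class IntervalEvaluator:
    """Hypothetical stand-in mirroring MonitorEval's parameters; not MindElec code."""

    evaluate: Callable[[], float]  # stands in for evaluating `model` on `eval_ds`
    eval_interval: int = 10        # assumed to count epochs
    draw_flag: bool = True         # when False, results are not recorded
    records: Dict[int, float] = field(default_factory=dict)
    num_evals: int = 0

    def on_epoch_end(self, epoch: int) -> None:
        # Assumed rule: evaluate only on epochs that are multiples of eval_interval.
        if epoch % self.eval_interval != 0:
            return
        loss = self.evaluate()
        self.num_evals += 1
        # Only keep a record of the result when draw_flag is set.
        if self.draw_flag:
            self.records[epoch] = loss


evaluator = IntervalEvaluator(evaluate=lambda: 0.5, eval_interval=10)
silent = IntervalEvaluator(evaluate=lambda: 0.5, eval_interval=10, draw_flag=False)
for epoch in range(1, 31):
    evaluator.on_epoch_end(epoch)
    silent.on_epoch_end(epoch)

print(sorted(evaluator.records))  # [10, 20, 30]
print(silent.num_evals, silent.records)  # 3 {}
```

In this sketch, turning off `draw_flag` still runs the evaluation but keeps no records. That separation is an assumption chosen for illustration.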

Supported Platforms:

Ascend

Examples:

>>> import mindspore.nn as nn
>>> from mindelec.solver import Solver
>>> from mindelec.vision import MonitorEval
>>> class S11Predictor(nn.Cell):
...     def __init__(self, input_dimension):
...         super(S11Predictor, self).__init__()
...         self.fc1 = nn.Dense(input_dimension, 128)
...         self.fc2 = nn.Dense(128, 128)
...         self.fc3 = nn.Dense(128, 1001)
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x0 = x
...         x1 = self.relu(self.fc1(x0))
...         x2 = self.relu(self.fc2(x1))
...         x = self.fc3(x1 + x2)
...         return x
>>> model_net = S11Predictor(3)
>>> model = Solver(network=model_net, mode="Data",
...                optimizer=nn.Adam(model_net.trainable_params(), 0.001), loss_fn=nn.MSELoss())
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> eval_ds = Dataset()
>>> summary_dir = './summary_eval_path'
>>> eval_interval = 10
>>> draw_flag = True
>>> MonitorEval(summary_dir, model, eval_ds, eval_interval, draw_flag)
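`MonitorEval` exposes an `epoch_end(run_context)` hook, which matches MindSpore's callback convention. Under that convention, the training loop calls the hook after each epoch, and the callback reads the current epoch from `run_context.original_args().cur_epoch_num`. The toy sketch below only illustrates that dispatch in plain Python. `FakeRunContext`, `EpochLogger` and `toy_train` are hypothetical stand-ins, not MindSpore or MindElec classes.

```python
from types import SimpleNamespace
from typing import List


class FakeRunContext:
    """Stand-in for MindSpore's RunContext; only original_args() is mimicked."""

    def __init__(self, cur_epoch_num: int) -> None:
        self._cb_params = SimpleNamespace(cur_epoch_num=cur_epoch_num)

    def original_args(self) -> SimpleNamespace:
        return self._cb_params


class EpochLogger:
    """Hypothetical callback: records the epoch number each time epoch_end fires."""

    def __init__(self) -> None:
        self.seen: List[int] = []

    def epoch_end(self, run_context: FakeRunContext) -> None:
        # Read the training state the loop packed into the run context.
        cb_params = run_context.original_args()
        self.seen.append(cb_params.cur_epoch_num)


def toy_train(num_epochs: int, callbacks: list) -> None:
    """Toy loop: after each (omitted) training epoch, call every callback's epoch_end."""
    for epoch in range(1, num_epochs + 1):
        ctx = FakeRunContext(epoch)
        for cb in callbacks:
            cb.epoch_end(ctx)


logger = EpochLogger()
toy_train(3, callbacks=[logger])
print(logger.seen)  # [1, 2, 3]
```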
epoch_end(run_context)

Evaluates the model at the end of an epoch.

Parameters: