mindelec.vision.MonitorEval
- class mindelec.vision.MonitorEval(summary_dir='./summary_eval', model=None, eval_ds=None, eval_interval=10, draw_flag=True)
A LossMonitor for evaluation: it evaluates model on eval_ds at the end of training epochs.
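- Parameters:
summary_dir (str) - Path where the summary is saved. Default: './summary_eval'.
model (Solver) - The model object to evaluate. Default: None.
eval_ds (Dataset) - The evaluation dataset. Default: None.
eval_interval (int) - The evaluation interval. Default: 10.
draw_flag (bool) - Whether to save summary records. Default: True.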
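The parameter list only says that eval_interval is the evaluation interval. The following is a minimal sketch of how such an interval is typically applied, assuming evaluation is triggered at 1-based epoch numbers that are exact multiples of eval_interval. The helper eval_epochs is hypothetical and is not part of MindElec.

```python
def eval_epochs(total_epochs, eval_interval=10):
    """Hypothetical helper: list the 1-based epochs at which evaluation would run,
    ASSUMING evaluation fires when cur_epoch % eval_interval == 0."""
    if eval_interval <= 0:
        raise ValueError("eval_interval must be a positive integer")
    # Keep only the epochs that are exact multiples of the interval.
    return [epoch for epoch in range(1, total_epochs + 1) if epoch % eval_interval == 0]


# With the default eval_interval=10, a 35-epoch run evaluates three times.
print(eval_epochs(35))    # [10, 20, 30]
print(eval_epochs(5, 2))  # [2, 4]
```

Under this assumption, a run shorter than eval_interval epochs never triggers an evaluation.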
- Supported Platforms:
Ascend
Examples:
>>> import mindspore.nn as nn
>>> from mindelec.data import Dataset
>>> from mindelec.solver import Solver
>>> from mindelec.vision import MonitorEval
>>> class S11Predictor(nn.Cell):
...     def __init__(self, input_dimension):
...         super(S11Predictor, self).__init__()
...         self.fc1 = nn.Dense(input_dimension, 128)
...         self.fc2 = nn.Dense(128, 128)
...         self.fc3 = nn.Dense(128, 1001)
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x0 = x
...         x1 = self.relu(self.fc1(x0))
...         x2 = self.relu(self.fc2(x1))
...         x = self.fc3(x1 + x2)
...         return x
>>> model_net = S11Predictor(3)
>>> model = Solver(network=model_net, mode="Data",
...                optimizer=nn.Adam(model_net.trainable_params(), 0.001), loss_fn=nn.MSELoss())
>>> # Placeholder: for details about how to build the dataset, please refer to
>>> # the tutorial document on the official website.
>>> eval_ds = Dataset()
>>> summary_dir = './summary_eval_path'
>>> eval_interval = 10
>>> draw_flag = True
>>> MonitorEval(summary_dir, model, eval_ds, eval_interval, draw_flag)
- epoch_end(run_context)
Evaluates the model at the end of an epoch.
- Parameters:
run_context (RunContext) - Holds information about the model. For details, refer to mindspore.train.RunContext.
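The epoch_end hook can be pictured with the pure-Python sketch below. It assumes the monitor reads the current epoch from run_context.original_args().cur_epoch_num, evaluates only when that epoch is a multiple of eval_interval, and records the result only when draw_flag is True. StubRunContext, SketchEvalMonitor, evaluate_fn and records are hypothetical stand-ins (no MindSpore import is used); this is not the MindElec implementation, which saves summaries under summary_dir rather than appending to a Python list.

```python
from types import SimpleNamespace


class StubRunContext:
    """Hypothetical stand-in for mindspore.train.RunContext; only original_args() is mimicked."""

    def __init__(self, cur_epoch_num):
        self._cb_params = SimpleNamespace(cur_epoch_num=cur_epoch_num)

    def original_args(self):
        return self._cb_params


class SketchEvalMonitor:
    """Illustrative epoch-end evaluation hook (ASSUMED behavior, not MindElec's code)."""

    def __init__(self, evaluate_fn, eval_interval=10, draw_flag=True):
        self.evaluate_fn = evaluate_fn    # hypothetical callable returning an evaluation metric
        self.eval_interval = eval_interval
        self.draw_flag = draw_flag
        self.records = []                 # stands in for summary records written to disk

    def epoch_end(self, run_context):
        cur_epoch = run_context.original_args().cur_epoch_num
        # Skip evaluation unless the epoch is a multiple of the interval.
        if cur_epoch % self.eval_interval != 0:
            return None
        metric = self.evaluate_fn()
        # Record the result only when draw_flag is enabled.
        if self.draw_flag:
            self.records.append((cur_epoch, metric))
        return metric


monitor = SketchEvalMonitor(evaluate_fn=lambda: 0.5, eval_interval=2)
for epoch in range(1, 6):
    monitor.epoch_end(StubRunContext(epoch))
print(monitor.records)  # [(2, 0.5), (4, 0.5)]
```

Returning early on non-evaluation epochs keeps the hook cheap for most epochs, since evaluation over a full dataset is usually far more expensive than a single training step.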