mindspore.train.History
- class mindspore.train.History[源代码]
将网络输出和评估指标的相关信息记录到 History 对象中。
用户不自定义训练网络或评估网络情况下,记录的内容将为损失值;用户自定义了训练网络/评估网络的情况下,如果定义的网络返回 Tensor 或 numpy.ndarray,则记录此返回值均值,如果返回 tuple 或 list,则记录第一个元素。
说明
通常使用在 mindspore.train.Model.train 和 mindspore.train.Model.fit 中。
样例:
>>> import numpy as np >>> import mindspore.dataset as ds >>> from mindspore import nn >>> from mindspore.train import Model, History >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))} >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32) >>> net = nn.Dense(10, 5) >>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') >>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9) >>> history_cb = History() >>> model = Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"}) >>> model.train(2, train_dataset, callbacks=[history_cb]) >>> print(history_cb.epoch) >>> print(history_cb.history) {'epoch': [1, 2]} {'net_output': [1.607877, 1.6033841]}
- begin(run_context)[源代码]
训练开始时初始化History对象的epoch属性。
- 参数:
run_context (RunContext) - 包含模型的一些基本信息。详情请参考
mindspore.train.RunContext
。
- epoch_end(run_context)[源代码]
epoch结束时记录网络输出和评估指标的相关信息。
- 参数:
run_context (RunContext) - 包含模型的一些基本信息。详情请参考
mindspore.train.RunContext
。