mindspore.train.LossMonitor
- class mindspore.train.LossMonitor(per_print_times=1)[源代码]
训练场景下,监控训练的loss;边训练边推理场景下,监控训练的loss和推理的metrics。
如果loss是NAN或INF,则终止训练。
说明
如果 per_print_times 为0,则不打印loss。
- 参数:
per_print_times (int) - 表示每隔多少个step打印一次loss。默认值:
1
。
- 异常:
ValueError - 当 per_print_times 不是整数或小于零。
样例:
>>> from mindspore import nn >>> from mindspore.train import Model, LossMonitor >>> >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.4.1/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9) >>> model = Model(net, loss_fn=loss, optimizer=optim) >>> # Create the dataset taking MNIST as an example. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.4.1/docs/mindspore/code/mnist.py >>> dataset = create_dataset() >>> loss_monitor = LossMonitor() >>> model.train(10, dataset, callbacks=loss_monitor)
- on_train_epoch_end(run_context)[源代码]
LossMoniter用于 model.fit,即边训练边推理场景时,打印训练的loss和当前epoch推理的metrics。
- 参数:
run_context (RunContext) - 包含模型的相关信息。详情请参考
mindspore.train.RunContext
。
- step_end(run_context)[源代码]
step结束时打印训练loss。
- 参数:
run_context (RunContext) - 包含模型的相关信息。详情请参考
mindspore.train.RunContext
。