mindspore.train.LambdaCallback
- class mindspore.train.LambdaCallback(on_train_epoch_begin=None, on_train_epoch_end=None, on_train_step_begin=None, on_train_step_end=None, on_train_begin=None, on_train_end=None, on_eval_epoch_begin=None, on_eval_epoch_end=None, on_eval_step_begin=None, on_eval_step_end=None, on_eval_begin=None, on_eval_end=None)
Callback for creating simple, custom callbacks.
This callback is constructed from user-supplied functions, typically anonymous lambdas, that are called at the appropriate stage of mindspore.train.Model.{train | eval | fit}. Each function must accept exactly one positional argument: run_context.
Note
This is an experimental interface that is subject to change or deletion.
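The dispatch idea described above can be pictured with a small pure-Python sketch. It does not import MindSpore and is not MindSpore's implementation: `MiniLambdaCallback` and `FakeRunContext` are hypothetical stand-ins, and the sketch assumes only that each hook receives run_context as its single positional argument and that a hook left as None does nothing.

```python
from types import SimpleNamespace


class FakeRunContext:
    """Hypothetical stand-in for MindSpore's RunContext (illustration only)."""

    def __init__(self, **fields):
        self._args = SimpleNamespace(**fields)

    def original_args(self):
        # Same accessor name as in the example below; returns the loop state.
        return self._args


class MiniLambdaCallback:
    """Hypothetical sketch of the dispatch idea, not MindSpore's source."""

    # Training hooks only, to keep the sketch short.
    HOOKS = frozenset({
        "on_train_begin", "on_train_epoch_begin", "on_train_step_begin",
        "on_train_step_end", "on_train_epoch_end", "on_train_end",
    })

    def __init__(self, **fns):
        unknown = set(fns) - self.HOOKS
        if unknown:
            raise TypeError(f"unknown hooks: {sorted(unknown)}")
        self._fns = fns

    def __getattr__(self, name):
        # Only reached for names not found normally, i.e. the hook methods.
        if name in self.HOOKS:
            # An unset hook becomes a no-op that still takes run_context.
            return self._fns.get(name, lambda run_context: None)
        raise AttributeError(name)


seen = []
cb = MiniLambdaCallback(
    on_train_epoch_end=lambda run_context: seen.append(
        run_context.original_args().net_outputs))

# A training loop would call the hooks like this:
cb.on_train_epoch_begin(FakeRunContext(net_outputs=None))  # not set: no-op
cb.on_train_epoch_end(FakeRunContext(net_outputs=1.61))
print(seen)  # [1.61]
```

The point of the sketch is that the callback object only stores the functions; the training loop decides when each one runs.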
- Parameters
on_train_epoch_begin (Function) – called at the beginning of each training epoch.
on_train_epoch_end (Function) – called at the end of each training epoch.
on_train_step_begin (Function) – called at the beginning of each training step.
on_train_step_end (Function) – called at the end of each training step.
on_train_begin (Function) – called once at the beginning of model training.
on_train_end (Function) – called once at the end of model training.
on_eval_epoch_begin (Function) – called at the beginning of an evaluation epoch.
on_eval_epoch_end (Function) – called at the end of an evaluation epoch.
on_eval_step_begin (Function) – called at the beginning of each evaluation step.
on_eval_step_end (Function) – called at the end of each evaluation step.
on_eval_begin (Function) – called once at the beginning of model evaluation.
on_eval_end (Function) – called once at the end of model evaluation.
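To make the timing in the parameter list concrete, the sketch below drives a hypothetical loop, `fake_train` (not a MindSpore function), that fires the training hooks in the order their descriptions imply: once at train begin, then per epoch an epoch-begin hook, a begin/end pair per step and an epoch-end hook, and finally a train-end hook. A plain dict stands in for run_context, and evaluation hooks are omitted.

```python
def fake_train(hooks, epochs, steps_per_epoch):
    """Hypothetical loop: call hooks[name](run_context) at each stage, if set."""
    def fire(name, **state):
        fn = hooks.get(name)
        if fn is not None:
            fn(state)  # a plain dict stands in for run_context here

    fire("on_train_begin")
    for epoch in range(1, epochs + 1):
        fire("on_train_epoch_begin", epoch=epoch)
        for step in range(1, steps_per_epoch + 1):
            fire("on_train_step_begin", epoch=epoch, step=step)
            fire("on_train_step_end", epoch=epoch, step=step)
        fire("on_train_epoch_end", epoch=epoch)
    fire("on_train_end")


log = []
names = ["on_train_begin", "on_train_epoch_begin", "on_train_step_begin",
         "on_train_step_end", "on_train_epoch_end", "on_train_end"]
# Bind n via a default argument so each lambda records its own hook name.
hooks = {n: (lambda ctx, n=n: log.append(n)) for n in names}
fake_train(hooks, epochs=2, steps_per_epoch=2)
print(len(log))  # 14
```

With two epochs the epoch-end hook fires twice, which is why the example below prints two `loss:` lines for `model.train(2, ...)`.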
Examples
>>> import numpy as np
>>> import mindspore.dataset as ds
>>> from mindspore import nn
>>> from mindspore.train import Model, LambdaCallback
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> net = nn.Dense(10, 5)
>>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> lambda_callback = LambdaCallback(on_train_epoch_end=
...     lambda run_context: print("loss: ", run_context.original_args().net_outputs))
>>> model = Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"})
>>> model.train(2, train_dataset, callbacks=[lambda_callback])
loss: 1.6127687
loss: 1.6106578
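Because each hook receives only run_context, a hook that needs extra configuration can have it bound in advance with the standard-library `functools.partial`. The sketch below assumes a hypothetical `log_loss` helper and uses a `SimpleNamespace` in place of a real RunContext so it runs without MindSpore; the resulting `hook` could then be passed as, for example, `LambdaCallback(on_train_epoch_end=hook)`.

```python
import functools
from types import SimpleNamespace


def log_loss(run_context, prefix, sink):
    # Hypothetical hook body: read net_outputs and record it with a label.
    sink.append(f"{prefix}: {run_context.original_args().net_outputs}")


records = []
# partial fixes prefix and sink, leaving run_context as the only positional arg.
hook = functools.partial(log_loss, prefix="train loss", sink=records)

# Hypothetical stand-in for a RunContext, used only to exercise the hook.
fake_ctx = SimpleNamespace(original_args=lambda: SimpleNamespace(net_outputs=1.5))
hook(fake_ctx)
print(records)  # ['train loss: 1.5']
```

Closures over local variables work equally well; `partial` simply keeps the hook body reusable across stages.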