mindspore.LambdaCallback

class mindspore.LambdaCallback(on_train_epoch_begin=None, on_train_epoch_end=None, on_train_step_begin=None, on_train_step_end=None, on_train_begin=None, on_train_end=None, on_eval_epoch_begin=None, on_eval_epoch_end=None, on_eval_step_begin=None, on_eval_step_end=None, on_eval_begin=None, on_eval_end=None)

Callback for creating simple, custom callbacks.

This callback is constructed from functions (typically lambdas) that are called at the appropriate points during mindspore.Model.train, mindspore.Model.eval, or mindspore.Model.fit. Each of these functions must accept exactly one positional argument: run_context.
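The calling convention can be illustrated in plain Python, without MindSpore. In the sketch below, FakeRunContext is a hypothetical stand-in for MindSpore's run context, and the cur_epoch_num attribute is assumed for illustration only. It shows that a lambda, a named function, and a functools.partial all satisfy the contract of taking one positional argument.

```python
from functools import partial
from types import SimpleNamespace


class FakeRunContext:
    """Hypothetical stand-in for MindSpore's run context (illustration only)."""

    def __init__(self, **cb_params):
        self._cb_params = SimpleNamespace(**cb_params)

    def original_args(self):
        # Mirrors the run_context.original_args() call used in the example.
        return self._cb_params


def log_epoch(run_context, prefix="epoch"):
    """A named function works just as well as a lambda."""
    return f"{prefix} {run_context.original_args().cur_epoch_num}"


hooks = {
    # A lambda, the form used in the Examples section.
    "on_train_epoch_begin": lambda run_context: "begin",
    # A plain function whose only positional parameter is run_context.
    "on_train_epoch_end": log_epoch,
    # partial pre-binds a keyword argument, leaving one positional slot.
    "on_train_end": partial(log_epoch, prefix="finished after epoch"),
}

ctx = FakeRunContext(cur_epoch_num=3)
# Every hook is invoked with exactly one positional argument.
results = {name: fn(ctx) for name, fn in hooks.items()}
print(results)
```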

Note

This is an experimental interface that is subject to change or deletion.

Parameters
  • on_train_epoch_begin (Function) – called at the beginning of each training epoch.

  • on_train_epoch_end (Function) – called at the end of each training epoch.

  • on_train_step_begin (Function) – called at the beginning of each training step.

  • on_train_step_end (Function) – called at the end of each training step.

  • on_train_begin (Function) – called at the beginning of model training.

  • on_train_end (Function) – called at the end of model training.

  • on_eval_epoch_begin (Function) – called at the beginning of each evaluation epoch.

  • on_eval_epoch_end (Function) – called at the end of each evaluation epoch.

  • on_eval_step_begin (Function) – called at the beginning of each evaluation step.

  • on_eval_step_end (Function) – called at the end of each evaluation step.

  • on_eval_begin (Function) – called at the beginning of model evaluation.

  • on_eval_end (Function) – called at the end of model evaluation.
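The training hooks nest inside one another. The sketch below is a simplified simulation, assuming the nesting implied by the hook names; simulate_train is a hypothetical helper, not MindSpore's training loop. Hooks left as None are skipped, matching the default of None for every parameter.

```python
def simulate_train(callbacks, epochs, steps_per_epoch):
    """Hypothetical, simplified loop that fires training hooks in nested order."""
    run_context = object()  # placeholder; real hooks receive a run context

    def fire(name):
        fn = callbacks.get(name)
        if fn is not None:  # unset hooks are simply skipped
            fn(run_context)

    fire("on_train_begin")
    for _ in range(epochs):
        fire("on_train_epoch_begin")
        for _ in range(steps_per_epoch):
            fire("on_train_step_begin")
            fire("on_train_step_end")
        fire("on_train_epoch_end")
    fire("on_train_end")


events = []
names = [
    "on_train_begin", "on_train_epoch_begin", "on_train_step_begin",
    "on_train_step_end", "on_train_epoch_end", "on_train_end",
]
# n=n binds each hook's own name; callers still pass one positional argument.
callbacks = {n: (lambda run_context, n=n: events.append(n)) for n in names}
simulate_train(callbacks, epochs=2, steps_per_epoch=2)
print(events)
```

With 2 epochs of 2 steps each, the step hooks fire four times each, while on_train_begin and on_train_end fire once.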

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> net = nn.Dense(10, 5)
>>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> lambda_callback = ms.LambdaCallback(on_train_epoch_end=
...     lambda run_context: print("loss:", run_context.original_args().net_outputs))
>>> model = ms.Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"})
>>> model.train(2, train_dataset, callbacks=[lambda_callback])
loss: 1.6127687
loss: 1.6106578
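Because a lambda can contain only an expression, one way to collect values instead of printing them is to append to a list, since list.append returns None. The sketch below runs without MindSpore: the run context is a hypothetical SimpleNamespace stand-in, and the loss values are made up. In real training, net_outputs is typically a Tensor, which can be converted with Tensor.asnumpy() if plain numbers are needed.

```python
from types import SimpleNamespace

losses = []

# In MindSpore, this dict's entry would be passed as
# ms.LambdaCallback(on_train_epoch_end=...).
hooks = {
    "on_train_epoch_end": lambda run_context: losses.append(
        run_context.original_args().net_outputs
    ),
}

# Hypothetical stand-in run contexts carrying made-up loss values.
for fake_loss in (1.5, 1.25):
    ctx = SimpleNamespace(
        original_args=lambda v=fake_loss: SimpleNamespace(net_outputs=v)
    )
    hooks["on_train_epoch_end"](ctx)

print(losses)  # [1.5, 1.25]
```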