sciai.common.TrainCellWithCallBack
- class sciai.common.TrainCellWithCallBack(network, optimizer, loss_interval=1, time_interval=0, ckpt_interval=0, loss_names=('loss',), batch_num=1, grad_first=False, amp_level='O0', ckpt_dir='./checkpoints', clip_grad=False, clip_norm=0.001, model_name='model')
TrainOneStepCell with callbacks, which can handle multiple losses. The following callbacks are available:
1. loss: print the loss(es).
2. time: print the time spent during the last interval of steps and the total time since the start.
3. ckpt: save a checkpoint.
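The interval parameters decide which of these callbacks run at a given step. The sketch below is a hypothetical pure-Python model, not the sciai implementation: it assumes steps are counted from 0, that a callback fires whenever the step is a multiple of its interval, and that the checkpoint step interval is `ckpt_interval * batch_num`. The helper name `callbacks_at_step` is made up for illustration.

```python
# Hypothetical model of when each callback fires; NOT the sciai implementation.
def callbacks_at_step(step, loss_interval=1, time_interval=0,
                      ckpt_interval=0, batch_num=1):
    """Return the callback names assumed to fire at a 0-based step."""
    fired = []
    # An interval of 0 disables the corresponding callback.
    if loss_interval > 0 and step % loss_interval == 0:
        fired.append("loss")
    if time_interval > 0 and step % time_interval == 0:
        fired.append("time")
    # Assumption: ckpt_interval counts epochs, so convert it to steps.
    ckpt_steps = ckpt_interval * batch_num
    if ckpt_steps > 0 and step % ckpt_steps == 0:
        fired.append("ckpt")
    return fired


# Same interval settings as the example at the end of this page.
for step in range(8):
    print(step, callbacks_at_step(step, loss_interval=1, time_interval=3,
                                  ckpt_interval=5, batch_num=1))
```

With those settings, this model matches the pattern in the example's log: timing is printed at steps 0, 3 and 6, and checkpoints are saved at steps 0 and 5.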
- Parameters
network (Cell) – The training network. It may return multiple outputs, each treated as a loss.
optimizer (Cell) – Optimizer for updating the network parameters.
loss_interval (int) – Step interval at which to print the loss(es). If 0, losses are not printed. Default: 1.
time_interval (int) – Step interval at which to print timing information. If 0, time is not printed. Default: 0.
ckpt_interval (int) – Epoch interval at which to save checkpoints; the corresponding step interval is derived from batch_num. If 0, no checkpoints are saved. Default: 0.
loss_names (Union(str, tuple[str], list[str])) – Loss names in the order of the network outputs. It accepts either n or n + 1 strings, where n is the number of network outputs. With n strings, each name labels the output at the same position; with n + 1 strings, the first name labels the sum of all outputs and the remaining names label the outputs in order. Default: ("loss",).
batch_num (int) – Number of batches per epoch. Default: 1.
grad_first (bool) – If True, only the first output of the network is used to compute gradients. Otherwise, the sum of all outputs is used. Default: False.
amp_level (str) – Mixed precision level, which supports ["O0", "O1", "O2", "O3"]. Default: "O0".
ckpt_dir (str) – Directory in which checkpoints are saved. Default: "./checkpoints".
clip_grad (bool) – Whether to clip gradients. Default: False.
clip_norm (Union(float, int)) – The clipping ratio, which must be greater than 0. Only takes effect when clip_grad is True. Default: 1e-3.
model_name (str) – Model name, used as part of the checkpoint filename. Default: "model".
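To make the loss_names and grad_first rules concrete, here is a hypothetical sketch that uses plain Python floats in place of MindSpore tensors. The helpers `label_losses` and `objective` are illustrative names, not part of sciai, and raising `ValueError` on a name-count mismatch is this sketch's own choice; the documented behavior only promises `TypeError` for wrongly typed arguments.

```python
# Hypothetical sketch of the documented loss_names / grad_first rules,
# using floats instead of MindSpore tensors. Not the sciai implementation.
def label_losses(outputs, loss_names):
    """Pair each network output with a name, following the n / n + 1 rule."""
    if isinstance(loss_names, str):
        loss_names = (loss_names,)
    names = list(loss_names)
    n = len(outputs)
    if len(names) == n:
        # One name per output, matched by position.
        return list(zip(names, outputs))
    if len(names) == n + 1:
        # The first name labels the sum of all outputs.
        return [(names[0], sum(outputs))] + list(zip(names[1:], outputs))
    # This sketch's own error choice for a mismatched name count.
    raise ValueError(f"expected {n} or {n + 1} loss names, got {len(names)}")


def objective(outputs, grad_first=False):
    """Scalar that gradients are taken of under the grad_first rule."""
    return outputs[0] if grad_first else sum(outputs)


outputs = (0.5, 0.25)
print(label_losses(outputs, ("total loss", "loss1", "loss2")))
# [('total loss', 0.75), ('loss1', 0.5), ('loss2', 0.25)]
print(objective(outputs), objective(outputs, grad_first=True))
# 0.75 0.5
```

With three names for two outputs, the first label goes to their sum, which is the n + 1 case described for loss_names.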
- Inputs:
*args (tuple[Tensor]) - Tuple of input tensors of the network.
- Outputs:
Union[Tensor, tuple[Tensor]] - Tensor(s) of the loss(es).
- Raises
TypeError – If the input parameters are not of required types.
- Supported Platforms:
GPU
CPU
Ascend
Examples
```python
>>> import mindspore as ms
>>> from mindspore import nn, ops
>>> from sciai.common import TrainCellWithCallBack
>>> class LossNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.dense1 = nn.Dense(2, 1)
...         self.dense2 = nn.Dense(2, 1)
...     def construct(self, x):
...         loss1 = self.dense1(x).sum()
...         loss2 = self.dense2(x).sum()
...         return loss1, loss2
>>> loss_net = LossNet()
>>> optimizer = nn.Adam(loss_net.trainable_params(), 1e-2)
>>> train_net = TrainCellWithCallBack(loss_net, optimizer, time_interval=3, loss_interval=1, ckpt_interval=5,
...                                   ckpt_dir='.', loss_names=("total loss", "loss1", "loss2"))
>>> x = ops.ones((3, 2), ms.float32)
>>> for epoch in range(8):
...     loss1, loss2 = train_net(x)
step: 0, loss1: 0.07256523, loss2: 0.010363013, interval: 3.132981061935425s, total: 3.132981061935425s, checkpoint saved at: ./model_iter_0_2000-12-31-23-59-59.ckpt
step: 1, loss1: 0.06356523, loss2: 0.0013630127
step: 2, loss1: 0.054565262, loss2: 0.007636956
step: 3, loss1: 0.04556533, loss2: 0.00999487, interval: 0.01753377914428711s, total: 3.150514841079712s
step: 4, loss1: 0.036565356, loss2: 0.0090501215
step: 5, loss1: 0.027565379, loss2: 0.0061383317, checkpoint saved at: ./model_iter_5_2000-12-31-23-59-59.ckpt
step: 6, loss1: 0.018565409, loss2: 0.0019272038, interval: 0.02319502830505371s, total: 3.1737098693847656s
step: 7, loss1: 0.00956542, loss2: 0.0032018598
```