sciai.common.TrainCellWithCallBack

class sciai.common.TrainCellWithCallBack(network, optimizer, loss_interval=1, time_interval=0, ckpt_interval=0, loss_names=('loss',), batch_num=1, grad_first=False, amp_level='O0', ckpt_dir='./checkpoints', clip_grad=False, clip_norm=0.001, model_name='model')

A TrainOneStepCell with callbacks that can handle multiple losses. The available callbacks are:

1. loss: print the loss(es).
2. time: print the time spent over the last interval of steps and the total time since the start.
3. ckpt: save a checkpoint.
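The following sketch shows when each callback would fire at a given 0-based step. It is a hypothetical reconstruction, not the library's code: the helper name `callbacks_at_step` is invented, and it assumes that an interval of 0 disables a callback and that `ckpt_interval` counts epochs of `batch_num` steps each. With the settings used in the Examples section (time_interval=3, ckpt_interval=5, batch_num=1), it reproduces the pattern seen in that output: time at steps 0, 3 and 6, checkpoints at steps 0 and 5.

```python
def callbacks_at_step(step, loss_interval=1, time_interval=0,
                      ckpt_interval=0, batch_num=1):
    """Return the callbacks that would fire at a 0-based training step.

    Hypothetical helper: an interval of 0 disables that callback, and
    ckpt_interval is measured in epochs of batch_num steps each.
    """
    fired = []
    if loss_interval and step % loss_interval == 0:
        fired.append("loss")
    if time_interval and step % time_interval == 0:
        fired.append("time")
    if ckpt_interval and step % (ckpt_interval * batch_num) == 0:
        fired.append("ckpt")
    return fired


# Same interval settings as the Examples section.
for step in range(8):
    print(step, callbacks_at_step(step, time_interval=3, ckpt_interval=5))
```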

Parameters
  • network (Cell) – The training network, which may return multiple outputs.

  • optimizer (Cell) – Optimizer for updating the network parameters.

  • loss_interval (int) – Step interval at which to print the loss. If 0, the loss is not printed. Default: 1.

  • time_interval (int) – Step interval at which to print timing information. If 0, timing is not printed. Default: 0.

  • ckpt_interval (int) – Epoch interval at which to save checkpoints, where epochs are counted according to batch_num. If 0, no checkpoint is saved. Default: 0.

  • loss_names (Union(str, tuple[str], list[str])) – Loss names in the order of the network outputs. It accepts either n or n+1 strings, where n is the number of network outputs. If n strings are given, each names the loss at the same position; if n+1, the first name refers to the sum of all outputs. Default: ("loss",).

  • batch_num (int) – Number of batches per epoch. Default: 1.

  • grad_first (bool) – If True, only the first output of the network participates in gradient descent. Otherwise, the sum of all network outputs is used. Default: False.

  • amp_level (str) – Mixed precision level; one of ["O0", "O1", "O2", "O3"]. Default: "O0".

  • ckpt_dir (str) – Directory in which checkpoints are saved. Default: "./checkpoints".

  • clip_grad (bool) – Whether to clip gradients. Default: False.

  • clip_norm (Union(float, int)) – The clipping ratio; must be greater than 0. Only used when clip_grad is True. Default: 1e-3.

  • model_name (str) – Model name, used in the checkpoint filename. Default: "model".
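This page does not say which clipping rule clip_norm feeds. Assuming it is global-norm clipping (a common choice; MindSpore itself provides `ops.clip_by_global_norm`), the arithmetic can be sketched in plain Python as below. The function name and the list-of-floats gradient representation are illustrative assumptions, not the library's implementation.

```python
import math


def clip_by_global_norm(grads, clip_norm=1e-3):
    """Scale all gradients so that their joint L2 norm is at most clip_norm.

    Plain-Python illustration; each gradient is a flat list of floats.
    """
    if clip_norm <= 0:
        raise ValueError("clip_norm must be greater than 0")
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Scale is 1.0 when the joint norm is already within clip_norm.
    scale = clip_norm / max(global_norm, clip_norm)
    return [[g * scale for g in grad] for grad in grads]


grads = [[3.0, 0.0], [0.0, 4.0]]  # joint norm = 5.0
print(clip_by_global_norm(grads, clip_norm=1.0))  # every entry scaled by 0.2
```

Scaling all gradients by one shared factor preserves the direction of the overall update, which is why global-norm clipping is often preferred over clipping each tensor independently.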

Inputs:
  • *args (tuple[Tensor]) - Tuple of input tensors of the network.

Outputs:

Union[Tensor, tuple[Tensor]], Tensor(s) of the loss(es).
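How the returned losses relate to loss_names and grad_first can be sketched with plain floats standing in for tensors. The helpers `name_losses` and `objective` are hypothetical; they only restate the parameter descriptions above (n names label the outputs one by one, n+1 names add a leading label for their sum, and grad_first selects the first output instead of the sum). Raising ValueError on a count mismatch is this sketch's own choice, not documented library behavior.

```python
def name_losses(outputs, loss_names=("loss",)):
    """Pair loss names with network outputs (hypothetical helper).

    With n outputs, loss_names holds either n names (one per output) or
    n + 1 names, where the first one labels the sum of all outputs.
    """
    outputs = list(outputs)
    names = [loss_names] if isinstance(loss_names, str) else list(loss_names)
    if len(names) == len(outputs):
        return dict(zip(names, outputs))
    if len(names) == len(outputs) + 1:
        return {names[0]: sum(outputs), **dict(zip(names[1:], outputs))}
    raise ValueError(
        f"expected {len(outputs)} or {len(outputs) + 1} names, got {len(names)}")


def objective(outputs, grad_first=False):
    """Value that would be differentiated: the first output, or the sum."""
    outputs = list(outputs)
    return outputs[0] if grad_first else sum(outputs)


print(name_losses((0.5, 0.25), ("total loss", "loss1", "loss2")))
print(objective((0.5, 0.25)), objective((0.5, 0.25), grad_first=True))
```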

Raises

TypeError – If the input parameters are not of required types.

Supported Platforms:

GPU CPU Ascend

Examples

>>> import mindspore as ms
>>> from mindspore import nn, ops
>>> from sciai.common import TrainCellWithCallBack
>>> class LossNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.dense1 = nn.Dense(2, 1)
...         self.dense2 = nn.Dense(2, 1)
...     def construct(self, x):
...         loss1 = self.dense1(x).sum()
...         loss2 = self.dense2(x).sum()
...         return loss1, loss2
...
>>> loss_net = LossNet()
>>> optimizer = nn.Adam(loss_net.trainable_params(), 1e-2)
>>> train_net = TrainCellWithCallBack(loss_net, optimizer, time_interval=3, loss_interval=1, ckpt_interval=5,
...                                   ckpt_dir='.', loss_names=("total loss", "loss1", "loss2"))
>>> x = ops.ones((3, 2), ms.float32)
>>> for epoch in range(8):
...     loss1, loss2 = train_net(x)
...
step: 0, loss1: 0.07256523, loss2: 0.010363013, interval: 3.132981061935425s, total: 3.132981061935425s,
    checkpoint saved at: ./model_iter_0_2000-12-31-23-59-59.ckpt
step: 1, loss1: 0.06356523, loss2: 0.0013630127
step: 2, loss1: 0.054565262, loss2: 0.007636956
step: 3, loss1: 0.04556533, loss2: 0.00999487, interval: 0.01753377914428711s, total: 3.150514841079712s
step: 4, loss1: 0.036565356, loss2: 0.0090501215
step: 5, loss1: 0.027565379, loss2: 0.0061383317, checkpoint saved at: ./model_iter_5_2000-12-31-23-59-59.ckpt
step: 6, loss1: 0.018565409, loss2: 0.0019272038, interval: 0.02319502830505371s, total: 3.1737098693847656s
step: 7, loss1: 0.00956542, loss2: 0.0032018598

static calc_ckpt_name(iter_str, model_name, postfix='')

Calculate checkpoint file name.

Parameters
  • iter_str (str) – Iteration number or epoch number.

  • model_name (str) – Model name.

  • postfix (str) – Filename postfix, typically the automatic mixed precision level. Default: "".

Returns

str, Filename of checkpoint.
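The checkpoint names printed in the Examples section, such as model_iter_0_2000-12-31-23-59-59.ckpt, suggest a pattern of model name, iteration and timestamp. The sketch below is a hypothetical reconstruction of that pattern, not the library's code: the name `sketch_ckpt_name`, the `now` parameter (added so the output is deterministic), and the position of the postfix are all assumptions.

```python
from datetime import datetime


def sketch_ckpt_name(iter_str, model_name, postfix="", now=None):
    """Hypothetical reconstruction of a checkpoint filename.

    Mirrors names like model_iter_0_2000-12-31-23-59-59.ckpt; where the
    postfix is placed is an assumption.
    """
    now = now or datetime.now()
    stamp = now.strftime("%Y-%m-%d-%H-%M-%S")
    suffix = f"_{postfix}" if postfix else ""
    return f"{model_name}_iter_{iter_str}{suffix}_{stamp}.ckpt"


when = datetime(2000, 12, 31, 23, 59, 59)
print(sketch_ckpt_name("0", "model", now=when))  # model_iter_0_2000-12-31-23-59-59.ckpt
print(sketch_ckpt_name("5", "pinns", postfix="O2", now=when))
```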

static calc_optim_ckpt_name(model_name, postfix='')

Calculate the latest checkpoint filename. For example, Optimal_pinns_O2.ckpt.

Parameters
  • model_name (str) – Model name.

  • postfix (str) – Filename postfix. Default: "".

Returns

str, Filename of checkpoint.
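The example filename Optimal_pinns_O2.ckpt suggests the pattern Optimal_<model_name>_<postfix>.ckpt. A minimal sketch under that assumption follows; the helper name is hypothetical, and the form without a postfix (no trailing underscore) is a guess rather than documented behavior.

```python
def sketch_optim_ckpt_name(model_name, postfix=""):
    """Hypothetical reconstruction of the Optimal_<model>[_<postfix>].ckpt pattern."""
    suffix = f"_{postfix}" if postfix else ""
    return f"Optimal_{model_name}{suffix}.ckpt"


print(sketch_optim_ckpt_name("pinns", "O2"))  # Optimal_pinns_O2.ckpt
print(sketch_optim_ckpt_name("pinns"))  # Optimal_pinns.ckpt (assumed no-postfix form)
```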