sciai.common.TrainCellWithCallBack

查看源文件
class sciai.common.TrainCellWithCallBack(network, optimizer, loss_interval=1, time_interval=0, ckpt_interval=0, loss_names=('loss',), batch_num=1, grad_first=False, amp_level='O0', ckpt_dir='./checkpoints', clip_grad=False, clip_norm=0.001, model_name='model')[源代码]

带有回调的 TrainOneStepCell,可以处理多重损失。 回调功能如下:

1.loss:打印损失。 2.time:打印步骤所花费的时间,以及从开始所花费的时间。 3.ckpt:保存checkpoint。

参数:
  • network (Cell) - 训练网络。该网络支持多个输出。

  • optimizer (Cell) - 用于更新网络参数的优化器。

  • loss_interval (int) - 打印loss的步长间隔。 如果为 0,则不会打印loss。 默认值:1。

  • time_interval (int) - 打印时间的步长间隔。 如果为 0,则不会打印时间。 默认值:0。

  • ckpt_interval (int) - 保存checkpoint的epoch间隔,根据 batch_num 计算,如果为0,则不会保存checkpoint。 默认值:0。 如果是n个,则每个字符串对应同一位置的loss;如果是n+1个,第一个损失名称代表所有输出的总和,其他一一对应。默认值:("loss",)。

  • loss_names (Union(str, tuple[str], list[str])) - 各损失的名字,按照网络输出的顺序排列。 它可以接受n个或n+1个字符串, 其中n为网络输出的个数。如果n个,每个字符串对应同一位置的loss;如果n+1个,第一个字符串为所有输出的总和的损失名。 默认值:(“loss”,)。

  • batch_num (int) - 每个时期有多少批次。 默认值:1。

  • grad_first (bool) - 若为True,则只有网络的第一个输出参与梯度下降。 否则所有输出之和参与梯度下降。默认值:False。

  • amp_level (str) - 混合精度等级,目前支持["O0", "O1", "O2", "O3"]. 默认值:"O0".

  • ckpt_dir (str) - checkpoint保存路径。 默认值:"./checkpoints"。

  • clip_grad (bool) - 是否裁剪梯度。默认值:False.

  • clip_norm (Union(float, int)) - 梯度裁剪率,需为正数. 仅当 clip_grad 为True时生效. 默认值:1e-3.

  • model_name (str) - 模型名,影响ckpt名字。 默认:"model"。

输入:
  • *args (tuple[Tensor]) - 网络输入张量的元组.

输出:

Union(Tensor, tuple[Tensor]) - 网络输出的单项或多项loss.

异常:
  • TypeError - 如果输入参数不是要求的类型。

支持平台:

GPU CPU Ascend

样例:

>>> import mindspore as ms
>>> from mindspore import nn, ops
>>> from sciai.common import TrainCellWithCallBack
>>> class LossNet(nn.Cell):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.dense1 = nn.Dense(2, 1)
>>>         self.dense2 = nn.Dense(2, 1)
>>>     def construct(self, x):
>>>         loss1 = self.dense1(x).sum()
>>>         loss2 = self.dense2(x).sum()
>>>         return loss1, loss2
>>> loss_net = LossNet()
>>> optimizer = nn.Adam(loss_net.trainable_params(), 1e-2)
>>> train_net = TrainCellWithCallBack(loss_net, optimizer, time_interval=3, loss_interval=1, ckpt_interval=5,
>>>                                   ckpt_dir='.', loss_names=("total loss", "loss1", "loss2"))
>>> x = ops.ones((3, 2), ms.float32)
>>> for epoch in range(8):
>>>     loss1, loss2 = train_net(x)
step: 0, loss1: 0.07256523, loss2: 0.010363013, interval: 3.132981061935425s, total: 3.132981061935425s,
checkpoint saved at: ./model_iter_0_2000-12-31-23-59-59.ckpt
step: 1, loss1: 0.06356523, loss2: 0.0013630127
step: 2, loss1: 0.054565262, loss2: 0.007636956
step: 3, loss1: 0.04556533, loss2: 0.00999487, interval: 0.01753377914428711s, total: 3.150514841079712s
step: 4, loss1: 0.036565356, loss2: 0.0090501215
step: 5, loss1: 0.027565379, loss2: 0.0061383317, checkpoint saved at: ./model_iter_5_2000-12-31-23-59-59.ckpt
step: 6, loss1: 0.018565409, loss2: 0.0019272038, interval: 0.02319502830505371s, total: 3.1737098693847656s
step: 7, loss1: 0.00956542, loss2: 0.0032018598
static calc_ckpt_name(iter_str, model_name, postfix='')[源代码]

计算检查点文件名。

参数:
  • iter_str (Union[str]) - 迭代次数或epoch数。

  • model_name (str) - 模型名称。

  • postfix (str) - 文件名后缀。 默认:""。

返回:

str,checkpoint的文件名。

static calc_optim_ckpt_name(model_name, postfix='')[源代码]

计算最新的检查点文件名。

参数:
  • model_name (str) - 模型名称。

  • postfix (str) - 文件名后缀。 默认:""。

返回:

str,checkpoint的文件名。