sciai.common.TrainCellWithCallBack
- class sciai.common.TrainCellWithCallBack(network, optimizer, loss_interval=1, time_interval=0, ckpt_interval=0, loss_names=('loss',), batch_num=1, grad_first=False, amp_level='O0', ckpt_dir='./checkpoints', clip_grad=False, clip_norm=0.001, model_name='model')[源代码]
带有回调的 TrainOneStepCell,可以处理多重损失。 回调功能如下:
1.loss:打印损失。 2.time:打印步骤所花费的时间,以及从开始所花费的时间。 3.ckpt:保存checkpoint。
- 参数:
network (Cell) - 训练网络。该网络支持多个输出。
optimizer (Cell) - 用于更新网络参数的优化器。
loss_interval (int) - 打印loss的步长间隔。 如果为 0,则不会打印loss。 默认值:1。
time_interval (int) - 打印时间的步长间隔。 如果为 0,则不会打印时间。 默认值:0。
ckpt_interval (int) - 保存checkpoint的epoch间隔,根据 batch_num 计算,如果为0,则不会保存checkpoint。 默认值:0。 如果是n个,则每个字符串对应同一位置的loss;如果是n+1个,第一个损失名称代表所有输出的总和,其他一一对应。默认值:("loss",)。
loss_names (Union(str, tuple[str], list[str])) - 各损失的名字,按照网络输出的顺序排列。 它可以接受n个或n+1个字符串, 其中n为网络输出的个数。如果n个,每个字符串对应同一位置的loss;如果n+1个,第一个字符串为所有输出的总和的损失名。 默认值:(“loss”,)。
batch_num (int) - 每个时期有多少批次。 默认值:1。
grad_first (bool) - 若为True,则只有网络的第一个输出参与梯度下降。 否则所有输出之和参与梯度下降。默认值:False。
amp_level (str) - 混合精度等级,目前支持["O0", "O1", "O2", "O3"]. 默认值:"O0".
ckpt_dir (str) - checkpoint保存路径。 默认值:"./checkpoints"。
clip_grad (bool) - 是否裁剪梯度。默认值:False.
clip_norm (Union(float, int)) - 梯度裁剪率,需为正数. 仅当 clip_grad 为True时生效. 默认值:1e-3.
model_name (str) - 模型名,影响ckpt名字。 默认:"model"。
- 输入:
*args (tuple[Tensor]) - 网络输入张量的元组.
- 输出:
Union(Tensor, tuple[Tensor]) - 网络输出的单项或多项loss.
- 异常:
TypeError - 如果输入参数不是要求的类型。
- 支持平台:
GPU
CPU
Ascend
样例:
>>> import mindspore as ms >>> from mindspore import nn, ops >>> from sciai.common import TrainCellWithCallBack >>> class LossNet(nn.Cell): >>> def __init__(self): >>> super().__init__() >>> self.dense1 = nn.Dense(2, 1) >>> self.dense2 = nn.Dense(2, 1) >>> def construct(self, x): >>> loss1 = self.dense1(x).sum() >>> loss2 = self.dense2(x).sum() >>> return loss1, loss2 >>> loss_net = LossNet() >>> optimizer = nn.Adam(loss_net.trainable_params(), 1e-2) >>> train_net = TrainCellWithCallBack(loss_net, optimizer, time_interval=3, loss_interval=1, ckpt_interval=5, >>> ckpt_dir='.', loss_names=("total loss", "loss1", "loss2")) >>> x = ops.ones((3, 2), ms.float32) >>> for epoch in range(8): >>> loss1, loss2 = train_net(x) step: 0, loss1: 0.07256523, loss2: 0.010363013, interval: 3.132981061935425s, total: 3.132981061935425s, checkpoint saved at: ./model_iter_0_2000-12-31-23-59-59.ckpt step: 1, loss1: 0.06356523, loss2: 0.0013630127 step: 2, loss1: 0.054565262, loss2: 0.007636956 step: 3, loss1: 0.04556533, loss2: 0.00999487, interval: 0.01753377914428711s, total: 3.150514841079712s step: 4, loss1: 0.036565356, loss2: 0.0090501215 step: 5, loss1: 0.027565379, loss2: 0.0061383317, checkpoint saved at: ./model_iter_5_2000-12-31-23-59-59.ckpt step: 6, loss1: 0.018565409, loss2: 0.0019272038, interval: 0.02319502830505371s, total: 3.1737098693847656s step: 7, loss1: 0.00956542, loss2: 0.0032018598