mindspore.train.TrainFaultTolerance
- class mindspore.train.TrainFaultTolerance(ckpt_save_path=None, **kwargs)[源代码]
该回调函数用于开启 MindIO的TTP特性,会嵌入训练的流程,完成TTP的初始化、上报、异常处理等操作。
说明
该特性仅支持Ascend后端的静态图模式,并且只支持sink_size值小于等于1的场景。
- 参数:
ckpt_save_path (str) - 异常发生时ckpt保存的路径,该路径是一个目录。保存时,会在该目录下创建新的名为‘ttp_saved_checkpoints-step_{cur_step_num}’目录。默认值为:
None
。kwargs (dict) - 其他字典类型参数。
- 异常:
Exception - TTP初始化失败,会对外抛Exception异常。
ModuleNotFoundError - Mindio TTP whl包未安装。
样例:
说明
在运行TrainFaultTolerance的用例之前,需要配置相应的环境变量。推荐使用msrun进行分布式的启动,参考 msrun启动方式。用例应该在4张卡上运行。
>>> import numpy as np >>> import os >>> import math >>> import mindspore as ms >>> import mindspore.dataset as ds >>> from mindspore import nn, ops, Parameter, train >>> from mindspore.communication import init, get_rank >>> from mindspore.common.initializer import initializer, HeUniform >>> from mindspore.train import Model, TrainFaultTolerance >>> from mindspore import dataset as ds >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2') >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2) >>> init() >>> ms.set_seed(1) >>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file": ... "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())}) >>> class MatMulCell(nn.Cell): ... def __init__(self, param=None, shape=None): ... super().__init__() ... if shape is None: ... shape = [28 * 28, 512] ... weight_init = HeUniform(math.sqrt(5)) ... self.param = Parameter(initializer(weight_init, shape), name="param") ... if param is not None: ... self.param = param ... self.print = ops.Print() ... self.matmul = ops.MatMul() ... ... def construct(self, x): ... out = self.matmul(x, self.param) ... self.print("out is:", out) ... return out >>> >>> class Network(nn.Cell): ... def __init__(self): ... super().__init__() ... self.flatten = nn.Flatten() ... self.layer1 = MatMulCell() ... self.relu1 = nn.ReLU() ... self.layer2 = nn.Dense(512, 512) ... self.relu2 = nn.ReLU() ... self.layer3 = nn.Dense(512, 10) ... ... def construct(self, x): ... x = self.flatten(x) ... x = self.layer1(x) ... x = self.relu1(x) ... x = self.layer2(x) ... x = self.relu2(x) ... logits = self.layer3(x) ... return logits >>> >>> net = Network() >>> net.layer1.pipeline_stage = 0 >>> net.relu1.pipeline_stage = 0 >>> net.layer2.pipeline_stage = 0 >>> net.relu2.pipeline_stage = 1 >>> net.layer3.pipeline_stage = 1 >>> >>> def create_dataset(batch_size): ... dataset_path = os.getenv("DATA_PATH") ... dataset = ds.MnistDataset(dataset_path) ... image_transforms = [ ... ds.vision.Rescale(1.0 / 255.0, 0), ... ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)), ... ds.vision.HWC2CHW() ... ] ... label_transform = ds.transforms.TypeCast(ms.int32) ... dataset = dataset.map(image_transforms, 'image') ... dataset = dataset.map(label_transform, 'label') ... dataset = dataset.batch(batch_size) ... return dataset >>> >>> dataset = create_dataset(32) >>> >>> optimizer = nn.SGD(net.trainable_params(), 1e-2) >>> optimizer_wrapper = nn.OptTFTWrapper(optimizer) >>> loss_fn = nn.CrossEntropyLoss() >>> >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4) >>> net_with_loss.set_train() >>> model = Model(net_with_loss, optimizer=optimizer_wrapper) >>> tft_cb = TrainFaultTolerance() >>> loss_cb = train.LossMonitor(1) >>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
- end(run_context)[源代码]
训练结束,解注册MindIO TTP。
- 参数:
run_context (RunContext) - 包含模型的相关信息。详情请参考
mindspore.train.RunContext
。
- classmethod get_optimizer_wrapper(origin_opt_cls)[源代码]
使用TFT功能时的优化器类封装函数。
- 参数:
origin_opt_cls (Class) - 原优化器类。
- on_train_begin(run_context)[源代码]
训练开始时,向MindIO TTP注册训练时参数。
- 参数:
run_context (RunContext) - 包含模型的相关信息。详情请参考
mindspore.train.RunContext
。
- on_train_step_end(run_context)[源代码]
每个step完成时,进行MindIO TTP的上报。
- 参数:
run_context (RunContext) - 包含模型的相关信息。详情请参考
mindspore.train.RunContext
。