mindspore.train.TFTRegister
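- class mindspore.train.TFTRegister(ctrl_rank_id, ctrl_ip, ctrl_port, ckpt_save_path)
This callback enables the MindIO TTP feature. It is embedded in the training flow and handles TTP initialization, reporting, and exception handling.
Note
This feature is supported only in static graph mode on the Ascend backend, and only when sink_size is less than or equal to 1.
- Parameters:
ctrl_rank_id (int) - The rank_id on which the TTP controller runs. Used to start the TTP controller.
ctrl_ip (str) - The IP address of the TTP controller. Used to start the TTP controller.
ctrl_port (int) - The port of the TTP controller. Used to start the TTP controller and processor.
ckpt_save_path (str) - The path where the checkpoint is saved when an exception occurs. The path is a directory; when a checkpoint is saved on exception, a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}' is created under it.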
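To make the ckpt_save_path naming rule concrete, the following is a minimal sketch that builds the directory path where an exception-time checkpoint would be expected for a given step. The helper name expected_ttp_ckpt_dir is hypothetical and not part of MindSpore; only the directory name pattern comes from the parameter description above.

```python
import os


def expected_ttp_ckpt_dir(ckpt_save_path: str, cur_step_num: int) -> str:
    """Return the directory expected under ckpt_save_path for a given step.

    Hypothetical helper: it only mirrors the documented naming pattern
    'ttp_saved_checkpoints-step_{cur_step_num}' and does not touch the disk.
    """
    return os.path.join(ckpt_save_path, f"ttp_saved_checkpoints-step_{cur_step_num}")


# Example: where to look after an exception at step 100.
print(expected_ttp_ckpt_dir("./tft_checkpoint/", 100))
```

A helper like this can be useful in recovery scripts that need to locate the saved checkpoint without hard-coding the directory name.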
- Raises:
Exception - Raised if TTP initialization fails.
ModuleNotFoundError - The MindIO TTP whl package is not installed.
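Because a missing MindIO TTP package surfaces as ModuleNotFoundError, it can help to check for the package before building the callback. The sketch below uses only the standard library. The import name "mindio_ttp" is an assumption and may differ in your installation; is_package_available is a hypothetical helper, not a MindSpore API.

```python
import importlib.util


def is_package_available(module_name: str) -> bool:
    """Return True if a top-level module with this name can be located."""
    return importlib.util.find_spec(module_name) is not None


# Assumption: the MindIO TTP wheel installs a top-level module named "mindio_ttp".
if not is_package_available("mindio_ttp"):
    print("MindIO TTP does not appear to be installed; "
          "creating TFTRegister is expected to raise ModuleNotFoundError.")
```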
Examples:
>>> import numpy as np
>>> import os
>>> import math
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, Parameter, train
>>> from mindspore.communication import init, get_rank
>>> from mindspore.common.initializer import initializer, HeUniform
>>> from mindspore.train import Model, TFTRegister
>>> # Static graph mode with a two-stage pipeline
>>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
>>> init()
>>> ms.set_seed(1)
>>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
...     "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
>>> class MatMulCell(nn.Cell):
...     def __init__(self, param=None, shape=None):
...         super().__init__()
...         if shape is None:
...             shape = [28 * 28, 512]
...         weight_init = HeUniform(math.sqrt(5))
...         self.param = Parameter(initializer(weight_init, shape), name="param")
...         if param is not None:
...             self.param = param
...         self.print = ops.Print()
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         out = self.matmul(x, self.param)
...         self.print("out is:", out)
...         return out
>>>
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = nn.Flatten()
...         self.layer1 = MatMulCell()
...         self.relu1 = nn.ReLU()
...         self.layer2 = nn.Dense(512, 512)
...         self.relu2 = nn.ReLU()
...         self.layer3 = nn.Dense(512, 10)
...
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.layer1(x)
...         x = self.relu1(x)
...         x = self.layer2(x)
...         x = self.relu2(x)
...         logits = self.layer3(x)
...         return logits
>>>
>>> # Assign each layer to a pipeline stage
>>> net = Network()
>>> net.layer1.pipeline_stage = 0
>>> net.relu1.pipeline_stage = 0
>>> net.layer2.pipeline_stage = 0
>>> net.relu2.pipeline_stage = 1
>>> net.layer3.pipeline_stage = 1
>>>
>>> def create_dataset(batch_size):
...     dataset_path = os.getenv("DATA_PATH")
...     dataset = ds.MnistDataset(dataset_path)
...     image_transforms = [
...         ds.vision.Rescale(1.0 / 255.0, 0),
...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
...         ds.vision.HWC2CHW()
...     ]
...     label_transform = ds.transforms.TypeCast(ms.int32)
...     dataset = dataset.map(image_transforms, 'image')
...     dataset = dataset.map(label_transform, 'label')
...     dataset = dataset.batch(batch_size)
...     return dataset
>>>
>>> data_set = create_dataset(32)
>>>
>>> # Wrap the optimizer for TFT and pass the wrapper to the model
>>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
>>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
>>> loss_fn = nn.CrossEntropyLoss()
>>>
>>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
>>> net_with_loss.set_train()
>>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
>>> # Arguments: ctrl_rank_id, ctrl_ip, ctrl_port, ckpt_save_path
>>> tft_cb = TFTRegister(0, "192.168.0.1", 2000, "./tft_checkpoint/")
>>> loss_cb = train.LossMonitor(1)
>>> model.train(1, data_set, callbacks=[tft_cb, loss_cb])
- on_train_step_end(run_context)
Reports to MindIO TTP at the end of each step.
- Parameters:
run_context (RunContext) - Contains information about the model. For details, see mindspore.train.RunContext.
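As an illustration of what a step-end hook can read from run_context, the sketch below follows the general MindSpore callback convention of calling run_context.original_args() and reading cur_step_num. Whether TFTRegister reports exactly this field is an assumption. FakeRunContext is a hypothetical stand-in so the snippet runs without MindSpore installed.

```python
from types import SimpleNamespace


def current_step(run_context) -> int:
    """Read the current step number the way MindSpore callbacks usually do."""
    cb_params = run_context.original_args()
    return cb_params.cur_step_num


class FakeRunContext:
    """Hypothetical stand-in for mindspore.train.RunContext (illustration only)."""

    def __init__(self, cur_step_num: int):
        self._params = SimpleNamespace(cur_step_num=cur_step_num)

    def original_args(self):
        return self._params


# In a real callback, run_context is supplied by Model.train at each step end.
print(current_step(FakeRunContext(5)))
```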