mindspore.train.TFTRegister

class mindspore.train.TFTRegister(ctrl_rank_id, ctrl_ip, ctrl_port, ckpt_save_path)[source]

This callback is used to enable the TTP feature of MindIO. It is embedded into the training flow and performs TTP initialization, per-step reporting, and exception handling.

Note

This feature only supports graph mode on the Ascend backend, and only scenarios where sink_size is less than or equal to 1.
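
The sink_size mentioned above refers to the sink_size argument of Model.train. A minimal sketch, assuming model, dataset, and tft_cb are built as in the full example below, that keeps data sinking enabled while staying within the constraint:

>>> # Hypothetical call: `model`, `dataset` and `tft_cb` come from the example further down.
>>> # sink_size is capped at 1 because the TTP feature does not support larger sink sizes.
>>> model.train(1, dataset, callbacks=[tft_cb], dataset_sink_mode=True, sink_size=1)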

Parameters:
  • ctrl_rank_id (int) - The rank_id on which the TTP controller runs. This parameter is used to start the TTP controller.

  • ctrl_ip (str) - The IP address of the TTP controller. This parameter is used to start the TTP controller.

  • ctrl_port (int) - The IP port of the TTP controller. This parameter is used to start both the TTP controller and the TTP processors.

  • ckpt_save_path (str) - The directory where checkpoints are saved when an exception occurs. When an exception triggers a checkpoint save, a new subdirectory named 'ttp_saved_checkpoints-step_{cur_step_num}' is created under this directory (a construction sketch follows this parameter list).
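
The four parameters map one-to-one onto the constructor. A minimal construction sketch; the rank id, address, port, and directory below are illustrative placeholders, not defaults of the API:

>>> # Illustrative values only: adjust them to your cluster before use.
>>> from mindspore.train import TFTRegister
>>> tft_cb = TFTRegister(ctrl_rank_id=0, ctrl_ip="192.168.0.1", ctrl_port=2000,
...                      ckpt_save_path="./ttp_checkpoints/")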

Raises:
  • Exception - If TTP initialization fails, an Exception is raised.

  • ModuleNotFoundError - If the MindIO TTP whl package is not installed (see the handling sketch after this list).
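
A hedged sketch of guarding against the two documented failure modes; the fallback behaviour shown here is illustrative only:

>>> # Illustrative handling: train without the callback if the MindIO TTP whl
>>> # package is missing, and surface initialization failures explicitly.
>>> try:
...     tft_cb = TFTRegister(0, "192.168.0.1", 2000, "./ttp_checkpoints/")
... except ModuleNotFoundError:
...     tft_cb = None  # MindIO TTP whl package is not installed
... except Exception as err:
...     raise RuntimeError("MindIO TTP initialization failed") from err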

Examples:

>>> import numpy as np
>>> import os
>>> import math
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, Parameter, train
>>> from mindspore.communication import init, get_rank
>>> from mindspore.common.initializer import initializer, HeUniform
>>> from mindspore.train import Model, TFTRegister
>>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
>>> init()
>>> ms.set_seed(1)
>>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
...     "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
>>> class MatMulCell(nn.Cell):
...     def __init__(self, param=None, shape=None):
...         super().__init__()
...         if shape is None:
...             shape = [28 * 28, 512]
...         weight_init = HeUniform(math.sqrt(5))
...         self.param = Parameter(initializer(weight_init, shape), name="param")
...         if param is not None:
...             self.param = param
...         self.print = ops.Print()
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         out = self.matmul(x, self.param)
...         self.print("out is:", out)
...         return out
>>>
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = nn.Flatten()
...         self.layer1 = MatMulCell()
...         self.relu1 = nn.ReLU()
...         self.layer2 = nn.Dense(512, 512)
...         self.relu2 = nn.ReLU()
...         self.layer3 = nn.Dense(512, 10)
...
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.layer1(x)
...         x = self.relu1(x)
...         x = self.layer2(x)
...         x = self.relu2(x)
...         logits = self.layer3(x)
...         return logits
>>>
>>> net = Network()
>>> net.layer1.pipeline_stage = 0
>>> net.relu1.pipeline_stage = 0
>>> net.layer2.pipeline_stage = 0
>>> net.relu2.pipeline_stage = 1
>>> net.layer3.pipeline_stage = 1
>>>
>>> def create_dataset(batch_size):
...     dataset_path = os.getenv("DATA_PATH")
...     dataset = ds.MnistDataset(dataset_path)
...     image_transforms = [
...         ds.vision.Rescale(1.0 / 255.0, 0),
...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
...         ds.vision.HWC2CHW()
...     ]
...     label_transform = ds.transforms.TypeCast(ms.int32)
...     dataset = dataset.map(image_transforms, 'image')
...     dataset = dataset.map(label_transform, 'label')
...     dataset = dataset.batch(batch_size)
...     return dataset
>>>
>>> dataset = create_dataset(32)
>>>
>>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
>>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
>>> loss_fn = nn.CrossEntropyLoss()
>>>
>>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
>>> net_with_loss.set_train()
>>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
>>> tft_cb = TFTRegister(0, "192.168.0.1", 2000, "./tft_checkpoint/")
>>> loss_cb = train.LossMonitor(1)
>>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
on_train_step_end(run_context)[source]

Reports to MindIO TTP when each training step finishes.

Parameters: