mindspore.train.MindIOTTPAdapter

class mindspore.train.MindIOTTPAdapter(controller_ip, controller_port, ckpt_save_path)[source]

This callback enables the MindIO TTP feature. It is embedded in the training process and handles TTP initialization, status reporting, and exception handling.

Note

This feature is supported only in Ascend GE LazyInline mode, and it requires the number of pipeline-parallel stages to be greater than 1.

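Parameters:
  • controller_ip (str) - IP address of the TTP controller. It is used to start the TTP controller.

  • controller_port (int) - Port of the TTP controller. It is used to start the TTP controller and processor.

  • ckpt_save_path (str) - Path where the checkpoint is saved when an exception occurs. The path is a directory; when a checkpoint is saved because of an exception, a new directory named 'ttp_saved_checkpoints-{cur_epoch_num}_{cur_step_num}' is created under it.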
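
The directory naming described for ckpt_save_path can be illustrated with a minimal sketch. The helper name ttp_checkpoint_dir is hypothetical and is not part of MindSpore; it only reproduces the naming pattern given above.

```python
import os


def ttp_checkpoint_dir(ckpt_save_path, cur_epoch_num, cur_step_num):
    """Hypothetical helper: compose the directory name described above."""
    name = f"ttp_saved_checkpoints-{cur_epoch_num}_{cur_step_num}"
    return os.path.join(ckpt_save_path, name)


# Example: an exception-triggered save at epoch 1, step 40.
print(ttp_checkpoint_dir("./ttp_checkpoint", 1, 40))
```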

Raises:
  • Exception - TTP initialization failed.

  • ModuleNotFoundError - The MindIO TTP whl package is not installed.
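
One possible way to handle the ModuleNotFoundError case is to fall back to training without TTP. The sketch below assumes the error surfaces when the callback is constructed; build_callbacks and make_ttp_callback are hypothetical names, and the factory stands in for a call such as lambda: MindIOTTPAdapter("192.168.0.1", 2000, "./ttp_checkpoint/").

```python
def build_callbacks(make_ttp_callback, other_callbacks):
    """Return the callback list, skipping TTP if its package is missing.

    make_ttp_callback is a hypothetical zero-argument factory that
    constructs the TTP callback.
    """
    callbacks = list(other_callbacks)
    try:
        # Put the TTP callback first, as in the example on this page.
        callbacks.insert(0, make_ttp_callback())
    except ModuleNotFoundError as err:
        # The MindIO TTP whl package is not installed: train without TTP.
        print(f"MindIO TTP unavailable, continuing without it: {err}")
    return callbacks
```

Whether to continue without TTP or to abort is a deployment decision; this sketch only shows the fallback variant.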

Examples:

>>> import numpy as np
>>> import os
>>> import math
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, Parameter, train
>>> from mindspore.communication import init, get_rank
>>> from mindspore.common.initializer import initializer, HeUniform
>>> from mindspore.train import Model, MindIOTTPAdapter
>>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
>>> init()
>>> ms.set_seed(1)
>>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
...                             "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
>>> class MatMulCell(nn.Cell):
...     def __init__(self, param=None, shape=None):
...         super().__init__()
...         if shape is None:
...             shape = [28 * 28, 512]
...         weight_init = HeUniform(math.sqrt(5))
...         self.param = Parameter(initializer(weight_init, shape), name="param")
...         if param is not None:
...             self.param = param
...         self.print = ops.Print()
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         out = self.matmul(x, self.param)
...         self.print("out is:", out)
...         return out
>>>
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = nn.Flatten()
...         self.layer1 = MatMulCell()
...         self.relu1 = nn.ReLU()
...         self.layer2 = nn.Dense(512, 512)
...         self.relu2 = nn.ReLU()
...         self.layer3 = nn.Dense(512, 10)
...
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.layer1(x)
...         x = self.relu1(x)
...         x = self.layer2(x)
...         x = self.relu2(x)
...         logits = self.layer3(x)
...         return logits
>>>
>>> net = Network()
>>> net.layer1.pipeline_stage = 0
>>> net.relu1.pipeline_stage = 0
>>> net.layer2.pipeline_stage = 0
>>> net.relu2.pipeline_stage = 1
>>> net.layer3.pipeline_stage = 1
>>>
>>> def create_dataset(batch_size):
...     dataset_path = os.getenv("DATA_PATH")
...     dataset = ds.MnistDataset(dataset_path)
...     image_transforms = [
...         ds.vision.Rescale(1.0 / 255.0, 0),
...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
...         ds.vision.HWC2CHW()
...     ]
...     label_transform = ds.transforms.TypeCast(ms.int32)
...     dataset = dataset.map(image_transforms, 'image')
...     dataset = dataset.map(label_transform, 'label')
...     dataset = dataset.batch(batch_size)
...     return dataset
>>>
>>> data_set = create_dataset(32)
>>>
>>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
>>> loss_fn = nn.CrossEntropyLoss()
>>>
>>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
>>> net_with_loss.set_train()
>>> model = Model(net_with_loss, optimizer=optimizer)
>>> ttp_cb = MindIOTTPAdapter("192.168.0.1", 2000, "./ttp_checkpoint/")
>>> loss_cb = train.LossMonitor(1)
>>> model.train(1, data_set, callbacks=[ttp_cb, loss_cb])

static load_checkpoint_with_backup(ckpt_file_path, strategy_file_path, net)[source]

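Loads the specified checkpoint file into the network. If the configured checkpoint file does not exist, a backup checkpoint is located based on the strategy file and loaded instead.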

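Note

This interface must be called after communication has been initialized, because it needs to obtain cluster information internally.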

Parameters:
  • ckpt_file_path (str) - The checkpoint file to load.

  • strategy_file_path (str) - The strategy file of the current card.

  • net (Cell) - The network into which the weights are loaded.

Returns:

Dict, the checkpoint weights after loading.

Raises:
  • ValueError - Failed to load the checkpoint file.
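
The fallback behaviour described for this method (use the configured file, otherwise a backup) can be sketched in plain Python. This is only an illustration: pick_checkpoint and backup_candidates are hypothetical, and how the real interface derives backup files from the strategy file and cluster information is not shown here.

```python
import os


def pick_checkpoint(ckpt_file_path, backup_candidates):
    """Return the first checkpoint path that exists on disk.

    backup_candidates is a hypothetical list of backup paths; the real
    interface derives backups from the strategy file and cluster information.
    """
    for path in [ckpt_file_path, *backup_candidates]:
        if os.path.isfile(path):
            return path
    # Mirror the documented failure mode: nothing could be loaded.
    raise ValueError(f"no loadable checkpoint found for {ckpt_file_path}")
```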

Examples:

>>> import os
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.communication import init, get_rank, get_group_size
>>> from mindspore.train import Model, MindIOTTPAdapter
>>> from mindspore import dataset as ds
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
>>> init()
>>> ms.set_seed(1)
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = nn.Flatten()
...         self.fc = nn.Dense(28*28, 10, weight_init="normal", bias_init="zeros")
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.flatten(x)
...         logits = self.relu(self.fc(x))
...         return logits
>>>
>>> net = Network()
>>>
>>> def create_dataset(batch_size):
...     dataset_path = os.getenv("DATA_PATH")
...     rank_id = get_rank()
...     rank_size = get_group_size()
...     dataset = ds.MnistDataset(dataset_path, num_shards=rank_size, shard_id=rank_id)
...     image_transforms = [
...         ds.vision.Rescale(1.0 / 255.0, 0),
...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
...         ds.vision.HWC2CHW()
...     ]
...     label_transform = ds.transforms.TypeCast(ms.int32)
...     dataset = dataset.map(image_transforms, 'image')
...     dataset = dataset.map(label_transform, 'label')
...     dataset = dataset.batch(batch_size)
...     return dataset
>>> data_set = create_dataset(32)
>>> ckpt_file = "./rank_5/iteration-1_40.ckpt"
>>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
>>> param_dict = MindIOTTPAdapter.load_checkpoint_with_backup(ckpt_file, strategy_file, net)
>>> data_set.set_init_step(param_dict["global_step"])

on_train_step_end(run_context)[source]

Initializes MindIO TTP when the first step completes, and reports to MindIO TTP at the end of every step.

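Parameters:
  • run_context (RunContext) - Context information of the training run. See mindspore.train.RunContext.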
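
The "initialize once, report every step" pattern used by on_train_step_end can be sketched with a toy class. StepReporter is not the real callback: it takes a plain step number instead of a run context, and it records events in a list instead of talking to MindIO TTP.

```python
class StepReporter:
    """Toy stand-in for the per-step behaviour (not the real callback)."""

    def __init__(self):
        self.initialized = False
        self.events = []

    def on_train_step_end(self, step):
        if not self.initialized:
            # First completed step: perform the one-time initialization.
            self.events.append("init")
            self.initialized = True
        # Every completed step: report progress.
        self.events.append(f"report step {step}")


reporter = StepReporter()
for step in (1, 2, 3):
    reporter.on_train_step_end(step)
print(reporter.events)
```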
wrapper_ttp_persist(func)[source]

Wraps the passed-in function with TTP exception handling.

Parameters:
  • func (function) - The training function to wrap.

Returns:

Function. If TTP is enabled, the wrapped function is returned; otherwise the original function is returned.
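
The return behaviour (wrapped function when enabled, original function otherwise) follows a common decorator pattern, sketched below. The names wrap_with_handler, enabled, and on_error are hypothetical, and what the real TTP handler does on failure is not specified on this page; here the handler is simply called before the exception is re-raised.

```python
import functools


def wrap_with_handler(func, enabled, on_error):
    """Hypothetical analogue: wrap func only when the feature is enabled."""
    if not enabled:
        # Feature disabled: hand back the original function untouched.
        return func

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as err:
            # Pass the failure to the handler, then re-raise it.
            on_error(err)
            raise

    return wrapped
```

With enabled=False the caller gets back the very same function object, so there is no overhead when the feature is off.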