mindspore.train.MindIOTTPAdapter

View Source On Gitee
class mindspore.train.MindIOTTPAdapter(controller_ip, controller_port, ckpt_save_path)[source]

This callback is used to enable the feature MindIO TTP. This callback will execute TTP operations during the training process, such as TTP init, report and exception handling.

Note

Required for Ascend GE LazyInline mode only. And pipeline size must be greater than 1.

Parameters
  • controller_ip (str) – TTP controller's ip address, used for init TTP controller.

  • controller_port (int) – TTP controller's ip port, used for init TTP controller and processor.

  • ckpt_save_path (str) – Checkpoint save directory when failure occurs, checkpoint file will save to directory named ttp_saved_checkpoints-{cur_epoch_num}_{cur_step_num} under this directory.

Raises

Examples

>>> import numpy as np
>>> import os
>>> import math
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, Parameter, train
>>> from mindspore.communication import init, get_rank
>>> from mindspore.common.initializer import initializer, HeUniform
>>> from mindspore.train import Model, MindIOTTPAdapter
>>> from mindspore import dataset as ds
>>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
>>> init()
>>> ms.set_seed(1)
>>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
>>>                             "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
>>> class MatMulCell(nn.Cell):
...     def __init__(self, param=None, shape=None):
...         super().__init__()
...         if shape is None:
...             shape = [28 * 28, 512]
...         weight_init = HeUniform(math.sqrt(5))
...         self.param = Parameter(initializer(weight_init, shape), name="param")
...         if param is not None:
...             self.param = param
...         self.print = ops.Print()
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         out = self.matmul(x, self.param)
...         self.print("out is:", out)
...         return out
>>>
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = nn.Flatten()
...         self.layer1 = MatMulCell()
...         self.relu1 = nn.ReLU()
...         self.layer2 = nn.Dense(512, 512)
...         self.relu2 = nn.ReLU()
...         self.layer3 = nn.Dense(512, 10)
...
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.layer1(x)
...         x = self.relu1(x)
...         x = self.layer2(x)
...         x = self.relu2(x)
...         logits = self.layer3(x)
...         return logits
>>>
>>> net = Network()
>>> net.layer1.pipeline_stage = 0
>>> net.relu1.pipeline_stage = 0
>>> net.layer2.pipeline_stage = 0
>>> net.relu2.pipeline_stage = 1
>>> net.layer3.pipeline_stage = 1
>>>
>>> def create_dataset(batch_size):
...     dataset_path = os.getenv("DATA_PATH")
...     dataset = ds.MnistDataset(dataset_path)
...     image_transforms = [
...         ds.vision.Rescale(1.0 / 255.0, 0),
...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
...         ds.vision.HWC2CHW()
...     ]
...     label_transform = ds.transforms.TypeCast(ms.int32)
...     dataset = dataset.map(image_transforms, 'image')
...     dataset = dataset.map(label_transform, 'label')
...     dataset = dataset.batch(batch_size)
...     return dataset
>>>
>>> data_set = create_dataset(32)
>>>
>>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
>>> loss_fn = nn.CrossEntropyLoss()
>>>
>>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
>>> net_with_loss.set_train()
>>> model = Model(net_with_loss, optimizer=optimizer)
>>> ttp_cb = MindIOTTPAdapter("192.168.0.1", 2000, "./ttp_checkpoint/")
>>> loss_cb = train.LossMonitor(1)
>>> model.train(1, data_set, callbacks=[ttp_cb, loss_cb])
static load_checkpoint_with_backup(ckpt_file_path, strategy_file_path, net)[source]

Load checkpoint into network, and use strategy file to find backup checkpoint file when origin checkpoint file not found.

Note

This API must be called after the communication is initialized because the cluster information needs to be obtained internally.

Parameters
  • ckpt_file_path (str) – the checkpoint file to be loaded.

  • strategy_file_path (str) – strategy file path for current rank.

  • net (Cell) – network that needs to load checkpoint.

Returns

Dict, checkpoint weights after loaded.

Raises

ValueError – Failed to load the checkpoint file.

Examples

>>> import os
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.communication import init, get_rank, get_group_size
>>> from mindspore.train import Model, MindIOTTPAdapter
>>> from mindspore import dataset as ds
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
>>> init()
>>> ms.set_seed(1)
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = nn.Flatten()
...         self.fc = nn.Dense(28*28, 10, weight_init="normal", bias_init="zeros")
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.flatten(x)
...         logits = self.relu(self.fc(x))
...         return logits
>>>
>>> net = Network()
>>>
>>> def create_dataset(batch_size):
...     dataset_path = os.getenv("DATA_PATH")
...     rank_id = get_rank()
...     rank_size = get_group_size()
...     dataset = ds.MnistDataset(dataset_path, num_shards=rank_size, shard_id=rank_id)
...     image_transforms = [
...         ds.vision.Rescale(1.0 / 255.0, 0),
...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
...         ds.vision.HWC2CHW()
...     ]
...     label_transform = ds.transforms.TypeCast(ms.int32)
...     dataset = dataset.map(image_transforms, 'image')
...     dataset = dataset.map(label_transform, 'label')
...     dataset = dataset.batch(batch_size)
...     return dataset
>>> data_set = create_dataset(32)
>>> ckpt_file = "./rank_5/iteration-1_40.ckpt"
>>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
>>> param_dict = MindIOTTPAdapter.load_checkpoint_with_backup(ckpt_file, strategy_file, net)
>>> data_set.set_init_step(param_dict["global_step"])
on_train_step_end(run_context)[source]

Init TTP Controller only once after first step finished. And report status to MindIO TTP after every step finished.

Parameters

run_context (RunContext) – Context of the train running. Refer to mindspore.train.RunContext for detail.

wrapper_ttp_persist(func)[source]

This method is used to wrapper TTP exception handler for the input func.

Parameters

func (function) – train method that need to be wrapper.

Returns

Function, if the TTP is enabled, return the encapsulated function, otherwise the original function is returned.