# Source code for mindspore.train.callback._mindio_ttp

# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Checkpoint related classes and functions."""

import os
import sys
import copy
from mindspore.train.serialization import save_checkpoint, _convert_cell_param_and_names_to_dict, _get_merged_param_data
from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
from mindspore.parallel._utils import _get_device_num
from mindspore import _checkparam as Validator
from mindspore.train.callback._callback import Callback
from mindspore.common.tensor import Tensor
from mindspore import context
import mindspore as ms
from mindspore.communication import get_rank
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters

from mindspore.train._utils import get_parameter_redundancy
from mindspore import log as logger
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.common.api import _get_parameter_layout


def _get_dp_from_layout(parameter_layout_dict):
    """Return the smallest redundancy group that contains the current rank.

    The first rank of the current pipeline stage is derived from the
    ``pipeline_stages`` auto-parallel setting, the device number and the
    global rank. The redundancy information is then obtained from
    ``get_parameter_redundancy`` using that first rank.

    Args:
        parameter_layout_dict: Layout information forwarded unchanged to
            ``get_parameter_redundancy``.

    Returns:
        The shortest group that contains the current global rank (the first
        one encountered wins on ties), or an empty tuple if no group matches.
        Entries whose name contains ``"accu_grads"`` or ``"inputs"`` are
        ignored.
    """
    stage_count = _get_auto_parallel_context("pipeline_stages")
    device_count = _get_device_num()
    rank_id = get_rank()
    ranks_per_stage = device_count // stage_count
    stage_first_rank = (rank_id // ranks_per_stage) * ranks_per_stage
    redundancy = get_parameter_redundancy(parameter_layout_dict, stage_first_rank)

    # Lazily walk every group of every relevant entry, keeping only the groups
    # the current rank belongs to.
    candidate_groups = (
        group
        for name, groups in redundancy.items()
        if "accu_grads" not in name and "inputs" not in name
        for group in groups
        if rank_id in group
    )
    # min() returns the first minimal element, matching a strict "<" search.
    return min(candidate_groups, key=len, default=())


def _get_ckpt_dir(append_dict, ckpt_save_path, is_tmp_file):
    """Build the checkpoint directory path for a given epoch and step.

    Args:
        append_dict: Mapping that provides ``'cur_epoch_num'`` and
            ``'cur_step_num'``.
        ckpt_save_path: Parent directory of the generated directory.
        is_tmp_file: When truthy, the directory name gets a ``_tmp`` suffix.

    Returns:
        ``ckpt_save_path`` joined with
        ``ttp_saved_checkpoints-{epoch}_{step}`` (plus ``_tmp`` if requested).
    """
    suffix = "_tmp" if is_tmp_file else ""
    epoch_text = str(append_dict['cur_epoch_num'])
    step_text = str(append_dict['cur_step_num'])
    dir_name = "ttp_saved_checkpoints-" + epoch_text + "_" + step_text + suffix
    return os.path.join(ckpt_save_path, dir_name)


def _flush_from_cache(cb_params):
    """Call ``flush_from_cache`` for every cache-enabled network parameter.

    Args:
        cb_params: Callback parameters; ``cb_params.train_network`` must
            provide ``get_parameters()``.
    """
    network = cb_params.train_network
    for parameter in network.get_parameters():
        # Parameters without cache enabled have nothing to flush.
        if not parameter.cache_enable:
            continue
        Tensor(parameter).flush_from_cache()


def _save_checkpoint_on_failure(save_rank, step, rank_list, save_args):
    """Save a checkpoint for ``save_rank``; registered as the TTP save handler.

    The file is written to
    ``<tmp ckpt dir>/rank_<save_rank>/iteration-<epoch>_<step>.ckpt`` where the
    temporary checkpoint directory comes from ``_get_ckpt_dir``.

    Args:
        save_rank: Rank id used to name the ``rank_<id>`` sub-directory.
        step: Not used by this function; part of the handler signature.
        rank_list: Not used by this function; part of the handler signature.
        save_args: Tuple ``(ckpt_save_path, save_params, append_dict)``.
            ``append_dict`` must contain ``'cur_epoch_num'`` and
            ``'cur_step_num'``.
    """
    logger.info("Enter _save_checkpoint_on_failure function")
    ckpt_save_path, save_params, append_dict = save_args
    ckpt_file = f"iteration-{str(append_dict['cur_epoch_num'])}_{str(append_dict['cur_step_num'])}.ckpt"
    # The "/rank_<id>" form is kept as-is because load_checkpoint_with_backup
    # locates backup files by replacing the "/rank_<id>/" path segment.
    cur_ckpt_dir = _get_ckpt_dir(
        append_dict, ckpt_save_path, True) + "/rank_" + str(save_rank)
    # exist_ok=True: the directory may already exist if this handler runs
    # again for the same epoch/step; that must not abort the emergency save.
    os.makedirs(cur_ckpt_dir, exist_ok=True)
    cur_file = os.path.join(cur_ckpt_dir, ckpt_file)
    save_checkpoint(save_params, cur_file,
                    integrated_save=False, append_dict=append_dict)
    logger.info("Finish _save_checkpoint_on_failure function")


def _convert_net_to_param_list(save_obj):
    """Turn a network into a list of ``{"name": ..., "data": ...}`` entries.

    Pipeline shared parameters are synchronized first. In auto parallel mode
    an empty ``parameter_layout_dict`` on the network is replaced by the
    result of ``_get_parameter_layout()``; otherwise the parameter data of the
    network is initialized.

    Args:
        save_obj: The network whose parameters are collected.

    Returns:
        A list with one dict per parameter, holding its name and a Tensor
        built from the parameter's numpy value. Parameters that appear in the
        layout dict are passed through ``_get_merged_param_data`` first.
    """
    sync_pipeline_shared_parameters(save_obj)
    layout = save_obj.parameter_layout_dict
    if _is_in_auto_parallel_mode():
        if not layout:
            layout = _get_parameter_layout()
    else:
        save_obj.init_parameters_data()

    collected = []
    named_params = _convert_cell_param_and_names_to_dict(save_obj, None)
    for name, param in named_params.items():
        data = Tensor(param.asnumpy())
        if name in layout:
            # Parameters split across devices in automatic model parallel
            # scenarios are combined before saving.
            data = _get_merged_param_data(save_obj, layout, name, data, False)
        collected.append({"name": name, "data": data})
    return collected


def _rename_save_result(rename_args):
    """Rename the temporary checkpoint directory to its final name.

    Registered as the TTP rename handler.

    Args:
        rename_args: Tuple ``(ckpt_save_path, <ignored>, append_dict)``; the
            directory names are derived with ``_get_ckpt_dir``.
    """
    logger.info("Enter _rename_save_result function")
    save_root, _, step_info = rename_args
    # Source is the "_tmp" directory, destination the same name without it.
    os.rename(
        _get_ckpt_dir(step_info, save_root, True),
        _get_ckpt_dir(step_info, save_root, False),
    )
    logger.info("Finish _rename_save_result function")


class MindIOTTPAdapter(Callback):
    """
    This callback is used to enable the feature
    `MindIO TTP <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc1/mindio/mindiottp/mindiottp001.html>`_.
    This callback will execute TTP operations during training process, such as TTP init, report and exception handle.

    Note:
        Required for Ascend GE LazyInline mode only. And pipeline size must be greater than 1.

    Args:
        controller_ip (str): TTP controller's ip address, used for init TTP controller.
        controller_port (int): TTP controller's ip port, used for init TTP controller and processor.
        ckpt_save_path (str): Checkpoint save directory when failure occurs, checkpoint file will save to directory
            named ttp_saved_checkpoints-{cur_epoch_num}_{cur_step_num} under this directory.

    Raises:
        Exception: TTP init failed.
        ModuleNotFoundError: Mindio TTP whl package is not installed.

    Examples:
        >>> import numpy as np
        >>> import os
        >>> import math
        >>> import mindspore as ms
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn, ops, Parameter, train
        >>> from mindspore.communication import init, get_rank
        >>> from mindspore.common.initializer import initializer, HeUniform
        >>> from mindspore.train import Model, MindIOTTPAdapter
        >>> from mindspore import dataset as ds
        >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
        >>> init()
        >>> ms.set_seed(1)
        >>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
        ...                              "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
        >>> class MatMulCell(nn.Cell):
        ...     def __init__(self, param=None, shape=None):
        ...         super().__init__()
        ...         if shape is None:
        ...             shape = [28 * 28, 512]
        ...         weight_init = HeUniform(math.sqrt(5))
        ...         self.param = Parameter(initializer(weight_init, shape), name="param")
        ...         if param is not None:
        ...             self.param = param
        ...         self.print = ops.Print()
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(x, self.param)
        ...         self.print("out is:", out)
        ...         return out
        >>>
        >>> class Network(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.flatten = nn.Flatten()
        ...         self.layer1 = MatMulCell()
        ...         self.relu1 = nn.ReLU()
        ...         self.layer2 = nn.Dense(512, 512)
        ...         self.relu2 = nn.ReLU()
        ...         self.layer3 = nn.Dense(512, 10)
        ...
        ...     def construct(self, x):
        ...         x = self.flatten(x)
        ...         x = self.layer1(x)
        ...         x = self.relu1(x)
        ...         x = self.layer2(x)
        ...         x = self.relu2(x)
        ...         logits = self.layer3(x)
        ...         return logits
        >>>
        >>> net = Network()
        >>> net.layer1.pipeline_stage = 0
        >>> net.relu1.pipeline_stage = 0
        >>> net.layer2.pipeline_stage = 0
        >>> net.relu2.pipeline_stage = 1
        >>> net.layer3.pipeline_stage = 1
        >>>
        >>> def create_dataset(batch_size):
        ...     dataset_path = os.getenv("DATA_PATH")
        ...     dataset = ds.MnistDataset(dataset_path)
        ...     image_transforms = [
        ...         ds.vision.Rescale(1.0 / 255.0, 0),
        ...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        ...         ds.vision.HWC2CHW()
        ...     ]
        ...     label_transform = ds.transforms.TypeCast(ms.int32)
        ...     dataset = dataset.map(image_transforms, 'image')
        ...     dataset = dataset.map(label_transform, 'label')
        ...     dataset = dataset.batch(batch_size)
        ...     return dataset
        >>>
        >>> data_set = create_dataset(32)
        >>>
        >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>>
        >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
        >>> net_with_loss.set_train()
        >>> model = Model(net_with_loss, optimizer=optimizer)
        >>> ttp_cb = MindIOTTPAdapter("192.168.0.1", 2000, "./ttp_checkpoint/")
        >>> loss_cb = train.LossMonitor(1)
        >>> model.train(1, data_set, callbacks=[ttp_cb, loss_cb])
    """

    def __init__(self, controller_ip, controller_port, ckpt_save_path):
        super(MindIOTTPAdapter, self).__init__()
        # Imported here so that a missing mindio_ttp package raises at
        # construction time rather than at module import time.
        from mindio_ttp import framework_ttp as ttp
        self.ttp = ttp
        Validator.check_non_negative_int(controller_port)
        self.has_init = False
        self.enable = False
        # Stored before the early returns below so these attributes always
        # exist, even when the adapter ends up disabled.
        self._controller_ip = controller_ip
        self._controller_port = controller_port
        self._ckpt_save_path = ckpt_save_path

        mode = context.get_context("mode")
        if context.get_context("device_target") != "Ascend" or mode != context.GRAPH_MODE:
            logger.warning(
                "MindIO adataper only support on Ascend device with GRAPH Mode.")
            return
        if os.getenv("MS_ENABLE_MINDIO_GRACEFUL_EXIT") != "true":
            logger.warning("MindIO adataper need custom switch on.")
            return
        ttp_lib_path = os.getenv("MS_MINDIO_TTP_LIB_PATH")
        if ttp_lib_path is None or not os.path.isfile(ttp_lib_path):
            logger.warning(
                "MindIO adataper switch on, but ttp library path is not correct.")
            return
        self.enable = True

    def wrapper_ttp_persist(self, func):
        """
        This method is used to wrapper TTP exception handler for the input func.

        Args:
            func (function): train method that need to be wrapper.

        Returns:
            Function, if the TTP is enabled, return the encapsulated function,
            otherwise the original function is returned.
        """
        if self.enable:
            return self.ttp.ttp_to_persist(func)
        return func

    def _init_ttp(self, run_context):
        """ Init Mindio TTP, used internal. """
        logger.info("Begin to init ttp.")
        dev_num = _get_device_num()
        cb_params = run_context.original_args()
        param_layout_dict = cb_params.train_network.parameter_layout_dict
        dp = _get_dp_from_layout(param_layout_dict)
        logger.info("Init TTP with dp: {}.".format(dp))

        self.ttp.ttp_register_save_ckpt_handler(_save_checkpoint_on_failure)
        self.ttp.ttp_register_rename_handler(_rename_save_result)

        world_size = dev_num
        cur_rank = get_rank()
        # Replica count: 2 for an odd-sized group, half the group size otherwise.
        is_odd = len(dp) % 2
        replica = 2 if is_odd else len(dp) // 2
        enable_local_copy = False
        # Only rank 0 starts the controller; every rank starts a processor.
        if cur_rank == 0:
            logger.info("Begin to start ttp controller.")
            self.ttp.ttp_init_controller(
                cur_rank, world_size, replica, enable_local_copy)
            self.ttp.ttp_start_controller(
                self._controller_ip, self._controller_port)
            logger.info("Finish start ttp controller.")

        logger.info("Begin to start ttp processor.")
        self.ttp.ttp_init_processor(
            cur_rank, dp, len(dp), world_size, replica, enable_local_copy)
        self.ttp.ttp_start_processor(
            self._controller_ip, self._controller_port)
        logger.info("Finished start ttp processor.")
        logger.info("Finish init ttp.")

    def on_train_step_end(self, run_context):
        """
        Init TTP Controller only once after first step finished.
        And report status to MindIO TTP after every step finished.

        Args:
            run_context (RunContext): Context of the train running. Refer to
                                      :class:`mindspore.train.RunContext` for detail.
        """
        if not self.enable:
            return
        # The adapter disables itself when fewer than two pipeline stages are
        # configured.
        pp_num = _get_auto_parallel_context("pipeline_stages")
        if pp_num < 2:
            self.enable = False
            return
        cb_params = run_context.original_args()
        # It also disables itself for dataset sink mode with sink_size > 1.
        if cb_params.dataset_sink_mode is True and cb_params.sink_size > 1:
            self.enable = False
            return
        if not self.has_init:
            self.has_init = True
            self._init_ttp(run_context)
        _flush_from_cache(cb_params)
        cur_rank = get_rank()
        append_dict = {
            "cur_epoch_num": cb_params.cur_epoch_num,
            # Step index inside the current epoch, 1-based.
            "cur_step_num": int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1),
            "cur_rank": cur_rank,
            "batch_num": cb_params.batch_num,
            "global_step": cb_params.cur_step_num,
        }

        save_params = _convert_net_to_param_list(cb_params.train_network)
        # A deep copy is handed to TTP rather than the freshly built list.
        save_params_copy = copy.deepcopy(save_params)

        logger.info("Set ckpt args to TTP.")
        self.ttp.ttp_set_ckpt_args(
            (self._ckpt_save_path, save_params_copy, append_dict))
        logger.info("Set optimizer finish step status to TTP.")
        self.ttp.ttp_end_updating_os(cb_params.cur_step_num)

    @staticmethod
    def load_checkpoint_with_backup(ckpt_file_path, strategy_file_path, net):
        """
        Load checkpoint into network, and use strategy file to find backup checkpoint file
        when origin checkpoint file not found.

        Note:
           This API must be called after the communication is initialized because the cluster information
           needs to be obtained internally.

        Args:
            ckpt_file_path (str): the checkpoint file to be loaded.
            strategy_file_path (str): strategy file path for current rank.
            net (Cell): network that needs to load checkpoint.

        Returns:
            Dict, checkpoint weights after loaded.

        Raises:
            ValueError: Failed to load the checkpoint file.

        Examples:
            >>> import os
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore.communication import init, get_rank, get_group_size
            >>> from mindspore.train import Model, MindIOTTPAdapter
            >>> from mindspore import dataset as ds
            >>> ms.set_context(mode=ms.GRAPH_MODE)
            >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
            >>> init()
            >>> ms.set_seed(1)
            >>> class Network(nn.Cell):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.flatten = nn.Flatten()
            ...         self.fc = nn.Dense(28*28, 10, weight_init="normal", bias_init="zeros")
            ...         self.relu = nn.ReLU()
            ...
            ...     def construct(self, x):
            ...         x = self.flatten(x)
            ...         logits = self.relu(self.fc(x))
            ...         return logits
            >>>
            >>> net = Network()
            >>>
            >>> def create_dataset(batch_size):
            ...     dataset_path = os.getenv("DATA_PATH")
            ...     rank_id = get_rank()
            ...     rank_size = get_group_size()
            ...     dataset = ds.MnistDataset(dataset_path, num_shards=rank_size, shard_id=rank_id)
            ...     image_transforms = [
            ...         ds.vision.Rescale(1.0 / 255.0, 0),
            ...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
            ...         ds.vision.HWC2CHW()
            ...     ]
            ...     label_transform = ds.transforms.TypeCast(ms.int32)
            ...     dataset = dataset.map(image_transforms, 'image')
            ...     dataset = dataset.map(label_transform, 'label')
            ...     dataset = dataset.batch(batch_size)
            ...     return dataset
            >>> data_set = create_dataset(32)
            >>> ckpt_file = "./rank_5/iteration-1_40.ckpt"
            >>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
            >>> param_dict = MindIOTTPAdapter.load_checkpoint_with_backup(ckpt_file, strategy_file, net)
            >>> data_set.set_init_step(param_dict["global_step"])
        """
        logger.info("Start load checkpoint with strategy file.")
        # Bound up front so the final check raises the documented ValueError
        # even when the original load fails and no backup file is found.
        param_dict = None
        try:
            param_dict = ms.load_checkpoint(ckpt_file_path)
        except ValueError as e:
            logger.warning(
                "Loading origin checkpoint file failed, the reason is:{}.".format(str(e)))
            dp = _get_dp_from_layout(strategy_file_path)
            rank = get_rank()
            logger.info(
                "Can't load origin checkpoint file, found dp:{}.".format(dp))
            for i in dp:
                if i == rank:
                    continue
                # Backup candidates live under the sibling "rank_<i>" directory.
                new_ckpt = ckpt_file_path.replace(
                    f"/rank_{rank}/", f"/rank_{str(i)}/")
                if not os.path.isfile(new_ckpt):
                    continue
                try:
                    param_dict = ms.load_checkpoint(new_ckpt)
                except ValueError as e1:
                    logger.warning(
                        "Loading strategy checkpoint file failed, the reason is:{}.".format(str(e1)))
                    param_dict = None
                    continue
                if param_dict:
                    # Stop at the first usable backup so that a later broken
                    # file cannot discard a successfully loaded checkpoint.
                    break

        if param_dict:
            logger.info("Found param dict, load it into network.")
            ms.load_param_into_net(net, param_dict)
        else:
            raise ValueError(
                "Load checkpoint file failed, please check your config is set correctly.")
        logger.info("Finish load checkpoint with strategy file.")
        return param_dict