# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Checkpoint related classes and functions."""
import os
from mindspore.utils import _tft_handler
from mindspore.train.serialization import save_checkpoint
from mindspore.train.callback._callback import Callback
from mindspore import context, ops
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.communication import get_rank, get_group_size
from mindspore import log as logger
from mindspore.train.serialization import _get_cur_rank_dp
from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post, _tft_sem_enable
from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm
from mindspore._c_expression import clean_tdt_channel
from mindspore._c_expression import send_recv, reset_params
from mindspore._c_expression import CollectiveManager
from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
from mindspore._c_expression import TensorPy as Tensor_
from mindspore.ops.operations.manually_defined._inner import TensorReport
import mindspore
import mindspore.common.dtype as mstype
from mindspore.parallel._recovery_context import _set_recovery_context
def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
""" Common func to generate ckpt dir name."""
tmp = "_tmp" if is_tmp_file else ""
mid_dir = f"tft_saved_checkpoints-step_{str(step)}{tmp}"
return os.path.join(ckpt_save_path, mid_dir)
def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
""" Callback used for TFT save ckpt function when errors occur."""
logger.info("Enter _save_checkpoint_on_failure function")
if not cb_ctx._is_params_consistent(): # pylint: disable=W0212
raise RuntimeError("Can't save parameters, because they are left in inconsistent state!")
cb_params = args
    # the current step and epoch numbers were recorded in on_train_step_end, so restore them here
cb_params.cur_step_num = cb_ctx.cur_step_num
cb_params.cur_epoch_num = cb_ctx.cur_epoch_num
if cb_params.optimizer is not None:
cb_params.optimizer.global_step = cb_ctx.global_step
if hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
cb_params.network.optimizer.global_step = cb_ctx.global_step
append_dict = {}
append_dict["__exception_save__"] = True
# if user has provided a custom save callback, use it
if cb_ctx.save_cb:
cb_ctx.save_cb(cb_params, append_dict)
logger.info("Finish _save_checkpoint_on_failure function")
return
# if user has not provided a custom save callback, use default save logic
ckpt_save_path = cb_ctx.ckpt_save_path
cur_rank = get_rank()
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
cur_epoch_num = cb_params.cur_epoch_num
append_dict["epoch_num"] = cur_epoch_num
append_dict["step_num"] = cb_params.cur_step_num
append_dict["cur_rank"] = cur_rank
append_dict["batch_num"] = cb_params.batch_num
append_dict["global_step"] = cb_ctx.global_step
outputs = cb_params.net_outputs
if isinstance(outputs, (tuple, list)) and len(outputs) >= 3:
append_dict["loss_scale"] = outputs[2]
ckpt_file = f"ttp_rank_{str(cur_rank)}-{str(cur_epoch_num)}_{str(step_num_in_epoch)}.ckpt"
cur_ckpt_dir = _get_ckpt_dir(step, ckpt_save_path, True) + "/rank_" + str(cur_rank)
os.makedirs(cur_ckpt_dir, exist_ok=True)
cur_file = os.path.join(cur_ckpt_dir, ckpt_file)
save_checkpoint(cb_params.train_network, cur_file,
integrated_save=False, append_dict=append_dict)
logger.info("Finish _save_checkpoint_on_failure function")
def _rename_save_result(step, cb_ctx):
""" Callback used for TFT rename function after ckpt save callback was finished and successful."""
logger.info("Enter _rename_save_result function")
if cb_ctx.save_cb:
logger.info("User's save callback is provided, skip rename")
return
tmp_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, True)
fin_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, False)
os.rename(tmp_dir, fin_dir)
logger.info("Finish _rename_save_result function")
def _tft_exit_cb(ctx):
"""Callback used for TFT exit function."""
logger.error("Enter mindio ttp exit process, which means other ranks occur exception, check other ranks' logs!")
_tft_sem_post()
os._exit(1) # pylint: disable=W0212
def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
""" Callback used for TFT repair function."""
logger.warning("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value)):
logger.warning("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
_repair_device(cb_ctx.device_id)
if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
cb_ctx.tft.RepairType.RT_SEND.value,
cb_ctx.tft.RepairType.RT_RECV_REPAIR.value)):
logger.warning("Enter _tft_repair_callback SEND_RECV repair type:{}, src_rank:{}, dst_rank: {}".format(
repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
cb_params = args
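        # For RT_SEND this rank sends its parameters to every listed destination rank; for the
        # receive-side repairs (RT_UCE_HIGHLEVEL / RT_RECV_REPAIR) only the first src/dst pair is used.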
if repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value:
for i in range(len(repair_info["src"])):
src_rank = repair_info["src"][i]
dst_rank = repair_info["dst"][i]
if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
raise ValueError("Call send_recv failed.")
else:
src_rank = repair_info["src"][0]
dst_rank = repair_info["dst"][0]
if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
raise ValueError("Call send_recv failed.")
logger.warning("Finish _tft_repair_callback")
def _tft_clean_callback(is_uce_error, args, ctx):
""" Callback used for TFT clean function."""
logger.warning("Enter _tft_clean_callback")
ret = 0
if is_uce_error:
_get_uce_mem_info(ctx.device_id)
err_strategy = _get_uce_process_strategy()
logger.warning("_tft_clean_callback err_strategy: {}".format(err_strategy))
if err_strategy == "RS_UCE_HIGHLEVEL":
ret = 0
elif err_strategy == "RS_UCE_LOWLEVEL":
ret = 2
else:
ret = 1
clean_tdt_channel()
logger.warning("Enter _tft_clean_callback resume_hccl_comm")
CollectiveManager.get_instance().resume_hccl_comm()
logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
return ret
def _tft_stop_callback(args, cb_ctx):
""" Callback used for TFT stop function."""
logger.warning("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
_stop_device(cb_ctx.device_id)
if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()): # pylint: disable=W0212
raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
cb_ctx.is_uce_rank = False
if cb_ctx.tft.tft_get_repair_type() == "recover":
logger.warning(f"Reset limit step")
cb_ctx.tft.tft_reset_limit_step()
logger.info("Finish _tft_stop_callback")
def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
"""Callback used for TFT Rebuild Group function."""
logger.warning(f"Enter _tft_rebuild_sub_groups, device id: ".format(ctx.device_id))
_finalize_comm()
_rebuild_world_group()
_rebuild_sub_group()
_set_recovery_context(is_arf=True)
logger.warning("Enter _tft_rebuild_sub_groups ok ")
class TrainFaultTolerance(Callback):
"""
This callback is used to enable the TFT feature
`MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
    and will execute TFT operations during the training process, such as TFT init, report, and exception handling.
Note:
        Required for Ascend graph mode only, and sink size must be less than or equal to 1.
Args:
        ckpt_save_path (str): Checkpoint save directory used when a failure occurs. On saving,
            a new directory named 'tft_saved_checkpoints-step_{cur_step_num}'
            is created under that directory. Default: ``None``.
        kwargs (dict): Other keyword arguments, such as ``ckpt_save_fn`` (a custom save callback)
            and ``initial_step`` (the step offset added when reporting step numbers to TFT, default ``0``).
Raises:
Exception: TFT init failed.
        ModuleNotFoundError: MindIO TFT whl package is not installed.
Examples:
.. note::
Before running the following examples, you need to configure the communication environment variables.
It's recommended to use the msrun startup method.
Please see the `msrun start up
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
for more details.
This example should be run with 4 devices.
>>> import numpy as np
>>> import os
>>> import math
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, Parameter, train
>>> from mindspore.communication import init, get_rank
>>> from mindspore.common.initializer import initializer, HeUniform
>>> from mindspore.train import Model, TrainFaultTolerance
>>> from mindspore import dataset as ds
>>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
>>> init()
>>> ms.set_seed(1)
>>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
... "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
>>> class MatMulCell(nn.Cell):
... def __init__(self, param=None, shape=None):
... super().__init__()
... if shape is None:
... shape = [28 * 28, 512]
... weight_init = HeUniform(math.sqrt(5))
... self.param = Parameter(initializer(weight_init, shape), name="param")
... if param is not None:
... self.param = param
... self.print = ops.Print()
... self.matmul = ops.MatMul()
...
... def construct(self, x):
... out = self.matmul(x, self.param)
... self.print("out is:", out)
... return out
>>>
>>> class Network(nn.Cell):
... def __init__(self):
... super().__init__()
... self.flatten = nn.Flatten()
... self.layer1 = MatMulCell()
... self.relu1 = nn.ReLU()
... self.layer2 = nn.Dense(512, 512)
... self.relu2 = nn.ReLU()
... self.layer3 = nn.Dense(512, 10)
...
... def construct(self, x):
... x = self.flatten(x)
... x = self.layer1(x)
... x = self.relu1(x)
... x = self.layer2(x)
... x = self.relu2(x)
... logits = self.layer3(x)
... return logits
>>>
>>> net = Network()
>>> net.layer1.pipeline_stage = 0
>>> net.relu1.pipeline_stage = 0
>>> net.layer2.pipeline_stage = 0
>>> net.relu2.pipeline_stage = 1
>>> net.layer3.pipeline_stage = 1
>>>
>>> def create_dataset(batch_size):
... dataset_path = os.getenv("DATA_PATH")
... dataset = ds.MnistDataset(dataset_path)
... image_transforms = [
... ds.vision.Rescale(1.0 / 255.0, 0),
... ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
... ds.vision.HWC2CHW()
... ]
... label_transform = ds.transforms.TypeCast(ms.int32)
... dataset = dataset.map(image_transforms, 'image')
... dataset = dataset.map(label_transform, 'label')
... dataset = dataset.batch(batch_size)
... return dataset
>>>
>>> dataset = create_dataset(32)
>>>
>>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
>>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
>>> loss_fn = nn.CrossEntropyLoss()
>>>
>>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 4)
>>> net_with_loss.set_train()
>>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
>>> tft_cb = TrainFaultTolerance()
>>> loss_cb = train.LossMonitor(1)
>>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
"""
def __init__(self, ckpt_save_path=None, **kwargs):
super(TrainFaultTolerance, self).__init__()
self.save_cb = kwargs.get("ckpt_save_fn", None)
self.ckpt_save_path = ckpt_save_path
if self.save_cb is None and self.ckpt_save_path is None:
raise ValueError("TrainFaultTolerance construct need to set ckpt_save_fn or ckpt_save_path!")
self.tft = _tft_handler.get_tft()
self._check_init()
self.global_step = None
self.learning_rate = None
self.has_init_replica = False
self.is_uce_rank = False
self.cb_params = None
self.initial_step = kwargs.get("initial_step", 0)
self.device_id = context.get_context("device_id")
self.assign = mindspore.ops.Assign()
self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
self.s1 = mindspore.hal.Stream()
self.cur_step_num = 0
self.cur_epoch_num = 0
_tft_sem_enable()
self._tft_register()
def _check_init(self):
"""Check if the mindio-ttp had inited"""
if self.tft is None:
tft_env = os.getenv("MS_ENABLE_TFT", "")
if "ARF:1" in tft_env:
raise ValueError("Must init by _tft_handler.init(config=params) if use ARF.")
logger.warning(f"TFT handle not init, try to init")
_tft_handler.init(config=None)
self.tft = _tft_handler.get_tft()
logger.warning(f"TFT handle init ok.")
mode = context.get_context("mode")
device_target = context.get_context("device_target")
if device_target != "Ascend" or mode != context.GRAPH_MODE:
raise ValueError(f"MindIO adataper only support on Ascend device with GRAPH Mode!"
f"device:{device_target}, run mode: {mode}")
def _is_params_consistent(self):
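        # 'tft_g_one_flag' is reset to 1 in on_train_step_end after each completed step and is
        # modified by the TFT optimizer wrapper while an update is in flight, so a value of 1
        # is treated here as "parameters are consistent" (inferred from how the flag is
        # maintained elsewhere in this file).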
for key, param in self.cb_params.train_network.parameters_and_names():
if "tft_g_one_flag" in key:
with mindspore.hal.StreamCtx(self.s1):
tft_g_one_flag = Tensor(Tensor_.move_to(param, "CPU", False))
self.s1.synchronize()
return int(tft_g_one_flag) == 1
return False
def _set_tft_optimizer_replica(self, run_context):
""" Set Mindio TFT optimizer replica info, used internal. """
cur_rank = get_rank()
cb_params = run_context.original_args()
train_network = cb_params.train_network
        # in data_parallel mode, every rank holds the same training parameters
if context.get_auto_parallel_context("parallel_mode") == "data_parallel":
group_size = get_group_size()
dp = tuple(range(group_size))
else:
param_layout_dict = train_network.parameter_layout_dict
dp = _get_cur_rank_dp(param_layout_dict) if param_layout_dict else _get_cur_rank_dp(train_network)
logger.warning(f"Set TFT replica with dp: {dp}.")
replica_info = [
{
"type": 1,
"rank_list": list(dp),
"replica_cnt": len(dp),
"replica_shift": 0
}
]
self.tft.tft_set_optimizer_replica(cur_rank, replica_info)
    @classmethod
def get_optimizer_wrapper(cls, origin_opt_cls):
"""
        Optimizer wrapper function used when TFT is enabled.
Args:
            origin_opt_cls (Class): the original optimizer class to be wrapped.
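        Examples:
            >>> # A minimal usage sketch; assumes `net` is an existing nn.Cell instance.
            >>> from mindspore import nn
            >>> TFTSGD = TrainFaultTolerance.get_optimizer_wrapper(nn.SGD)
            >>> optimizer = TFTSGD(net.trainable_params(), learning_rate=1e-2)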
"""
class TFTOptSubCls(origin_opt_cls):
"""
Optimizer wrapper class when using tft.
"""
def __init__(self, *args, **kwargs):
super(TFTOptSubCls, self).__init__(*args, **kwargs)
self.report = TensorReport()
self.report_end = TensorReport()
self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
self.depend = ops.Depend()
self.allreduce_sum = ops.AllReduce()
self.allreduce_sum.add_prim_attr("tft_report_before", True)
self.tft_g_one_flag = Parameter(Tensor([1], dtype=mstype.int32))
def construct(self, gradients, **kwargs):
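                # Ordering enforced via ops.Depend: gradients -> all-reduce of the flag ->
                # report to TFT -> optimizer update -> end report, which lets TFT bracket the
                # optimizer update (intent inferred from the primitive attributes set in __init__).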
tft_g_one_flag = self.depend(self.tft_g_one_flag, gradients)
self.tft_g_one_flag = self.allreduce_sum(tft_g_one_flag)
grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
opt_ret = super(TFTOptSubCls, self).construct(grads, **kwargs)
self.report_end("tft_report", self.tft_g_one_flag)
return opt_ret
return TFTOptSubCls
def _tft_register(self):
"""Register callback functions."""
self.tft.tft_register_save_ckpt_handler(_save_checkpoint_on_failure, self)
self.tft.tft_register_rename_handler(_rename_save_result, self)
self.tft.tft_register_exit_handler(_tft_exit_cb, self)
self.tft.tft_register_stop_handler(_tft_stop_callback, self)
self.tft.tft_register_clean_handler(_tft_clean_callback, self)
self.tft.tft_register_repair_handler(_tft_repair_callback, self)
self.tft.tft_register_rebuild_group_handler(_tft_rebuild_sub_groups, self)
def _reset_acc_grads(self):
accu_grad_params = map(lambda e: e[1],
filter(lambda e: e[1].name.startswith('accu_grads'),
self.cb_params.train_network.parameters_and_names()))
accu_grad_list = list(accu_grad_params)
if reset_params(accu_grad_list) != 0:
raise ValueError("Call reset_params failed.")
    def on_train_step_end(self, run_context):
"""
        Report status to MindIO TFT after every step finishes.
Args:
run_context (RunContext): Context of the train running. Refer to
:class:`mindspore.train.RunContext` for detail.
"""
if self.has_init_replica is False:
self.has_init_replica = True
self._set_tft_optimizer_replica(run_context)
cb_params = run_context.original_args()
logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
self.cur_step_num = cb_params.cur_step_num
self.cur_epoch_num = cb_params.cur_epoch_num
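        # Snapshot the optimizer's global_step and reset the consistency flag to 1, so that a
        # later fault can be detected by _is_params_consistent and the step restored in
        # _save_checkpoint_on_failure.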
if cb_params.optimizer is not None:
self.global_step = cb_params.optimizer.global_step.clone()
self.assign(cb_params.optimizer.tft_g_one_flag, self.g_one)
elif hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
self.global_step = cb_params.network.optimizer.global_step.clone()
self.assign(cb_params.network.optimizer.tft_g_one_flag, self.g_one)
else:
raise ValueError("TFT feature need optimizer or network's optimizer!")
self.tft.tft_end_updating_os(cb_params.cur_step_num + self.initial_step)
logger.info("END Set optimizer finish step status to TFT.")
    def on_train_begin(self, run_context):
"""
        Register train params to MindIO TFT at the beginning of training.
Args:
run_context (RunContext): Context of the train running. Refer to
:class:`mindspore.train.RunContext` for detail.
"""
cb_params = run_context.original_args()
sink_size = cb_params.get("sink_size", 0)
if sink_size > 1:
raise ValueError("TFT feature doesn't support sink_size > 1.")
logger.info("Set set args to TFT.")
self.tft.tft_set_step_args(cb_params)
self.cb_params = cb_params
    def end(self, run_context):
"""
        Unregister MindIO TFT at the end of training.
Args:
run_context (RunContext): Context of the train running. Refer to
:class:`mindspore.train.RunContext` for detail.
"""
_tft_handler.unregister_tft()