Source code for mindspore.train.callback._checkpoint

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Checkpoint related classes and functions."""
from __future__ import absolute_import

import os
import stat
import time

import mindspore.context as context
from mindspore import log as logger
from mindspore import nn
from mindspore import _checkparam as Validator
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph, _wait_async_process_save_ckpt, \
    _wait_async_thread_save_ckpt, _check_async_save
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
from mindspore.parallel._utils import _get_device_num
from mindspore.communication.management import get_rank
from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
from mindspore.train.callback._callback import Callback, set_cur_net
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.generator import Generator
from mindspore.common.api import _cell_graph_executor
from mindspore._c_expression import collect_host_info, get_clock_syscnt

_cur_dir = os.getcwd()
SAVE_DIR = _cur_dir
_info_list = ["epoch_num", "step_num"]


def _get_dp_tp_from_redundancy(redundancy_tuple):
    """From redundancy get dp and tp"""
    dp = []
    tp = []
    for dp_value in redundancy_tuple:
        dp.append(list(dp_value))
    for i in range(len(redundancy_tuple[0])):
        tp.append([v[i] for v in redundancy_tuple])
    return dp, tp


def _get_dp_tp_from_layout(parameter_redundancy_dict):
    """From layout dict get dp and tp"""
    tp = []
    dp = []
    value_len = 0
    for _, value in parameter_redundancy_dict.items():
        if len(value) > value_len:
            value_len = len(value)
            dp, tp = _get_dp_tp_from_redundancy(value)
    return dp, tp


def _wait_async_save_ckpt(async_save=False):
    """Waiting for asynchronous saving of ckpt to complete."""
    if async_save:
        if async_save == "process":
            _wait_async_process_save_ckpt()
        else:
            _wait_async_thread_save_ckpt()


def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
    """Check if there is a file with the same name."""
    if callable(prefix) or callable(directory):
        return prefix
    files = os.listdir(directory)
    suffix_num = 0
    pre_len = len(prefix)
    for filename in files:
        name_ext = os.path.splitext(filename)
        if exception and filename[-16:] != "_breakpoint.ckpt":
            continue
        if not exception and (name_ext[-1] != ".ckpt" or filename[-16:] == "_breakpoint.ckpt"):
            continue
        # find same prefix file
        if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
            # add the max suffix + 1
            index = filename[pre_len:].find("-")
            if index == 0:
                suffix_num = max(suffix_num, 1)
            elif index != -1:
                num = filename[pre_len + 1:pre_len + index]
                if num.isdigit():
                    suffix_num = max(suffix_num, int(num) + 1)

    if suffix_num != 0:
        prefix = f'{prefix}_{suffix_num}'

    return prefix


def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, exception_save=False,
                                   map_param_inc=False, global_step_num=None):
    param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
                         or exception_save or map_param_inc or global_step_num is not None)
    if format == "safetensors" and param_not_default:
        raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")


[docs]class CheckpointConfig: """ The configuration of model checkpoint. Note: - During the training process, if dataset is transmitted through the data channel, it is suggested to set 'save_checkpoint_steps' to an integer multiple of loop_size. Otherwise, the time to save the checkpoint may be biased. It is recommended to set only one save strategy and one keep strategy at the same time. If both `save_checkpoint_steps` and `save_checkpoint_seconds` are set, `save_checkpoint_seconds` will be invalid. If both `keep_checkpoint_max` and `keep_checkpoint_per_n_minutes` are set, `keep_checkpoint_per_n_minutes` will be invalid. - The `enc_mode` and `crc_check` parameters are mutually exclusive and cannot be configured simultaneously. Args: save_checkpoint_steps (int): Steps to save checkpoint. Default: ``1`` . save_checkpoint_seconds (int): Seconds to save checkpoint. Can't be used with save_checkpoint_steps at the same time. Default: ``0`` . keep_checkpoint_max (int): Maximum number of checkpoint files can be saved. Default: ``5`` . keep_checkpoint_per_n_minutes (int): Save the checkpoint file every `keep_checkpoint_per_n_minutes` minutes. Can't be used with keep_checkpoint_max at the same time. Default: ``0`` . integrated_save (bool): Whether to merge and save the split Tensor in the automatic parallel scenario. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. Default: ``True`` . async_save (Union[bool, str]):Whether to use asynchronous saving of the checkpoint file, if True, the asynchronous thread is used by default. If the type is string, the method of asynchronous saving, it can be "process" or "thread". Default: ``False`` . saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation with the network in training, the initial value of saved_network will be saved. Default: ``None`` . append_info (list): The information save to checkpoint file. Support "epoch_num", "step_num" and dict. The key of dict must be str, the value of dict must be one of int, float, bool, Parameter or Tensor. Default: ``None`` . enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption is not required. Default: ``None`` . enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption mode, currently supports 'AES-GCM', 'AES-CBC' and 'SM4-CBC'. Default: ``'AES-GCM'`` . exception_save (bool): Whether to save the current checkpoint when an exception occurs. Default: ``False`` . crc_check (bool): Whether to perform crc32 calculation when saving checkpoint and save the calculation result to the end of ckpt. Default: ``False`` . remove_redundancy (bool): Whether to enable saving the checkpoint with redundancy removal. Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means redundant-free saving is not enabled. format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt". kwargs (dict): Configuration options dictionary. Raises: ValueError: If input parameter is not the correct type. Examples: >>> from mindspore import nn >>> from mindspore.train import Model, CheckpointConfig, ModelCheckpoint >>> >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9) >>> model = Model(net, loss_fn=loss, optimizer=optim) >>> # Create the dataset taking MNIST as an example. Refer to >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py >>> dataset = create_dataset() >>> config = CheckpointConfig(save_checkpoint_seconds=100, keep_checkpoint_per_n_minutes=5, saved_network=net) >>> config.save_checkpoint_steps 1 >>> config.save_checkpoint_seconds >>> config.keep_checkpoint_max 5 >>> config.keep_checkpoint_per_n_minutes >>> config.integrated_save True >>> config.async_save False >>> config.saved_network >>> config.enc_key >>> config.enc_mode 'AES-GCM' >>> config.append_dict >>> config.get_checkpoint_policy >>> ckpoint_cb = ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config) >>> model.train(10, dataset, callbacks=ckpoint_cb) """ def __init__(self, save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, integrated_save=True, async_save=False, saved_network=None, append_info=None, enc_key=None, enc_mode='AES-GCM', exception_save=False, crc_check=False, remove_redundancy=False, format="ckpt", **kwargs): if save_checkpoint_steps is not None: save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps) if save_checkpoint_seconds is not None: save_checkpoint_seconds = Validator.check_non_negative_int(save_checkpoint_seconds) if keep_checkpoint_max is not None: keep_checkpoint_max = Validator.check_non_negative_int(keep_checkpoint_max) if keep_checkpoint_per_n_minutes is not None: keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes) if saved_network is not None and not isinstance(saved_network, nn.Cell): raise TypeError(f"For 'CheckpointConfig', the type of 'saved_network' must be None or Cell, " f"but got {str(type(saved_network))}.") if not save_checkpoint_steps and not save_checkpoint_seconds and \ not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: raise ValueError("For 'CheckpointConfig', the input arguments 'save_checkpoint_steps', " "'save_checkpoint_seconds', " "'keep_checkpoint_max' and 'keep_checkpoint_per_n_minutes' can't be all None or 0.") Validator.check_bool(exception_save) self.exception_save = exception_save self._save_checkpoint_steps = save_checkpoint_steps self._save_checkpoint_seconds = save_checkpoint_seconds if self._save_checkpoint_steps and self._save_checkpoint_steps > 0: self._save_checkpoint_seconds = None self._keep_checkpoint_max = keep_checkpoint_max self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes if self._keep_checkpoint_max and self._keep_checkpoint_max > 0: self._keep_checkpoint_per_n_minutes = None else: if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0: self._keep_checkpoint_max = 1 self._integrated_save = Validator.check_bool(integrated_save) self._async_save = _check_async_save(async_save) self._saved_network = saved_network self._append_dict = self._handle_append_info(append_info) self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes)) self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str) self._crc_check = Validator.check_isinstance('crc_check', crc_check, bool) self._format = Validator.check_isinstance('format', format, str) self._map_param_inc = kwargs.get('incremental', False) self.enable_redundance = kwargs.get('enable_redundance', False) self.remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool) _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, exception_save, self._map_param_inc) @property def save_checkpoint_steps(self): """ Get the value of steps to save checkpoint. Returns: int, steps to save checkpoint. """ return self._save_checkpoint_steps @property def save_checkpoint_seconds(self): """Get the value of _save_checkpoint_seconds. Returns: int, seconds to save the checkpoint file. """ return self._save_checkpoint_seconds @property def keep_checkpoint_max(self): """ Get the value of maximum number of checkpoint files can be saved. Returns: int, Maximum number of checkpoint files can be saved. """ return self._keep_checkpoint_max @property def keep_checkpoint_per_n_minutes(self): """ Get the value of save the checkpoint file every n minutes. Returns: Int, save the checkpoint file every n minutes. """ return self._keep_checkpoint_per_n_minutes @property def integrated_save(self): """ Get the value of whether to merge and save the split Tensor in the automatic parallel scenario. Returns: bool, whether to merge and save the split Tensor in the automatic parallel scenario. """ return self._integrated_save @property def async_save(self): """ Get the value of whether or how asynchronous execution saves the checkpoint to a file. Returns: (bool, str), whether or how asynchronous execution saves the checkpoint to a file. """ return self._async_save @property def saved_network(self): """ Get the value of network to be saved in checkpoint file. Returns: Cell, network to be saved in checkpoint file. """ return self._saved_network @property def enc_key(self): """ Get the value of byte type key used for encryption. Returns: (None, bytes), byte type key used for encryption. """ return self._enc_key @property def enc_mode(self): """ Get the value of the encryption mode. Returns: str, encryption mode. """ return self._enc_mode @property def crc_check(self): """ Get the value of the whether to enable crc check. Returns: bool, whether to enable crc check. """ return self._crc_check @property def format(self): return self._format @property def append_dict(self): """ Get the value of information dict saved to checkpoint file. Returns: dict, the information saved to checkpoint file. """ return self._append_dict @property def map_param_inc(self): """ Get the value of whether to save map Parameter incrementally. Returns: bool, whether to save map Parameter incrementally. """ return self._map_param_inc
[docs] def get_checkpoint_policy(self): """ Get the policy of checkpoint. Returns: dict, the information of checkpoint policy. """ checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps, 'save_checkpoint_seconds': self.save_checkpoint_seconds, 'keep_checkpoint_max': self.keep_checkpoint_max, 'keep_checkpoint_per_n_minutes': self.keep_checkpoint_per_n_minutes, 'saved_network': self.saved_network} return checkpoint_policy
@staticmethod def _handle_append_info(append_info): """Handle ckpt append info.""" if append_info is None or append_info == []: return None if not isinstance(append_info, list): raise TypeError(f"For 'CheckpointConfig', the type of 'append_info' must be list," f"but got {str(type(append_info))}.") handle_append_info = {} if "epoch_num" in append_info: handle_append_info["epoch_num"] = 0 if "step_num" in append_info: handle_append_info["step_num"] = 0 if "random_op" in append_info: handle_append_info["random_op"] = 0 dict_num = 0 for element in append_info: if not isinstance(element, str) and not isinstance(element, dict): raise TypeError(f"For 'CheckpointConfig', the type of 'append_info' element must be str or dict," f"but got {str(type(element))}.") if isinstance(element, str) and element not in _info_list: raise ValueError(f"For 'CheckpointConfig', the value of element in the argument 'append_info' " f"must be in {_info_list}, " f"but got {element}.") if isinstance(element, dict): dict_num += 1 if dict_num > 1: raise TypeError(f"For 'CheckpointConfig', the element of 'append_info' must has only one dict, " "but got {dict_num}") for key, value in element.items(): if isinstance(key, str) and isinstance(value, (int, float, bool, str, Parameter, Tensor, Generator)): handle_append_info[key] = value else: raise TypeError(f"For 'CheckpointConfig', the key type of the dict 'append_info' " f"must be string, the value type must be int or float or bool, " f"but got key type {type(key)}, value type {type(value)}") return handle_append_info
[docs]class ModelCheckpoint(Callback): """ The checkpoint callback class. It is called to combine with train process and save the model and network parameters after training. Note: In the distributed training scenario, please specify different directories for each training process to save the checkpoint file. Otherwise, the training may fail. If this callback is used in the `model` function, the checkpoint file will saved parameters of the optimizer by default. Args: prefix (Union[str, callable object]): The prefix name or callable object to generate name of checkpoint files. Default: ``'CKP'`` . directory (Union[str, callable object]): The folder path where the checkpoint is stored, or the callable object used to generate the path. By default, the file is saved in the current directory. Default: ``None`` . config (CheckpointConfig): Checkpoint strategy configuration. Default: ``None`` . Raises: ValueError: If `prefix` is not str or contains the '/' character and is not a callable object. ValueError: If `directory` is not str and is not a callable object. TypeError: If the config is not CheckpointConfig type. Examples: >>> import numpy as np >>> import mindspore.dataset as ds >>> from mindspore import nn >>> from mindspore.train import Model, ModelCheckpoint >>> >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))} >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32) >>> net = nn.Dense(10, 5) >>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') >>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9) >>> ckpt_callback = ModelCheckpoint(prefix="myckpt") >>> model = Model(network=net, optimizer=opt, loss_fn=crit) >>> model.train(2, train_dataset, callbacks=[ckpt_callback]) """ def __init__(self, prefix='CKP', directory=None, config=None): super(ModelCheckpoint, self).__init__() self._latest_ckpt_file_name = "" self._init_time = time.time() self._last_time = time.time() self._last_time_for_keep = time.time() self._last_triggered_step = 0 """a callable for users to set self-defined prefix.""" self._prefix_func = None """a callable for users to set self-defined directory.""" self._directory_func = None if not callable(prefix) and (not isinstance(prefix, str) or prefix.find('/') >= 0): raise ValueError("For 'ModelCheckpoint', the argument 'prefix' " "for checkpoint file name is invalid, it must be " "callable or string that does not contain '/', but got {}.".format(prefix)) self._prefix = prefix self._exception_prefix = prefix if directory is not None: if callable(directory): self._directory_func = directory else: self._directory = _make_directory(directory) else: self._directory = _cur_dir if callable(prefix): self._prefix_func = prefix if _get_recovery_context("enable_recovery"): _set_recovery_context(ckpt_path=self._directory) if config is None: self._config = CheckpointConfig() else: if not isinstance(config, CheckpointConfig): raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be " "'CheckpointConfig', " "but got {}.".format(type(config))) self._config = config self._aiturbo_init_flag = os.getenv("AITURBO") == "1" # get existing checkpoint files if self._aiturbo_init_flag: from aiturbo.checkpoint.aiturbo_mindspore_ckpt import CheckpointShmManager self._manager = CheckpointShmManager() else: self._manager = CheckpointManager(self._config.format) if not callable(directory) and not callable(prefix): self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) self._append_dict = self._config.append_dict or {} self._append_epoch_num = self._append_dict.get("epoch_num") if "epoch_num" in self._append_dict else 0 self._append_step_num = self._append_dict.get("step_num") if "step_num" in self._append_dict else 0 self._graph_saved = False self._need_flush_from_cache = True self._map_param_inc = self._config.map_param_inc
[docs] def step_end(self, run_context): """ Save the checkpoint at the end of step. Args: run_context (RunContext): Context of the train running. """ cb_params = run_context.original_args() if self._aiturbo_init_flag: from aiturbo.checkpoint import aiturbo_mindspore as aiturbo ckpt_storage_path = self._directory rank_id = get_rank() stage_num = _get_auto_parallel_context("pipeline_stages") stage_rank_num = _get_device_num() // stage_num param_layout = cb_params.train_network.parameter_layout_dict if not param_layout: layout = {"stage_num": stage_num, "stage_rank_num": stage_rank_num, "stage_layout": None} aiturbo.init(ckpt_storage_path, rank_id, layout, None, False, None) else: device_num = _get_device_num() chunk_size = device_num // stage_num initial_rank = (rank_id // chunk_size) * chunk_size param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank) dp, _ = _get_dp_tp_from_layout(param_redundancy_dict) layout = {"stage_num": stage_num, "stage_rank_num": stage_rank_num, "stage_layout": param_redundancy_dict} single_params = remove_param_redundancy(param_redundancy_dict) single_params = {device_id: list(params) for device_id, params in single_params.items()} aiturbo.init(ckpt_storage_path, rank_id, layout, single_params, not self._config.enable_redundance, dp) self._aiturbo_init_flag = False if self._prefix_func: self._prefix = self._prefix_func(cb_params) if not isinstance(self._prefix, str) or self._prefix.find('/') >= 0: raise ValueError("For 'ModelCheckpoint', the argument 'prefix' " "for checkpoint file name is callable, it must return a " "string that does not contain '/', but got {}.".format(self._prefix)) if self._directory_func: self._directory = self._directory_func(cb_params) _make_directory(self._directory) collect_host_info("Callback", "ModelCheckpoint", "step_end", start_time=get_clock_syscnt(), level=1) # In disaster recovery scenario, the training process may be rolled back to the last step where # the ckpt was successfully saved, so the _last_triggered_step should be updated. if _get_recovery_context("enable_recovery") and cb_params.last_save_ckpt_step is not None: self._last_triggered_step = cb_params.last_save_ckpt_step cb_params.last_save_ckpt_step = None # save graph (only once) if not self._graph_saved: graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE: os.remove(graph_file_name) _save_graph(cb_params.train_network, graph_file_name) self._graph_saved = True self._save_ckpt(cb_params)
[docs] def end(self, run_context): """ Save the last checkpoint after training finished. Args: run_context (RunContext): Context of the train running. """ cb_params = run_context.original_args() collect_host_info("Callback", "ModelCheckpoint", "end", start_time=get_clock_syscnt(), level=1) _to_save_last_ckpt = True self._save_ckpt(cb_params, _to_save_last_ckpt) _wait_async_save_ckpt(self._config.async_save) destroy_allgather_cell()
def _check_save_ckpt(self, cb_params, force_to_save): """Check whether save checkpoint files or not.""" if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0: if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \ or force_to_save is True: return True elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0: self._cur_time = time.time() if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save: self._last_time = self._cur_time return True return False def _append_dict_content(self, epoch_num, step_num): """Append append_dict content.""" if "epoch_num" in self._append_dict: self._append_dict["epoch_num"] = self._append_epoch_num + epoch_num if "step_num" in self._append_dict: self._append_dict["step_num"] = self._append_step_num + step_num def _save_ckpt(self, cb_params, force_to_save=False): """Save checkpoint files.""" if cb_params.cur_step_num == self._last_triggered_step: return # if param is cache enable, flush data from cache to host before save_ckpt if self._need_flush_from_cache: self._flush_from_cache(cb_params) save_ckpt = self._check_save_ckpt(cb_params, force_to_save) step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) if save_ckpt: _wait_async_save_ckpt(self._config.async_save) if self._prefix_func: cur_ckpoint_file = self._prefix + f".{self._config.format}" else: cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \ + str(step_num_in_epoch) + f".{self._config.format}" # update checkpoint file list. self._manager.update_ckpoint_filelist(self._directory, self._prefix) # keep checkpoint files number equal max number. if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num: self._manager.remove_oldest_ckpoint_file() elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0: self._cur_time_for_keep = time.time() if (self._cur_time_for_keep - self._last_time_for_keep) \ < self._config.keep_checkpoint_per_n_minutes * 60: self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes, self._cur_time_for_keep) # generate the new checkpoint file and rename it. global SAVE_DIR SAVE_DIR = self._directory cur_file = os.path.join(self._directory, cur_ckpoint_file) self._last_time_for_keep = time.time() self._last_triggered_step = cb_params.cur_step_num # TODO(MS_DISABLE_REF_MODE): Delete when remove MS_DISABLE_REF_MODE env. if context.get_context("enable_ge") and os.getenv('MS_DISABLE_REF_MODE') \ and context.get_context("mode") == context.GRAPH_MODE: set_cur_net(cb_params.train_network) cb_params.train_network.add_flags(ge_sync_data=True) _cell_graph_executor(cb_params.train_network, phase='save') self._append_dict_content(cb_params.cur_epoch_num, cb_params.cur_step_num) network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network if os.getenv("AITURBO") == "1": save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save, self._append_dict, self._config.enc_key, self._config.enc_mode, crc_check=self._config.crc_check, incremental=self._map_param_inc, global_step_num=cb_params.cur_step_num) elif self._config.remove_redundancy: parallel_mode = context.get_auto_parallel_context("parallel_mode") if parallel_mode == "stand_alone": raise TypeError(f"The deduplication feature for saving checkpoint can only be used " f"in parallel scenarios, but got {parallel_mode}.") param_layout = network.parameter_layout_dict rank_id = get_rank() if param_layout: device_num = _get_device_num() stage_num = _get_auto_parallel_context("pipeline_stages") chunk_size = device_num // stage_num initial_rank = (rank_id // chunk_size) * chunk_size param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank) single_params = remove_param_redundancy(param_redundancy_dict) save_param_names = single_params.get(rank_id) param_layout_set = set(param_layout.keys()) if save_param_names == param_layout.keys(): logger.warning( f"For remove_redundancy save checkpoint, the saved parameters are non-redundant.") def choice_func(x): return x not in param_layout_set or (save_param_names is not None and x in save_param_names) else: param_redundancy_dict = get_parameter_redundancy(network) single_params = remove_param_redundancy(param_redundancy_dict) save_param_names = single_params.get(rank_id) def choice_func(x): return save_param_names is not None and x in save_param_names save_checkpoint(network, cur_file, False, self._config.async_save, self._append_dict, self._config.enc_key, self._config.enc_mode, crc_check=self._config.crc_check, format=self._config.format, incremental=self._map_param_inc, choice_func=choice_func) else: save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save, self._append_dict, self._config.enc_key, self._config.enc_mode, crc_check=self._config.crc_check, format=self._config.format, incremental=self._map_param_inc) self._latest_ckpt_file_name = cur_file def _flush_from_cache(self, cb_params): """Flush cache data to host if tensor is cache enable.""" has_cache_params = False params = cb_params.train_network.get_parameters() for param in params: if param.cache_enable: has_cache_params = True Tensor(param).flush_from_cache() if not has_cache_params: self._need_flush_from_cache = False @property def latest_ckpt_file_name(self): """Return the latest checkpoint path and file name.""" return self._latest_ckpt_file_name @property def _get_save_checkpoint_steps(self): """Return save ckpt steps""" return self._config.save_checkpoint_steps @property def _get_last_trigger_step(self): """Return last triggered steps""" return self._last_triggered_step
class CheckpointManager: """Manage checkpoint files according to train_config of checkpoint.""" def __init__(self, format='ckpt'): self._ckpoint_filelist = [] self._format = format @property def ckpoint_filelist(self): """Get all the related checkpoint files managed here.""" return self._ckpoint_filelist @property def ckpoint_num(self): """Get the number of the related checkpoint files managed here.""" return len(self._ckpoint_filelist) def update_ckpoint_filelist(self, directory, prefix): """Update the checkpoint file list.""" self._ckpoint_filelist = [] format = self._format format_length = len(format) + 1 files = os.listdir(directory) for filename in files: if os.path.splitext(filename)[-1] == f".{format}" and filename.startswith(prefix + "-"): mid_name = filename[len(prefix):-format_length] flag = not (True in [char.isalpha() for char in mid_name]) if flag: self._ckpoint_filelist.append(os.path.join(directory, filename)) def remove_ckpoint_file(self, file_name): """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" try: os.chmod(file_name, stat.S_IWRITE) os.remove(file_name) self._ckpoint_filelist.remove(file_name) except OSError: logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) except ValueError: logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) def remove_oldest_ckpoint_file(self): """Remove the oldest checkpoint file from this checkpoint manager and also from the directory.""" ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime) self.remove_ckpoint_file(ckpoint_files[0]) def keep_one_ckpoint_per_minutes(self, minutes, cur_time): """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time].""" del_list = [] oldest_file = '' oldest_time = cur_time for ck_file in self._ckpoint_filelist: modify_time = os.path.getmtime(ck_file) if cur_time - modify_time < 60 * minutes: del_list.append(ck_file) if modify_time < oldest_time: oldest_time = modify_time oldest_file = ck_file for mv_file in del_list: if mv_file == oldest_file: continue self.remove_ckpoint_file(mv_file)