Fault Recovery
Overview
Faults may be encountered during model training. The overhead of restarting the training with various resources is huge. For this purpose, MindSpore provides a fault recovery scheme, i.e., periodically saving the model parameters, which allows the model to recover quickly and continue training at the point of failure. MindSpore saves the model parameters in a step or epoch cycle. The model parameters are saved in CheckPoint (ckpt for short) files. If a fault occurs during model training, load the latest saved model parameters, restore the state here, and continue training.
This document describes the use case for fault recovery, saving the CheckPoint file only at the end of each epoch.
Data and Model Preparation
To provide a complete experience, the fault recovery process is simulated here by using the MNIST dataset and the LeNet5 network. You can skip this section if you are ready.
Data Preparation
Download the MNIST dataset and unzip the dataset to the project directory.
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)
Model Definition
import os
import mindspore
from mindspore.common.initializer import Normal
from mindspore.dataset import MnistDataset, vision
from mindspore import nn
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, Callback
import mindspore.dataset.transforms as transforms
mindspore.set_context(mode=mindspore.GRAPH_MODE)
# Create a training dataset
def create_dataset(data_path, batch_size=32):
train_dataset = MnistDataset(data_path, shuffle=False)
image_transfroms = [
vision.Rescale(1.0 / 255.0, 0),
vision.Resize(size=(32, 32)),
vision.HWC2CHW()
]
train_dataset = train_dataset.map(image_transfroms, input_columns='image')
train_dataset = train_dataset.map(transforms.TypeCast(mindspore.int32), input_columns='label')
train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
return train_dataset
# Load the training dataset
data_path = "MNIST_Data/train"
train_dataset = create_dataset(data_path)
# Fault during the simulation training
class myCallback(Callback):
def __init__(self, break_epoch_num=6):
super(myCallback, self).__init__()
self.epoch_num = 0
self.break_epoch_num = break_epoch_num
def on_train_epoch_end(self, run_context):
self.epoch_num += 1
if self.epoch_num == self.break_epoch_num:
raise Exception("Some errors happen.")
class LeNet5(nn.Cell):
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode="valid")
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode="valid")
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
net = LeNet5() # Model initialization
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") # Loss function
optim = nn.Momentum(net.trainable_params(), 0.01, 0.9) # Optimizer
model = Model(net, loss_fn=loss, optimizer=optim) # Model encapsulation
Periodically Saving CheckPoint Files
Configuring CheckpointConfig
mindspore.train.CheckpointConfig
can be configured according to the number of iterations, and the parameters for configuring the iteration strategy are as follows:
save_checkpoint_steps
: indicates how many steps to save a CheckPoint file. The default value is 1.keep_checkpoint_max
: indicates the maximum number of CheckPoint files to be saved. The default value is 5.
If the iteration strategy script ends normally, the CheckPoint file of the last step is saved by default.
During model training, the callbacks
parameter in Model.train
is used to pass in the object ModelCheckpoint
of saving model (used in conjunction with mindspore.train.CheckpointConfig
), which generates CheckPoint file.
User-defined Saved Data
The parameter append_info
of CheckpointConfig
can save user-defined information in the CheckPoint file. append_info
supports passing in epoch_num
, step_num
and data of dictionary type. epoch_num
and step_num
can save the number of epochs and the number of steps during training in the CheckPoint file.
key
of the dictionary type data must be of type string, and value
must be of type int, float, bool, string, Parameter, or Tensor.
# User-defined saved data
append_info = ["epoch_num", "step_num", {"lr": 0.01, "momentum": 0.9}]
# In the data sinking mode, the CheckPoint file of the last step is saved by default
config_ck = CheckpointConfig(append_info=append_info)
# The CheckPoint file is saved with the prefix "lenet" and is saved in ". /lenet" path
ckpoint_cb = ModelCheckpoint(prefix='lenet', directory='./lenet', config=config_ck)
# Simulation program fault. The default is to fail at the end of the 6th epoch
my_callback = myCallback()
# In the data sinking mode, 10 epoch training is performed by using Model.train
model.train(10, train_dataset, callbacks=[ckpoint_cb, my_callback], dataset_sink_mode=True)
User-defined Script to Find the Latest CheckPoint File
The program fails at the end of the 6th epoch. After the failure, the . /lenet
directory holds the CheckPoint files for the latest generated 5 epochs.
└── lenet
├── lenet-graph.meta # Compiled compute graph
├── lenet-2_1875.ckpt # CheckPoint files with the suffix '.ckpt'
├── lenet-3_1875.ckpt # The naming of the file indicates the number of epochs and steps where the parameters are stored. Here is the model parameters of the 1875th step of the 3rd epoch
├── lenet-4_1875.ckpt
├── lenet-5_1875.ckpt
└── lenet-6_1875.ckpt
If the user runs the training script multiple times using the same prefix name, a CheckPoint file with the same name may be generated. MindSpore adds "_" and a number after the user-defined prefix to make it easier for users to distinguish between the files generated each time. If you want to delete the .ckpt file, please delete the .meta file at the same time. For example:
lenet_3-2_1875.ckpt
indicates the CheckPoint file for the 1875th step of the 2nd epoch generated by running the fourth script.
Users can use user-defined scripts to find the latest saved CheckPoint files.
ckpt_path = "./lenet"
filenames = os.listdir(ckpt_path)
# Filter all CheckPoint file names
ckptnames = [ckpt for ckpt in filenames if ckpt.endswith(".ckpt")]
# Sort CheckPoint file names from oldest to newest in order of creation
ckptnames.sort(key=lambda ckpt: os.path.getctime(ckpt_path + "/" + ckpt))
# Get the latest CheckPoint file path
ckpt_file = ckpt_path + "/" + ckptnames[-1]
Recovery Training
Loading CheckPoint File
Use the load_checkpoint
and load_param_into_net
methods to load the latest saved CheckPoint file.
The
load_checkpoint
method will load the network parameters from the CheckPoint file into the dictionary param_dict.The
load_param_into_net
method will load the parameters from the dictionary param_dict into the network or optimizer, and the parameters in the network after loading are the ones saved in the CheckPoint file.
# Load the model parameters into param_dict. Here the model parameters saved during training and the user-defined saved data are loaded
param_dict = mindspore.load_checkpoint(ckpt_file)
net = LeNet5()
# Load the parameters into the model
mindspore.load_param_into_net(net, param_dict)
Obtaining the User-defined Data
The user can obtain the number of epochs and user-defined saved data from the CheckPoint file for training. Note that the data obtained at this point is of type Parameter.
epoch_num = int(param_dict["epoch_num"].asnumpy())
step_num = int(param_dict["step_num"].asnumpy())
lr = float(param_dict["lr"].asnumpy())
momentum = float(param_dict["momentum"].asnumpy())
Setting the Epoch for Continued Training
Pass the number of obtained epochs to the initial_epoch
parameter of Model.train
. The network will continue training from that epoch. In this case, the epoch
parameter of Model.train
indicates the last epoch of training.
model.train(10, train_dataset, callbacks=ckpoint_cb, initial_epoch=epoch_num, dataset_sink_mode=True)
Training Ends
At the end of the training, . /lenet
directory generates 4 new CheckPoint files. Based on the CheckPoint file names, it can be seen that the model is retrained at the 7th epoch and ends at the 10th epoch after the failure occurs. The fault recovery is successful.
└── lenet
├── lenet-graph.meta
├── lenet-2_1875.ckpt
├── lenet-3_1875.ckpt
├── lenet-4_1875.ckpt
├── lenet-5_1875.ckpt
├── lenet-6_1875.ckpt
├── lenet-1-7_1875.ckpt
├── lenet-1-8_1875.ckpt
├── lenet-1-9_1875.ckpt
├── lenet-1-10_1875.ckpt
└── lenet-1-graph.meta