Loading the Model from Hub
For Inference Validation
The `mindspore_hub.load` API loads a pre-trained model in a single line of code. The main steps of model loading are as follows:
Search for the model of interest on the MindSpore Hub website.
For example, if you want to perform image classification on the CIFAR-10 dataset using GoogleNet, search the MindSpore Hub website with the keyword `GoogleNet`. All related models will then be returned. Once you open the related model page, you can get the website `url`.
Complete the task of loading the model using the `url`
# Load a pre-trained model from MindSpore Hub by its url, for example:
import mindspore_hub as mshub
import mindspore
from mindspore import context, Tensor, nn, Model
from mindspore import dtype as mstype
import mindspore.dataset.vision.py_transforms as py_transforms

# Run in graph mode on Ascend device 0.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target="Ascend",
    device_id=0,
)

# Hub url of the model, as shown on its MindSpore Hub model page.
model = "mindspore/ascend/0.7/googlenet_v1_cifar10"

# Initialize the number of classes based on the pre-trained model.
network = mshub.load(model, num_classes=10)
# Switch the loaded network out of training mode.
network.set_train(False)

# ...
After loading the model, you can use MindSpore to perform inference. For details, refer to the Multi-Platform Inference Overview.
For Transfer Training
When loading a model with the `mindspore_hub.load` API, we can pass an extra argument to load only the feature-extraction part of the model, so that new layers can easily be added to perform transfer learning. This feature is described on the related model page when the model developer has integrated an extra argument (e.g., `include_top`) into the model construction. The value of `include_top` is `True` or `False`, indicating whether to keep the top (fully-connected) layer of the network.
We use MobileNetV2 as an example to illustrate how to load a model trained on the OpenImage dataset and then perform transfer learning (re-training) on a dataset for a specific sub-task. The main steps are listed below:
Search for the model of interest on the MindSpore Hub website and get the related `url`.
Load the model from MindSpore Hub using the `url`. Note that the parameter `include_top`
# The `include_top` parameter is provided by the model developer; passing
# include_top=False loads only the feature-extraction part of the model.
import os
import mindspore_hub as mshub
import mindspore
from mindspore import context, Tensor, nn
from mindspore.nn import Momentum
from mindspore import save_checkpoint, load_checkpoint, load_param_into_net
from mindspore import ops
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C
from mindspore import dtype as mstype
from mindspore import Model

# Run in graph mode on Ascend device 0.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target="Ascend",
    device_id=0,
)

# Hub url of the pre-trained MobileNetV2 (built with a 500-class setting);
# include_top=False drops its top layer so a new head can be attached.
model = "mindspore/ascend/1.0/mobilenetv2_v1.0_openimage"
network = mshub.load(
    model,
    num_classes=500,
    include_top=False,
    activation="Sigmoid",
)
# Keep the pre-trained backbone out of training mode.
network.set_train(False)
Add a new classification layer to the current model architecture.
class ReduceMeanFlatten(nn.Cell):
    """Average a feature map over axes 2 and 3, then flatten it.

    The mean is taken with ``keep_dims=True`` over axes (2, 3) — presumably
    the H and W axes of an NCHW feature map — and the result is flattened
    to (batch, features).
    """

    def __init__(self):
        super().__init__()
        self.mean = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()

    def construct(self, x):
        pooled = self.mean(x, (2, 3))
        return self.flatten(pooled)


# Check the MindSpore Hub website to confirm that the last output shape is 1280.
last_channel = 1280
# The number of classes in the target task (CIFAR-10) is 10.
num_classes = 10

reducemean_flatten = ReduceMeanFlatten()
classification_layer = nn.Dense(last_channel, num_classes)
# The new head is put in training mode; the backbone stays out of it.
classification_layer.set_train(True)

# Backbone -> global average pooling + flatten -> new classification layer.
train_network = nn.SequentialCell(
    [network, reducemean_flatten, classification_layer]
)
Define `dataset_loader`.
As shown below, the new dataset used for fine-tuning is CIFAR-10. Note that you need to download the binary version of the dataset. After downloading and decompressing it, you can use the following code for data loading and processing. Note that `dataset_path`
# `dataset_path` is the path to the dataset and should be given by the user.
def create_cifar10dataset(dataset_path, batch_size, do_train):
    """Build a batched CIFAR-10 pipeline for fine-tuning or evaluation.

    Args:
        dataset_path: Directory holding the decompressed binary version of
            CIFAR-10 (e.g. ``cifar-10-batches-bin``).
        batch_size: Number of samples per batch; an incomplete final batch
            is dropped (``drop_remainder=True``).
        do_train: If True, read the "train" split, shuffle it and apply a
            random horizontal flip; otherwise read the "test" split without
            shuffling or augmentation.

    Returns:
        A batched dataset whose ``image`` column is resized to 256x256,
        rescaled, normalized and transposed to CHW, and whose ``label``
        column is cast to int32.
    """
    if do_train:
        usage, shuffle = "train", True
    else:
        usage, shuffle = "test", False
    # Use the split-dependent flag so that only the training split is
    # shuffled (it was previously ignored in favour of a hard-coded True).
    data_set = ds.Cifar10Dataset(dataset_dir=dataset_path, usage=usage, shuffle=shuffle)

    # define map operations
    trans = [C.Resize((256, 256))]
    if do_train:
        trans += [
            C.RandomHorizontalFlip(prob=0.5),
        ]
    trans += [
        # Multiply pixels by 1/255, then normalize each channel with
        # mean 0.5 and std 0.5.
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        # Transpose HWC -> CHW.
        C.HWC2CHW(),
    ]

    type_cast_op = C2.TypeCast(mstype.int32)
    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set


# Create Dataset
dataset_path = "/path_to_dataset/cifar-10-batches-bin"
dataset = create_cifar10dataset(dataset_path, batch_size=32, do_train=True)
Define `loss`, `optimizer` and `learning rate`
# Define the loss, the optimizer and the learning-rate schedule.
def generate_steps_lr(lr_init, steps_per_epoch, total_epochs):
    """Return a per-step, piecewise-constant learning-rate schedule.

    The rate starts at ``lr_init``. It becomes 0.1x, 0.01x and 0.001x of
    ``lr_init`` once 30%, 60% and 80% of the
    ``total_epochs * steps_per_epoch`` steps have been reached.

    Returns:
        A list with one learning rate per training step.
    """
    total_steps = total_epochs * steps_per_epoch
    # Boundaries are step indices (not epoch indices).
    boundaries = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
    multipliers = (0.1, 0.01, 0.001)

    def lr_at(step):
        # Number of boundaries already reached; boundaries are ascending.
        reached = sum(step >= bound for bound in boundaries)
        return lr_init if reached == 0 else lr_init * multipliers[reached - 1]

    return [lr_at(step) for step in range(total_steps)]


# Set epoch size
epoch_size = 60

# Wrap the full training network (backbone + new head) with the loss;
# sparse=True means labels are class indices rather than one-hot vectors.
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
loss_net = nn.WithLossCell(train_network, loss_fn)

steps_per_epoch = dataset.get_dataset_size()
lr = generate_steps_lr(
    lr_init=0.01,
    steps_per_epoch=steps_per_epoch,
    total_epochs=epoch_size,
)

# Create an optimizer over the trainable parameters of the new
# classification layer only (momentum 0.9, weight decay 4e-5).
optim = Momentum(
    filter(lambda x: x.requires_grad, classification_layer.get_parameters()),
    Tensor(lr, mindspore.float32),
    0.9,
    4e-5,
)
train_net = nn.TrainOneStepCell(loss_net, optim)
Start fine-tuning.
# Fine-tune: one optimizer step per batch, one checkpoint per epoch.
# Create the checkpoint directory once, before training, rather than
# checking for it every epoch; exist_ok=True tolerates an existing directory.
os.makedirs("ckpt", exist_ok=True)
for epoch in range(epoch_size):
    for data, label in dataset:
        data = mindspore.Tensor(data)
        label = mindspore.Tensor(label)
        # Forward pass, loss, backward pass and parameter update for this batch.
        loss = train_net(data, label)
        print(f"epoch: {epoch}/{epoch_size}, loss: {loss}")
    # Save the ckpt file for each epoch.
    ckpt_path = f"./ckpt/cifar10_finetune_epoch{epoch}.ckpt"
    save_checkpoint(train_network, ckpt_path)
Evaluate on the test set.
# Rebuild the network in the same layer order as `train_network`; presumably
# this keeps parameter names aligned with the saved checkpoint — verify if the
# layout is changed.
model = "mindspore/ascend/1.0/mobilenetv2_v1.0_openimage"
network = mshub.load(
    model,
    num_classes=500,
    pretrained=True,
    include_top=False,
    activation="Sigmoid",
)
network.set_train(False)
reducemean_flatten = ReduceMeanFlatten()
classification_layer = nn.Dense(last_channel, num_classes)
classification_layer.set_train(False)
# No Softmax layer is appended: SoftmaxCrossEntropyWithLogits below applies
# softmax itself and expects raw logits, so an extra Softmax would corrupt the
# reported 'Loss' metric. Top-1/Top-5 accuracy do not change under softmax.
network = nn.SequentialCell([network, reducemean_flatten, classification_layer])

# Load the fine-tuned weights of the classification layer from the checkpoint
# of the last epoch (index 59 when training for 60 epochs).
ckpt_path = "./ckpt/cifar10_finetune_epoch59.ckpt"
trained_ckpt = load_checkpoint(ckpt_path)
load_param_into_net(classification_layer, trained_ckpt)

# Define loss and create model.
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
eval_dataset = create_cifar10dataset(dataset_path, batch_size=32, do_train=False)
eval_metrics = {
    'Loss': nn.Loss(),
    'Top1-Acc': nn.Top1CategoricalAccuracy(),
    'Top5-Acc': nn.Top5CategoricalAccuracy(),
}
model = Model(network, loss_fn=loss, optimizer=None, metrics=eval_metrics)
metrics = model.eval(eval_dataset)
print("metric: ", metrics)