.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_modelarts.svg
    :target: https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9yMi4yL3R1dG9yaWFscy96aF9jbi9hZHZhbmNlZC9tb2RlbC9taW5kc3BvcmVfbW9kZWwuaXB5bmI=&imageid=4c43b3ad-9df7-4b83-a096-c775dc4ba243
.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_notebook.svg
    :target: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/tutorials/zh_cn/advanced/model/mindspore_model.ipynb
.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_download_code.svg
    :target: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/tutorials/zh_cn/advanced/model/mindspore_model.py
.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_source.svg
    :target: https://gitee.com/mindspore/docs/blob/r2.2/tutorials/source_zh_cn/advanced/model/model.ipynb
    :alt: 查看源文件

高阶封装:Model
===============

.. toctree::
  :maxdepth: 1
  :hidden:

  model/callback
  model/metric

通常情况下,定义训练和评估网络并直接运行,已经可以满足基本需求。

一方面,\ ``Model``\ 可以在一定程度上简化代码。例如:无需手动遍历数据集;在不需要自定义\ ``nn.TrainOneStepCell``\ 的场景下,可以借助\ ``Model``\ 自动构建训练网络;可以使用\ ``Model``\ 的\ ``eval``\ 接口进行模型评估,直接输出评估结果,无需手动调用评价指标的\ ``clear``\ 、\ ``update``\ 、\ ``eval``\ 函数等。

另一方面,\ ``Model``\ 提供了很多高阶功能,如数据下沉、混合精度等,在不借助\ ``Model``\ 的情况下,使用这些功能需要花费较多的时间仿照\ ``Model``\ 进行自定义。

本文档首先对MindSpore的Model进行基本介绍,然后重点讲解如何使用\ ``Model``\ 进行模型训练、评估和推理。

.. figure:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/tutorials/source_zh_cn/advanced/model/images/model.png
   :alt: model


.. code:: python

    import mindspore
    from mindspore import nn
    from mindspore.dataset import vision, transforms
    from mindspore.dataset import MnistDataset
    from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor


Model基本介绍
-------------

`Model <https://www.mindspore.cn/docs/zh-CN/r2.2/api_python/train/mindspore.train.Model.html#mindspore.train.Model>`__\ 是MindSpore提供的高阶API,可以进行模型训练、评估和推理。其接口的常用参数如下:

-  ``network``\ :用于训练或推理的神经网络。
-  ``loss_fn``\ :所使用的损失函数。
-  ``optimizer``\ :所使用的优化器。
-  ``metrics``\ :用于模型评估的评价函数。
-  ``eval_network``\ :模型评估所使用的网络,未定义情况下,\ ``Model``\ 会使用\ ``network``\ 和\ ``loss_fn``\ 进行封装。

``Model``\ 提供了以下接口用于模型训练、评估和推理:

-  ``fit``\ :边训练边评估模型。
-  ``train``\ :用于在训练集上进行模型训练。
-  ``eval``\ :用于在验证集上进行模型评估。
-  ``predict``\ :用于对输入的一组数据进行推理,输出预测结果。

使用Model接口
~~~~~~~~~~~~~

对于简单场景的神经网络,可以在定义\ ``Model``\ 时指定前向网络\ ``network``\ 、损失函数\ ``loss_fn``\ 、优化器\ ``optimizer``\ 和评价函数\ ``metrics``\ 。

下载并处理数据集
----------------

.. code:: python

    # Download data from open datasets
    from download import download
    
    url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
          "notebook/datasets/MNIST_Data.zip"
    path = download(url, "./", kind="zip", replace=True)
    
    
    def datapipe(path, batch_size):
        image_transforms = [
            vision.Rescale(1.0 / 255.0, 0),
            vision.Normalize(mean=(0.1307,), std=(0.3081,)),
            vision.HWC2CHW()
        ]
        label_transform = transforms.TypeCast(mindspore.int32)
        
        dataset = MnistDataset(path)
        dataset = dataset.map(image_transforms, 'image')
        dataset = dataset.map(label_transform, 'label')
        dataset = dataset.batch(batch_size)
        return dataset
    
    train_dataset = datapipe('MNIST_Data/train', 64)
    test_dataset = datapipe('MNIST_Data/test', 64)


创建模型
--------

.. code:: python

    # Define model
    class Network(nn.Cell):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.dense_relu_sequential = nn.SequentialCell(
                nn.Dense(28*28, 512),
                nn.ReLU(),
                nn.Dense(512, 512),
                nn.ReLU(),
                nn.Dense(512, 10)
            )
    
        def construct(self, x):
            x = self.flatten(x)
            logits = self.dense_relu_sequential(x)
            return logits
    
    model = Network()

定义损失函数和优化器
--------------------

要训练神经网络模型,需要定义损失函数和优化器函数。

-  损失函数这里使用交叉熵损失函数\ ``CrossEntropyLoss``\ 。
-  优化器这里使用\ ``SGD``\ 。

.. code:: python

    # Instantiate loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = nn.SGD(model.trainable_params(), 1e-2)

训练及保存模型
--------------

在开始训练之前,MindSpore需要提前声明网络模型在训练过程中是否需要保存中间过程和结果,因此使用\ ``ModelCheckpoint``\ 接口用于保存网络模型和参数,以便进行后续的Fine-tuning(微调)操作。

.. code:: python

    steps_per_epoch = train_dataset.get_dataset_size()
    config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch)
    
    ckpt_callback = ModelCheckpoint(prefix="mnist", directory="./checkpoint", config=config)
    loss_callback = LossMonitor(steps_per_epoch)

通过MindSpore提供的\ ``model.fit``\ 接口可以方便地进行网络的训练与评估,\ ``LossMonitor``\ 可以监控训练过程中\ ``loss``\ 值的变化。

.. code:: python

    trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})
    
    trainer.fit(10, train_dataset, test_dataset, callbacks=[ckpt_callback, loss_callback])


.. raw:: html

    <div class="highlight"><pre>
    epoch: 1 step: 938, loss is 0.602992594242096
    Eval result: epoch 1, metrics: {'accuracy': 0.8435}
    epoch: 2 step: 938, loss is 0.2797124981880188
    Eval result: epoch 2, metrics: {'accuracy': 0.9003}
    epoch: 3 step: 938, loss is 0.32015785574913025
    Eval result: epoch 3, metrics: {'accuracy': 0.9179}
    epoch: 4 step: 938, loss is 0.17153620719909668
    Eval result: epoch 4, metrics: {'accuracy': 0.9308}
    epoch: 5 step: 938, loss is 0.18772485852241516
    Eval result: epoch 5, metrics: {'accuracy': 0.9382}
    epoch: 6 step: 938, loss is 0.45641791820526123
    Eval result: epoch 6, metrics: {'accuracy': 0.946}
    epoch: 7 step: 938, loss is 0.11519066989421844
    Eval result: epoch 7, metrics: {'accuracy': 0.9506}
    epoch: 8 step: 938, loss is 0.43486487865448
    Eval result: epoch 8, metrics: {'accuracy': 0.9555}
    epoch: 9 step: 938, loss is 0.1941455900669098
    Eval result: epoch 9, metrics: {'accuracy': 0.9588}
    epoch: 10 step: 938, loss is 0.13441434502601624
    Eval result: epoch 10, metrics: {'accuracy': 0.9632}
    </pre></div>


训练过程中会打印loss值,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。每个人运行的loss值有一定随机性,不一定完全相同。

通过模型运行测试数据集得到的结果,验证模型的泛化能力:

1. 使用\ ``model.eval``\ 接口读入测试数据集。
2. 使用保存后的模型参数进行推理。

.. code:: python

    acc = trainer.eval(test_dataset)
    acc

.. raw:: html

    <div class="highlight"><pre>
    {'accuracy': 0.9632}
    </pre></div>



可以在打印信息中看出模型精度数据,示例中精度数据达到95%以上,模型质量良好。随着网络迭代次数增加,模型精度会进一步提高。