开发入门

下载Notebook下载样例代码查看源文件

因开发者可能会在OrangePi AIpro(下称:香橙派开发板)进行自定义模型和案例开发,本章节通过基于MindSpore的手写数字识别案例,说明香橙派开发板中的开发注意事项。

环境准备

开发者拿到香橙派开发板后,首先需要进行硬件资源确认、镜像烧录以及CANN和MindSpore版本的升级,才可运行该案例,具体如下:

  • 硬件:香橙派AIpro 16G 8-12T开发板

  • 镜像:香橙派官网Ubuntu镜像

  • CANN:8.0.RC3.alpha002

  • MindSpore:2.4.10

镜像烧录

运行该案例需要烧录香橙派官网Ubuntu镜像,参考镜像烧录章节。

CANN升级

参考CANN升级章节。

MindSpore升级

参考MindSpore升级章节。

[1]:
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset
/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
  setattr(self, word, getattr(machar, word).flat[0])
/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
  return self._float_to_str(self.smallest_subnormal)
/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
  setattr(self, word, getattr(machar, word).flat[0])
/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
  return self._float_to_str(self.smallest_subnormal)

设置运行环境

由于资源限制,需开启性能优化模式,具体设置如下参数:

max_device_memory=“2GB” : 设置设备可用的最大内存为2GB。

mode=mindspore.GRAPH_MODE : 表示在GRAPH_MODE模式中运行。

device_target=“Ascend” : 表示待运行的目标设备为Ascend。

jit_config={“jit_level”:“O2”} : 编译优化级别开启极致性能优化,使用下沉的执行方式。

ascend_config={“precision_mode”:“allow_mix_precision”} : 自动混合精度,自动将部分算子的精度降低到float16或bfloat16。

[2]:
import mindspore
mindspore.set_context(max_device_memory="2GB", mode=mindspore.GRAPH_MODE, device_target="Ascend", jit_config={"jit_level":"O2"}, ascend_config={"precision_mode":"allow_mix_precision"})

数据集准备与加载

MindSpore提供基于Pipeline的数据引擎,通过数据集(Dataset)实现高效的数据预处理。在本案例中,我们使用Mnist数据集,自动下载完成后,使用mindspore.dataset提供的数据变换进行预处理。

[3]:
#install download

!pip install download
Looking in indexes: https://repo.huaweicloud.com/repository/pypi/simple/
Requirement already satisfied: download in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (0.3.5)
Requirement already satisfied: tqdm in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from download) (4.66.5)
Requirement already satisfied: six in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from download) (1.16.0)
Requirement already satisfied: requests in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from download) (2.32.3)
Requirement already satisfied: charset-normalizer<4,>=2 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->download) (3.4.0)
Requirement already satisfied: idna<4,>=2.5 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->download) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->download) (2.2.3)
Requirement already satisfied: certifi>=2017.4.17 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->download) (2024.8.30)
[4]:
# Download data from open datasets
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)

file_sizes: 100%|███████████████████████████| 10.8M/10.8M [00:00<00:00, 101MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

MNIST数据集目录结构如下:

MNIST_Data
└── train
    ├── train-images-idx3-ubyte (60000个训练图片)
    ├── train-labels-idx1-ubyte (60000个训练标签)
└── test
    ├── t10k-images-idx3-ubyte (10000个测试图片)
    ├── t10k-labels-idx1-ubyte (10000个测试标签)

数据下载完成后,获得数据集对象。

[5]:
train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

打印数据集中包含的数据列名,用于dataset的预处理。

[6]:
print(train_dataset.get_col_names())
['image', 'label']

MindSpore的dataset使用数据处理流水线(Data Processing Pipeline),需指定map、batch、shuffle等操作。这里我们使用map对图像数据及标签进行变换处理,将输入的图像缩放为1/255,根据均值0.1307和标准差值0.3081进行归一化处理,然后将处理好的数据集打包为大小为64的batch。

[7]:
def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset
[8]:
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

可使用create_tuple_iteratorcreate_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatype。

[9]:
for image, label in test_dataset.create_tuple_iterator():
    print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")
    print(f"Shape of label: {label.shape} {label.dtype}")
    break
Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32
Shape of label: (64,) Int32
[10]:
for data in test_dataset.create_dict_iterator():
    print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")
    print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")
    break
Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32
Shape of label: (64,) Int32

模型构建

[11]:
# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
print(model)
Network<
  (flatten): Flatten<>
  (dense_relu_sequential): SequentialCell<
    (0): Dense<input_channels=784, output_channels=512, has_bias=True>
    (1): ReLU<>
    (2): Dense<input_channels=512, output_channels=512, has_bias=True>
    (3): ReLU<>
    (4): Dense<input_channels=512, output_channels=10, has_bias=True>
    >
  >

模型训练

在模型训练中,一个完整的训练过程(step)需要实现以下三步:

  1. 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。

  2. 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。

  3. 参数优化:将梯度更新到参数上。

MindSpore使用函数式自动微分机制,因此针对上述步骤需要实现:

  1. 定义正向计算函数。

  2. 使用value_and_grad通过函数变换获得梯度计算函数。

  3. 定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。

[12]:
# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)

# 1. Define forward function
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# 3. Define function of one-step training
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

除训练外,我们定义测试函数,用来评估模型的性能。

[13]:
def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程需多次迭代数据集,一次完整的迭代称为一轮(epoch)。在每一轮,遍历训练集进行训练,结束后使用测试集进行预测。打印每一轮的loss值和预测准确率(Accuracy),可以看到loss在不断下降,Accuracy在不断提高。

[14]:
epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model, train_dataset)
    test(model, test_dataset, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 2.307922  [  0/938]
loss: 1.746887  [100/938]
loss: 0.848251  [200/938]
loss: 0.607513  [300/938]
loss: 0.369690  [400/938]
loss: 0.382843  [500/938]
loss: 0.293686  [600/938]
loss: 0.391556  [700/938]
loss: 0.227386  [800/938]
loss: 0.189972  [900/938]
Test:
 Accuracy: 91.1%, Avg loss: 0.314895

Epoch 2
-------------------------------
loss: 0.346725  [  0/938]
loss: 0.268921  [100/938]
loss: 0.247742  [200/938]
loss: 0.196686  [300/938]
loss: 0.264954  [400/938]
loss: 0.320938  [500/938]
loss: 0.368820  [600/938]
loss: 0.274811  [700/938]
loss: 0.373581  [800/938]
loss: 0.441010  [900/938]
Test:
 Accuracy: 92.8%, Avg loss: 0.247373

Epoch 3
-------------------------------
loss: 0.168976  [  0/938]
loss: 0.313812  [100/938]
loss: 0.195068  [200/938]
loss: 0.329803  [300/938]
loss: 0.464447  [400/938]
loss: 0.170197  [500/938]
loss: 0.280670  [600/938]
loss: 0.324707  [700/938]
loss: 0.134583  [800/938]
loss: 0.191467  [900/938]
Test:
 Accuracy: 93.9%, Avg loss: 0.207696

Done!

保存模型

模型训练完成后,需要将其参数进行保存。

[15]:
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")
Saved Model to model.ckpt

权重加载

加载保存的权重分为两步:

  1. 重新实例化模型对象,构造模型。

  2. 加载模型参数,并将其加载至模型上。

[16]:
# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)
[]

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

模型推理

加载后的模型可以直接用于预测推理。

[17]:
import matplotlib.pyplot as plt

model.set_train(False)
for data, label in test_dataset:
    pred = model(data)
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:6]}", Actual: "{label[:6]}"')

    # 显示数字及数字的预测值
    plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        # 若预测正确,显示为蓝色;若预测错误,显示为红色
        color = 'blue' if predicted[i] == label[i] else 'red'
        plt.title('Predicted:{}'.format(predicted[i]), color=color)
        plt.imshow(data.asnumpy()[i][0], interpolation="None", cmap="gray")
        plt.axis('off')
    plt.show()
    break
Predicted: "[2 1 0 4 1 7]", Actual: "[2 1 0 4 1 7]"
../_images/orange_pi_dev_start_35_1.png

本案例已同步上线GitHub仓,更多案例可参考该仓库。

本案例运行所需环境:

  • 硬件:香橙派AIpro 16G 8-12T开发板

  • 镜像:香橙派官网Ubuntu镜像

  • CANN:8.0.RC3.alpha002

  • MindSpore:2.4.10