使用不确定性估计工具箱

下载Notebook下载样例代码查看源文件

贝叶斯神经网络的优势之一就是可以获取不确定性,MindSpore Probability在上层提供了不确定性估计的工具箱,用户可以很方便地使用该工具箱计算不确定性。不确定性意味着深度学习模型对预测结果的不确定程度。目前,大多数深度学习算法只能给出预测结果,而不能判断预测结果的可靠性。不确定性主要有两种类型:偶然不确定性和认知不确定性。

  • 偶然不确定性(Aleatoric Uncertainty):描述数据中的内在噪声,即无法避免的误差,这个现象不能通过增加采样数据来削弱。

  • 认知不确定性(Epistemic Uncertainty):模型自身对输入数据的估计可能因为训练不佳、训练数据不够等原因而不准确,可以通过增加训练数据等方式来缓解。

不确定性估计工具箱,适用于主流的深度学习模型,如回归、分类等。在推理阶段,利用不确定性估计工具箱,开发人员只需通过训练模型和训练数据集,指定需要估计的任务和样本,即可得到任意不确定性和认知不确定性。基于不确定性信息,开发人员可以更好地理解模型和数据集。

本例将使用MNIST数据集和LeNet5网络模型示例,进行本次体验。

  1. 数据准备。

  2. 定义深度学习网络。

  3. 初始化不确定性评估工具箱。

  4. 评估认知不确定性。

本例适用于GPU和Ascend环境,你可以在这里下载完整的样例代码:https://gitee.com/mindspore/mindspore/tree/r1.7/tests/st/probability/toolbox

数据准备

下载数据集

以下示例代码将MNIST数据集下载并解压到指定位置。

[ ]:
import os
import requests

requests.packages.urllib3.disable_warnings()

def download_dataset(dataset_url, path):
    filename = dataset_url.split("/")[-1]
    save_path = os.path.join(path, filename)
    if os.path.exists(save_path):
        return
    if not os.path.exists(path):
        os.makedirs(path)
    res = requests.get(dataset_url, stream=True, verify=False)
    with open(save_path, "wb") as f:
        for chunk in res.iter_content(chunk_size=512):
            if chunk:
                f.write(chunk)
    print("The {} file is downloaded and saved in the path {} after processing".format(os.path.basename(dataset_url), path))

train_path = "datasets/MNIST_Data/train"
test_path = "datasets/MNIST_Data/test"

download_dataset("https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte", train_path)
download_dataset("https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte", train_path)
download_dataset("https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte", test_path)
download_dataset("https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte", test_path)

下载的数据集文件的目录结构如下:

./datasets/MNIST_Data
├── test
│   ├── t10k-images-idx3-ubyte
│   └── t10k-labels-idx1-ubyte
└── train
    ├── train-images-idx3-ubyte
    └── train-labels-idx1-ubyte

数据增强

定义数据集增强函数,并将原始数据增强为适用于LeNet网络的数据。

[2]:
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter

def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # define some parameters needed for data enhancement and rough justification
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # according to the parameters, generate the corresponding data enhancement method
    c_trans = [
        CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR),
        CV.Rescale(rescale_nml, shift_nml),
        CV.Rescale(rescale, shift),
        CV.HWC2CHW()
    ]
    type_cast_op = C.TypeCast(mstype.int32)

    # using map to apply operations to a dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=c_trans, input_columns="image", num_parallel_workers=num_parallel_workers)

    # process the generated dataset
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds

定义深度学习网络

本例采用LeNet5深度神经网络,在MindSpore中实现如下:

[3]:
import mindspore.nn as nn
from mindspore import load_checkpoint, load_param_into_net
from mindspore.common.initializer import Normal

class LeNet5(nn.Cell):
    """Lenet network structure."""
    # define the operator required
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    # use the preceding operators to construct networks
    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

初始化不确定性工具箱

初始化不确定性工具箱的UncertaintyEvaluation功能,准备如下:

  1. 准备模型权重参数文件。

  2. 将模型权重参数文件载入神经网络中。

  3. 将训练数据集增强为适用于神经网络的数据。

  4. 将上述网络和数据集载入到UncertaintyEvaluation中。

MindSpore中使用不确定性工具箱UncertaintyEvaluation接口来测量模型偶然不确定性和认知不确定性,更多使用方法请参见API

准备模型权重参数文件

本例已经准备好了对应的模型权重参数文件checkpoint_lenet.ckpt,本参数文件为初学入门中训练完成5个epoch后保存的权重参数文件,执行如下命令进行下载:

[4]:
download_dataset("https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/checkpoint_lenet.ckpt", ".")

完成初始化

将需要进行不确定性测量的DNN网络与训练数据集载入,由于不确定性测量需要贝叶斯网络,所以当第一次调用初始化完成的不确定性测量工具时,会将DNN网络转成贝叶斯网络进行训练,完成后可传入对应的数据进行偶然不确定性或认知不确定性进行测量。

[5]:
from mindspore import context
from mindspore.nn.probability.toolbox import UncertaintyEvaluation
from mindspore import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# get trained model
network = LeNet5()
param_dict = load_checkpoint('checkpoint_lenet.ckpt')
load_param_into_net(network, param_dict)
# get train
ds_train = create_dataset('./datasets/MNIST_Data/train')
evaluation = UncertaintyEvaluation(model=network,
                                   train_dataset=ds_train,
                                   task_type='classification',
                                   num_classes=10,
                                   epochs=1,
                                   epi_uncer_model_path=None,
                                   ale_uncer_model_path=None,
                                   save_model=False)

评估认知不确定性

转换成贝叶斯训练测量

首先将验证数据集取出一个batch,进行认知不确定性测量,首次调用时会将原本深度神经网络转换为贝叶斯网络进行训练。

[6]:
ds_test = create_dataset("./datasets/MNIST_Data/test")
batch_data = next(ds_test.create_dict_iterator())
eval_images = batch_data["image"]
eval_labels = batch_data["label"]
epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_images)
epoch: 1 step: 1, loss is 0.14702837
epoch: 1 step: 2, loss is 0.00017862688
epoch: 1 step: 3, loss is 0.09421586
epoch: 1 step: 4, loss is 0.0003434865
epoch: 1 step: 5, loss is 7.1358285e-05
... ...
epoch: 1 step: 1871, loss is 0.20069705
epoch: 1 step: 1872, loss is 0.12135945
epoch: 1 step: 1873, loss is 0.04572148
epoch: 1 step: 1874, loss is 0.04962858
epoch: 1 step: 1875, loss is 0.0019944885
evaluation.eval_epistemic_uncertainty:认知不确定性测量接口,第一次调用时会使用训练数据对DNN模型进行转换成贝叶斯训练。
eval_images:即偶然不确定性测试使用的batch图片。

打印认知不确定性

取一个batch的数据将label打印出来。

[7]:
print(eval_labels)
print(epistemic_uncertainty.shape)
[2 9 4 3 9 9 2 4 9 6 0 5 6 8 7 6 1 9 7 6 5 4 0 3 7 7 6 7 7 4 6 2]
(32, 10)

认知不确定性内容为32张图片对应0-9的分类模型的不确定性值。

取前面两个图片打印出对应模型的的偶然不确定性值。

[8]:
print("the picture one, number is {}, epistemic uncertainty is:\n{}".format(eval_labels[0], epistemic_uncertainty[0]))
print("the picture two, number is {}, epistemic uncertainty is:\n{}".format(eval_labels[1], epistemic_uncertainty[1]))
the picture one, number is 2, epistemic uncertainty is:
[0.75372726 0.2053496  3.737096   0.7113453  0.93452704 0.40339947
 0.91918266 0.44237098 0.40863538 0.8195221 ]
the picture two, number is 9, epistemic uncertainty is:
[0.97602427 0.37808532 0.4955423  0.17907992 1.3365419  0.20227651
 2.2211757  0.27501273 0.30733848 3.7536747 ]