应用自动数据增强

Linux Ascend GPU CPU 数据准备 中级 高级

查看源文件    查看notebook    在线运行

概述

自动数据增强(AutoAugment)[1]是在一系列图像增强子策略的搜索空间中,通过搜索算法找到适合特定数据集的图像增强方案。MindSpore的c_transforms模块提供了丰富的C++算子来实现AutoAugment,用户也可以自定义函数或者算子来实现。更多MindSpore算子的详细说明参见API文档

MindSpore算子和AutoAugment中的算子的对应关系如下:

AutoAugment算子

MindSpore算子

描述

shearX

RandomAffine

横向剪切

shearY

RandomAffine

纵向剪切

translateX

RandomAffine

水平平移

translateY

RandomAffine

垂直平移

rotate

RandomRotation

旋转变换

color

RandomColor

颜色变换

posterize

RandomPosterize

减少颜色通道位数

solarize

RandomSolarize

指定的阈值范围内,反转所有的像素点

contrast

RandomColorAdjust

调整对比度

sharpness

RandomSharpness

调整锐度

brightness

RandomColorAdjust

调整亮度

autocontrast

AutoContrast

最大化图像对比度

equalize

Equalize

均衡图像直方图

invert

Invert

反转图像

ImageNet自动数据增强

本教程以在ImageNet数据集上实现AutoAugment作为示例。

针对ImageNet数据集的数据增强策略包含25条子策略,每条子策略中包含两种变换,针对一个batch中的每张图像随机挑选一个子策略的组合,以预定的概率来决定是否执行子策略中的每种变换。

用户可以使用MindSpore中c_transforms模块的RandomSelectSubpolicy接口来实现AutoAugment,在ImageNet分类训练中标准的数据增强方式分以下几个步骤:

  • RandomCropDecodeResize:随机裁剪后进行解码。

  • RandomHorizontalFlip:水平方向上随机翻转。

  • Normalize:归一化。

  • HWC2CHW:图片通道变化。

RandomCropDecodeResize后插入AutoAugment变换,如下所示:

  1. 引入MindSpore数据增强模块。

    import matplotlib.pyplot as plt
    
    import mindspore.dataset as ds
    import mindspore.dataset.transforms.c_transforms as c_transforms
    import mindspore.dataset.vision.c_transforms as c_vision
    from mindspore import dtype as mstype
    
  2. 定义MindSpore算子到AutoAugment算子的映射:

    # define Auto Augmentation operators
    PARAMETER_MAX = 10
    
    def float_parameter(level, maxval):
        return float(level) * maxval /  PARAMETER_MAX
    
    def int_parameter(level, maxval):
        return int(level * maxval / PARAMETER_MAX)
    
    def shear_x(level):
        v = float_parameter(level, 0.3)
        return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, shear=(-v,-v)), c_vision.RandomAffine(degrees=0, shear=(v, v))])
    
    def shear_y(level):
        v = float_parameter(level, 0.3)
        return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, shear=(0, 0, -v,-v)), c_vision.RandomAffine(degrees=0, shear=(0, 0, v, v))])
    
    def translate_x(level):
        v = float_parameter(level, 150 / 331)
        return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, translate=(-v,-v)), c_vision.RandomAffine(degrees=0, translate=(v, v))])
    
    def translate_y(level):
        v = float_parameter(level, 150 / 331)
        return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, translate=(0, 0, -v,-v)), c_vision.RandomAffine(degrees=0, translate=(0, 0, v, v))])
    
    def color_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomColor(degrees=(v, v))
    
    def rotate_impl(level):
        v = int_parameter(level, 30)
        return c_transforms.RandomChoice([c_vision.RandomRotation(degrees=(-v, -v)), c_vision.RandomRotation(degrees=(v, v))])
    
    def solarize_impl(level):
        level = int_parameter(level, 256)
        v = 256 - level
        return c_vision.RandomSolarize(threshold=(0, v))
    
    def posterize_impl(level):
        level = int_parameter(level, 4)
        v = 4 - level
        return c_vision.RandomPosterize(bits=(v, v))
    
    def contrast_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomColorAdjust(contrast=(v, v))
    
    def autocontrast_impl(level):
        return c_vision.AutoContrast()
    
    def sharpness_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomSharpness(degrees=(v, v))
    
    def brightness_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomColorAdjust(brightness=(v, v))
    
  3. 定义ImageNet数据集的AutoAugment策略:

    # define the Auto Augmentation policy
    imagenet_policy = [
        [(posterize_impl(8), 0.4), (rotate_impl(9), 0.6)],
        [(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],
        [(c_vision.Equalize(), 0.8), (c_vision.Equalize(), 0.6)],
        [(posterize_impl(7), 0.6), (posterize_impl(6), 0.6)],
        [(c_vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],
    
        [(c_vision.Equalize(), 0.4), (rotate_impl(8), 0.8)],
        [(solarize_impl(3), 0.6), (c_vision.Equalize(), 0.6)],
        [(posterize_impl(5), 0.8), (c_vision.Equalize(), 1.0)],
        [(rotate_impl(3), 0.2), (solarize_impl(8), 0.6)],
        [(c_vision.Equalize(), 0.6), (posterize_impl(6), 0.4)],
    
        [(rotate_impl(8), 0.8), (color_impl(0), 0.4)],
        [(rotate_impl(9), 0.4), (c_vision.Equalize(), 0.6)],
        [(c_vision.Equalize(), 0.0), (c_vision.Equalize(), 0.8)],
        [(c_vision.Invert(), 0.6), (c_vision.Equalize(), 1.0)],
        [(color_impl(4), 0.6), (contrast_impl(8), 1.0)],
    
        [(rotate_impl(8), 0.8), (color_impl(2), 1.0)],
        [(color_impl(8), 0.8), (solarize_impl(7), 0.8)],
        [(sharpness_impl(7), 0.4), (c_vision.Invert(), 0.6)],
        [(shear_x(5), 0.6), (c_vision.Equalize(), 1.0)],
        [(color_impl(0), 0.4), (c_vision.Equalize(), 0.6)],
    
        [(c_vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],
        [(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],
        [(c_vision.Invert(), 0.6), (c_vision.Equalize(), 1.0)],
        [(color_impl(4), 0.6), (contrast_impl(8), 1.0)],
        [(c_vision.Equalize(), 0.8), (c_vision.Equalize(), 0.6)],
    ]
    
  4. RandomCropDecodeResize操作后插入AutoAugment变换。

    def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, shuffle=True, num_samples=5, target="Ascend"):
        # create a train or eval imagenet2012 dataset for ResNet-50
        dataset = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8,
                                   shuffle=shuffle, num_samples=num_samples)
    
        image_size = 224
        mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    
        # define map operations
        if do_train:
            trans = [
                c_vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            ]
    
            post_trans = [
                c_vision.RandomHorizontalFlip(prob=0.5),
            ]
        else:
            trans = [
                c_vision.Decode(),
                c_vision.Resize(256),
                c_vision.CenterCrop(image_size),
                c_vision.Normalize(mean=mean, std=std),
                c_vision.HWC2CHW()
            ]
        dataset = dataset.map(operations=trans, input_columns="image")
        if do_train:
            dataset = dataset.map(operations=c_vision.RandomSelectSubpolicy(imagenet_policy), input_columns=["image"])
            dataset = dataset.map(operations=post_trans, input_columns="image")
        type_cast_op = c_transforms.TypeCast(mstype.int32)
        dataset = dataset.map(operations=type_cast_op, input_columns="label")
        # apply the batch operation
        dataset = dataset.batch(batch_size, drop_remainder=True)
        # apply the repeat operation
        dataset = dataset.repeat(repeat_num)
    
        return dataset
    
  5. 验证自动数据增强效果。

    # Define the path to image folder directory. This directory needs to contain sub-directories which contain the images
    DATA_DIR = "/path/to/image_folder_directory"
    dataset = create_dataset(dataset_path=DATA_DIR, do_train=True, batch_size=5, shuffle=False, num_samples=5)
    
    epochs = 5
    itr = dataset.create_dict_iterator()
    fig=plt.figure(figsize=(8, 8))
    columns = 5
    rows = 5
    
    step_num = 0
    for ep_num in range(epochs):
        for data in itr:
            step_num += 1
            for index in range(rows):
                fig.add_subplot(rows, columns, ep_num * rows + index + 1)
                plt.imshow(data['image'].asnumpy()[index])
    plt.show()
    

    为了更好地演示效果,此处只加载5张图片,且读取时不进行shuffle操作,自动数据增强时也不进行NormalizeHWC2CHW操作。

    augment

    运行结果可以看到,batch中每张图像的增强效果,水平方向表示1个batch的5张图像,垂直方向表示5个batch。

参考文献

[1] AutoAugment: Learning Augmentation Policies from Data.