mindspore.dataset.vision.AutoAugment

class mindspore.dataset.vision.AutoAugment(policy=AutoAugmentPolicy.IMAGENET, interpolation=Inter.NEAREST, fill_value=0)

Apply the AutoAugment data augmentation method, based on the paper "AutoAugment: Learning Augmentation Strategies from Data". This operation works only with 3-channel RGB images.

Parameters
  • policy (AutoAugmentPolicy, optional) –

    AutoAugment policy learned on a specific dataset. It can be AutoAugmentPolicy.IMAGENET, AutoAugmentPolicy.CIFAR10 or AutoAugmentPolicy.SVHN. Default: AutoAugmentPolicy.IMAGENET. Each policy randomly applies 2 operations from a candidate set. See AutoAugmentPolicy for details.

    • AutoAugmentPolicy.IMAGENET: apply the AutoAugment policy learned on the ImageNet dataset.

    • AutoAugmentPolicy.CIFAR10: apply the AutoAugment policy learned on the CIFAR-10 dataset.

    • AutoAugmentPolicy.SVHN: apply the AutoAugment policy learned on the SVHN dataset.

  • interpolation (Inter, optional) – Image interpolation method defined by Inter. Default: Inter.NEAREST.

  • fill_value (Union[int, tuple[int]], optional) – Pixel fill value for the area outside the transformed image. It can be an int or a 3-tuple. If it is an int, it is used for all three RGB channels. If it is a 3-tuple, its elements fill the R, G and B channels respectively. Each value must be in the range [0, 255]. Default: 0.
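The fill_value rules above can be sketched in plain Python. This is only an illustration of the documented semantics, not MindSpore's implementation: the helper name normalize_fill_value and its choice of exception types are assumptions made for this example.

```python
from typing import Tuple, Union

FillValue = Union[int, Tuple[int, int, int]]


def normalize_fill_value(fill_value: FillValue = 0) -> Tuple[int, int, int]:
    """Hypothetical helper: expand fill_value into an (R, G, B) tuple.

    Mirrors the documented rules only: an int is reused for all three
    channels, a 3-tuple is taken as (R, G, B), and every value must lie
    in [0, 255]. The exception types here are assumptions.
    """
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(fill_value, bool):
        raise TypeError("fill_value must be an int or a 3-tuple of ints")
    if isinstance(fill_value, int):
        channels = (fill_value, fill_value, fill_value)
    elif isinstance(fill_value, tuple) and len(fill_value) == 3:
        channels = fill_value
    else:
        raise TypeError("fill_value must be an int or a 3-tuple of ints")

    for value in channels:
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError("each fill value must be an int")
        if not 0 <= value <= 255:
            raise ValueError("each fill value must be in the range [0, 255]")
    return channels


print(normalize_fill_value(0))               # int reused for R, G and B
print(normalize_fill_value((255, 0, 128)))   # per-channel values kept as given
```

In practice the int or 3-tuple is passed directly as the fill_value argument of AutoAugment; the helper only makes the two accepted forms and the range constraint explicit.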

Raises
Supported Platforms:

CPU

Examples

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.vision as vision
>>> from mindspore.dataset.vision import AutoAugmentPolicy, Inter
>>>
>>> # Use the transform in dataset pipeline mode
>>> transforms_list = [vision.AutoAugment(policy=AutoAugmentPolicy.IMAGENET,
...                                       interpolation=Inter.NEAREST,
...                                       fill_value=0)]
>>> data = np.random.randint(0, 255, size=(1, 100, 100, 3)).astype(np.uint8)
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data, ["image"])
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms_list, input_columns=["image"])
>>> for item in numpy_slices_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["image"].shape, item["image"].dtype)
...     break
(100, 100, 3) uint8
>>>
>>> # Use the transform in eager mode
>>> data = np.random.randint(0, 255, size=(100, 100, 3)).astype(np.uint8)
>>> output = vision.AutoAugment()(data)
>>> print(output.shape, output.dtype)
(100, 100, 3) uint8
Tutorial Examples: