mindspore.dataset.vision.LinearTransformation

class mindspore.dataset.vision.LinearTransformation(transformation_matrix, mean_vector)

Linearly transform the input numpy.ndarray image with a square transformation matrix and a mean vector.

It first flattens the input image and subtracts the mean vector from it, then computes the dot product with the transformation matrix, and finally reshapes the result back to the original shape of the image.
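The four steps can be sketched in plain NumPy as follows. This is a minimal illustration of the math, not MindSpore's implementation. It assumes the row-vector convention `(x - mean_vector) @ transformation_matrix` for the dot product, because the description above does not state the operand order. The function name `linear_transformation` is hypothetical.

```python
import numpy as np


def linear_transformation(image, transformation_matrix, mean_vector):
    """Hypothetical NumPy sketch of the documented steps."""
    d = image.size                                  # D = C * H * W
    if transformation_matrix.shape != (d, d):
        raise ValueError("transformation_matrix must have shape (D, D)")
    if mean_vector.shape != (d,):
        raise ValueError("mean_vector must have shape (D,)")
    flat = image.reshape(1, d)                      # 1. flatten to a (1, D) row
    centered = flat - mean_vector                   # 2. subtract the mean vector
    # 3. Assumed operand order: (1, D) row times (D, D) matrix -> (1, D)
    transformed = centered @ transformation_matrix
    return transformed.reshape(image.shape)         # 4. restore the original shape


# An identity matrix with a zero mean vector returns the image unchanged.
img = np.arange(12, dtype=np.float64).reshape(3, 2, 2)   # C=3, H=2, W=2
out = linear_transformation(img, np.eye(12), np.zeros(12))
print(np.array_equal(out, img))  # True
```

With an identity matrix the call reduces to mean subtraction, which makes it a convenient sanity check for the shapes of the two parameters.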

Parameters
  • transformation_matrix (numpy.ndarray) – A square transformation matrix of shape (D, D), where \(D = C \times H \times W\) and C, H and W are the number of channels, the height and the width of the input image.

  • mean_vector (numpy.ndarray) – A mean vector of shape (D,), where \(D = C \times H \times W\).
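One common way to obtain a matching pair of parameters is ZCA whitening, estimated from a batch of images. The sketch below is an assumed usage pattern in plain NumPy, not part of the MindSpore API; the helper name `zca_parameters`, the `eps` regularizer and the random batch are all hypothetical. It returns a (D, D) matrix and a (D,) vector. The whitening matrix is symmetric, so the result does not depend on the operand order of the dot product.

```python
import numpy as np


def zca_parameters(images, eps=1e-5):
    """Hypothetical helper: estimate (transformation_matrix, mean_vector).

    images has shape (N, C, H, W); the results have shapes (D, D) and (D,)
    with D = C * H * W.
    """
    n = images.shape[0]
    flat = images.reshape(n, -1).astype(np.float64)   # (N, D)
    mean_vector = flat.mean(axis=0)                    # (D,)
    centered = flat - mean_vector
    cov = centered.T @ centered / n                    # (D, D) covariance
    eigvals, eigvecs = np.linalg.eigh(cov)             # cov is symmetric
    # W = U diag(1 / sqrt(S + eps)) U^T; eps guards against tiny eigenvalues.
    transformation_matrix = (eigvecs / np.sqrt(eigvals + eps)) @ eigvecs.T
    return transformation_matrix, mean_vector


rng = np.random.default_rng(0)
images = rng.random((2000, 3, 2, 2))      # hypothetical batch, D = 12
matrix, mean = zca_parameters(images)
print(matrix.shape, mean.shape)           # (12, 12) (12,)
```

Estimating a well-conditioned covariance needs many more samples than D, which becomes expensive quickly: a 32x32 RGB image already gives D = 3072.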

Raises
Supported Platforms:

CPU

Examples

>>> import mindspore.dataset as ds
>>> import mindspore.dataset.vision as vision
>>> import numpy as np
>>> from mindspore.dataset.transforms import Compose
>>>
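>>> height, width = 32, 32
>>> # D = C * H * W for a 3-channel image resized to (height, width)
>>> dim = 3 * height * width
>>> # Illustrative parameters: a (D, D) all-ones matrix and a zero mean vector of shape (D,)
>>> transformation_matrix = np.ones([dim, dim])
>>> mean_vector = np.zeros(dim)
>>> # ToTensor converts the resized PIL image into a CHW numpy.ndarray for LinearTransformation
>>> transforms_list = Compose([vision.Decode(to_pil=True),
...                            vision.Resize((height, width)),
...                            vision.ToTensor(),
...                            vision.LinearTransformation(transformation_matrix, mean_vector)])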
>>> # Apply the transforms to the dataset through the map function
>>> image_folder_dataset = ds.ImageFolderDataset("/path/to/image_folder_dataset_directory")
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
...                                                 input_columns="image")
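The all-ones matrix and zero mean vector in this example only demonstrate the call. Because every column of an all-ones matrix sums the whole input, every element of the output equals the sum of all input values. The NumPy check below shows this without MindSpore. It scales the image down to 4x4, since at 32x32 D = 3072 and the matrix already holds 3072 x 3072 (about 9.4 million) entries. The random image is a hypothetical stand-in for the ToTensor output.

```python
import numpy as np

# Scaled-down stand-in for the example above (4x4 instead of 32x32).
height, width = 4, 4
dim = 3 * height * width
transformation_matrix = np.ones([dim, dim])
mean_vector = np.zeros(dim)

# Hypothetical CHW float32 image in [0, 1], like the output of ToTensor().
rng = np.random.default_rng(0)
image = rng.random((3, height, width)).astype(np.float32)

# Flatten, subtract the mean, multiply by the matrix, reshape back.
flat = image.reshape(1, -1) - mean_vector
output = (flat @ transformation_matrix).reshape(image.shape)

# Each output element is the sum of all input values.
print(np.allclose(output, image.sum(dtype=np.float64)))  # True
```

In practice a matrix with real structure, such as a whitening matrix estimated from data, would replace the all-ones placeholder.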
Tutorial Examples: