mindspore.dataset.vision.py_transforms.LinearTransformation

class mindspore.dataset.vision.py_transforms.LinearTransformation(transformation_matrix, mean_vector)

Apply a linear transformation to the input NumPy image array, given a square transformation matrix and a mean vector.

The transformation first flattens the input array and subtracts the mean vector from it. It then computes the dot product of the result with the transformation matrix and reshapes the output back to the input's original shape.
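The steps above can be sketched in plain NumPy as follows. This is a minimal, hypothetical re-implementation for illustration only, not MindSpore's source: the function name `linear_transformation_sketch` is made up, and the sketch assumes the flattened image is treated as a row vector that multiplies the matrix from the left, i.e. \((x - \mu) M\).

```python
import numpy as np


def linear_transformation_sketch(image, transformation_matrix, mean_vector):
    """Hypothetical NumPy version of the flatten / subtract / dot / reshape steps.

    Assumes the flattened image is a (1, D) row vector, so the result is
    (x - mean) @ M. This ordering is an assumption for illustration.
    """
    d = image.size  # D = C * H * W
    if transformation_matrix.shape != (d, d):
        raise ValueError("transformation_matrix must have shape (D, D)")
    if mean_vector.shape != (d,):
        raise ValueError("mean_vector must have shape (D,)")
    # 1. Flatten the image into a (1, D) row and subtract the mean vector.
    centered = image.reshape(1, -1) - mean_vector
    # 2. Dot product with the (D, D) transformation matrix -> shape (1, D).
    transformed = np.dot(centered, transformation_matrix)
    # 3. Reshape back to the original (C, H, W) shape.
    return transformed.reshape(image.shape)


# Usage: a tiny 3 x 2 x 2 "image", so D = 12.
c, h, w = 3, 2, 2
dim = c * h * w
image = np.arange(dim, dtype=np.float32).reshape(c, h, w)
identity = np.eye(dim, dtype=np.float32)
mean = np.full(dim, 1.0, dtype=np.float32)

out = linear_transformation_sketch(image, identity, mean)
print(out.shape)                      # (3, 2, 2)
print(np.allclose(out, image - 1.0))  # True: identity matrix leaves only mean subtraction
```

With the identity matrix, the transform reduces to subtracting the mean vector, which is a quick way to sanity-check the shapes before plugging in a real matrix.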

Parameters
  • transformation_matrix (numpy.ndarray) – a square transformation matrix of shape (D, D), where \(D = C \times H \times W\).

  • mean_vector (numpy.ndarray) – the mean vector, a NumPy ndarray of shape (D,), where \(D = C \times H \times W\).
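To make the shape requirements concrete, the sketch below builds matching parameters for a hypothetical batch of images already in (C, H, W) layout (as produced by ToTensor). Using the element-wise mean of the flattened images as the mean vector and the identity as the matrix are illustrative assumptions; the documented parameters only fix the shapes (D,) and (D, D).

```python
import numpy as np

# Hypothetical batch of N images in (C, H, W) layout, e.g. after ToTensor.
rng = np.random.default_rng(0)
n, c, h, w = 8, 3, 4, 4
images = rng.random((n, c, h, w), dtype=np.float32)

d = c * h * w                 # D = C * H * W = 48
flat = images.reshape(n, d)   # one row of length D per image

# mean_vector: shape (D,); here the element-wise mean over the batch (an assumption).
mean_vector = flat.mean(axis=0)
# transformation_matrix: square (D, D); the identity is used purely as a placeholder.
transformation_matrix = np.eye(d, dtype=np.float32)

print(mean_vector.shape)            # (48,)
print(transformation_matrix.shape)  # (48, 48)
```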

Examples

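>>> import numpy as np
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.vision.py_transforms as py_vision
>>> from mindspore.dataset.transforms.py_transforms import Compose
>>> # the directory below is a placeholder; point it at your own image folder
>>> image_folder_dataset = ds.ImageFolderDataset("/path/to/image_folder_dataset_directory")
>>> height, width = 32, 32
>>> # after ToTensor each image has shape (C, H, W) = (3, height, width), so D = 3 * height * width
>>> dim = 3 * height * width
>>> transformation_matrix = np.ones([dim, dim])
>>> mean_vector = np.zeros(dim)
>>> transforms_list = Compose([py_vision.Decode(),
...                            py_vision.Resize((height, width)),
...                            py_vision.ToTensor(),
...                            py_vision.LinearTransformation(transformation_matrix, mean_vector)])
>>> # apply the transform to the dataset through the map function
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
...                                                 input_columns="image")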