mindspore.dataset.vision.LinearTransformation
- class mindspore.dataset.vision.LinearTransformation(transformation_matrix, mean_vector)
Linearly transform the input numpy.ndarray image using a square transformation matrix and a mean vector.
It first flattens the input image and subtracts the mean vector from it, then computes the dot product with the transformation matrix, and finally reshapes the result back to the original shape of the image.
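The four steps described above can be sketched in plain NumPy. This is a minimal illustration, not MindSpore's implementation: the helper name `linear_transform_reference` is hypothetical, and the sketch assumes the centered, flattened image is treated as a row vector multiplied on the left of the matrix (`flat @ M`, not `M @ flat`).

```python
import numpy as np


def linear_transform_reference(image, transformation_matrix, mean_vector):
    """Hypothetical NumPy sketch: flatten -> subtract mean -> dot -> reshape.

    Assumes row-vector convention: the (D,) vector is multiplied on the left
    of the (D, D) matrix. Not the library's own code.
    """
    flat = image.reshape(-1)                                # shape (D,), D = C * H * W
    centered = flat - mean_vector                           # subtract the mean vector
    transformed = np.dot(centered, transformation_matrix)   # (D,) . (D, D) -> (D,)
    return transformed.reshape(image.shape)                 # restore the original shape


# Small example: a 2 x 2 x 3 image, so D = 12.
image = np.arange(12, dtype=np.float64).reshape(2, 2, 3)
dim = image.size
matrix = 2.0 * np.eye(dim)    # scales every element by 2
mean = np.ones(dim)           # subtracts 1 from every element
out = linear_transform_reference(image, matrix, mean)
print(out.shape)              # (2, 2, 3)
print(out[0, 0].tolist())     # [-2.0, 0.0, 2.0]
```

With a diagonal matrix the result is simply an element-wise affine map, which makes it easy to check by hand; a full (D, D) matrix mixes all pixels and channels together.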
- Parameters
transformation_matrix (numpy.ndarray) – A square transformation matrix of shape (D, D), where D = C × H × W.
mean_vector (numpy.ndarray) – A mean vector of shape (D,), where D = C × H × W.
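As a sketch of how D relates to the image, the snippet below builds parameters for a hypothetical 4 × 5 × 3 image. An identity matrix with a zero mean vector should leave the image unchanged under the flatten, subtract, dot, reshape steps, which gives a quick sanity check that the shapes are consistent. The image size and random data are illustrative assumptions only.

```python
import numpy as np

# Hypothetical image in HWC layout; D must equal C * H * W, i.e. image.size.
height, width, channels = 4, 5, 3
rng = np.random.default_rng(0)
image = rng.random((height, width, channels))
dim = channels * height * width                  # D = 60

transformation_matrix = np.eye(dim)              # shape (D, D)
mean_vector = np.zeros(dim)                      # shape (D,)

assert transformation_matrix.shape == (dim, dim)
assert mean_vector.shape == (dim,)

# Identity matrix and zero mean: the round trip should reproduce the input.
flat = image.reshape(-1) - mean_vector
result = np.dot(flat, transformation_matrix).reshape(image.shape)
print(np.allclose(result, image))                # True
```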
- Raises
TypeError – If transformation_matrix is not of type numpy.ndarray.
TypeError – If mean_vector is not of type numpy.ndarray.
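The documented TypeError conditions amount to simple isinstance checks. The sketch below is hypothetical: `check_linear_transformation_args` is not a MindSpore function, and its ValueError shape checks are an added assumption that is not listed under Raises above.

```python
import numpy as np


def check_linear_transformation_args(transformation_matrix, mean_vector):
    """Hypothetical validator mirroring the documented TypeError conditions."""
    if not isinstance(transformation_matrix, np.ndarray):
        raise TypeError("transformation_matrix should be a numpy.ndarray, "
                        f"got {type(transformation_matrix).__name__}")
    if not isinstance(mean_vector, np.ndarray):
        raise TypeError("mean_vector should be a numpy.ndarray, "
                        f"got {type(mean_vector).__name__}")
    # Assumption: shape consistency checks, not part of the documented Raises.
    if (transformation_matrix.ndim != 2
            or transformation_matrix.shape[0] != transformation_matrix.shape[1]):
        raise ValueError("transformation_matrix must be square, shape (D, D)")
    if mean_vector.shape != (transformation_matrix.shape[0],):
        raise ValueError("mean_vector must have shape (D,)")


try:
    check_linear_transformation_args([[1.0]], np.zeros(1))  # a list, not an ndarray
except TypeError as err:
    print("TypeError:", err)
```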
- Supported Platforms:
CPU
Examples
>>> import numpy as np
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.vision as vision
>>> from mindspore.dataset.transforms import Compose
>>>
>>> # Use the transform in dataset pipeline mode
>>> height, width = 32, 32
>>> dim = 3 * height * width
>>> transformation_matrix = np.ones([dim, dim])
>>> mean_vector = np.zeros(dim)
>>> transforms_list = Compose([vision.Resize((height, width)),
...                            vision.ToTensor(),
...                            vision.LinearTransformation(transformation_matrix, mean_vector)])
>>> # Apply the transform to the dataset through the map function
>>> data = np.random.randint(0, 255, size=(1, 100, 100, 3)).astype(np.uint8)
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data, ["image"])
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms_list, input_columns="image")
>>> for item in numpy_slices_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["image"].shape, item["image"].dtype)
...     break
(3, 32, 32) float64
>>>
>>> # Use the transform in eager mode
>>> data = np.random.randn(10, 10, 3)
>>> transformation_matrix = np.random.randn(300, 300)
>>> mean_vector = np.random.randn(300,)
>>> output = vision.LinearTransformation(transformation_matrix, mean_vector)(data)
>>> print(output.shape, output.dtype)
(10, 10, 3) float64