# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""image"""
import numbers
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .conv import Conv2d
from .container import CellList
from .pooling import AvgPool2d
from .activation import ReLU
from ..cell import Cell
# Names listed here are the ones exported by `from <this module> import *`.
__all__ = ['ImageGradients', 'SSIM', 'MSSSIM', 'PSNR', 'CentralCrop']
class ImageGradients(Cell):
    r"""
    Returns two tensors, the first is along the height dimension and the second is along the width dimension.

    Assume an image shape is :math:`h*w`. The gradients along the height and the width are :math:`dy` and :math:`dx`,
    respectively.

    .. math::
        dy[i] = \begin{cases} image[i+1, :]-image[i, :], &if\ 0<=i<h-1 \cr
        0, &if\ i==h-1\end{cases}

        dx[i] = \begin{cases} image[:, i+1]-image[:, i], &if\ 0<=i<w-1 \cr
        0, &if\ i==w-1\end{cases}

    Inputs:
        - **images** (Tensor) - The input image data, with format 'NCHW'.

    Outputs:
        - **dy** (Tensor) - vertical image gradients, the same type and shape as input.
        - **dx** (Tensor) - horizontal image gradients, the same type and shape as input.

    Raises:
        ValueError: If the shape of `images` does not have exactly 4 dimensions.

    Examples:
        >>> net = nn.ImageGradients()
        >>> image = Tensor(np.array([[[[1,2],[3,4]]]]), dtype=mstype.int32)
        >>> net(image)
        [[[[2,2]
           [0,0]]]]
        [[[[1,0]
           [1,0]]]]
    """

    def __init__(self):
        super(ImageGradients, self).__init__()

    def construct(self, images):
        # Tie the shape check to `images` so that it is not dropped from the graph.
        check = _check_input_4d(F.shape(images), "images", self.cls_name)
        images = F.depend(images, check)
        batch_size, depth, height, width = P.Shape()(images)
        dtype = P.DType()(images)

        # Gradient along the height axis; the last row is defined to be zero.
        if height == 1:
            dy = P.Fill()(dtype, (batch_size, depth, 1, width), 0)
        else:
            dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
            dy_last = P.Fill()(dtype, (batch_size, depth, 1, width), 0)
            dy = P.Concat(2)((dy, dy_last))

        # Gradient along the width axis; the last column is defined to be zero.
        if width == 1:
            dx = P.Fill()(dtype, (batch_size, depth, height, 1), 0)
        else:
            dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
            dx_last = P.Fill()(dtype, (batch_size, depth, height, 1), 0)
            dx = P.Concat(3)((dx, dx_last))
        return dy, dx
def _convert_img_dtype_to_float32(img, max_val):
    """
    Cast `img` to float32 and rescale it by `1 / max_val` when `max_val` is greater than 1.

    Args:
        img: The value to convert; it is handed to `F.cast`.
        max_val: The scaling reference, usually 1.0 or 255.

    Returns:
        The float32 value, multiplied by `1 / max_val` only if `max_val > 1`.
    """
    peak = F.scalar_cast(max_val, mstype.float32)
    converted = F.cast(img, mstype.float32)
    # Only scale when the peak exceeds 1; otherwise the cast alone is the result.
    if peak > 1.:
        converted = converted * (1. / peak)
    return converted
@constexpr
def _get_dtype_max(dtype):
    """
    Return the scaling maximum associated with `dtype`.

    The dtype is first mapped to a numpy type with `mstype.dtype_to_nptype`. For integral
    numpy types the result is the largest representable value as `np.float64`; for every
    other type the result is 1.0.
    """
    np_type = mstype.dtype_to_nptype(dtype)
    if not issubclass(np_type, numbers.Integral):
        return 1.0
    return np.float64(np.iinfo(np_type).max)
@constexpr
def _check_input_4d(input_shape, param_name, func_name):
    """
    Verify that `input_shape` has exactly four entries.

    Returns:
        True when the shape is 4-D.

    Raises:
        ValueError: If `len(input_shape)` is not 4.
    """
    if len(input_shape) == 4:
        return True
    raise ValueError(f"{func_name} {param_name} should be 4d, but got shape {input_shape}")
@constexpr
def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
    """
    Check that `input_shape` is 4-D and that its entries at index 2 and 3 are each
    greater than or equal to `filter_size`.
    """
    _check_input_4d(input_shape, param_name, func_name)
    # Index 2 is checked before index 3.
    for axis in (2, 3):
        validator.check(f"{param_name} shape[{axis}]", input_shape[axis],
                        "filter_size", filter_size, Rel.GE, func_name)
@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    """Check `input_dtype` against `allow_dtypes` by delegating to `validator.check_type_name`."""
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
def _conv2d(in_channels, out_channels, kernel_size, weight, stride=1, padding=0):
    """
    Build a `Conv2d` cell with pad mode "valid" whose weights are initialized from `weight`.

    Args:
        in_channels: Passed to `Conv2d` as the first positional argument.
        out_channels: Passed to `Conv2d` as the second positional argument.
        kernel_size: Forwarded as `kernel_size`.
        weight: Forwarded as `weight_init`.
        stride: Forwarded as `stride`. Default: 1.
        padding: Forwarded as `padding`. Default: 0.

    Returns:
        The constructed `Conv2d` cell.
    """
    layer = Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        pad_mode="valid",
        padding=padding,
        weight_init=weight,
    )
    return layer
def _create_window(size, sigma):
    """
    Create a normalized 2-D Gaussian window.

    Args:
        size (int): Number of samples along each side of the window.
        sigma (float): Standard deviation of the Gaussian.

    Returns:
        numpy.ndarray, float32 array of shape (1, 1, size, size) whose entries sum to 1.
    """
    # Sample offsets: for odd sizes they are centred on 0, e.g. -5..5 for size 11.
    offsets = np.arange(-size // 2 + 1, size // 2 + 1).astype(np.float32)
    offsets_sq = offsets ** 2
    # Squared distance from the centre for every (row, column) pair.
    dist_sq = offsets_sq[:, np.newaxis] + offsets_sq[np.newaxis, :]
    kernel = np.exp(-dist_sq / (2 * sigma ** 2))
    kernel = kernel / np.sum(kernel)
    # Add two leading singleton axes in front of the two spatial axes.
    return kernel[np.newaxis, np.newaxis, :, :]
def _split_img(x):
    """
    Split `x` along axis 1 into one piece per channel.

    Returns:
        Tuple `(pieces, channels)`: the output of `P.Split(1, channels)` applied to `x`,
        and the size of axis 1 of `x`.
    """
    _, channels, _, _ = F.shape(x)
    pieces = P.Split(1, channels)(x)
    return pieces, channels
def _compute_per_channel_loss(c1, c2, img1, img2, conv):
    """
    Compute the SSIM index and the contrast-structure term for a single channel.

    Args:
        c1: Stabilizing constant of the luminance term.
        c2: Stabilizing constant of the contrast-structure term.
        img1: First single-channel image.
        img2: Second single-channel image.
        conv: Callable that produces the local weighted mean of its input.

    Returns:
        Tuple `(ssim, cs)`, where `cs` is the contrast-structure ratio and `ssim` is that
        ratio multiplied by the luminance ratio.
    """
    cross = img1 * img2

    # Local means and their products.
    mean1 = conv(img1)
    mean2 = conv(img2)
    mean1_sq = mean1 * mean1
    mean2_sq = mean2 * mean2
    mean_prod = mean1 * mean2

    # Local variances and covariance: E[x*y] - E[x]*E[y].
    var1 = conv(img1 * img1) - mean1_sq
    var2 = conv(img2 * img2) - mean2_sq
    covar = conv(cross) - mean_prod

    lum_num = 2 * mean_prod + c1
    lum_den = mean1_sq + mean2_sq + c1
    cs_num = 2 * covar + c2
    cs_den = var1 + var2 + c2

    ssim = (lum_num * cs_num) / (lum_den * cs_den)
    cs = cs_num / cs_den
    return ssim, cs
def _compute_multi_channel_loss(c1, c2, img1, img2, conv, concat, mean):
    """
    Compute the SSIM index and contrast-structure term for every channel of two images.

    Each image is split along axis 1, the per-channel terms are computed independently,
    joined again with `concat`, and reduced over axes (2, 3) with `mean`.

    Returns:
        Tuple `(ssim, cs)` of the reduced per-channel results.
    """
    channels1, num_channels = _split_img(img1)
    channels2, _ = _split_img(img2)

    ssim_parts = ()
    cs_parts = ()
    for idx in range(num_channels):
        ssim_i, cs_i = _compute_per_channel_loss(c1, c2, channels1[idx], channels2[idx], conv)
        ssim_parts = ssim_parts + (ssim_i,)
        cs_parts = cs_parts + (cs_i,)

    spatial_axes = (2, 3)
    ssim = mean(concat(ssim_parts), spatial_axes)
    cs = mean(concat(cs_parts), spatial_axes)
    return ssim, cs
class SSIM(Cell):
    r"""
    Returns SSIM index between img1 and img2.

    Its implementation is based on Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). `Image quality
    assessment: from error visibility to structural similarity <https://ieeexplore.ieee.org/document/1284395>`_.
    IEEE transactions on image processing.

    .. math::

        l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
        c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
        s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
        SSIM(x,y)&=l*c*s\\&=\frac{(2\mu_x\mu_y+C_1)(2\sigma_{xy}+C_2)}{(\mu_x^2+\mu_y^2+C_1)(\sigma_x^2+\sigma_y^2+C_2)}.

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
            The value must be greater than 0. Default: 1.0.
        filter_size (int): The size of the Gaussian filter. The value must be at least 1. Default: 11.
        filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5.
        k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
        k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
          Its dtype must be float32 or float16, and its height and width must not be smaller than `filter_size`.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor. It is a 1-D tensor with shape N, where N is the batch num of img1.
        Both inputs are cast to float32 before the index is computed.

    Examples:
        >>> net = nn.SSIM()
        >>> img1 = Tensor(np.random.random((1,3,16,16)))
        >>> img2 = Tensor(np.random.random((1,3,16,16)))
        >>> ssim = net(img1, img2)
    """

    def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
        super(SSIM, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val
        self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
        self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
        self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
        self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
        # One single-channel Gaussian convolution is shared by all channels; its weights are fixed.
        window = _create_window(filter_size, filter_sigma)
        self.conv = _conv2d(1, 1, filter_size, Tensor(window))
        self.conv.weight.requires_grad = False
        self.reduce_mean = P.ReduceMean()
        self.concat = P.Concat(axis=1)

    def construct(self, img1, img2):
        _check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
        P.SameTypeShape()(img1, img2)

        # Bring max_val and both images to float32 using the same conversion helper.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        c1 = (self.k1 * max_val) ** 2
        c2 = (self.k2 * max_val) ** 2

        # Per-channel SSIM averaged over the spatial axes, then averaged over channels.
        ssim_ave_channel, _ = _compute_multi_channel_loss(c1, c2, img1, img2, self.conv, self.concat, self.reduce_mean)
        loss = self.reduce_mean(ssim_ave_channel, -1)
        return loss
def _downsample(img1, img2, op):
    """Apply `op` to `img1` and then to `img2`, returning both results as a tuple."""
    return op(img1), op(img2)
class MSSSIM(Cell):
    r"""
    Returns MS-SSIM index between img1 and img2.

    Its implementation is based on Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. `Multiscale structural similarity
    for image quality assessment <https://ieeexplore.ieee.org/document/1292216>`_.
    Signals, Systems and Computers, 2004.

    .. math::

        l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
        c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
        s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
        MSSSIM(x,y)&=l^{\alpha_M}*{\prod_{1\leq j\leq M} (c^{\beta_j}*s^{\gamma_j})}.

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
            The value must be greater than 0. Default: 1.0.
        power_factors (Union[tuple, list]): Iterable of weights for each scale. It must not be empty.
            Default: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). Default values obtained by Wang et al.
        filter_size (int): The size of the Gaussian filter. The value must be at least 1. Default: 11.
        filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5.
        k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
        k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, the value is in range [0, 1]. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Examples:
        >>> net = nn.MSSSIM(power_factors=(0.033, 0.033, 0.033))
        >>> img1 = Tensor(np.random.random((1,3,128,128)))
        >>> img2 = Tensor(np.random.random((1,3,128,128)))
        >>> msssim = net(img1, img2)
    """

    def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11,
                 filter_sigma=1.5, k1=0.01, k2=0.03):
        super(MSSSIM, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val
        validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
        self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
        self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
        self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
        self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
        window = _create_window(filter_size, filter_sigma)
        # At least one scale is required: with no scales `construct` has no SSIM term to combine.
        self.level = validator.check_integer('power_factors length', len(power_factors), 1, Rel.GE, self.cls_name)
        # One fixed-weight Gaussian convolution per scale.
        self.conv = []
        for i in range(self.level):
            self.conv.append(_conv2d(1, 1, filter_size, Tensor(window)))
            self.conv[i].weight.requires_grad = False
        self.multi_convs_list = CellList(self.conv)
        self.weight_tensor = Tensor(power_factors, mstype.float32)
        self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
        self.relu = ReLU()
        self.reduce_mean = P.ReduceMean()
        self.prod = P.ReduceProd()
        self.pow = P.Pow()
        self.pack = P.Pack(axis=-1)
        self.concat = P.Concat(axis=1)

    def construct(self, img1, img2):
        _check_input_4d(F.shape(img1), "img1", self.cls_name)
        _check_input_4d(F.shape(img2), "img2", self.cls_name)
        _check_input_dtype(F.dtype(img1), 'img1', mstype.number_type, self.cls_name)
        P.SameTypeShape()(img1, img2)

        # Bring max_val and both images to float32 using the same conversion helper.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        c1 = (self.k1 * max_val) ** 2
        c2 = (self.k2 * max_val) ** 2

        sim = ()
        mcs = ()
        # Evaluate every scale, halving the resolution with average pooling after each one.
        for i in range(self.level):
            sim, cs = _compute_multi_channel_loss(c1, c2, img1, img2,
                                                  self.multi_convs_list[i], self.concat, self.reduce_mean)
            mcs += (self.relu(cs),)
            img1, img2 = _downsample(img1, img2, self.avg_pool)

        # The last scale contributes its full SSIM value instead of its contrast-structure term.
        mcs = mcs[0:-1:1]
        mcs_and_ssim = self.pack(mcs + (self.relu(sim),))
        # Weight each scale by its power factor, multiply across scales, then average over channels.
        mcs_and_ssim = self.pow(mcs_and_ssim, self.weight_tensor)
        ms_ssim = self.prod(mcs_and_ssim, -1)
        loss = self.reduce_mean(ms_ssim, -1)
        return loss
class PSNR(Cell):
    r"""
    Returns Peak Signal-to-Noise Ratio of two image batches.

    It produces a PSNR value for each image in batch.
    Assume inputs are :math:`I` and :math:`K`, both with shape :math:`h*w`.
    :math:`MAX` represents the dynamic range of pixel values.

    .. math::

        MSE&=\frac{1}{hw}\sum\limits_{i=0}^{h-1}\sum\limits_{j=0}^{w-1}[I(i,j)-K(i,j)]^2\\
        PSNR&=10*log_{10}(\frac{MAX^2}{MSE})

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
            The value must be greater than 0. Default: 1.0.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, with dtype mindspore.float32. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Raises:
        ValueError: If the shape of `img1` or `img2` does not have exactly 4 dimensions.

    Examples:
        >>> net = nn.PSNR()
        >>> img1 = Tensor(np.random.random((1,3,16,16)))
        >>> img2 = Tensor(np.random.random((1,3,16,16)))
        >>> psnr = net(img1, img2)
    """

    def __init__(self, max_val=1.0):
        super(PSNR, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val

    def construct(self, img1, img2):
        _check_input_4d(F.shape(img1), "img1", self.cls_name)
        _check_input_4d(F.shape(img2), "img2", self.cls_name)
        P.SameTypeShape()(img1, img2)

        # Bring max_val and both images to float32 using the same conversion helper.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        # Mean squared error over the last three axes leaves one value per image.
        mse = P.ReduceMean()(F.square(img1 - img2), (-3, -2, -1))
        # 10*log_10(max_val^2/MSE), written with the natural logarithm divided by ln(10).
        psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0)
        return psnr
@constexpr
def _raise_dims_rank_error(input_shape, param_name, func_name):
    """
    Unconditionally report that an input is neither 3-D nor 4-D.

    Raises:
        ValueError: Always, with a message naming the caller, the parameter and the shape.
    """
    message = f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}"
    raise ValueError(message)
@constexpr
def _get_bbox(rank, shape, size_h, size_w):
    """
    Compute the begin offsets and the sizes of a centred crop, for use with a slice op.

    Args:
        rank (int): 3 for a (C, H, W) shape; any other value is treated as (N, C, H, W).
        shape (tuple): Shape of the image to crop.
        size_h (float): Requested crop height.
        size_w (float): Requested crop width.

    Returns:
        Tuple `(begin, size)`, each a tuple with one entry per axis of `shape`.
    """
    is_chw = rank == 3
    if is_chw:
        c, h, w = shape
    else:
        n, c, h, w = shape

    # The margin is truncated to an integer and removed from both sides of each axis.
    h_margin = int((float(h) - size_h) / 2)
    w_margin = int((float(w) - size_w) / 2)
    crop_h = h - h_margin * 2
    crop_w = w - w_margin * 2

    if is_chw:
        return (0, h_margin, w_margin), (c, crop_h, crop_w)
    return (0, 0, h_margin, w_margin), (n, c, crop_h, crop_w)
class CentralCrop(Cell):
    """
    Crop the central region of the images with the central_fraction.

    Args:
        central_fraction (float): Fraction of size to crop. It must be float and in range (0.0, 1.0].

    Inputs:
        - **image** (Tensor) - A 3-D tensor of shape [C, H, W], or a 4-D tensor of shape [N, C, H, W].

    Outputs:
        Tensor, 3-D or 4-D float tensor, according to the input.

    Raises:
        ValueError: If `image` is neither 3-D nor 4-D.

    Examples:
        >>> net = nn.CentralCrop(central_fraction=0.5)
        >>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32)
        >>> output = net(image)
    """

    def __init__(self, central_fraction):
        super(CentralCrop, self).__init__()
        validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
        self.central_fraction = validator.check_number_range('central_fraction', central_fraction,
                                                             0.0, 1.0, Rel.INC_RIGHT, self.cls_name)
        self.slice = P.Slice()

    def construct(self, image):
        image_shape = F.shape(image)
        rank = len(image_shape)
        # Validate the rank before indexing the shape, so that inputs with fewer than two
        # dimensions get the rank error instead of an index error.
        if not rank in (3, 4):
            return _raise_dims_rank_error(image_shape, "image", self.cls_name)
        # A fraction of 1.0 keeps the whole image.
        if self.central_fraction == 1.0:
            return image

        h, w = image_shape[-2], image_shape[-1]
        size_h = self.central_fraction * h
        size_w = self.central_fraction * w
        bbox_begin, bbox_size = _get_bbox(rank, image_shape, size_h, size_w)
        image = self.slice(image, bbox_begin, bbox_size)
        return image