r"""SSIM, MS-SSIM, CW-SSIM Metric
Created by:
- https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/SSIM.py
- https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/MS_SSIM.py
- https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/CW_SSIM.py
Modified by: Jiadi Mo (https://github.com/JiadiMo)
Refer to:
- Official SSIM matlab code from https://www.cns.nyu.edu/~lcv/ssim/;
- PIQ from https://github.com/photosynthesis-team/piq;
- BasicSR from https://github.com/xinntao/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py;
- Official MS-SSIM matlab code from https://ece.uwaterloo.ca/~z70wang/research/iwssim/msssim.zip;
- Official CW-SSIM matlab code from
https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/submissions/43017/versions/1/download/zip;
"""
import numpy as np
import torch
import torch.nn.functional as F
from pyiqa.utils.color_util import to_y_channel
from pyiqa.matlab_utils import fspecial, SCFpyr_PyTorch, math_util, filter2
from pyiqa.utils.registry import ARCH_REGISTRY
from .func_util import preprocess_rgb
[docs]
def ssim(
X,
Y,
win=None,
get_ssim_map=False,
get_cs=False,
get_weight=False,
downsample=False,
data_range=1.0,
):
if win is None:
win = fspecial(11, 1.5, X.shape[1]).to(X)
C1 = (0.01 * data_range) ** 2
C2 = (0.03 * data_range) ** 2
# Averagepool image if the size is large enough
f = max(1, round(min(X.size()[-2:]) / 256))
# Downsample operation is used in official matlab code
if (f > 1) and downsample:
X = F.avg_pool2d(X, kernel_size=f)
Y = F.avg_pool2d(Y, kernel_size=f)
mu1 = filter2(X, win, 'valid')
mu2 = filter2(Y, win, 'valid')
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = filter2(X * X, win, 'valid') - mu1_sq
sigma2_sq = filter2(Y * Y, win, 'valid') - mu2_sq
sigma12 = filter2(X * Y, win, 'valid') - mu1_mu2
cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
cs_map = F.relu(
cs_map
) # force the ssim response to be nonnegative to avoid negative results.
ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
ssim_val = ssim_map.mean([1, 2, 3])
if get_weight:
weights = torch.log((1 + sigma1_sq / C2) * (1 + sigma2_sq / C2))
return ssim_map, weights
if get_ssim_map:
return ssim_map
if get_cs:
return ssim_val, cs_map.mean([1, 2, 3])
return ssim_val
@ARCH_REGISTRY.register()
[docs]
class SSIM(torch.nn.Module):
r"""Args:
- channel: number of channel.
- downsample: boolean, whether to downsample same as official matlab code.
- test_y_channel: boolean, whether to use y channel on ycbcr same as official matlab code.
"""
def __init__(
self,
channels=3,
downsample=False,
test_y_channel=True,
color_space='yiq',
crop_border=0.0,
):
super(SSIM, self).__init__()
self.downsample = downsample
self.test_y_channel = test_y_channel
self.color_space = color_space
self.crop_border = crop_border
self.data_range = 255
[docs]
def forward(self, X, Y):
assert X.shape == Y.shape, (
f'Input {X.shape} and reference images should have the same shape'
)
if self.crop_border != 0:
crop_border = self.crop_border
X = X[..., crop_border:-crop_border, crop_border:-crop_border]
Y = Y[..., crop_border:-crop_border, crop_border:-crop_border]
X = preprocess_rgb(
X, self.test_y_channel, self.data_range, self.color_space
).to(torch.float64)
Y = preprocess_rgb(
Y, self.test_y_channel, self.data_range, self.color_space
).to(torch.float64)
score = ssim(X, Y, data_range=self.data_range, downsample=self.downsample)
return score
[docs]
def ms_ssim(
X,
Y,
win=None,
data_range=1.0,
downsample=False,
test_y_channel=True,
is_prod=True,
color_space='yiq',
):
r"""Compute Multiscale structural similarity for a batch of images.
Args:
x: An input tensor. Shape :math:`(N, C, H, W)`.
y: A target tensor. Shape :math:`(N, C, H, W)`.
win: Window setting.
downsample: Boolean, whether to downsample which mimics official SSIM matlab code.
test_y_channel: Boolean, whether to use y channel on ycbcr.
is_prod: Boolean, calculate product or sum between mcs and weight.
Returns:
Index of similarity between two images. Usually in [0, 1] interval.
"""
if not X.shape == Y.shape:
raise ValueError('Input images must have the same dimensions.')
weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(X)
levels = weights.shape[0]
mcs = []
for _ in range(levels):
ssim_val, cs = ssim(
X,
Y,
win=win,
get_cs=True,
downsample=downsample,
data_range=data_range,
)
mcs.append(cs)
padding = (X.shape[2] % 2, X.shape[3] % 2)
X = F.avg_pool2d(X, kernel_size=2, padding=padding)
Y = F.avg_pool2d(Y, kernel_size=2, padding=padding)
mcs = torch.stack(mcs, dim=0)
if is_prod:
msssim_val = torch.prod((mcs[:-1] ** weights[:-1].unsqueeze(1)), dim=0) * (
ssim_val ** weights[-1]
)
else:
weights = weights / torch.sum(weights)
msssim_val = torch.sum((mcs[:-1] * weights[:-1].unsqueeze(1)), dim=0) + (
ssim_val * weights[-1]
)
return msssim_val
@ARCH_REGISTRY.register()
[docs]
class MS_SSIM(torch.nn.Module):
r"""Multiscale structure similarity
References:
Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale structural similarity for image
quality assessment." In The Thrity-Seventh Asilomar Conference on Signals, Systems & Computers,
2003, vol. 2, pp. 1398-1402. Ieee, 2003.
Args:
channel: Number of channel.
downsample: Boolean, whether to downsample which mimics official SSIM matlab code.
test_y_channel: Boolean, whether to use y channel on ycbcr which mimics official matlab code.
"""
def __init__(
self,
channels=3,
downsample=False,
test_y_channel=True,
is_prod=True,
color_space='yiq',
):
super(MS_SSIM, self).__init__()
self.downsample = downsample
self.test_y_channel = test_y_channel
self.color_space = color_space
self.is_prod = is_prod
self.data_range = 255
[docs]
def forward(self, X, Y):
"""Computation of MS-SSIM metric.
Args:
x: An input tensor. Shape :math:`(N, C, H, W)`.
y: A target tensor. Shape :math:`(N, C, H, W)`.
Returns:
Value of MS-SSIM metric in [0, 1] range.
"""
assert X.shape == Y.shape, (
'Input and reference images should have the same shape, but got'
)
f'{X.shape} and {Y.shape}'
X = preprocess_rgb(
X, self.test_y_channel, self.data_range, self.color_space
).to(torch.float64)
Y = preprocess_rgb(
Y, self.test_y_channel, self.data_range, self.color_space
).to(torch.float64)
score = ms_ssim(
X,
Y,
data_range=self.data_range,
downsample=self.downsample,
is_prod=self.is_prod,
)
return score
@ARCH_REGISTRY.register()
[docs]
class CW_SSIM(torch.nn.Module):
r"""Complex-Wavelet Structural SIMilarity (CW-SSIM) index.
References:
M. P. Sampat, Z. Wang, S. Gupta, A. C. Bovik, M. K. Markey.
"Complex Wavelet Structural Similarity: A New Image Similarity Index",
IEEE Transactions on Image Processing, 18(11), 2385-401, 2009.
Args:
- channel: Number of channel.
- test_y_channel: Boolean, whether to use y channel on ycbcr.
- level: The number of levels to used in the complex steerable pyramid decomposition
- ori: The number of orientations to be used in the complex steerable pyramid decomposition
- guardb: How much is discarded from the four image boundaries.
- K: the constant in the CWSSIM index formula (see the above reference) default value: K=0
"""
def __init__(
self,
channels=1,
level=4,
ori=8,
guardb=0,
K=0,
test_y_channel=True,
color_space='yiq',
):
super(CW_SSIM, self).__init__()
self.channels = channels
self.level = level
self.ori = ori
self.guardb = guardb
self.K = K
self.test_y_channel = test_y_channel
self.color_space = color_space
self.register_buffer('win7', torch.ones(channels, 1, 7, 7) / (7 * 7))
[docs]
def conj(self, x, y):
a = x[..., 0]
b = x[..., 1]
c = y[..., 0]
d = -y[..., 1]
return torch.stack((a * c - b * d, b * c + a * d), dim=1)
[docs]
def conv2d_complex(self, x, win, groups=1):
real = F.conv2d(x[:, 0, ...].unsqueeze(1), win, groups=groups)
imaginary = F.conv2d(x[:, 1, ...].unsqueeze(1), win, groups=groups)
return torch.stack((real, imaginary), dim=-1)
[docs]
def cw_ssim(self, x, y, test_y_channel):
r"""Compute CW-SSIM for a batch of images.
Args:
x: An input tensor. Shape :math:`(N, C, H, W)`.
y: A target tensor. Shape :math:`(N, C, H, W)`.
test_y_channel: Boolean, whether to use y channel on ycbcr.
Returns:
Index of similarity between two images. Usually in [0, 1] interval.
"""
# Whether calculate on y channel of ycbcr
if test_y_channel and x.shape[1] == 3:
x = to_y_channel(x, 255, self.color_space)
y = to_y_channel(y, 255, self.color_space)
pyr = SCFpyr_PyTorch(
height=self.level, nbands=self.ori, scale_factor=2, device=x.device
)
cw_x = pyr.build(x)
cw_y = pyr.build(y)
bandind = self.level
band_cssim = []
s = np.array(cw_x[bandind][0].size()[1:3])
w = fspecial(s - 7 + 1, s[0] / 4, 1).to(x.device)
gb = int(self.guardb / (2 ** (self.level - 1)))
self.win7 = self.win7.to(x.dtype)
for i in range(self.ori):
band1 = cw_x[bandind][i]
band2 = cw_y[bandind][i]
band1 = band1[:, gb : s[0] - gb, gb : s[1] - gb, :]
band2 = band2[:, gb : s[0] - gb, gb : s[1] - gb, :]
corr = self.conj(band1, band2)
corr_band = self.conv2d_complex(corr, self.win7, groups=self.channels)
varr = (
(math_util.abs(band1)) ** 2 + (math_util.abs(band2)) ** 2
).unsqueeze(1)
varr_band = F.conv2d(
varr, self.win7, stride=1, padding=0, groups=self.channels
)
cssim_map = (2 * math_util.abs(corr_band) + self.K) / (varr_band + self.K)
band_cssim.append(
(cssim_map * w.repeat(cssim_map.shape[0], 1, 1, 1)).sum([2, 3]).mean(1)
)
return torch.stack(band_cssim, dim=1).mean(1)
[docs]
def forward(self, X, Y):
r"""Computation of CW-SSIM metric.
Args:
X: An input tensor. Shape :math:`(N, C, H, W)`.
Y: A target tensor. Shape :math:`(N, C, H, W)`.
Returns:
Value of CW-SSIM metric in [0, 1] range.
"""
assert X.shape == Y.shape, (
f'Input {X.shape} and reference images should have the same shape'
)
score = self.cw_ssim(
X.to(torch.float64), Y.to(torch.float64), self.test_y_channel
)
return score