r"""VIF Metric
@article{sheikh2006vif,
title={Image information and visual quality},
author={Sheikh, Hamid R and Bovik, Alan C},
journal={IEEE Transactions on image processing},
volume={15},
number={2},
pages={430--444},
year={2006},
publisher={IEEE}
}
Created by: https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/VIF.py
Modified by: Jiadi Mo (https://github.com/JiadiMo)
Refer to:
Matlab code from http://live.ece.utexas.edu/research/Quality/vifvec_release.zip;
"""
import torch
from torch.nn import functional as F
import numpy as np
from pyiqa.utils.color_util import to_y_channel
from pyiqa.utils.registry import ARCH_REGISTRY
[docs]
def sp5_filters():
r"""Define spatial filters."""
filters = {}
filters['harmonics'] = np.array([1, 3, 5])
filters['mtx'] = np.array(
[
[0.3333, 0.2887, 0.1667, 0.0000, -0.1667, -0.2887],
[0.0000, 0.1667, 0.2887, 0.3333, 0.2887, 0.1667],
[0.3333, -0.0000, -0.3333, -0.0000, 0.3333, -0.0000],
[0.0000, 0.3333, 0.0000, -0.3333, 0.0000, 0.3333],
[0.3333, -0.2887, 0.1667, -0.0000, -0.1667, 0.2887],
[-0.0000, 0.1667, -0.2887, 0.3333, -0.2887, 0.1667],
]
)
filters['hi0filt'] = np.array(
[
[
-0.00033429,
-0.00113093,
-0.00171484,
-0.00133542,
-0.00080639,
-0.00133542,
-0.00171484,
-0.00113093,
-0.00033429,
],
[
-0.00113093,
-0.00350017,
-0.00243812,
0.00631653,
0.01261227,
0.00631653,
-0.00243812,
-0.00350017,
-0.00113093,
],
[
-0.00171484,
-0.00243812,
-0.00290081,
-0.00673482,
-0.00981051,
-0.00673482,
-0.00290081,
-0.00243812,
-0.00171484,
],
[
-0.00133542,
0.00631653,
-0.00673482,
-0.07027679,
-0.11435863,
-0.07027679,
-0.00673482,
0.00631653,
-0.00133542,
],
[
-0.00080639,
0.01261227,
-0.00981051,
-0.11435863,
0.81380200,
-0.11435863,
-0.00981051,
0.01261227,
-0.00080639,
],
[
-0.00133542,
0.00631653,
-0.00673482,
-0.07027679,
-0.11435863,
-0.07027679,
-0.00673482,
0.00631653,
-0.00133542,
],
[
-0.00171484,
-0.00243812,
-0.00290081,
-0.00673482,
-0.00981051,
-0.00673482,
-0.00290081,
-0.00243812,
-0.00171484,
],
[
-0.00113093,
-0.00350017,
-0.00243812,
0.00631653,
0.01261227,
0.00631653,
-0.00243812,
-0.00350017,
-0.00113093,
],
[
-0.00033429,
-0.00113093,
-0.00171484,
-0.00133542,
-0.00080639,
-0.00133542,
-0.00171484,
-0.00113093,
-0.00033429,
],
]
)
filters['lo0filt'] = np.array(
[
[0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614],
[-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],
[-0.03848215, 0.15925570, 0.40304148, 0.15925570, -0.03848215],
[-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],
[0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614],
]
)
filters['lofilt'] = 2 * np.array(
[
[
0.00085404,
-0.00244917,
-0.00387812,
-0.00944432,
-0.00962054,
-0.00944432,
-0.00387812,
-0.00244917,
0.00085404,
],
[
-0.00244917,
-0.00523281,
-0.00661117,
0.00410600,
0.01002988,
0.00410600,
-0.00661117,
-0.00523281,
-0.00244917,
],
[
-0.00387812,
-0.00661117,
0.01396746,
0.03277038,
0.03981393,
0.03277038,
0.01396746,
-0.00661117,
-0.00387812,
],
[
-0.00944432,
0.00410600,
0.03277038,
0.06426333,
0.08169618,
0.06426333,
0.03277038,
0.00410600,
-0.00944432,
],
[
-0.00962054,
0.01002988,
0.03981393,
0.08169618,
0.10096540,
0.08169618,
0.03981393,
0.01002988,
-0.00962054,
],
[
-0.00944432,
0.00410600,
0.03277038,
0.06426333,
0.08169618,
0.06426333,
0.03277038,
0.00410600,
-0.00944432,
],
[
-0.00387812,
-0.00661117,
0.01396746,
0.03277038,
0.03981393,
0.03277038,
0.01396746,
-0.00661117,
-0.00387812,
],
[
-0.00244917,
-0.00523281,
-0.00661117,
0.00410600,
0.01002988,
0.00410600,
-0.00661117,
-0.00523281,
-0.00244917,
],
[
0.00085404,
-0.00244917,
-0.00387812,
-0.00944432,
-0.00962054,
-0.00944432,
-0.00387812,
-0.00244917,
0.00085404,
],
]
)
filters['bfilts'] = np.array(
[
[
0.00277643,
0.00496194,
0.01026699,
0.01455399,
0.01026699,
0.00496194,
0.00277643,
-0.00986904,
-0.00893064,
0.01189859,
0.02755155,
0.01189859,
-0.00893064,
-0.00986904,
-0.01021852,
-0.03075356,
-0.08226445,
-0.11732297,
-0.08226445,
-0.03075356,
-0.01021852,
0.00000000,
0.00000000,
0.00000000,
0.00000000,
0.00000000,
0.00000000,
0.00000000,
0.01021852,
0.03075356,
0.08226445,
0.11732297,
0.08226445,
0.03075356,
0.01021852,
0.00986904,
0.00893064,
-0.01189859,
-0.02755155,
-0.01189859,
0.00893064,
0.00986904,
-0.00277643,
-0.00496194,
-0.01026699,
-0.01455399,
-0.01026699,
-0.00496194,
-0.00277643,
],
[
-0.00343249,
-0.00640815,
-0.00073141,
0.01124321,
0.00182078,
0.00285723,
0.01166982,
-0.00358461,
-0.01977507,
-0.04084211,
-0.00228219,
0.03930573,
0.01161195,
0.00128000,
0.01047717,
0.01486305,
-0.04819057,
-0.12227230,
-0.05394139,
0.00853965,
-0.00459034,
0.00790407,
0.04435647,
0.09454202,
-0.00000000,
-0.09454202,
-0.04435647,
-0.00790407,
0.00459034,
-0.00853965,
0.05394139,
0.12227230,
0.04819057,
-0.01486305,
-0.01047717,
-0.00128000,
-0.01161195,
-0.03930573,
0.00228219,
0.04084211,
0.01977507,
0.00358461,
-0.01166982,
-0.00285723,
-0.00182078,
-0.01124321,
0.00073141,
0.00640815,
0.00343249,
],
[
0.00343249,
0.00358461,
-0.01047717,
-0.00790407,
-0.00459034,
0.00128000,
0.01166982,
0.00640815,
0.01977507,
-0.01486305,
-0.04435647,
0.00853965,
0.01161195,
0.00285723,
0.00073141,
0.04084211,
0.04819057,
-0.09454202,
-0.05394139,
0.03930573,
0.00182078,
-0.01124321,
0.00228219,
0.12227230,
-0.00000000,
-0.12227230,
-0.00228219,
0.01124321,
-0.00182078,
-0.03930573,
0.05394139,
0.09454202,
-0.04819057,
-0.04084211,
-0.00073141,
-0.00285723,
-0.01161195,
-0.00853965,
0.04435647,
0.01486305,
-0.01977507,
-0.00640815,
-0.01166982,
-0.00128000,
0.00459034,
0.00790407,
0.01047717,
-0.00358461,
-0.00343249,
],
[
-0.00277643,
0.00986904,
0.01021852,
-0.00000000,
-0.01021852,
-0.00986904,
0.00277643,
-0.00496194,
0.00893064,
0.03075356,
-0.00000000,
-0.03075356,
-0.00893064,
0.00496194,
-0.01026699,
-0.01189859,
0.08226445,
-0.00000000,
-0.08226445,
0.01189859,
0.01026699,
-0.01455399,
-0.02755155,
0.11732297,
-0.00000000,
-0.11732297,
0.02755155,
0.01455399,
-0.01026699,
-0.01189859,
0.08226445,
-0.00000000,
-0.08226445,
0.01189859,
0.01026699,
-0.00496194,
0.00893064,
0.03075356,
-0.00000000,
-0.03075356,
-0.00893064,
0.00496194,
-0.00277643,
0.00986904,
0.01021852,
-0.00000000,
-0.01021852,
-0.00986904,
0.00277643,
],
[
-0.01166982,
-0.00128000,
0.00459034,
0.00790407,
0.01047717,
-0.00358461,
-0.00343249,
-0.00285723,
-0.01161195,
-0.00853965,
0.04435647,
0.01486305,
-0.01977507,
-0.00640815,
-0.00182078,
-0.03930573,
0.05394139,
0.09454202,
-0.04819057,
-0.04084211,
-0.00073141,
-0.01124321,
0.00228219,
0.12227230,
-0.00000000,
-0.12227230,
-0.00228219,
0.01124321,
0.00073141,
0.04084211,
0.04819057,
-0.09454202,
-0.05394139,
0.03930573,
0.00182078,
0.00640815,
0.01977507,
-0.01486305,
-0.04435647,
0.00853965,
0.01161195,
0.00285723,
0.00343249,
0.00358461,
-0.01047717,
-0.00790407,
-0.00459034,
0.00128000,
0.01166982,
],
[
-0.01166982,
-0.00285723,
-0.00182078,
-0.01124321,
0.00073141,
0.00640815,
0.00343249,
-0.00128000,
-0.01161195,
-0.03930573,
0.00228219,
0.04084211,
0.01977507,
0.00358461,
0.00459034,
-0.00853965,
0.05394139,
0.12227230,
0.04819057,
-0.01486305,
-0.01047717,
0.00790407,
0.04435647,
0.09454202,
-0.00000000,
-0.09454202,
-0.04435647,
-0.00790407,
0.01047717,
0.01486305,
-0.04819057,
-0.12227230,
-0.05394139,
0.00853965,
-0.00459034,
-0.00358461,
-0.01977507,
-0.04084211,
-0.00228219,
0.03930573,
0.01161195,
0.00128000,
-0.00343249,
-0.00640815,
-0.00073141,
0.01124321,
0.00182078,
0.00285723,
0.01166982,
],
]
).T
return filters
[docs]
def corrDn(image, filt, step=1, channels=1):
r"""Compute correlation of image with FILT, followed by downsampling.
Args:
image: A tensor. Shape :math:`(N, C, H, W)`.
filt: A filter.
step: Downsampling factors.
channels: Number of channels.
"""
filt_ = (
torch.from_numpy(filt)
.float()
.unsqueeze(0)
.unsqueeze(0)
.repeat(channels, 1, 1, 1)
.to(image.device)
)
p = (filt_.shape[2] - 1) // 2
image = F.pad(image, (p, p, p, p), 'reflect')
img = F.conv2d(image, filt_, stride=step, padding=0, groups=channels)
return img
[docs]
def SteerablePyramidSpace(image, height=4, order=5, channels=1):
r"""Construct a steerable pyramid on image.
Args:
image: A tensor. Shape :math:`(N, C, H, W)`.
height (int): Number of pyramid levels to build.
order (int): Number of orientations.
channels (int): Number of channels.
"""
num_orientations = order + 1
filters = sp5_filters()
hi0 = corrDn(image, filters['hi0filt'], step=1, channels=channels)
pyr_coeffs = []
pyr_coeffs.append(hi0)
lo = corrDn(image, filters['lo0filt'], step=1, channels=channels)
for _ in range(height):
bfiltsz = int(np.floor(np.sqrt(filters['bfilts'].shape[0])))
for b in range(num_orientations):
filt = filters['bfilts'][:, b].reshape(bfiltsz, bfiltsz).T
band = corrDn(lo, filt, step=1, channels=channels)
pyr_coeffs.append(band)
lo = corrDn(lo, filters['lofilt'], step=2, channels=channels)
pyr_coeffs.append(lo)
return pyr_coeffs
@ARCH_REGISTRY.register()
[docs]
class VIF(torch.nn.Module):
r"""Image Information and Visual Quality metric
Args:
channels (int): Number of channels.
level (int): Number of levels to build.
ori (int): Number of orientations.
Reference:
Sheikh, Hamid R., and Alan C. Bovik. "Image information and visual quality."
IEEE Transactions on image processing 15, no. 2 (2006): 430-444.
"""
def __init__(self, channels=1, level=4, ori=6):
super(VIF, self).__init__()
self.ori = ori - 1
self.level = level
self.channels = channels
self.M = 3
self.subbands = [4, 7, 10, 13, 16, 19, 22, 25]
self.sigma_nsq = 0.4
self.tol = 1e-12
[docs]
def corrDn_win(self, image, filt, step=1, channels=1, start=[0, 0], end=[0, 0]):
r"""Compute correlation of image with FILT using window, followed by downsampling.
Args:
image: A tensor. Shape :math:`(N, C, H, W)`.
filt: A filter.
step (int): Downsampling factors.
channels (int): Number of channels.
start (list): The window over which the convolution occurs.
end (list): The window over which the convolution occurs.
"""
filt_ = (
torch.from_numpy(filt)
.float()
.unsqueeze(0)
.unsqueeze(0)
.repeat(channels, 1, 1, 1)
.to(image.device)
)
p = (filt_.shape[2] - 1) // 2
image = F.pad(image, (p, p, p, p), 'reflect')
img = F.conv2d(image, filt_, stride=1, padding=0, groups=channels)
img = img[:, :, start[0] : end[0] : step, start[1] : end[1] : step]
return img
[docs]
def vifsub_est_M(self, org, dist):
r"""Calculate the parameters of the distortion channel.
Args:
org: A reference tensor. Shape :math:`(N, C, H, W)`.
dist: A distortion tensor. Shape :math:`(N, C, H, W)`.
"""
g_all = []
vv_all = []
for i in range(len(self.subbands)):
sub = self.subbands[i] - 1
y = org[sub]
yn = dist[sub]
lev = np.ceil((sub - 1) / 6)
winsize = int(2**lev + 1)
win = np.ones((winsize, winsize))
newsizeX = int(np.floor(y.shape[2] / self.M) * self.M)
newsizeY = int(np.floor(y.shape[3] / self.M) * self.M)
y = y[:, :, :newsizeX, :newsizeY]
yn = yn[:, :, :newsizeX, :newsizeY]
winstart = [int(1 * np.floor(self.M / 2)), int(1 * np.floor(self.M / 2))]
winend = [
int(y.shape[2] - np.ceil(self.M / 2)) + 1,
int(y.shape[3] - np.ceil(self.M / 2)) + 1,
]
mean_x = self.corrDn_win(
y,
win / (winsize**2),
step=self.M,
channels=self.channels,
start=winstart,
end=winend,
)
mean_y = self.corrDn_win(
yn,
win / (winsize**2),
step=self.M,
channels=self.channels,
start=winstart,
end=winend,
)
cov_xy = (
self.corrDn_win(
y * yn,
win,
step=self.M,
channels=self.channels,
start=winstart,
end=winend,
)
- (winsize**2) * mean_x * mean_y
)
ss_x = (
self.corrDn_win(
y**2,
win,
step=self.M,
channels=self.channels,
start=winstart,
end=winend,
)
- (winsize**2) * mean_x**2
)
ss_y = (
self.corrDn_win(
yn**2,
win,
step=self.M,
channels=self.channels,
start=winstart,
end=winend,
)
- (winsize**2) * mean_y**2
)
ss_x = F.relu(ss_x)
ss_y = F.relu(ss_y)
g = cov_xy / (ss_x + self.tol)
vv = (ss_y - g * cov_xy) / (winsize**2)
g = g.masked_fill(ss_x < self.tol, 0)
vv[ss_x < self.tol] = ss_y[ss_x < self.tol]
ss_x = ss_x.masked_fill(ss_x < self.tol, 0)
g = g.masked_fill(ss_y < self.tol, 0)
vv = vv.masked_fill(ss_y < self.tol, 0)
vv[g < 0] = ss_y[g < 0]
g = F.relu(g)
vv = vv.masked_fill(vv < self.tol, self.tol)
g_all.append(g)
vv_all.append(vv)
return g_all, vv_all
[docs]
def refparams_vecgsm(self, org):
r"""Calculate the parameters of the reference image.
Args:
org: A reference tensor. Shape :math:`(N, C, H, W)`.
"""
ssarr, l_arr, cu_arr = [], [], []
for i in range(len(self.subbands)):
sub = self.subbands[i] - 1
y = org[sub]
M = self.M
newsizeX = int(np.floor(y.shape[2] / M) * M)
newsizeY = int(np.floor(y.shape[3] / M) * M)
y = y[:, :, :newsizeX, :newsizeY]
B, C, H, W = y.shape
temp = []
for j in range(M):
for k in range(M):
temp.append(
y[:, :, k : H - (M - k) + 1, j : W - (M - j) + 1].reshape(
B, C, -1
)
)
temp = torch.stack(temp, dim=3)
mcu = torch.mean(temp, dim=2).unsqueeze(2).repeat(1, 1, temp.shape[2], 1)
cu = (
torch.matmul((temp - mcu).permute(0, 1, 3, 2), temp - mcu)
/ temp.shape[2]
)
temp = []
for j in range(M):
for k in range(M):
temp.append(y[:, :, k : H + 1 : M, j : W + 1 : M].reshape(B, C, -1))
temp = torch.stack(temp, dim=2)
ss = torch.matmul(torch.pinverse(cu), temp)
ss = torch.sum(ss * temp, dim=2) / (M * M)
ss = ss.reshape(B, C, H // M, W // M)
v, _ = torch.linalg.eigh(cu, UPLO='U')
l_arr.append(v)
ssarr.append(ss)
cu_arr.append(cu)
return ssarr, l_arr, cu_arr
[docs]
def vif(self, x, y):
r"""VIF metric. Order of input is important.
Args:
x: A distortion tensor. Shape :math:`(N, C, H, W)`.
y: A reference tensor. Shape :math:`(N, C, H, W)`.
"""
# Convert RGB image to YCBCR and use the Y-channel.
x = to_y_channel(x, 255)
y = to_y_channel(y, 255)
sp_x = SteerablePyramidSpace(
x, height=self.level, order=self.ori, channels=self.channels
)[::-1]
sp_y = SteerablePyramidSpace(
y, height=self.level, order=self.ori, channels=self.channels
)[::-1]
g_all, vv_all = self.vifsub_est_M(sp_y, sp_x)
ss_arr, l_arr, cu_arr = self.refparams_vecgsm(sp_y)
num, den = [], []
for i in range(len(self.subbands)):
sub = self.subbands[i]
g = g_all[i]
vv = vv_all[i]
ss = ss_arr[i]
lamda = l_arr[i]
neigvals = lamda.shape[2]
lev = np.ceil((sub - 1) / 6)
winsize = 2**lev + 1
offset = (winsize - 1) / 2
offset = int(np.ceil(offset / self.M))
_, _, H, W = g.shape
g = g[:, :, offset : H - offset, offset : W - offset]
vv = vv[:, :, offset : H - offset, offset : W - offset]
ss = ss[:, :, offset : H - offset, offset : W - offset]
temp1 = 0
temp2 = 0
for j in range(neigvals):
cc = lamda[:, :, j].unsqueeze(2).unsqueeze(3)
temp1 = temp1 + torch.sum(
torch.log2(1 + g * g * ss * cc / (vv + self.sigma_nsq)), dim=[2, 3]
)
temp2 = temp2 + torch.sum(
torch.log2(1 + ss * cc / (self.sigma_nsq)), dim=[2, 3]
)
num.append(temp1.mean(1))
den.append(temp2.mean(1))
return torch.stack(num, dim=1).sum(1) / (torch.stack(den, dim=1).sum(1) + 1e-12)
[docs]
def forward(self, X, Y):
r"""Args:
x: A distortion tensor. Shape :math:`(N, C, H, W)`.
y: A reference tensor. Shape :math:`(N, C, H, W)`.
Order of input is important.
"""
assert X.shape == Y.shape, (
'Input and reference images should have the same shape, but got'
)
f'{X.shape} and {Y.shape}'
score = self.vif(X, Y)
return score