r"""Debiased Mapping for Full-Reference Image Quality Assessment
@article{chen2025debiased,
title={Debiased mapping for full-reference image quality assessment},
author={Chen, Baoliang and Zhu, Hanwei and Zhu, Lingyu and Wang, Shanshe and Pan, Jingshan and Wang, Shiqi},
journal={IEEE Transactions on Multimedia},
year={2025},
publisher={IEEE}
}
Reference:
- Arxiv link: https://ieeexplore.ieee.org/abstract/document/10886996
- Official Github: https://github.com/Baoliang93/DMM
"""
import torch
import torchvision
import torch.nn as nn
from collections import OrderedDict
import torch.nn.functional as F
from pyiqa.utils.registry import ARCH_REGISTRY
#----------------------- VGGNet-----------------------------------
[docs]
names = {'vgg16': ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
'conv3_3', 'relu3_3', 'pool3',
'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2',
'conv4_3', 'relu4_3', 'pool4',
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
'conv5_3', 'relu5_3', 'pool5'],
}
[docs]
class L2Pool2d(torch.nn.Module):
r"""Applies L2 pooling with Hann window of size 3x3
Args:
x: Tensor with shape (N, C, H, W)"""
EPS = 1e-12
def __init__(self, kernel_size: int = 3, stride: int = 2, padding=1) -> None:
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.kernel = None
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.kernel is None:
C = x.size(1)
self.kernel = self._hann_filter(self.kernel_size).repeat((C, 1, 1, 1)).to(x)
out = torch.nn.functional.conv2d(
x ** 2, self.kernel,
stride=self.stride,
padding=self.padding,
groups=x.shape[1]
)
return (out + self.EPS).sqrt()
def _hann_filter(self, kernel_size: int) -> torch.Tensor:
r"""Creates Hann kernel
Returns:
kernel: Tensor with shape (1, kernel_size, kernel_size)
"""
window = torch.hann_window(kernel_size + 2, periodic=False)[1:-1]
kernel = window[:, None] * window[None, :]
return kernel.view(1, kernel_size, kernel_size) / kernel.sum()
#----------------------- Main Class-----------------------------------
@ARCH_REGISTRY.register()
[docs]
class DMM(nn.Module):
def __init__(self, reduce_dim=256, kernel_size=5, features_to_compute=('relu3_3','relu4_3'), criterion=torch.nn.CosineSimilarity(), use_dropout=True, **kwargs):
super().__init__()
self.criterion = criterion
self.features_extractor = FeaturesExtractor(target_features=features_to_compute, replace_pooling=True)
self.patchsize = 16
self.stride = 4
self.unfold = nn.Unfold(kernel_size=self.patchsize, stride=self.stride )
[docs]
def forward(self, Dist, Ref, as_loss=False):
# preprocess image
Ref = self.prepare_image_adt(Ref)
Dist = self.prepare_image_adt(Dist)
# main forward
Ref_fea = self.features_extractor(Ref)
with torch.no_grad():
Dist_fea = self.features_extractor(Dist)
dist = 0.
c1 = 1e-6
c2 = 1e-6
for key in Ref_fea.keys():
tdistparam = Dist_fea[key]
tprisparam = Ref_fea[key]
a, b, c, d = tdistparam.size()
k = self.patchsize
distparam = self.unfold(tdistparam).view(a, b, k, k, -1).transpose(2, 4).contiguous()
prisparam = self.unfold(tprisparam).view(a, b, k, k, -1).transpose(2, 4).contiguous()
pt_num = distparam.shape[2]
distparam = distparam.view(a*b*pt_num, k, k)
prisparam = prisparam.view(a*b*pt_num, k, k)
u_r, s_r, v_r = torch.svd(prisparam)
u_d, s_d, v_d = torch.svd(distparam)
s_r = s_r.contiguous().view(a*b,pt_num,k)
s_d = s_d.contiguous().view(a*b,pt_num,k)
diff_s = ((s_r - s_d)**2)
u_rd = (torch.abs((u_r.contiguous().view(a,b,pt_num,k,k))*(u_d.contiguous().view(a,b,pt_num,k,k)))).sum(-1)
wt = (u_rd.std(-1))/(u_rd.mean(-1)+1e-9)
diff_s= (((diff_s).view(a,b,pt_num,k).sum(-1))*wt).mean(-1).mean(-1)
x_mean = tdistparam.mean([2,3], keepdim=True)
y_mean = tprisparam.mean([2,3], keepdim=True)
S1 = (2*x_mean*y_mean+c1)/(x_mean**2+y_mean**2+c2)
glb = S1.squeeze(-1).squeeze(-1).mean(1)
glb = torch.exp(-2.0*glb)
dist = dist+diff_s*glb
if as_loss:
return dist.mean()
else:
return dist
[docs]
def prepare_image_adt(self, tensor_image):
b, c, h, w = tensor_image.shape
msize = min(w, h)
if msize > 128:
tar_size = max(int(msize/(1.0*48))*32, 128)
tensor_image = F.interpolate(tensor_image, size=tar_size, mode='bilinear',align_corners=False )
return tensor_image