# Source code for pyiqa.archs.dists_arch

r"""DISTS metric introduced in

@article{ding2020iqa,
  title={Image Quality Assessment: Unifying Structure and Texture Similarity},
  author={Ding, Keyan and Ma, Kede and Wang, Shiqi and Simoncelli, Eero P.},
  journal = {CoRR},
  volume = {abs/2004.07728},
  year={2020},
  url = {https://arxiv.org/abs/2004.07728}
}

Created by: https://github.com/dingkeyan93/DISTS/blob/master/DISTS_pytorch/DISTS_pt.py
Re-implemented by: Jiadi Mo (https://github.com/JiadiMo)

"""

import numpy as np
import torch
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F

from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network
from pyiqa.archs.arch_util import get_url_from_name

# URL of the default DISTS weight file, resolved through get_url_from_name.
# DISTS.__init__ falls back to this entry when `pretrained` is true and no
# explicit `pretrained_model_path` is supplied.
default_model_urls = {'url': get_url_from_name('DISTS_weights-f5e65c96.pth')}
class L2pooling(nn.Module):
    r"""L2 pooling layer: a strided, per-channel weighted root-mean-square.

    The input is squared, filtered channel by channel with a fixed, normalised
    2-D Hanning window (stored as a non-trainable buffer named ``filter``),
    and the square root of the result is returned.

    Args:
        filter_size (int): Length passed to ``np.hanning``. The two end points
            of the window are dropped, so the actual kernel is
            ``(filter_size - 2) x (filter_size - 2)``.
        stride (int): Stride of the pooling convolution.
        channels (int): Number of channels of the tensors this layer will
            receive. Required; the kernel is replicated once per channel.
        pad_off (int): Unused. Kept only so existing callers keep working.

    Raises:
        ValueError: If ``channels`` is not provided.
    """

    def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
        super().__init__()
        if channels is None:
            # Without this guard the failure surfaces later as an opaque
            # TypeError raised from Tensor.repeat.
            raise ValueError(
                'L2pooling requires `channels` (the number of input channels).'
            )
        self.padding = (filter_size - 2) // 2
        self.stride = stride
        self.channels = channels

        # Separable Hanning window without its end points, turned into a 2-D
        # kernel via an outer product and normalised to sum to one.
        window = np.hanning(filter_size)[1:-1]
        kernel = torch.Tensor(window[:, None] * window[None, :])
        kernel = kernel / torch.sum(kernel)

        # One copy of the kernel per channel: shape (channels, 1, k, k), the
        # weight layout expected by a grouped convolution with one group per
        # channel.
        self.register_buffer(
            'filter', kernel[None, None, :, :].repeat((self.channels, 1, 1, 1))
        )

    def forward(self, input):
        """Apply L2 pooling to ``input``.

        Args:
            input (torch.Tensor): Tensor accepted by ``F.conv2d`` whose channel
                dimension (``input.shape[1]``) matches ``channels``.

        Returns:
            torch.Tensor: Square root of the window-weighted mean of
            ``input ** 2``, spatially downsampled by ``stride``.
        """
        squared = input**2
        out = F.conv2d(
            squared,
            self.filter,
            stride=self.stride,
            padding=self.padding,
            groups=input.shape[1],
        )
        # Small epsilon keeps sqrt (and its gradient) finite at zero.
        return (out + 1e-12).sqrt()
@ARCH_REGISTRY.register()
class DISTS(torch.nn.Module):
    r"""DISTS model.

    The feature extractor is torchvision's VGG16 ``features`` stack split into
    five stages, with the layers at indices 4, 9, 16 and 23 replaced by
    :class:`L2pooling`. All extractor parameters are frozen; only ``alpha`` and
    ``beta`` (the per-channel weights of the two similarity terms) remain
    trainable.

    Args:
        pretrained (bool): If True and ``pretrained_model_path`` is None, load
            the default DISTS weights from ``default_model_urls['url']``.
        pretrained_model_path (String): Pretrained model path. Takes priority
            over ``pretrained`` when given.
        **kwargs: Accepted and ignored.

    Note:
        The ImageNet VGG16 weights (``IMAGENET1K_V1``) are always requested
        from torchvision, regardless of ``pretrained``.
    """

    def __init__(self, pretrained=True, pretrained_model_path=None, **kwargs):
        """Refer to official code https://github.com/dingkeyan93/DISTS"""
        super().__init__()
        vgg_pretrained_features = models.vgg16(weights='IMAGENET1K_V1').features

        # (attribute name, (index, channels) of the leading L2pooling or None,
        #  indices of the VGG layers copied into the stage).
        # Sub-module names are the original VGG indices so that state-dict
        # keys stay compatible with existing checkpoints.
        stage_specs = (
            ('stage1', None, range(0, 4)),
            ('stage2', (4, 64), range(5, 9)),
            ('stage3', (9, 128), range(10, 16)),
            ('stage4', (16, 256), range(17, 23)),
            ('stage5', (23, 512), range(24, 30)),
        )
        for stage_name, pool_spec, layer_ids in stage_specs:
            stage = torch.nn.Sequential()
            if pool_spec is not None:
                pool_idx, pool_channels = pool_spec
                stage.add_module(str(pool_idx), L2pooling(channels=pool_channels))
            for idx in layer_ids:
                stage.add_module(str(idx), vgg_pretrained_features[idx])
            setattr(self, stage_name, stage)

        # Freeze the extractor. alpha/beta are registered afterwards and
        # therefore keep requires_grad=True.
        for param in self.parameters():
            param.requires_grad = False

        # Per-channel normalisation constants applied to the input image.
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
        )
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
        )

        # Channel counts of the six feature maps returned by forward_once:
        # the raw input followed by the five stage outputs.
        self.chns = [3, 64, 128, 256, 512, 512]
        self.register_parameter(
            'alpha', nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))
        )
        self.register_parameter(
            'beta', nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))
        )
        # Re-initialise around 0.1 (mean 0.1, std 0.01).
        self.alpha.data.normal_(0.1, 0.01)
        self.beta.data.normal_(0.1, 0.01)

        # NOTE(review): the trailing False is presumably `strict=False` for
        # load_pretrained_network - confirm against pyiqa.archs.arch_util.
        if pretrained_model_path is not None:
            load_pretrained_network(self, pretrained_model_path, False)
        elif pretrained:
            load_pretrained_network(self, default_model_urls['url'], False)

    def forward_once(self, x):
        """Extract the multi-scale features of a single image batch.

        Args:
            x (torch.Tensor): Image batch; it is normalised with ``mean`` and
                ``std`` before entering the first stage.

        Returns:
            list[torch.Tensor]: ``[x, stage1, stage2, stage3, stage4, stage5]``
            where the first entry is the un-normalised input and the rest are
            the successive stage outputs.
        """
        feats = [x]
        h = (x - self.mean) / self.std
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4, self.stage5):
            h = stage(h)
            feats.append(h)
        return feats

    def forward(self, x, y):
        r"""Compute IQA using DISTS model.

        Args:
            - x: An input tensor with (N, C, H, W) shape. RGB channel order for colour images.
            - y: A reference tensor with (N, C, H, W) shape. RGB channel order for colour images.

        Returns:
            Value of DISTS model, of shape (N, 1). The value is
            ``1 - (texture similarity + structure similarity)``, so identical
            inputs give 0 up to floating-point error.

        """
        feats0 = self.forward_once(x)
        feats1 = self.forward_once(y)

        # Stabilising constants of the two similarity ratios.
        c1 = 1e-6
        c2 = 1e-6

        # Normalise alpha and beta jointly so all weights sum to one, then
        # split them into one chunk per feature map.
        w_sum = self.alpha.sum() + self.beta.sum()
        alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
        beta = torch.split(self.beta / w_sum, self.chns, dim=1)

        dist1 = 0
        dist2 = 0
        for feat_x, feat_y, alpha_k, beta_k in zip(feats0, feats1, alpha, beta):
            # Texture term: similarity of the per-channel spatial means.
            x_mean = feat_x.mean([2, 3], keepdim=True)
            y_mean = feat_y.mean([2, 3], keepdim=True)
            S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
            dist1 = dist1 + (alpha_k * S1).sum(1, keepdim=True)

            # Structure term: per-channel covariance normalised by the
            # variances.
            x_var = ((feat_x - x_mean) ** 2).mean([2, 3], keepdim=True)
            y_var = ((feat_y - y_mean) ** 2).mean([2, 3], keepdim=True)
            xy_cov = (feat_x * feat_y).mean([2, 3], keepdim=True) - x_mean * y_mean
            S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
            dist2 = dist2 + (beta_k * S2).sum(1, keepdim=True)

        score = 1 - (dist1 + dist2)
        # Drop the two singleton spatial dimensions: (N, 1, 1, 1) -> (N, 1).
        return score.squeeze(-1).squeeze(-1)