Source code for pyiqa.archs.deepdc_arch

r"""DeepDC: Deep Distance Correlation as a Perceptual Image Quality Evaluator

Reference:
@article{zhu2024adaptive,
  title={DeepDC: Deep Distance Correlation as a Perceptual Image Quality Evaluator},
  author={Zhu, Hanwei and Chen, Baoliang and Zhu, Lingyu and Wang, Shiqi and Weisi, Lin},
  journal={arXiv preprint arXiv:2211.04927},
  year={2024},
}

Reference url: https://github.com/h4nwei/DeepDC
"""

import torch
import torch.nn as nn
from collections import OrderedDict
import torchvision
from torchvision import models, transforms
from pyiqa.utils.registry import ARCH_REGISTRY

[docs] names = { 'vgg19': [ 'image', 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5', ], }
[docs] class MultiVGGFeaturesExtractor(nn.Module): """Extract multiple VGG19 feature maps for DeepDC. Args: target_features (tuple[str, ...]): VGG layer names to return. use_input_norm (bool): Whether to apply ImageNet normalization. requires_grad (bool): Whether backbone parameters require gradients. """ def __init__( self, target_features=('conv1_2', 'conv2_2', 'conv3_4', 'conv4_4', 'conv5_4'), use_input_norm=True, requires_grad=False, ): # ALL FALSE is the best for COS_Similarity; Correlation: use_norm = True super(MultiVGGFeaturesExtractor, self).__init__() self.use_input_norm = use_input_norm self.target_features = target_features model = torchvision.models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1) names_key = 'vgg19' if self.use_input_norm: mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) self.register_buffer('mean', mean) self.register_buffer('std', std) self.target_indexes = [ names[names_key].index(k) - 1 for k in self.target_features ] self.features = nn.Sequential( *list(model.features.children())[: (max(self.target_indexes) + 1)] ) if not requires_grad: for k, v in self.features.named_parameters(): v.requires_grad = False self.features.eval()
[docs] def forward(self, x): """Extract requested VGG feature maps. Args: x (torch.Tensor): Input tensor in range ``[0, 1]`` with shape ``(N, 3, H, W)``. Returns: collections.OrderedDict[str, torch.Tensor]: Mapping from feature names to feature tensors. """ # assume input range is [0, 1] if self.use_input_norm: x = (x - self.mean) / self.std y = OrderedDict() if 'image' in self.target_features: y.update({'image': x}) for key, layer in self.features._modules.items(): x = layer(x) # x = self._normalize_tensor(x) if int(key) in self.target_indexes: y.update({self.target_features[self.target_indexes.index(int(key))]: x}) return y
def _normalize_tensor(sefl, in_feat, eps=1e-10): norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True)) return in_feat / (norm_factor + eps)
@ARCH_REGISTRY.register()
[docs] class DeepDC(nn.Module): """DeepDC full-reference IQA metric. The score is derived from distance correlation between feature-space double-centered distance matrices. Args: features_to_compute (tuple[str, ...]): VGG feature names used in the DeepDC aggregation. """ def __init__( self, features_to_compute=('conv1_2', 'conv2_2', 'conv3_4', 'conv4_4', 'conv5_4'), ): super(DeepDC, self).__init__() self.MSE = torch.nn.MSELoss() self.features_extractor = MultiVGGFeaturesExtractor( target_features=features_to_compute ).eval()
[docs] def forward(self, x, y): r"""Compute DeepDC quality score. Args: x (torch.Tensor): Distorted image tensor with shape ``(N, 3, H, W)``. y (torch.Tensor): Reference image tensor with shape ``(N, 3, H, W)``. Returns: torch.Tensor: Quality score tensor with shape ``(N, 1)``. """ targets, inputs = x, y inputs_fea = self.features_extractor(inputs) with torch.no_grad(): targets_fea = self.features_extractor(targets) dc_scores = [] for _, key in enumerate(inputs_fea.keys()): inputs_dcdm = self._DCDM(inputs_fea[key]) targets_dcdm = self._DCDM(targets_fea[key]) dc_scores.append(self.Distance_Correlation(inputs_dcdm, targets_dcdm)) dc_scores = torch.stack(dc_scores, dim=1) score = 1 - dc_scores.mean(dim=1, keepdim=True) return score
# double-centered distance matrix (dcdm) def _DCDM(self, x): """Compute double-centered distance matrix for each sample. Args: x (torch.Tensor): Feature tensor of shape ``(N, C, H, W)`` or ``(N, M, C)``. Returns: torch.Tensor: Matrix tensor with shape ``(N, C, C)``. """ if len(x.shape) == 4: batchSize, dim, h, w = x.data.shape M = h * w elif len(x.shape) == 3: batchSize, M, dim = x.data.shape x = x.reshape(batchSize, dim, M) t = torch.log((1.0 / (torch.tensor(dim) * torch.tensor(dim)))) I = ( torch.eye(dim, dim, device=x.device) .view(1, dim, dim) .repeat(batchSize, 1, 1) .type(x.dtype) ) I_M = torch.ones(batchSize, dim, dim, device=x.device).type(x.dtype) x_pow2 = x.bmm(x.transpose(1, 2)) dcov = I_M.bmm(x_pow2 * I) + (x_pow2 * I).bmm(I_M) - 2 * x_pow2 dcov = torch.clamp(dcov, min=0.0) dcov = torch.exp(t) * dcov dcov = torch.sqrt(dcov + 1e-5) dcdm = ( dcov - 1.0 / dim * dcov.bmm(I_M) - 1.0 / dim * I_M.bmm(dcov) + 1.0 / (dim * dim) * I_M.bmm(dcov).bmm(I_M) ) return dcdm
[docs] def Distance_Correlation(self, matrix_A, matrix_B): """Compute distance correlation between matrix batches. Args: matrix_A (torch.Tensor): First matrix batch of shape ``(N, C, C)``. matrix_B (torch.Tensor): Second matrix batch of shape ``(N, C, C)``. Returns: torch.Tensor: Correlation values with shape ``(N,)``. """ Gamma_XY = torch.sum(matrix_A * matrix_B, dim=[1, 2]) Gamma_XX = torch.sum(matrix_A * matrix_A, dim=[1, 2]) Gamma_YY = torch.sum(matrix_B * matrix_B, dim=[1, 2]) c = 1e-6 correlation_r = (Gamma_XY + c) / (torch.sqrt(Gamma_XX * Gamma_YY) + c) return correlation_r
[docs] def prepare_image(image, resize=True): """Convert a PIL image to a 4D tensor for DeepDC demos. Args: image (PIL.Image.Image): Input image. resize (bool): If ``True``, resize shortest side to ``256`` when both sides are larger than ``256``. Returns: torch.Tensor: Tensor with shape ``(1, 3, H, W)``. """ if resize and min(image.size) > 256: image = transforms.functional.resize(image, 256) image = transforms.ToTensor()(image) return image.unsqueeze(0)
if __name__ == '__main__': from PIL import Image import argparse
[docs] parser = argparse.ArgumentParser()
parser.add_argument('--ref', type=str, default='../images/r0.png') parser.add_argument('--dist', type=str, default='../images/r1.png') args = parser.parse_args() ref = prepare_image(Image.open(args.ref).convert('RGB')) dist = prepare_image(Image.open(args.dist).convert('RGB')) assert ref.shape == dist.shape device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = DeepDC().to(device) ref = ref.to(device) dist = dist.to(device) score = model(ref, dist) print(score.item()) # score: 0.3347