# Source code for pyiqa.archs.arniqa_arch

r"""ARNIQA: Learning Distortion Manifold for Image Quality Assessment

@inproceedings{agnolucci2024arniqa,
  title={ARNIQA: Learning Distortion Manifold for Image Quality Assessment},
  author={Agnolucci, Lorenzo and Galteri, Leonardo and Bertini, Marco and Del Bimbo, Alberto},
  booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
  pages={189--198},
  year={2024}
}

Reference:
    - Arxiv link: https://www.arxiv.org/abs/2310.14918
    - Official Github: https://github.com/miccunifi/ARNIQA
"""

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models
from typing import Tuple
import warnings
from collections import OrderedDict

from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import get_url_from_name
from pyiqa.api_helpers import get_dataset_info

# Loading the TorchScript regressor checkpoints through torch.hub makes
# torch.serialization emit a UserWarning; silence it. Note that this filter is
# installed process-wide when the module is imported.
warnings.filterwarnings('ignore', category=UserWarning, module='torch.serialization')

# Per-dataset metadata ('mos_range', 'lower_better', ...) used to rescale the
# regressor output. Short aliases are added so the regressor names used in
# ``default_model_urls`` can be looked up directly.
DATASET_INFO = get_dataset_info()
DATASET_INFO.update(
    {
        alias: DATASET_INFO[canonical]
        for alias, canonical in (
            ('clive', 'livec'),
            ('tid', 'tid2013'),
            ('koniq', 'koniq10k'),
            ('kadid', 'kadid10k'),
        )
    }
)

# Checkpoint URLs. 'ARNIQA' is the encoder checkpoint; every other key names
# the dataset whose pretrained regressor is stored in the given file.
default_model_urls = {'ARNIQA': get_url_from_name(name='ARNIQA.pth')}
default_model_urls.update(
    (dataset, get_url_from_name(name=file_name))
    for dataset, file_name in (
        ('live', 'regressor_live.pth'),
        ('csiq', 'regressor_csiq.pth'),
        ('tid', 'regressor_tid2013.pth'),
        ('kadid', 'regressor_kadid10k.pth'),
        ('koniq', 'regressor_koniq10k.pth'),
        ('clive', 'regressor_clive.pth'),
        ('flive', 'regressor_flive.pth'),
        ('spaq', 'regressor_spaq.pth'),
    )
)
@ARCH_REGISTRY.register()
class ARNIQA(nn.Module):
    """ARNIQA model implementation.

    This class implements the ARNIQA model for image quality assessment, which
    combines a ResNet50 encoder with a regressor network for predicting image
    quality scores.

    Args:
        regressor_dataset (str, optional): The dataset whose pretrained
            regressor maps encoder features to a quality score. Must be one of
            the non-``'ARNIQA'`` keys of ``default_model_urls`` (``'live'``,
            ``'csiq'``, ``'tid'``, ``'kadid'``, ``'koniq'``, ``'clive'``,
            ``'flive'``, ``'spaq'``). Default is ``'koniq'``.

    Raises:
        KeyError: If ``regressor_dataset`` does not name a known regressor.
            The check runs before any checkpoint is downloaded.

    Attributes:
        regressor_dataset (str): The dataset to use for the regressor.
        encoder (nn.Module): The ResNet50 backbone without its final ``fc``
            classifier, loaded with the ARNIQA encoder checkpoint.
        feat_dim (int): The feature dimension of the encoder.
        regressor (nn.Module): The regressor network, loaded from torch.hub
            as a TorchScript (JIT) module.
        default_mean (torch.Tensor): ImageNet mean used for normalization,
            shaped ``(1, 3, 1, 1)``.
        default_std (torch.Tensor): ImageNet std used for normalization,
            shaped ``(1, 3, 1, 1)``.
    """

    def __init__(self, regressor_dataset: str = 'koniq'):
        super().__init__()

        # Fail fast on a bad name, before any checkpoint is downloaded.
        # 'ARNIQA' is the encoder checkpoint, not a regressor, so exclude it.
        regressor_choices = [name for name in default_model_urls if name != 'ARNIQA']
        if regressor_dataset not in regressor_choices:
            raise KeyError(
                f'Unknown regressor_dataset {regressor_dataset!r}; '
                f'expected one of {regressor_choices}.'
            )
        self.regressor_dataset = regressor_dataset

        # Only the architecture is needed here: every encoder parameter and
        # buffer is replaced below by a strict load of the ARNIQA checkpoint,
        # so downloading torchvision's ImageNet weights first would be wasted.
        backbone = torchvision.models.resnet50(weights=None)
        self.feat_dim = backbone.fc.in_features
        # Keep every child module except the final ``fc`` classifier.
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

        encoder_state_dict = torch.hub.load_state_dict_from_url(
            default_model_urls['ARNIQA'], progress=True, map_location='cpu'
        )
        # The checkpoint stores the encoder weights under a 'model.' prefix.
        # Keep only those entries and strip the prefix so the keys match
        # ``self.encoder``.
        prefix = 'model.'
        cleaned_encoder_state_dict = OrderedDict(
            (key[len(prefix):], value)
            for key, value in encoder_state_dict.items()
            if key.startswith(prefix)
        )
        self.encoder.load_state_dict(cleaned_encoder_state_dict)
        self.encoder.eval()

        # The regressor checkpoint is a JIT model rather than a plain state
        # dict, so the loaded object is used directly as the module.
        self.regressor: nn.Module = torch.hub.load_state_dict_from_url(
            default_model_urls[self.regressor_dataset],
            progress=True,
            map_location='cpu',
        )
        self.regressor.eval()

        self.default_mean = torch.Tensor(IMAGENET_DEFAULT_MEAN).view(1, 3, 1, 1)
        self.default_std = torch.Tensor(IMAGENET_DEFAULT_STD).view(1, 3, 1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the ARNIQA model.

        Args:
            x (torch.Tensor): The input image batch, shaped ``(N, 3, H, W)``.
                Presumably RGB in ``[0, 1]``, since it is normalized with
                ImageNet statistics; confirm against callers.

        Returns:
            torch.Tensor: The predicted quality scores, linearly rescaled so
            that the dataset's MOS range maps to ``[0, 1]`` with higher
            meaning better. The shape is whatever the regressor emits
            (presumably ``(N, 1)``).
        """
        x, x_ds = self._preprocess(x)
        # Embed the full- and half-resolution images and L2-normalize each
        # feature vector along the channel dimension.
        f = F.normalize(self.encoder(x), dim=1)
        f_ds = F.normalize(self.encoder(x_ds), dim=1)
        # Concatenate both scales into one (N, 2 * feat_dim) feature matrix.
        f_combined = torch.hstack((f, f_ds)).view(-1, self.feat_dim * 2)
        score = self.regressor(f_combined)
        return self._scale_score(score)

    def _preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Downsample the input image by a factor of 2 and normalize both images.

        Args:
            x (torch.Tensor): The input tensor, shaped ``(N, 3, H, W)``.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The normalized original and
            downsampled tensors.
        """
        x_ds = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
        x = (x - self.default_mean.to(x)) / self.default_std.to(x)
        x_ds = (x_ds - self.default_mean.to(x_ds)) / self.default_std.to(x_ds)
        return x, x_ds

    def _scale_score(self, score: torch.Tensor) -> torch.Tensor:
        """Scale the score into the range [0, 1], where higher is better.

        The dataset's ``mos_range`` is mapped linearly onto ``[0, 1]``. For
        datasets flagged ``lower_better`` the result is flipped. Raw scores
        outside ``mos_range`` map outside ``[0, 1]``; they are not clamped.

        Args:
            score (torch.Tensor): The raw regressor output.

        Returns:
            torch.Tensor: The scaled score.
        """
        new_range = (0.0, 1.0)
        dataset_info = DATASET_INFO[self.regressor_dataset]

        # Compute scaling factors
        original_range = (dataset_info['mos_range'][0], dataset_info['mos_range'][1])
        original_width = original_range[1] - original_range[0]
        new_width = new_range[1] - new_range[0]
        scaling_factor = new_width / original_width

        # Scale score
        scaled_score = new_range[0] + (score - original_range[0]) * scaling_factor

        # Invert the scale if needed so that higher always means better
        if dataset_info['lower_better']:
            scaled_score = new_range[1] - scaled_score
        return scaled_score