Source code for pyiqa.archs.paq2piq_arch

r"""Paq2piq metric, proposed by

Ying, Zhenqiang, Haoran Niu, Praful Gupta, Dhruv Mahajan, Deepti Ghadiyaram, and Alan Bovik.
"From patches to pictures (PaQ-2-PiQ): Mapping the perceptual space of picture quality."
In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3575-3585. 2020.

Ref url: https://github.com/baidut/paq2piq/blob/master/paq2piq/model.py
Modified by: Chaofeng Chen (https://github.com/chaofengc)

"""

import torch
import torch.nn as nn
import torchvision as tv
from torchvision.ops import RoIPool

from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network
from pyiqa.archs.arch_util import get_url_from_name

# Download location of the default PaQ-2-PiQ checkpoint; PAQ2PIQ reads the
# 'url' entry when asked for pretrained weights without an explicit path.
default_model_urls = {
    'url': get_url_from_name(
        'P2P_RoIPoolModel-fit.10.bs.120-ca69882e.pth'
    ),
}
class AdaptiveConcatPool2d(nn.Module):
    """Pool with adaptive max and adaptive average pooling and concatenate.

    Both pooling layers share the same target output size, and their results
    are joined along dimension 1 with the max-pooled tensor first. For an
    ``(N, C, H, W)`` input this doubles the channel count.

    Args:
        sz (int | tuple | None): Target output size passed to both pooling
            layers. Any falsy value (e.g. ``None``) falls back to ``(1, 1)``.
    """

    def __init__(self, sz=None):
        super().__init__()
        output_size = sz if sz else (1, 1)
        # Attribute names and creation order are kept as-is so the submodule
        # layout (``ap`` first, then ``mp``) does not change.
        self.ap = nn.AdaptiveAvgPool2d(output_size)
        self.mp = nn.AdaptiveMaxPool2d(output_size)

    def forward(self, x):
        """Return ``x`` max-pooled and average-pooled, concatenated on dim 1."""
        max_pooled = self.mp(x)
        avg_pooled = self.ap(x)
        return torch.cat((max_pooled, avg_pooled), dim=1)
@ARCH_REGISTRY.register()
class PAQ2PIQ(nn.Module):
    """PaQ-2-PiQ no-reference image quality predictor.

    A truncated torchvision backbone produces a feature map, one region of
    interest covering the whole image is RoI-pooled from it, and a small
    fully connected head regresses a single quality score per image.

    Args:
        backbone (str): Backbone name. Currently only ``'resnet18'`` is
            supported.
        pretrained (bool): Whether to load the default pretrained PaQ-2-PiQ
            weights. Only consulted when ``pretrained_model_path`` is None.
        pretrained_model_path (str | None): Optional local checkpoint path.
            When given, it takes precedence over ``pretrained``.

    Raises:
        ValueError: If ``backbone`` is not a supported backbone name.
    """

    def __init__(
        self, backbone='resnet18', pretrained=True, pretrained_model_path=None
    ):
        super().__init__()

        if backbone == 'resnet18':
            # NOTE(review): ImageNet weights are always requested here, even
            # when a PaQ-2-PiQ checkpoint is loaded below -- confirm whether
            # that download is needed in the pretrained case.
            model = tv.models.resnet18(weights='IMAGENET1K_V1')
            # Drop the last two children (global pooling and classifier) so
            # the body outputs a spatial feature map.
            cut = -2
            # ResNet-18 downsamples the input by 32 before the pooling layer.
            spatial_scale = 1 / 32
        else:
            # Fail early with a clear message; previously an unknown name fell
            # through and crashed later on the unbound ``model`` variable.
            raise ValueError(
                f'Unsupported backbone {backbone!r}; '
                "only 'resnet18' is supported."
            )

        # Not used inside this class; presumably read by external code.
        self.blk_size = 20, 20
        self.model_type = self.__class__.__name__

        self.body = nn.Sequential(*list(model.children())[:cut])
        # 1024 input features = max + average pooled 512-channel features.
        self.head = nn.Sequential(
            AdaptiveConcatPool2d(),
            nn.Flatten(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=0.25),
            nn.Linear(in_features=1024, out_features=512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=512, out_features=1),
        )
        self.roi_pool = RoIPool((2, 2), spatial_scale)

        # An explicit checkpoint path wins over the default download.
        if pretrained_model_path is not None:
            load_pretrained_network(self, pretrained_model_path)
        elif pretrained:
            load_pretrained_network(self, default_model_urls['url'])

    def forward(self, x):
        """Predict quality score from image tensor.

        Args:
            x (torch.Tensor): Input tensor with shape ``(N, 3, H, W)``.

        Returns:
            torch.Tensor: Predicted scores with shape ``(N, 1)``.
        """
        batch_size = x.shape[0]
        height, width = x.shape[-2:]

        feats = self.body(x)

        # One box per image in (x1, y1, x2, y2) input-pixel coordinates that
        # spans the full image. ``new_tensor`` creates it directly with the
        # dtype and device of ``x`` instead of building on CPU and moving.
        global_roi = x.new_tensor([[0, 0, width, height]])
        feats = self.roi_pool(feats, [global_roi] * batch_size)

        preds = self.head(feats)
        return preds.view(batch_size, -1)