Source code for pyiqa.archs.pieapp_arch

r"""PieAPP metric, proposed by

Prashnani, Ekta, Hong Cai, Yasamin Mostofi, and Pradeep Sen.
"Pieapp: Perceptual image-error assessment through pairwise preference."
In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1808-1817. 2018.

Ref url: https://github.com/prashnani/PerceptualImageError
Modified by: Chaofeng Chen (https://github.com/chaofengc)

!!! Important Note: to keep the test process simple and the comparison with other methods fair,
                    we use zero padding and extract subpatches only once
                    rather than from multiple subimages as in the original code.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network
from .func_util import extract_2d_patches
from pyiqa.archs.arch_util import get_url_from_name

# Download location of the default pretrained PieAPP checkpoint; read by
# PieAPP.__init__ when no explicit checkpoint path is supplied.
default_model_urls = dict(url=get_url_from_name('PieAPPv0.1-0937b014.pth'))
class CompactLinear(nn.Module):
    """Learnable scalar affine map ``y = x * weight + bias``.

    Both ``weight`` and ``bias`` are single-element parameters drawn from a
    standard normal distribution at construction time.
    """

    def __init__(self):
        super(CompactLinear, self).__init__()
        # Creation order (weight, then bias) is kept so the random
        # initialisation consumes the RNG in the same sequence.
        self.weight = nn.Parameter(torch.randn(1))
        self.bias = nn.Parameter(torch.randn(1))

    def forward(self, x):
        """Scale ``x`` by ``weight`` and shift it by ``bias`` (broadcast)."""
        scaled = x * self.weight
        return scaled + self.bias
@ARCH_REGISTRY.register()
class PieAPP(nn.Module):
    r"""
    PieAPP model implementation.

    Args:
        - patch_size (int): Size of the patches to extract from the images.
        - stride (int): Stride to use when extracting patches.
        - pretrained (bool): Whether to use a pretrained model or not.
        - pretrained_model_path (str): Path to the pretrained model. When
          given, it takes precedence over ``pretrained``.

    Methods:
        - flatten(matrix): Takes NxCxHxW input and outputs Nx(C*H*W).
        - compute_features(input): Computes the features of the input image.
        - preprocess(x): Preprocesses the input image.
        - forward(dist, ref): Computes the PieAPP score between the distorted
          and reference images.

    Note:
        The fully connected layers have fixed input sizes (120832 and 2048),
        which correspond to 64x64 patches passing through the convolutional
        stack below (65536 + 32768 + 16384 + 4096 + 2048 = 120832 multi-scale
        features, 512 * 2 * 2 = 2048 coarse features).
    """

    def __init__(
        self, patch_size=64, stride=27, pretrained=True, pretrained_model_path=None
    ):
        super(PieAPP, self).__init__()
        # Feature extractor: eleven 3x3 convolutions with five 2x2 max-pools.
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.conv5 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, 3, padding=1)
        self.pool6 = nn.MaxPool2d(2, 2)
        self.conv7 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv8 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool8 = nn.MaxPool2d(2, 2)
        self.conv9 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv10 = nn.Conv2d(256, 512, 3, padding=1)
        self.pool10 = nn.MaxPool2d(2, 2)
        self.conv11 = nn.Conv2d(512, 512, 3, padding=1)

        # Per-patch score head (fed with the multi-scale feature difference).
        self.fc1_score = nn.Linear(120832, 512)
        self.fc2_score = nn.Linear(512, 1)
        # Per-patch weight head (fed with the coarsest feature difference).
        self.fc1_weight = nn.Linear(2048, 512)
        self.fc2_weight = nn.Linear(512, 1)

        self.ref_score_subtract = CompactLinear()

        self.patch_size = patch_size
        self.stride = stride

        # An explicit checkpoint path wins over the default pretrained weights.
        if pretrained_model_path is not None:
            load_pretrained_network(self, pretrained_model_path)
        elif pretrained:
            load_pretrained_network(self, default_model_urls['url'])

        self.pretrained = pretrained

    def flatten(self, matrix):
        """Flatten every dimension after the first: NxCxHxW -> Nx(C*H*W)."""
        return torch.flatten(matrix, 1)

    def compute_features(self, input):
        """Run the convolutional stack on a batch of patches.

        Args:
            input: Batch of patches fed to ``conv1`` (3 input channels).

        Returns:
            tuple: ``(feature_ms, x11)`` where ``feature_ms`` is the
            concatenation (along dim 1) of the flattened activations after
            conv3, conv5, conv7, conv9 and conv11, and ``x11`` is the
            flattened activation after conv11 alone.
        """
        # conv1 -> relu -> conv2 -> relu -> pool2 -> conv3 -> relu
        x3 = F.relu(
            self.conv3(self.pool2(F.relu(self.conv2(F.relu(self.conv1(input))))))
        )
        # conv4 -> relu -> pool4 -> conv5 -> relu
        x5 = F.relu(self.conv5(self.pool4(F.relu(self.conv4(x3)))))
        # conv6 -> relu -> pool6 -> conv7 -> relu
        x7 = F.relu(self.conv7(self.pool6(F.relu(self.conv6(x5)))))
        # conv8 -> relu -> pool8 -> conv9 -> relu
        x9 = F.relu(self.conv9(self.pool8(F.relu(self.conv8(x7)))))
        # conv10 -> relu -> pool10 -> conv11 -> relu
        x11 = self.flatten(F.relu(self.conv11(self.pool10(F.relu(self.conv10(x9))))))
        # flatten and concatenate
        feature_ms = torch.cat(
            (
                self.flatten(x3),
                self.flatten(x5),
                self.flatten(x7),
                self.flatten(x9),
                x11,
            ),
            1,
        )
        return feature_ms, x11

    def preprocess(self, x):
        """Default BGR in [0, 255] in original codes.

        Reverses the order of the three channels along dim 1 and multiplies
        by 255 (presumably the input is RGB in [0, 1] -- confirm with caller).
        """
        x = x[:, [2, 1, 0]] * 255.0
        return x

    def forward(self, dist, ref):
        """Compute the PieAPP score between distorted and reference images.

        In evaluation mode both images are split into patches with
        ``extract_2d_patches`` (using ``self.patch_size``, ``self.stride`` and
        ``padding='none'``). In training mode the inputs are used as they are,
        each one being treated as a single patch.

        Args:
            dist: Distorted image batch.
            ref: Reference image batch, same shape as ``dist``.

        Returns:
            Tensor of shape ``(bsz, 1)``: the weighted average of the
            per-patch scores, using the predicted per-patch weights.

        Raises:
            AssertionError: If ``dist`` and ``ref`` have different shapes.
        """
        # Both halves of the message must live inside the parentheses;
        # otherwise the second f-string is a no-op statement and the
        # reference shape never appears in the error.
        assert dist.shape == ref.shape, (
            f'Input and reference images should have the same shape, but got {dist.shape}'
            f' and {ref.shape}'
        )

        dist = self.preprocess(dist)
        ref = self.preprocess(ref)

        if not self.training:
            image_A_patches = extract_2d_patches(
                dist, self.patch_size, self.stride, padding='none'
            )
            image_ref_patches = extract_2d_patches(
                ref, self.patch_size, self.stride, padding='none'
            )
        else:
            # Add a patch dimension of size one so both modes share the
            # 5-D (bsz, num_patches, c, h, w) layout unpacked below.
            image_A_patches, image_ref_patches = dist, ref
            image_A_patches = image_A_patches.unsqueeze(1)
            image_ref_patches = image_ref_patches.unsqueeze(1)

        # Patches are treated as square: the last dimension is used for both
        # spatial sizes when folding the patch dimension into the batch.
        bsz, num_patches, c, _, psz = image_A_patches.shape
        image_A_patches = image_A_patches.reshape(bsz * num_patches, c, psz, psz)
        image_ref_patches = image_ref_patches.reshape(bsz * num_patches, c, psz, psz)

        A_multi_scale, A_coarse = self.compute_features(image_A_patches)
        ref_multi_scale, ref_coarse = self.compute_features(image_ref_patches)

        diff_ms = ref_multi_scale - A_multi_scale
        diff_coarse = ref_coarse - A_coarse

        # per patch score: fc1_score -> relu -> fc2_score
        per_patch_score = self.ref_score_subtract(
            0.01 * self.fc2_score(F.relu(self.fc1_score(diff_ms)))
        )
        per_patch_score = per_patch_score.view((-1, num_patches))

        # per patch weight: fc1_weight -> relu -> fc2_weight
        per_patch_weight = self.fc2_weight(F.relu(self.fc1_weight(diff_coarse))) + 1e-6
        per_patch_weight = per_patch_weight.view((-1, num_patches))

        # Weighted average of the patch scores for every image in the batch.
        score = (per_patch_weight * per_patch_score).sum(dim=-1) / per_patch_weight.sum(
            dim=-1
        )
        return score.reshape(bsz, 1)