# Source code for pyiqa.archs.hypernet_arch

r"""HyperNet Metric

@InProceedings{hyperiqa,
    author = {Su, Shaolin and Yan, Qingsen and Zhu, Yu and Zhang, Cheng and Ge, Xin and Sun, Jinqiu and Zhang, Yanning},
    title = {Blindly Assess Image Quality in the Wild Guided by a Self-Adaptive Hyper Network},
    booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month = {June},
    year = {2020}
}

Ref url: https://github.com/SSL92/hyperIQA
Re-implemented by: Chaofeng Chen (https://github.com/chaofengc)

"""

import torch
import torch.nn as nn
import timm
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network, uniform_crop
from pyiqa.archs.arch_util import get_url_from_name


# Locations of the released pretrained weights. HyperNet.__init__ loads the
# 'resnet50-koniq' entry when `pretrained=True` and no explicit
# `pretrained_model_path` is given.
default_model_urls = {
    'resnet50-koniq': get_url_from_name('HyperIQA-resnet50-koniq10k-c96c41b1.pth'),
}


@ARCH_REGISTRY.register()
class HyperNet(nn.Module):
    """HyperNet Model.

    A backbone extracts multi-scale features. Local distortion aware (LDA)
    heads turn them into a 224-d vector, and a hyper network predicts, per
    input sample, the weights and biases of a small fully connected target
    network that maps this vector to a single quality score.

    Args:
        - base_model_name (String): pretrained model to extract features,
          created through timm with ``features_only=True``. Default: resnet50.
          NOTE: the LDA heads and ``conv1`` hard-code input channels of
          256/512/1024/2048 and 7x7 pooling, so the backbone must produce
          resnet50-like feature maps.
        - num_crop (int): number of 224x224 crops passed to ``uniform_crop``
          and averaged in evaluation mode. Default: 25.
        - pretrained (bool): if True and ``pretrained_model_path`` is None,
          load the default ``'resnet50-koniq'`` weights. The timm backbone
          itself is always created with ``pretrained=True`` regardless of
          this flag.
        - pretrained_model_path (String): Pretrained model path. Takes
          precedence over ``pretrained`` when given.
        - default_mean (sequence): Default mean value (one value per channel).
        - default_std (sequence): Default std value (one value per channel).

    Reference:
        Su, Shaolin, Qingsen Yan, Yu Zhu, Cheng Zhang, Xin Ge,
        Jinqiu Sun, and Yanning Zhang. "Blindly assess image
        quality in the wild guided by a self-adaptive hyper network."
        In Proceedings of the IEEE/CVF Conference on Computer Vision
        and Pattern Recognition (CVPR), pp. 3667-3676. 2020.

    """

    def __init__(
        self,
        base_model_name='resnet50',
        num_crop=25,
        pretrained=True,
        pretrained_model_path=None,
        default_mean=(0.485, 0.456, 0.406),
        default_std=(0.229, 0.224, 0.225),
    ):
        super(HyperNet, self).__init__()

        # The backbone always starts from timm's pretrained weights; the
        # `pretrained` argument only controls loading of the full IQA
        # checkpoint at the end of this constructor.
        self.base_model = timm.create_model(
            base_model_name, pretrained=True, features_only=True
        )

        lda_out_channels = 16
        hyper_in_channels = 112
        target_in_size = 224
        hyper_fc_channels = [112, 56, 28, 14, 1]
        feature_size = 7  # spatial size of the last features from base model

        self.hyper_fc_channels = hyper_fc_channels
        self.num_crop = num_crop

        # local distortion aware module: one head per backbone stage; the
        # four outputs concatenate to `target_in_size` (3 * 16 + 176 = 224)
        self.lda_modules = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.AvgPool2d(7, stride=7),
                    nn.Flatten(),
                    nn.Linear(16 * 64, lda_out_channels),
                ),
                nn.Sequential(
                    nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.AvgPool2d(7, stride=7),
                    nn.Flatten(),
                    nn.Linear(32 * 16, lda_out_channels),
                ),
                nn.Sequential(
                    nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.AvgPool2d(7, stride=7),
                    nn.Flatten(),
                    nn.Linear(64 * 4, lda_out_channels),
                ),
                nn.Sequential(
                    nn.AvgPool2d(7, stride=7),
                    nn.Flatten(),
                    nn.Linear(2048, target_in_size - lda_out_channels * 3),
                ),
            ]
        )

        # Hyper network part, conv for generating target fc weights, fc for
        # generating target fc biases
        self.fc_w_modules = nn.ModuleList([])
        # Input width of each of the first four target fc layers.
        fc_in_sizes = [target_in_size] + hyper_fc_channels[:3]
        for in_size, out_size in zip(fc_in_sizes, hyper_fc_channels):
            # The conv output (out_ch x 7 x 7) is later reshaped into an
            # (out_size x in_size) weight matrix, hence the division by 7**2.
            out_ch = int(in_size * out_size / feature_size**2)
            self.fc_w_modules.append(
                nn.Conv2d(hyper_in_channels, out_ch, 3, padding=(1, 1)),
            )
        # Weights of the last target layer (14 -> 1) come from a linear head.
        self.fc_w_modules.append(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(hyper_in_channels, hyper_fc_channels[3]),
            )
        )

        self.fc_b_modules = nn.ModuleList(
            [
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(1),
                    nn.Flatten(),
                    nn.Linear(hyper_in_channels, out_size),
                )
                for out_size in hyper_fc_channels
            ]
        )

        # Conv layers for resnet output features
        self.conv1 = nn.Sequential(
            nn.Conv2d(2048, 1024, 1, padding=(0, 0)),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 512, 1, padding=(0, 0)),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, hyper_in_channels, 1, padding=(0, 0)),
            nn.ReLU(inplace=True),
        )
        # Kept for attribute compatibility; it is an empty container.
        self.global_pool = nn.Sequential()

        # Plain tensors (not buffers); they are moved to the input's
        # device/dtype on each call in `preprocess`.
        self.default_mean = torch.Tensor(default_mean).view(1, 3, 1, 1)
        self.default_std = torch.Tensor(default_std).view(1, 3, 1, 1)

        # An explicit checkpoint path wins over the default download.
        if pretrained_model_path is not None:
            weight_path = pretrained_model_path
        elif pretrained:
            weight_path = default_model_urls['resnet50-koniq']
        else:
            weight_path = None

        if weight_path is not None:
            load_pretrained_network(self, weight_path, True, weight_keys='params')

    def preprocess(self, x):
        """Resize the input to 224x224 if needed and normalize it.

        Args:
            x: Image tensor whose last two dimensions are the spatial size
                and whose channel dimension broadcasts against the
                ``(1, 3, 1, 1)`` mean/std tensors.

        Returns:
            ``(x - default_mean) / default_std`` at spatial size (224, 224).
        """
        # input must have shape of (224, 224) because of network design
        if x.shape[2:] != torch.Size([224, 224]):
            x = nn.functional.interpolate(x, (224, 224), mode='bicubic')
        mean = self.default_mean.to(x)
        std = self.default_std.to(x)
        return (x - mean) / std

    def forward_patch(self, x):
        """Predict a quality score for each 224x224 patch in the batch.

        Args:
            x: Patch tensor with spatial size exactly (224, 224).

        Returns:
            Tensor with the last dimension of the target network output
            squeezed away (one score per patch, since the final target
            layer has a single output).

        Raises:
            AssertionError: if the spatial size is not (224, 224).
        """
        assert x.shape[2:] == torch.Size([224, 224]), (
            f'Input patch size must be (224, 224), but got {x.shape[2:]}'
        )
        x = self.preprocess(x)

        # Drop the first (stem) feature map; the remaining stages feed the
        # LDA heads one-to-one.
        base_feats = self.base_model(x)[1:]

        # multi-scale local distortion aware features
        lda_feat = torch.cat(
            [ldam(bf) for bf, ldam in zip(base_feats, self.lda_modules)], dim=1
        )

        # calculate target net weights & bias
        hyper_in_feat = self.conv1(base_feats[-1])
        batch_size = hyper_in_feat.shape[0]

        target_fc_w = []
        target_fc_b = []
        for out_size, w_module, b_module in zip(
            self.hyper_fc_channels, self.fc_w_modules, self.fc_b_modules
        ):
            # Per-sample weight matrix of shape (out_size, in_size).
            target_fc_w.append(
                w_module(hyper_in_feat).reshape(batch_size, out_size, -1)
            )
            target_fc_b.append(b_module(hyper_in_feat))

        # get final IQA score: run the per-sample target network, with a
        # sigmoid after every layer except the last one
        x = lda_feat.unsqueeze(1)
        last_idx = len(target_fc_w) - 1
        for idx, (fc_w, fc_b) in enumerate(zip(target_fc_w, target_fc_b)):
            x = torch.bmm(x, fc_w.transpose(1, 2)) + fc_b.unsqueeze(1)
            if idx != last_idx:
                x = torch.sigmoid(x)

        return x.squeeze(-1)

    def forward(self, x):
        r"""HYPERNET model.

        In training mode the input is scored directly as a batch of
        224x224 patches. In evaluation mode ``num_crop`` crops of size 224
        are taken from every image with ``uniform_crop`` and their scores
        are averaged.

        Args:
            x: A distortion tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Training: the output of ``forward_patch``. Evaluation: the
            per-image mean over crops with an extra trailing dimension
            appended.
        """
        # input normalization is applied inside forward_patch (see preprocess)
        if self.training:
            return self.forward_patch(x)

        # Unpacking four values also enforces a 4-D input.
        batch_size, _, _, _ = x.shape
        crops = uniform_crop([x], 224, self.num_crop)
        results = self.forward_patch(crops)
        # Assumes uniform_crop stacks the crops of each image contiguously
        # along the batch dimension, so this reshape groups them per image.
        results = results.reshape(batch_size, self.num_crop, -1).mean(dim=1)
        return results.unsqueeze(-1)