Source code for pyiqa.archs.cnniqa_arch

r"""CNNIQA Model.

Zheng, Heliang, Huan Yang, Jianlong Fu, Zheng-Jun Zha, and Jiebo Luo.
"Learning conditional knowledge distillation for degraded-reference image
quality assessment." In Proceedings of the IEEE/CVF International Conference
on Computer Vision (ICCV), pp. 10242-10251. 2021.

Ref url: https://github.com/lidq92/CNNIQA
Re-implemented by: Chaofeng Chen (https://github.com/chaofengc) with modification:
    - We use 3 channel RGB input.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network, get_url_from_name

[docs] default_model_urls = {'koniq10k': get_url_from_name('CNNIQA_koniq10k-e6f14c91.pth')}
@ARCH_REGISTRY.register()
[docs] class CNNIQA(nn.Module): r"""CNNIQA model. Args: ker_size (int): Kernel size. n_kers (int): Number of kernels. n1_nodes (int): Number of n1 nodes. n2_nodes (int): Number of n2 nodes. pretrained (str): Pretrained model name. pretrained_model_path (str): Pretrained model path. """ def __init__( self, ker_size=7, n_kers=50, n1_nodes=800, n2_nodes=800, pretrained='koniq10k', pretrained_model_path=None, ): super(CNNIQA, self).__init__() self.conv1 = nn.Conv2d(3, n_kers, ker_size) self.fc1 = nn.Linear(2 * n_kers, n1_nodes) self.fc2 = nn.Linear(n1_nodes, n2_nodes) self.fc3 = nn.Linear(n2_nodes, 1) self.dropout = nn.Dropout() if pretrained_model_path is None and pretrained is not None: pretrained_model_path = default_model_urls[pretrained] if pretrained_model_path is not None: load_pretrained_network(self, pretrained_model_path, True, 'params')
[docs] def forward(self, x): r"""Compute IQA using CNNIQA model. Args: x (torch.Tensor): An input tensor with (N, C, H, W) shape. RGB channel order for colour images. Returns: torch.Tensor: Value of CNNIQA model. """ h = self.conv1(x) h1 = F.max_pool2d(h, (h.size(-2), h.size(-1))) h2 = -F.max_pool2d(-h, (h.size(-2), h.size(-1))) h = torch.cat((h1, h2), 1) # max-min pooling h = h.squeeze(3).squeeze(2) h = F.relu(self.fc1(h)) h = self.dropout(h) h = F.relu(self.fc2(h)) q = self.fc3(h) return q