Source code for pyiqa.archs.dbcnn_arch

r"""DBCNN Metric

Reference:
    Zhang, Weixia, et al. "Blind image quality assessment using
    a deep bilinear convolutional neural network." IEEE Transactions
    on Circuits and Systems for Video Technology 30.1 (2018): 36-47.

Ref url: https://github.com/zwx8981/DBCNN-PyTorch/blob/master/DBCNN.py
Re-implemented by: Chaofeng Chen (https://github.com/chaofengc)
"""

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network, get_url_from_name

# Lookup table from a short key to the URL that `get_url_from_name` returns for
# the matching checkpoint file. The 'scnn' entry is the one DBCNN loads into its
# SCNN branch; the remaining keys are selectable through DBCNN's `pretrained`
# argument (they look like dataset names -- confirm against the release notes).
default_model_urls = {
    key: get_url_from_name(checkpoint_file)
    for key, checkpoint_file in (
        ('csiq', 'DBCNN_CSIQ-8677d071.pth'),
        ('tid2008', 'DBCNN_TID2008-4b47c5d1.pth'),
        ('tid2013', 'DBCNN_TID2013-485d021d.pth'),
        ('live', 'DBCNN_LIVE-97262bf4.pth'),
        ('livec', 'DBCNN_LIVEC-83f6dad3.pth'),
        ('livem', 'DBCNN_LIVEM-698474e3.pth'),
        ('koniq', 'DBCNN_KonIQ10k-2de81c0a.pth'),
        ('scnn', 'DBCNN_scnn-7ea73d75.pth'),
    )
}
class SCNN(nn.Module):
    """Network branch for synthetic distortions.

    The model is a stack of nine 3x3 convolution blocks (``features``), a
    global average pool (``pooling``), two 1x1 convolution blocks
    (``projection``) and a linear layer with ``num_class`` (39) outputs
    (``classifier``).

    Args:
        use_bn (bool): Whether to use batch normalization.

    Modified from https://github.com/zwx8981/DBCNN-PyTorch/blob/master/SCNN.py
    """

    def __init__(self, use_bn=True):
        super().__init__()

        self.num_class = 39
        self.use_bn = use_bn

        # (in_channels, out_channels, stride) of every 3x3 block; the padding
        # is 1 for all of them.
        backbone_spec = (
            (3, 48, 1),
            (48, 48, 2),
            (48, 64, 1),
            (64, 64, 2),
            (64, 64, 1),
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 1),
            (128, 128, 2),
        )
        backbone = []
        for c_in, c_out, step in backbone_spec:
            backbone.extend(self._make_layers(c_in, c_out, 3, step, 1))
        self.features = nn.Sequential(*backbone)

        self.pooling = nn.AdaptiveAvgPool2d(1)

        # Two 1x1 blocks applied to the pooled (N, 128, 1, 1) tensor.
        head = []
        for c_in, c_out in ((128, 256), (256, 256)):
            head.extend(self._make_layers(c_in, c_out, 1, 1, 0))
        self.projection = nn.Sequential(*head)

        self.classifier = nn.Linear(256, self.num_class)

    def _make_layers(self, in_ch, out_ch, ksz, stride, pad):
        """Build one block: convolution, optional batch norm, in-place ReLU.

        Args:
            in_ch (int): Input channels of the convolution.
            out_ch (int): Output channels of the convolution.
            ksz (int): Kernel size.
            stride (int): Convolution stride.
            pad (int): Convolution padding.

        Returns:
            list: The freshly created modules, in application order.
        """
        block = [nn.Conv2d(in_ch, out_ch, ksz, stride, pad)]
        if self.use_bn:
            block.append(nn.BatchNorm2d(out_ch))
        block.append(nn.ReLU(True))
        return block

    def forward(self, X):
        """Run the SCNN on a batch of images.

        Args:
            X (torch.Tensor): Input tensor with shape (N, C, H, W).

        Returns:
            torch.Tensor: Output of the linear classifier, shape
            (N, num_class).
        """
        feature_map = self.features(X)
        pooled = self.pooling(feature_map)
        projected = self.projection(pooled)
        # Collapse the trailing (256, 1, 1) dimensions into a 256-vector.
        flat = projected.view(projected.shape[0], -1)
        return self.classifier(flat)
@ARCH_REGISTRY.register()
class DBCNN(nn.Module):
    """Full DBCNN network.

    Two convolutional branches process the normalized input: the VGG-16
    feature extractor (``features1``, 512 channels) and the SCNN feature
    extractor (``features2``, 128 channels). Their outputs are combined by
    bilinear pooling and mapped to a single value by a linear layer.

    Args:
        fc (bool): Whether to initialize the fc layers. When True, the
            parameters of both branches are also frozen.
        use_bn (bool): Whether to use batch normalization.
        pretrained_scnn_path (str): Pretrained SCNN path. When it is None
            (or empty), the ``'scnn'`` entry of ``default_model_urls`` is
            used instead.
        pretrained (bool | str): Whether to load pretrained weights. ``True``
            selects the ``'koniq'`` entry of ``default_model_urls``; any
            other truthy value is used as the key into that table. Ignored
            when ``pretrained_model_path`` is given.
        pretrained_model_path (str): Pretrained model path.
        default_mean (list): Default mean value.
        default_std (list): Default std value.
    """

    def __init__(
        self,
        fc=True,
        use_bn=True,
        pretrained_scnn_path=None,
        pretrained=True,
        pretrained_model_path=None,
        default_mean=[0.485, 0.456, 0.406],
        default_std=[0.229, 0.224, 0.225],
    ):
        super().__init__()

        # Convolution and pooling layers of VGG-16, minus the last module of
        # its feature extractor.
        vgg_features = torchvision.models.vgg16(weights='IMAGENET1K_V1').features
        self.features1 = nn.Sequential(*list(vgg_features.children())[:-1])

        scnn = SCNN(use_bn=use_bn)
        # Honor a caller-supplied SCNN checkpoint. This argument used to be
        # silently ignored in favour of the default checkpoint.
        scnn_weights = pretrained_scnn_path or default_model_urls['scnn']
        load_pretrained_network(scnn, scnn_weights)
        self.features2 = scnn.features

        # Linear classifier on the flattened 512 x 128 bilinear feature.
        self.fc = torch.nn.Linear(512 * 128, 1)

        # Kept as plain tensors (not parameters or buffers); `preprocess`
        # moves them to the input's device and dtype on every call.
        self.default_mean = torch.Tensor(default_mean).view(1, 3, 1, 1)
        self.default_std = torch.Tensor(default_std).view(1, 3, 1, 1)

        if fc:
            # Freeze all previous layers.
            for param in self.features1.parameters():
                param.requires_grad = False
            for param in scnn.parameters():
                param.requires_grad = False
            # Initialize the fc layers.
            nn.init.kaiming_normal_(self.fc.weight.data)
            if self.fc.bias is not None:
                nn.init.constant_(self.fc.bias.data, val=0)

        # An explicit model path wins; otherwise resolve `pretrained` to one
        # of the default checkpoints.
        if pretrained_model_path is None and pretrained:
            url_key = 'koniq' if isinstance(pretrained, bool) else pretrained
            pretrained_model_path = default_model_urls[url_key]

        if pretrained_model_path is not None:
            load_pretrained_network(self, pretrained_model_path, True, 'params')

    def preprocess(self, x):
        """Normalize the input with the stored per-channel mean and std.

        Args:
            x (torch.Tensor): Input tensor with shape (N, C, H, W).

        Returns:
            torch.Tensor: ``(x - mean) / std``, with mean and std converted
            to the device and dtype of ``x``.
        """
        mean = self.default_mean.to(x)
        std = self.default_std.to(x)
        return (x - mean) / std

    def forward(self, X):
        """Compute IQA using DBCNN model.

        Args:
            X (torch.Tensor): An input tensor with (N, C, H, W) shape. RGB
                channel order for colour images.

        Returns:
            torch.Tensor: Value of DBCNN model, shape (N, 1).
        """
        X = self.preprocess(X)

        X1 = self.features1(X)
        X2 = self.features2(X)

        N, _, H, W = X1.shape
        N, _, H2, W2 = X2.shape

        # Bring the SCNN map to the VGG spatial size so that both branches
        # share the same H * W positions.
        if (H != H2) or (W != W2):
            X2 = F.interpolate(X2, (H, W), mode='bilinear', align_corners=True)

        X1 = X1.view(N, 512, H * W)
        X2 = X2.view(N, 128, H * W)

        # Bilinear pooling: product of the two feature maps, averaged over
        # the spatial positions, giving an (N, 512, 128) tensor.
        X = torch.bmm(X1, torch.transpose(X2, 1, 2)) / (H * W)
        X = X.view(N, 512 * 128)

        # Element-wise square root (with a small epsilon added first), then
        # normalization along the feature dimension.
        X = torch.sqrt(X + 1e-8)
        X = torch.nn.functional.normalize(X)

        X = self.fc(X)
        return X