# Source code for pyiqa.archs.metaiqa_arch
r"""MetaIQA Model wrapper for pyiqa.
Reference: Zhu, Hancheng, et al. "MetaIQA: Deep Meta-Learning for No-Reference Image Quality Assessment." CVPR 2020.
Re-implemented and optimized for the pyiqa ecosystem.
"""
import os
import sys
import torch
import torch.nn as nn
from torchvision import models
from PIL import Image
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network, get_url_from_name, load_file_from_url
[docs]
# Default checkpoint URLs, keyed by the role MetaIQA.__init__ looks them up under:
#   'meta-train-seed'  -> loaded when the launching script's name contains
#                         'train.py' or 'train_nsplits.py'
#   'meta-infer-ready' -> loaded in every other case
default_model_urls = {
    role: get_url_from_name(weight_file)
    for role, weight_file in (
        ('meta-train-seed', 'Metaiqa_prior.pth'),
        ('meta-infer-ready', 'Metaiqa_livec.pth'),
    )
}
[docs]
class BaselineModel1(nn.Module):
    """Fully connected regression head ending in a sigmoid.

    Two hidden stages of Linear -> BatchNorm1d -> PReLU -> Dropout (widths
    1024 and 512) are followed by a final Linear layer whose output is passed
    through a Sigmoid.

    Args:
        num_classes: Output width of the final linear layer.
        keep_probability: Probability of keeping an activation; both dropout
            layers use ``1 - keep_probability`` as their drop rate.
        inputsize: Number of input features expected by the first linear layer.
    """

    def __init__(self, num_classes=1, keep_probability=0.5, inputsize=1000):
        super().__init__()

        # nn.Dropout is parameterised by the drop rate, not the keep rate.
        self.drop_prob = 1 - keep_probability

        # Submodules are registered in the order they are applied; that order
        # also determines the key order of the state dict.
        # Hidden stage 1: inputsize -> 1024.
        self.fc1 = nn.Linear(inputsize, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.relu1 = nn.PReLU()
        self.drop1 = nn.Dropout(p=self.drop_prob)

        # Hidden stage 2: 1024 -> 512.
        self.fc2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.relu2 = nn.PReLU()
        self.drop2 = nn.Dropout(p=self.drop_prob)

        # Output stage: 512 -> num_classes, squashed by a sigmoid.
        self.fc3 = nn.Linear(512, num_classes)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Run ``x`` through both hidden stages and the sigmoid output stage.

        Args:
            x: Input tensor whose last dimension holds ``inputsize`` features.
                (assumes a 2-D ``(batch, inputsize)`` layout — confirm against callers)

        Returns:
            The sigmoid-activated output of the final linear layer.
        """
        hidden = self.drop1(self.relu1(self.bn1(self.fc1(x))))
        hidden = self.drop2(self.relu2(self.bn2(self.fc2(hidden))))
        return self.sig(self.fc3(hidden))
@ARCH_REGISTRY.register()
class MetaIQA(nn.Module):
    """MetaIQA quality predictor: resnet18 backbone plus a fully connected head.

    A torchvision ``resnet18`` (built with ``weights=None``) feeds a
    :class:`BaselineModel1` head configured for 1000 input features and a
    single output. Inputs are resized to 224x224 and normalised with the
    registered ``mean``/``std`` buffers before entering the backbone.

    Args:
        pretrained: When truthy and ``pretrained_model_path`` is ``None``, one
            of the checkpoints in ``default_model_urls`` is loaded.
        pretrained_model_path: Optional checkpoint location. When given, it is
            loaded through ``load_pretrained_network`` and ``pretrained`` is
            not consulted.
        **kwargs: Accepted and ignored.

    Attributes:
        metric_mode: Set to ``'NR'``.
        lower_better: Set to ``False``.
    """

    def __init__(self, pretrained=True, pretrained_model_path=None, **kwargs):
        super().__init__()
        self.metric_mode = 'NR'
        self.lower_better = False

        self.resnet_layer = models.resnet18(weights=None)
        self.head = BaselineModel1(num_classes=1, keep_probability=0.5, inputsize=1000)

        # Per-channel normalisation constants, shaped (1, 3, 1, 1) so they
        # broadcast over a batch of 3-channel images.
        # NOTE(review): values look like the usual ImageNet statistics — confirm.
        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # An explicit checkpoint path takes precedence over the defaults.
        if pretrained_model_path is not None:
            load_pretrained_network(self, pretrained_model_path, strict=False, weight_keys=None)
            return
        if not pretrained:
            return

        # The default checkpoint is chosen from the name of the launching
        # script: training entry points get the 'meta-train-seed' weights.
        main_script = os.path.basename(sys.argv[0])
        is_train_env = any(
            marker in main_script for marker in ('train.py', 'train_nsplits.py')
        )
        if is_train_env:
            load_pretrained_network(
                self, default_model_urls['meta-train-seed'], strict=False, weight_keys=None
            )
            return

        ckpt_path = load_file_from_url(default_model_urls['meta-infer-ready'])
        print(f'Loading pretrained model from {ckpt_path} for inference...')
        # NOTE(review): torch.load unpickles the file; consider weights_only=True
        # if every supported torch version accepts it.
        checkpoint = torch.load(ckpt_path, map_location='cpu')

        # Unwrap the state dict if it is nested under a known container key.
        for wrapper_key in ('params', 'state_dict'):
            if wrapper_key in checkpoint:
                checkpoint = checkpoint[wrapper_key]
                break

        # Drop 'net_g.' / 'module.' and rename 'net.' to 'head.' so the keys
        # line up with this module's attribute names.
        # NOTE(review): these are substring replacements, not prefix matches.
        renamed = {
            key.replace('net_g.', '').replace('module.', '').replace('net.', 'head.'): value
            for key, value in checkpoint.items()
        }
        # NOTE(review): strict=False means missing or unexpected keys are not
        # reported here.
        self.load_state_dict(renamed, strict=False)

    def preprocess(self, x):
        """Resize ``x`` to 224x224 if necessary, then normalise it per channel.

        Args:
            x: Image batch. (assumes ``(N, 3, H, W)`` — the code reads
                ``x.shape[2:]`` as the spatial size and broadcasts against
                ``(1, 3, 1, 1)`` buffers; confirm against callers)

        Returns:
            ``(x - mean) / std`` computed on the (possibly resized) input.
        """
        target_size = (224, 224)
        if x.shape[2:] != target_size:
            x = nn.functional.interpolate(
                x, size=target_size, mode='bilinear', align_corners=False
            )
        return (x - self.mean) / self.std

    def forward(self, x, ref=None):
        """Predict one score per input image.

        Args:
            x: Image batch passed to :meth:`preprocess`.
            ref: Unused by this method. (presumably kept so the call signature
                matches metrics that take a reference — confirm)

        Returns:
            The head's output reshaped to ``(-1, 1)``.
        """
        features = self.resnet_layer(self.preprocess(x))
        return self.head(features).view(-1, 1)