import torch
from collections import OrderedDict
from pyiqa.default_model_configs import DEFAULT_CONFIGS
# from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs import build_network
from pyiqa.utils.img_util import imread2tensor
from pyiqa.losses.loss_util import weight_reduce_loss
from pyiqa.archs.arch_util import load_pretrained_network
[docs]
class InferenceModel(torch.nn.Module):
"""Common interface for quality inference of images with default setting of each metric."""
def __init__(
self,
metric_name,
as_loss=False,
loss_weight=None,
loss_reduction='mean',
device=None,
seed=123,
check_input_range=True,
**kwargs, # Other metric options
):
super(InferenceModel, self).__init__()
self.metric_name = metric_name
# ============ set metric properties ===========
self.lower_better = DEFAULT_CONFIGS[metric_name].get('lower_better', False)
self.metric_mode = DEFAULT_CONFIGS[metric_name].get('metric_mode', None)
self.score_range = DEFAULT_CONFIGS[metric_name].get('score_range', None)
if self.metric_mode is None:
self.metric_mode = kwargs.pop('metric_mode')
elif 'metric_mode' in kwargs:
kwargs.pop('metric_mode')
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = device
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
self.as_loss = as_loss
self.loss_weight = loss_weight
self.loss_reduction = loss_reduction
if metric_name == 'compare2score':
self.as_loss = True
self.loss_reduction = 'none'
# disable input range check when used as loss
self.check_input_range = check_input_range if not as_loss else False
# =========== define metric model ===============
net_opts = OrderedDict()
# load default setting first
if metric_name in DEFAULT_CONFIGS.keys():
default_opt = DEFAULT_CONFIGS[metric_name]['metric_opts']
net_opts.update(default_opt)
# then update with custom setting
net_opts.update(kwargs)
self.net = build_network(net_opts)
self.net = self.net.to(self.device)
self.net.eval()
self.seed = seed
self.dummy_param = torch.nn.Parameter(torch.empty(0)).to(self.device)
self.eps = 1e-6
[docs]
def load_weights(self, weights_path, weight_keys='params'):
load_pretrained_network(self.net, weights_path, weight_keys=weight_keys)
[docs]
def forward(self, target, ref=None, **kwargs):
device = self.dummy_param.device
with torch.set_grad_enabled(self.as_loss):
if 'fid' in self.metric_name:
output = self.net(target, ref, device=device, **kwargs)
elif self.metric_name == 'inception_score':
output = self.net(target, device=device, **kwargs)
else:
if not torch.is_tensor(target):
target = imread2tensor(target, rgb=True)
target = target.unsqueeze(0)
if self.metric_mode == 'FR':
assert ref is not None, (
'Please specify reference image for Full Reference metric'
)
ref = imread2tensor(ref, rgb=True)
ref = ref.unsqueeze(0)
self.is_valid_input(ref)
self.is_valid_input(target)
if self.metric_mode == 'FR':
assert ref is not None, (
'Please specify reference image for Full Reference metric'
)
output = self.net(target.to(device), ref.to(device), **kwargs)
elif self.metric_mode == 'NR':
output = self.net(target.to(device), **kwargs)
if self.as_loss:
if isinstance(output, tuple):
output = output[0]
return weight_reduce_loss(output, self.loss_weight, self.loss_reduction)
else:
return output