# Source code for pyiqa.models.inference_model

import torch

from collections import OrderedDict
from pyiqa.default_model_configs import DEFAULT_CONFIGS

from pyiqa.archs import build_network
from pyiqa.utils.img_util import imread2tensor

from pyiqa.losses.loss_util import weight_reduce_loss
from pyiqa.archs.arch_util import load_pretrained_network


class InferenceModel(torch.nn.Module):
    """Common interface for quality inference of images with default setting of each metric.

    The metric network is built from the ``metric_opts`` entry of
    ``DEFAULT_CONFIGS[metric_name]``, updated with any extra keyword options,
    moved to the selected device and put in eval mode.

    Args:
        metric_name: Key into ``DEFAULT_CONFIGS`` selecting the metric.
        as_loss: When true, ``forward`` runs with gradients enabled and returns
            the network output passed through ``weight_reduce_loss``.
        loss_weight: Weight forwarded to ``weight_reduce_loss``.
        loss_reduction: Reduction forwarded to ``weight_reduce_loss``.
        device: Device the network is moved to. When ``None``, CUDA is used if
            available, otherwise CPU.
        seed: Stored as ``self.seed``; not otherwise used inside this class.
        check_input_range: Whether inputs are asserted to lie in [0, 1]
            (within ``self.eps``). Always disabled when used as a loss.
        **kwargs: Other metric options; they override the default
            ``metric_opts``. A ``metric_mode`` entry is consumed here and is
            only honoured when the default config does not define one.

    Raises:
        KeyError: If ``metric_name`` is not in ``DEFAULT_CONFIGS``.
        ValueError: If no ``metric_mode`` is available from the config or kwargs.
    """

    def __init__(
        self,
        metric_name,
        as_loss=False,
        loss_weight=None,
        loss_reduction='mean',
        device=None,
        seed=123,
        check_input_range=True,
        **kwargs,  # Other metric options
    ):
        super(InferenceModel, self).__init__()

        if metric_name not in DEFAULT_CONFIGS:
            raise KeyError(f'Unknown metric: {metric_name}')

        self.metric_name = metric_name
        metric_cfg = DEFAULT_CONFIGS[metric_name]

        # ============ set metric properties ===========
        self.lower_better = metric_cfg.get('lower_better', False)
        self.metric_mode = metric_cfg.get('metric_mode', None)
        self.score_range = metric_cfg.get('score_range', None)

        # `metric_mode` must never reach the network options, so always remove
        # it from kwargs; it is only used when the config does not provide one.
        user_metric_mode = kwargs.pop('metric_mode', None)
        if self.metric_mode is None:
            self.metric_mode = user_metric_mode
            if self.metric_mode is None:
                raise ValueError(f'`metric_mode` must be provided for metric: {metric_name}')

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        self.as_loss = as_loss
        self.loss_weight = loss_weight
        self.loss_reduction = loss_reduction
        if metric_name == 'compare2score':
            # This metric is always run in loss mode without reduction.
            self.as_loss = True
            self.loss_reduction = 'none'

        # disable input range check when used as loss
        self.check_input_range = check_input_range if not self.as_loss else False

        # =========== define metric model ===============
        net_opts = OrderedDict()
        # load default setting first
        net_opts.update(metric_cfg['metric_opts'])
        # then update with custom setting
        net_opts.update(kwargs)

        self.net = build_network(net_opts)
        self.net = self.net.to(self.device)
        self.net.eval()

        self.seed = seed

        # Create the parameter directly on the target device. Calling `.to()`
        # on a Parameter returns a plain (non-registered) tensor whenever a
        # copy is needed, which left this attribute behind when the module was
        # later moved and made `forward` pick a stale device.
        self.dummy_param = torch.nn.Parameter(torch.empty(0, device=self.device))
        self.eps = 1e-6

    def load_weights(self, weights_path, weight_keys='params'):
        """Load pretrained weights into ``self.net`` via ``load_pretrained_network``.

        Args:
            weights_path: Location of the weights, passed through unchanged.
            weight_keys: Forwarded as ``weight_keys`` to the loader.
        """
        load_pretrained_network(self.net, weights_path, weight_keys=weight_keys)

    def is_valid_input(self, x):
        """Assert that ``x`` is a usable image batch; ``None`` is accepted as-is.

        Checks that ``x`` is a 4D ``torch.Tensor`` with 1 or 3 channels and,
        when ``self.check_input_range`` is set, that its values lie in [0, 1]
        up to a tolerance of ``self.eps``.

        Raises:
            AssertionError: If any of the checks fails.
        """
        if x is None:
            return

        assert isinstance(x, torch.Tensor), 'Input must be a torch.Tensor'
        assert x.dim() == 4, 'Input must be 4D tensor (B, C, H, W)'
        assert x.shape[1] in [1, 3], 'Input must be RGB or gray image'

        if self.check_input_range:
            # Reduce once and reuse the values for both the check and the message.
            x_min, x_max = x.min(), x.max()
            assert x_min > -self.eps and x_max < 1 + self.eps, (
                f'Input must be normalized to [0, 1], but got min={x_min:.4f}, max={x_max:.4f}'
            )

    def _to_batched_tensor(self, img):
        """Return tensors unchanged; read anything else with ``imread2tensor`` and add a batch dim."""
        if torch.is_tensor(img):
            return img
        return imread2tensor(img, rgb=True).unsqueeze(0)

    def _prepare_inputs(self, target, ref=None):
        """Convert and validate the inputs for a regular (non-FID/IS) metric.

        ``target`` is always converted and validated. ``ref`` is converted and
        validated only in ``'FR'`` mode, where it is mandatory; otherwise it is
        returned untouched.

        Returns:
            Tuple ``(target, ref)``.

        Raises:
            AssertionError: If validation fails or ``ref`` is missing in FR mode.
        """
        target = self._to_batched_tensor(target)
        self.is_valid_input(target)

        if self.metric_mode == 'FR':
            assert ref is not None, 'Please specify reference image for Full Reference metric'
            ref = self._to_batched_tensor(ref)
            self.is_valid_input(ref)

        return target, ref

    def forward(self, target, ref=None, **kwargs):
        """Run the metric on ``target`` (and ``ref`` for full-reference metrics).

        Metrics whose name contains ``'fid'`` and the ``'inception_score'``
        metric receive the raw arguments plus the current device; every other
        metric gets inputs prepared by ``_prepare_inputs`` and moved to the
        module's device. Gradients are enabled only when ``self.as_loss`` is set.

        Returns:
            The raw network output, or, when ``self.as_loss`` is set, the
            result of ``weight_reduce_loss`` applied to the output (to its
            first element if the output is a tuple).

        Raises:
            ValueError: If ``self.metric_mode`` is neither ``'FR'`` nor ``'NR'``.
        """
        # The registered dummy parameter follows `.to()` calls on the module,
        # so its device always matches the network's device.
        device = self.dummy_param.device

        with torch.backends.cudnn.flags(enabled=True, benchmark=False, deterministic=True):
            with torch.set_grad_enabled(self.as_loss):
                if 'fid' in self.metric_name:
                    output = self.net(target, ref, device=device, **kwargs)
                elif self.metric_name == 'inception_score':
                    output = self.net(target, device=device, **kwargs)
                else:
                    target, ref = self._prepare_inputs(target, ref)

                    if self.metric_mode == 'FR':
                        output = self.net(target.to(device), ref.to(device), **kwargs)
                    elif self.metric_mode == 'NR':
                        output = self.net(target.to(device), **kwargs)
                    else:
                        raise ValueError(f'Unsupported metric mode: {self.metric_mode}')

        if not self.as_loss:
            return output

        if isinstance(output, tuple):
            output = output[0]
        return weight_reduce_loss(output, self.loss_weight, self.loss_reduction)