Source code for pyiqa.models.pieapp_model
from collections import OrderedDict
import torch
from pyiqa.metrics.correlation_coefficient import calculate_rmse
from pyiqa.utils.registry import MODEL_REGISTRY
from .general_iqa_model import GeneralIQAModel
@MODEL_REGISTRY.register()
class PieAPPModel(GeneralIQAModel):
    """IQA model trained on pairs of distorted images sharing one reference.

    During training the network scores two distorted images (A and B)
    against the same reference image; the score difference is mapped to a
    probability and compared with the ground-truth label of the pair.
    """

    def feed_data(self, data):
        """Move one batch onto ``self.device`` and store it on the model.

        Args:
            data: Mapping produced by the data loader. Two layouts are
                handled, selected by the presence of the ``'img'`` key:

                * with ``'img'`` (single-image batch): ``'img'``,
                  ``'mos_label'``, ``'ref_img'``, ``'ref_img_path'`` and
                  ``'img_path'`` are read;
                * without ``'img'`` (pairwise batch): ``'distA_img'``,
                  ``'distB_img'``, ``'ref_img'`` and ``'mos_label'`` are read.
        """
        # A batch carrying a plain 'img' entry is a single-image batch;
        # otherwise it is an A/B pair used by optimize_parameters().
        is_test = 'img' in data

        # Only override use_ref when the training options define it.
        if 'use_ref' in self.opt['train']:
            self.use_ref = self.opt['train']['use_ref']

        if is_test:
            self.img_input = data['img'].to(self.device)
            self.gt_mos = data['mos_label'].to(self.device)
            self.ref_input = data['ref_img'].to(self.device)
            # Paths are stored as-is (no device transfer).
            self.ref_img_path = data['ref_img_path']
            self.img_path = data['img_path']
        else:
            self.img_A_input = data['distA_img'].to(self.device)
            self.img_B_input = data['distB_img'].to(self.device)
            self.img_ref_input = data['ref_img'].to(self.device)
            # For pairwise batches 'mos_label' is used as the target
            # probability of the pair.
            self.gt_prob = data['mos_label'].to(self.device)

    def optimize_parameters(self, current_iter):
        """Run one optimisation step on the pairwise batch set by feed_data().

        Args:
            current_iter: Current training iteration. Not used by this
                implementation.

        Side effects:
            Updates the parameters through ``self.optimizer`` and sets
            ``self.log_dict`` to the reduced losses plus the batch RMSE
            between predicted and ground-truth probabilities under the key
            ``'train_metrics/rmse'``.
        """
        self.optimizer.zero_grad()

        score_A = self.net(self.img_A_input, self.img_ref_input)
        score_B = self.net(self.img_B_input, self.img_ref_input)

        # Equivalent to 1 / (1 + exp(score_A - score_B)), but torch.sigmoid
        # does not overflow for large score differences. The explicit exp()
        # form becomes inf there and produces NaN gradients in backward().
        train_output_score = torch.sigmoid(score_B - score_A)

        l_total = 0
        loss_dict = OrderedDict()

        # Loss between the predicted and the ground-truth pair probability.
        # NOTE: if no cri_mos is configured, l_total stays the int 0 and the
        # backward() call below fails.
        if self.cri_mos:
            l_mos = self.cri_mos(train_output_score, self.gt_prob)
            l_total += l_mos
            loss_dict['l_mos'] = l_mos

        l_total.backward()
        self.optimizer.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        # Log metrics of the training batch. Detach first so the host copy
        # is not recorded in the autograd graph.
        pred_score = train_output_score.squeeze(-1).detach().cpu().numpy()
        gt_prob = self.gt_prob.squeeze(-1).detach().cpu().numpy()
        self.log_dict['train_metrics/rmse'] = calculate_rmse(pred_score, gt_prob)