Source code for pyiqa.models.distiqa_model
from collections import OrderedDict
import torch
from pyiqa.metrics import calculate_metric
from pyiqa.utils.registry import MODEL_REGISTRY
from .general_iqa_model import GeneralIQAModel
@MODEL_REGISTRY.register()
[docs]
class DistIQAModel(GeneralIQAModel):
"""General module to train an IQA network."""
[docs]
def feed_data(self, data):
self.img_input = data['img'].to(self.device)
self.gt_mos = data['mos_label'].to(self.device)
self.gt_mos_dist = data['mos_dist'].to(self.device)
self.use_ref = False
[docs]
def test(self):
self.net.eval()
with torch.no_grad():
self.output_score = self.net(
self.img_input, return_mos=True, return_dist=False
)
self.net.train()
[docs]
def optimize_parameters(self, current_iter):
self.optimizer.zero_grad()
self.output_mos, self.output_dist = self.net(
self.img_input, return_mos=True, return_dist=True
)
l_total = 0
loss_dict = OrderedDict()
if self.cri_mos:
l_mos = self.cri_mos(self.output_dist, self.gt_mos_dist)
l_total += l_mos
loss_dict['l_mos'] = l_mos
l_total.backward()
self.optimizer.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
# log metrics in training batch
pred_score = self.output_mos.squeeze(1).cpu().detach().numpy()
gt_mos = self.gt_mos.squeeze(1).cpu().detach().numpy()
for name, opt_ in self.opt['val']['metrics'].items():
self.log_dict[f'train_metrics/{name}'] = calculate_metric(
[pred_score, gt_mos], opt_
)