# Source code for pyiqa.models.bapps_model

from collections import OrderedDict
import torch
from pyiqa.metrics import calculate_metric
from pyiqa.utils.registry import MODEL_REGISTRY
from .general_iqa_model import GeneralIQAModel
from tqdm import tqdm
import os.path as osp


@MODEL_REGISTRY.register()
class BAPPSModel(GeneralIQAModel):
    """IQA model trained on triplets of one reference and two distorted images.

    Every batch provides ``distA_img``, ``distB_img``, ``ref_img`` and a
    ``mos_label`` judgement. The network scores each distorted image against
    the reference, and the two scores are combined into a single preference
    value that is compared with the judgement.
    """

    def feed_data(self, data):
        """Move one batch onto ``self.device`` and cache it on the instance.

        Args:
            data (dict): Batch with the tensor entries ``distA_img``,
                ``distB_img``, ``ref_img`` and ``mos_label``, plus
                ``img_path``, which is stored as-is.
        """
        # This method is also used by validation, where the options may not
        # contain a ``train`` section, so look it up instead of indexing.
        train_opt = self.opt.get('train') or {}
        if 'use_ref' in train_opt:
            self.use_ref = train_opt['use_ref']

        self.img_A_input = data['distA_img'].to(self.device)
        self.img_B_input = data['distB_img'].to(self.device)
        self.img_ref_input = data['ref_img'].to(self.device)
        self.gt_mos = data['mos_label'].to(self.device)
        self.img_path = data['img_path']

        # Debug aid: to inspect a batch visually, dump the triplet with
        #   from torchvision.utils import save_image
        #   save_image(torch.cat([self.img_ref_input, self.img_A_input,
        #                         self.img_B_input], dim=0),
        #              'tmp_test_bappsdataset.jpg')

    def compute_accuracy(self, d0, d1, judge):
        """Return the mean agreement between the score ordering and ``judge``.

        A sample contributes ``judge`` when ``d1 < d0`` and ``1 - judge``
        otherwise, so ``judge == 1`` means full agreement with ``d1 < d0``.

        Args:
            d0 (Tensor): Scores of the first candidate.
            d1 (Tensor): Scores of the second candidate.
            judge (Tensor): Per-sample judgement, presumably in ``[0, 1]``
                (confirm against the dataset).

        Returns:
            NumPy scalar: mean agreement over all (flattened) samples.
        """
        # ``detach`` (rather than the legacy ``.data``) so tensors that are
        # still attached to the autograd graph can be converted safely.
        d1_lt_d0 = (d1 < d0).detach().cpu().numpy().flatten()
        judge_per = judge.detach().cpu().numpy().flatten()
        acc = d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
        return acc.mean()

    def optimize_parameters(self, current_iter):
        """Run one optimisation step on the batch cached by ``feed_data``.

        Args:
            current_iter (int): Current training iteration (not used here).
        """
        self.optimizer.zero_grad()

        score_A = self.net(self.img_A_input, self.img_ref_input)
        score_B = self.net(self.img_B_input, self.img_ref_input)
        # For BAPPS, the two scores are merged into one preference value,
        # equivalent to sigmoid(score_A - score_B).
        train_output_score = 1 / (1 + torch.exp(score_B - score_A))

        l_total = 0
        loss_dict = OrderedDict()
        # loss between the predicted preference and the judgement
        if self.cri_mos:
            l_mos = self.cri_mos(train_output_score, self.gt_mos)
            l_total += l_mos
            loss_dict['l_mos'] = l_mos

        l_total.backward()
        self.optimizer.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        # log metrics in training batch
        self.log_dict['train_metrics/acc'] = self.compute_accuracy(
            score_A, score_B, self.gt_mos
        )

    @torch.no_grad()
    def test(self):
        """Score the cached batch without gradients.

        Results are stored in ``self.score_A`` and ``self.score_B``. The
        network is switched to eval mode for the forward passes and then put
        back into whatever mode it was in before the call.
        """
        was_training = self.net.training
        self.net.eval()
        try:
            self.score_A = self.net(self.img_A_input, self.img_ref_input)
            self.score_B = self.net(self.img_B_input, self.img_ref_input)
        finally:
            # Restore the previous mode instead of forcing train mode, so an
            # inference-only network stays in eval mode.
            self.net.train(was_training)

    @torch.no_grad()
    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Evaluate the network on ``dataloader`` and track the best metrics.

        Args:
            dataloader: Validation loader; ``dataloader.dataset.opt['name']``
                is used as the dataset name.
            current_iter (int): Iteration recorded with the metric values.
            tb_logger: Logger forwarded to ``_log_validation_metric_values``.
            save_img: Not used by this method.
        """
        dataset_name = dataloader.dataset.opt['name']
        val_opt = self.opt['val']
        with_metrics = val_opt.get('metrics') is not None
        use_pbar = val_opt.get('pbar', False)

        if with_metrics:
            if not hasattr(self, 'metric_results'):  # only execute in the first run
                self.metric_results = {metric: 0 for metric in val_opt['metrics']}
            # initialize the best metric results for each dataset_name
            # (supporting multiple validation datasets)
            self._initialize_best_metric_results(dataset_name)
            # zero self.metric_results
            self.metric_results = {metric: 0 for metric in self.metric_results}

        if use_pbar:
            pbar = tqdm(total=len(dataloader), unit='image')

        pred_score_A = []
        pred_score_B = []
        gt_mos = []
        for val_data in dataloader:
            img_name = osp.basename(val_data['img_path'][0])
            self.feed_data(val_data)
            self.test()

            # Normalise everything to 2-D so the ``squeeze(1)`` below is valid.
            if len(self.score_A.shape) <= 1:
                self.score_A = self.score_A.reshape(-1, 1)
                self.score_B = self.score_B.reshape(-1, 1)
            gt_batch = self.gt_mos
            if len(gt_batch.shape) <= 1:
                gt_batch = gt_batch.reshape(-1, 1)

            pred_score_A.append(self.score_A)
            pred_score_B.append(self.score_B)
            gt_mos.append(gt_batch)

            if use_pbar:
                pbar.update(1)
                pbar.set_description(f'Test {img_name:>20}')
        if use_pbar:
            pbar.close()

        pred_score_A = torch.cat(pred_score_A, dim=0).squeeze(1).cpu().numpy()
        pred_score_B = torch.cat(pred_score_B, dim=0).squeeze(1).cpu().numpy()
        gt_mos = torch.cat(gt_mos, dim=0).squeeze(1).cpu().numpy()

        if with_metrics:
            # calculate all metrics
            for name, opt_ in val_opt['metrics'].items():
                self.metric_results[name] = calculate_metric(
                    [pred_score_A, pred_score_B, gt_mos], opt_
                )

            if self.key_metric is not None:
                # If the best metric is updated, update and save best model
                to_update = self._update_best_metric_result(
                    dataset_name,
                    self.key_metric,
                    self.metric_results[self.key_metric],
                    current_iter,
                )

                if to_update:
                    for name in val_opt['metrics']:
                        self._update_metric_result(
                            dataset_name, name, self.metric_results[name], current_iter
                        )
                    self.copy_model(self.net, self.net_best)
                    self.save_network(self.net_best, 'net_best')
            else:
                # update each metric separately
                updated = []
                for name in val_opt['metrics']:
                    tmp_updated = self._update_best_metric_result(
                        dataset_name, name, self.metric_results[name], current_iter
                    )
                    updated.append(tmp_updated)
                # save best model if any metric is updated
                if any(updated):
                    self.copy_model(self.net, self.net_best)
                    self.save_network(self.net_best, 'net_best')

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)