import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
from pyiqa.archs import build_network
from pyiqa.losses import build_loss
from pyiqa.metrics import calculate_metric
from pyiqa.utils import get_root_logger
from pyiqa.utils.registry import MODEL_REGISTRY
from .base_model import BaseModel
@MODEL_REGISTRY.register()
[docs]
class GeneralIQAModel(BaseModel):
"""General module to train an IQA network."""
def __init__(self, opt):
    """Build the network, optionally restore pretrained weights, and set up training.

    Args:
        opt: Configuration mapping. ``opt['network']`` is handed to
            ``build_network``. The ``'path'`` section of ``self.opt`` is
            read for ``pretrain_network`` (checkpoint to load, optional),
            ``param_key_g`` (key inside the checkpoint, default
            ``'params'``) and ``strict_load`` (default ``True``).
    """
    super(GeneralIQAModel, self).__init__(opt)

    # Build the network, pass it through model_to_device and print it.
    self.net = self.model_to_device(build_network(opt['network']))
    self.print_network(self.net)

    # Restore pretrained weights only when a checkpoint path is configured.
    path_opt = self.opt['path']
    pretrained_path = path_opt.get('pretrain_network', None)
    if pretrained_path is not None:
        weights_key = path_opt.get('param_key_g', 'params')
        strict = path_opt.get('strict_load', True)
        self.load_network(self.net, pretrained_path, strict, weights_key)

    if self.is_train:
        self.init_training_settings()
[docs]
def init_training_settings(self):
    """Switch the network to training mode and create losses, optimizer and schedulers.

    Reads ``self.opt['train']``. ``mos_loss_opt`` and ``metric_loss_opt``
    are both optional: a missing or falsy entry leaves the corresponding
    criterion (``self.cri_mos`` / ``self.cri_metric``) set to ``None``.
    """
    self.net.train()
    train_opt = self.opt['train']

    # Second network built from the same config; validation copies the
    # best-performing weights into it before saving them.
    self.net_best = build_network(self.opt['network']).to(self.device)

    def _criterion_from(option_name):
        # Build the loss described by train_opt[option_name], or None if unset.
        loss_opt = train_opt.get(option_name)
        if not loss_opt:
            return None
        return build_loss(loss_opt).to(self.device)

    # MOS loss, then the metric-related loss (e.g. a PLCC loss).
    self.cri_mos = _criterion_from('mos_loss_opt')
    self.cri_metric = _criterion_from('metric_loss_opt')

    # Optimizers must exist before the schedulers are created.
    self.setup_optimizers()
    self.setup_schedulers()
[docs]
def setup_optimizers(self):
    """Create the optimizer, with optional per-module learning rates.

    Reads ``self.opt['train']['optim']``:

    * ``type``: optimizer name, forwarded to ``self.get_optimizer``.
    * ``lr_<module>``: parameters whose name contains ``<module>`` are put
      in their own parameter group using this learning rate.
    * ``weight_decay_<module>``: weight decay of that group (default 0.0).
    * every other entry is forwarded to ``self.get_optimizer`` as a
      keyword argument.

    Parameters not claimed by any ``lr_<module>`` entry form a final group
    without group-specific settings. Only parameters with
    ``requires_grad`` set are optimised; the others are reported with a
    warning. Groups that end up empty are dropped.

    The configuration stored in ``self.opt`` is left unmodified, so the
    method can be called more than once.
    """
    # Shallow copy: the pops below must not strip keys from the shared config.
    optim_opt = dict(self.opt['train']['optim'])
    optim_type = optim_opt.pop('type')

    param_dict = dict(self.net.named_parameters())
    unassigned = list(param_dict)

    # Set different lr for different modules if needed, e.g. lr_backbone, lr_head.
    prefix = 'lr_'
    optim_params = []
    for key in [k for k in optim_opt if k.startswith(prefix)]:
        # Strip only the leading prefix; str.replace would also remove any
        # later 'lr_' occurring inside the module name itself.
        module_key = key[len(prefix):]
        lr = optim_opt.pop(key)
        weight_decay = optim_opt.pop(f'weight_decay_{module_key}', 0.0)

        logger = get_root_logger()
        logger.info(
            f'Set optimizer for {module_key} with lr: {lr}, weight_decay: {weight_decay}'
        )

        optim_params.append(
            {
                'params': [
                    param_dict[k]
                    for k in unassigned
                    if module_key in k and param_dict[k].requires_grad
                ],
                'lr': lr,
                'weight_decay': weight_decay,
            }
        )
        # A parameter belongs to the first group that matches it.
        unassigned = [k for k in unassigned if module_key not in k]

    # Append the rest of the parameters.
    optim_params.append(
        {
            'params': [
                param_dict[k] for k in unassigned if param_dict[k].requires_grad
            ],
        }
    )

    # Log params that will not be optimized.
    for name, param in param_dict.items():
        if not param.requires_grad:
            logger = get_root_logger()
            logger.warning(f'Params {name} will not be optimized.')

    # Drop empty groups by building a new list: removing items from a list
    # while iterating over it skips the element after each removal.
    optim_params = [group for group in optim_params if len(group['params']) > 0]

    self.optimizer = self.get_optimizer(optim_type, optim_params, **optim_opt)
    self.optimizers.append(self.optimizer)
[docs]
def feed_data(self, data):
    """Move one batch to ``self.device`` and decide whether a reference image is used.

    Args:
        data: Mapping with the key ``'img'`` and, optionally,
            ``'mos_label'`` and ``'ref_img'``. Every value used must
            provide a ``.to(device)`` method.

    Sets ``self.img_input``, ``self.gt_mos`` (only when ``'mos_label'`` is
    present), ``self.ref_input`` (only when ``'ref_img'`` is present) and
    ``self.use_ref``.
    """
    self.img_input = data['img'].to(self.device)

    if 'mos_label' in data:
        self.gt_mos = data['mos_label'].to(self.device)

    # By default a reference is used exactly when the batch provides one.
    self.use_ref = 'ref_img' in data
    if self.use_ref:
        self.ref_input = data['ref_img'].to(self.device)

    # An explicit train.use_ref setting overrides the default. The 'train'
    # section may be missing (e.g. evaluation-only configs), so do not
    # index it unconditionally.
    train_opt = self.opt.get('train') or {}
    if 'use_ref' in train_opt:
        self.use_ref = train_opt['use_ref']
[docs]
def net_forward(self, net):
    """Call ``net`` on the current batch.

    Args:
        net: Callable network. It receives ``self.img_input`` and, when
            ``self.use_ref`` is truthy, ``self.ref_input`` as a second
            positional argument.

    Returns:
        Whatever ``net`` returns.
    """
    inputs = (self.img_input, self.ref_input) if self.use_ref else (self.img_input,)
    return net(*inputs)
[docs]
def optimize_parameters(self, current_iter):
    """Run one optimisation step on the current batch and fill ``self.log_dict``.

    Args:
        current_iter: Current training iteration. Not used by this
            method; kept as part of the method's signature.

    Raises:
        ValueError: If neither ``self.cri_mos`` nor ``self.cri_metric`` is
            configured, because there would be nothing to back-propagate.
    """
    if self.cri_mos is None and self.cri_metric is None:
        raise ValueError(
            'No loss is configured: set mos_loss_opt and/or metric_loss_opt '
            'in the train options.'
        )

    self.optimizer.zero_grad()
    self.output_score = self.net_forward(self.net)

    l_total = 0
    loss_dict = OrderedDict()
    # MOS loss
    if self.cri_mos is not None:
        l_mos = self.cri_mos(self.output_score, self.gt_mos)
        l_total += l_mos
        loss_dict['l_mos'] = l_mos
    # metric related loss, such as plcc loss
    if self.cri_metric is not None:
        l_metric = self.cri_metric(self.output_score, self.gt_mos)
        l_total += l_metric
        loss_dict['l_metric'] = l_metric

    l_total.backward()
    self.optimizer.step()

    self.log_dict = self.reduce_loss_dict(loss_dict)

    # Log metrics on the training batch. Validation metrics are optional
    # (nondist_validation treats them that way too), so a missing 'val'
    # section or 'metrics' entry simply disables this step.
    val_opt = self.opt.get('val') or {}
    metrics_opt = val_opt.get('metrics') or {}
    if metrics_opt:
        pred_score = self.output_score.squeeze(1).cpu().detach().numpy()
        gt_mos = self.gt_mos.squeeze(1).cpu().detach().numpy()
        for name, opt_ in metrics_opt.items():
            self.log_dict[f'train_metrics/{name}'] = calculate_metric(
                [pred_score, gt_mos], opt_
            )
[docs]
def test(self):
    """Run the network on the current batch without tracking gradients.

    The result is stored in ``self.output_score``. The network is put in
    eval mode for the forward pass. Afterwards training mode is re-enabled
    only if it was active before the call, and this also happens when the
    forward pass raises, so a failed evaluation does not leave the network
    stuck in eval mode.
    """
    was_training = self.net.training
    self.net.eval()
    try:
        with torch.no_grad():
            self.output_score = self.net_forward(self.net)
    finally:
        if was_training:
            self.net.train()
[docs]
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Run validation only on the process whose ``opt['rank']`` is 0.

    All arguments are forwarded unchanged to ``nondist_validation``; other
    ranks return immediately.
    """
    if self.opt['rank'] != 0:
        return
    self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
[docs]
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Evaluate on one validation set, track the best metrics and save the best network.

    Args:
        dataloader: Iterable of validation batches;
            ``dataloader.dataset.opt['name']`` names the dataset.
        current_iter: Current training iteration, recorded alongside the
            metric values.
        tb_logger: Forwarded to ``_log_validation_metric_values``.
        save_img: Not used by this method; kept as part of the signature.

    Metrics are only computed when ``self.opt['val']['metrics']`` is set.
    With ``self.key_metric`` defined, the best network is saved when that
    metric improves; otherwise it is saved when any metric improves.
    """
    dataset_name = dataloader.dataset.opt['name']
    metrics_opt = self.opt['val'].get('metrics')
    with_metrics = metrics_opt is not None
    use_pbar = self.opt['val'].get('pbar', False)

    if with_metrics:
        if not hasattr(self, 'metric_results'):  # only execute in the first run
            self.metric_results = {metric: 0 for metric in metrics_opt}
        # Best results are kept per dataset name, which supports several
        # validation datasets.
        self._initialize_best_metric_results(dataset_name)
        # zero self.metric_results
        self.metric_results = {metric: 0 for metric in self.metric_results}

    pbar = tqdm(total=len(dataloader), unit='image') if use_pbar else None

    pred_score = []
    gt_mos = []
    try:
        for val_data in dataloader:
            self.feed_data(val_data)
            self.test()
            pred_score.append(self.output_score)
            gt_mos.append(self.gt_mos)
            if pbar is not None:
                # The image name is only needed for the progress display,
                # so 'img_path' is not required when the bar is disabled.
                img_name = osp.basename(val_data['img_path'][0])
                pbar.update(1)
                pbar.set_description(f'Test {img_name:>20}')
    finally:
        # Close the bar even if a batch fails.
        if pbar is not None:
            pbar.close()

    pred_score = torch.cat(pred_score, dim=0).squeeze(1).cpu().numpy()
    gt_mos = torch.cat(gt_mos, dim=0).squeeze(1).cpu().numpy()

    if not with_metrics:
        return

    # calculate all metrics
    for name, opt_ in metrics_opt.items():
        self.metric_results[name] = calculate_metric([pred_score, gt_mos], opt_)

    if self.key_metric is not None:
        # The key metric alone decides whether this is a new best model;
        # if it is, the records of all metrics are updated with it.
        save_best = self._update_best_metric_result(
            dataset_name,
            self.key_metric,
            self.metric_results[self.key_metric],
            current_iter,
        )
        if save_best:
            for name in metrics_opt:
                self._update_metric_result(
                    dataset_name, name, self.metric_results[name], current_iter
                )
    else:
        # Update each metric separately. A list (not a generator) is built
        # so that every metric record is updated before any() is evaluated.
        improved = [
            self._update_best_metric_result(
                dataset_name, name, self.metric_results[name], current_iter
            )
            for name in metrics_opt
        ]
        save_best = any(improved)

    if save_best:
        self.copy_model(self.net, self.net_best)
        self.save_network(self.net_best, 'net_best')

    self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
    """Log the current validation metrics and mirror them to ``tb_logger``.

    Args:
        current_iter: Iteration used as the step of the scalar entries.
        dataset_name: Name of the validation dataset; also the key into
            ``self.best_metric_results`` when that attribute exists.
        tb_logger: Object with an ``add_scalar`` method, or a falsy value
            to skip scalar logging.
    """
    has_best = hasattr(self, 'best_metric_results')

    # One line per metric, optionally followed by the best value so far.
    entries = []
    for metric, value in self.metric_results.items():
        entry = f'\t # {metric}: {value:.4f}'
        if has_best:
            best = self.best_metric_results[dataset_name][metric]
            entry += (
                f'\tBest: {best["val"]:.4f} @ '
                f'{best["iter"]} iter'
            )
        entries.append(entry + '\n')

    log_str = f'Validation {dataset_name}\n' + ''.join(entries)
    get_root_logger().info(log_str)

    if not tb_logger:
        return
    for metric, value in self.metric_results.items():
        tb_logger.add_scalar(
            f'val_metrics/{dataset_name}/{metric}', value, current_iter
        )
[docs]
def save(self, epoch, current_iter, save_net_label='net'):
    """Save the network and then the training state.

    Args:
        epoch: Passed to ``save_training_state``.
        current_iter: Passed to both ``save_network`` and
            ``save_training_state``.
        save_net_label: Label passed to ``save_network`` (default
            ``'net'``).
    """
    network = self.net
    self.save_network(network, save_net_label, current_iter)
    self.save_training_state(epoch, current_iter)