Source code for pyiqa.models.dbcnn_model
from pyiqa.utils import get_root_logger
from pyiqa.utils.registry import MODEL_REGISTRY
from pyiqa.models import lr_scheduler as lr_scheduler
from .general_iqa_model import GeneralIQAModel
@MODEL_REGISTRY.register()
[docs]
class DBCNNModel(GeneralIQAModel):
"""General module to train an IQA network."""
def __init__(self, opt):
super(DBCNNModel, self).__init__(opt)
self.train_stage = 'train'
[docs]
def reset_optimizers_finetune(self):
logger = get_root_logger()
logger.info('\n Start finetune stage. Set all parameters trainable\n')
train_opt = self.opt['train']
optim_params = []
for k, v in self.net.named_parameters():
v.requires_grad = True
optim_params.append(v)
optim_type = train_opt['optim_finetune'].pop('type')
self.optimizer = self.get_optimizer(
optim_type, optim_params, **train_opt['optim_finetune']
)
self.optimizers = [self.optimizer]
# reset schedulers
self.schedulers = []
self.setup_schedulers('scheduler_finetune')
[docs]
def optimize_parameters(self, current_iter):
if (
current_iter >= self.opt['train']['finetune_start_iter']
and self.train_stage != 'finetune'
):
# copy best model from coarse training stage and reset optimizers
self.copy_model(self.net_best, self.net)
self.reset_optimizers_finetune()
self.train_stage = 'finetune'
super().optimize_parameters(current_iter)