# Source code for pyiqa.models.wadiqam_model


from pyiqa.utils.registry import MODEL_REGISTRY
from .general_iqa_model import GeneralIQAModel


@MODEL_REGISTRY.register()
class WaDIQaMModel(GeneralIQAModel):
    """General module to train an IQA network.

    Overrides ``setup_optimizers`` so that a single optimizer is built with
    two parameter groups, each using its own learning rate taken from the
    training options.
    """

    def setup_optimizers(self):
        """Build the optimizer with one learning rate per parameter group.

        Reads ``self.opt['train']['optim']``, which must contain ``type``,
        ``lr_basemodel`` and ``lr_fc_layers``. Every other entry of that
        mapping is forwarded to ``self.get_optimizer`` as a keyword argument.
        The created optimizer is stored in ``self.optimizer`` and appended to
        ``self.optimizers``.

        Raises:
            KeyError: If ``type``, ``lr_basemodel`` or ``lr_fc_layers`` is
                missing from the optimizer options.
        """
        # Pop from a shallow copy: consuming the keys on the live config
        # would strip them from ``self.opt``, so a second call would raise
        # KeyError and later readers of the options would see them missing.
        optim_opt = dict(self.opt['train']['optim'])
        bare_net = self.get_bare_model(self.net)

        # Group 1: parameters of the ``features`` sub-module.
        base_group = {
            'params': bare_net.features.parameters(),
            'lr': optim_opt.pop('lr_basemodel'),
        }
        # Group 2: every parameter whose qualified name does not contain
        # 'features'.
        # NOTE(review): this is a substring test, so a parameter outside
        # ``bare_net.features`` whose name merely contains 'features' would
        # land in neither group -- confirm against the network definition.
        fc_group = {
            'params': [
                param
                for name, param in bare_net.named_parameters()
                if 'features' not in name
            ],
            'lr': optim_opt.pop('lr_fc_layers'),
        }
        optim_params = [base_group, fc_group]

        # Whatever is left after removing the three consumed keys is passed
        # through unchanged to the optimizer factory.
        optim_type = optim_opt.pop('type')
        self.optimizer = self.get_optimizer(optim_type, optim_params, **optim_opt)
        self.optimizers.append(self.optimizer)