pyiqa.models.base_model
=======================

.. py:module:: pyiqa.models.base_model


Module Contents
---------------

.. py:class:: BaseModel(opt)

   Base model.

   .. py:method:: feed_data(data)

   .. py:method:: optimize_parameters()

   .. py:method:: get_current_visuals()

   .. py:method:: save(epoch, current_iter)

      Save networks and training state.

   .. py:method:: validation(dataloader, current_iter, tb_logger, save_img=False)

      Validation function.

      :param dataloader: Validation dataloader.
      :type dataloader: torch.utils.data.DataLoader
      :param current_iter: Current iteration.
      :type current_iter: int
      :param tb_logger: Tensorboard logger.
      :type tb_logger: tensorboard logger
      :param save_img: Whether to save images. Default: False.
      :type save_img: bool

   .. py:method:: model_ema(decay=0.999)

   .. py:method:: copy_model(net_a, net_b)

      Copy the model from ``net_a`` to ``net_b``.

   .. py:method:: get_current_log()

   .. py:method:: model_to_device(net)

      Move the model to the device. It also wraps the model with DistributedDataParallel or DataParallel.

      :param net: Network to be moved to the device.
      :type net: nn.Module

   .. py:method:: get_optimizer(optim_type, params, lr, **kwargs)

   .. py:method:: setup_schedulers(scheduler_name='scheduler')

      Set up schedulers.

   .. py:method:: get_bare_model(net)

      Get the bare model, unwrapping it if it is wrapped with DistributedDataParallel or DataParallel.

   .. py:method:: print_network(net)

      Print the string representation and the number of parameters of a network.

      :param net: Network to be printed.
      :type net: nn.Module

   .. py:method:: update_learning_rate(current_iter, warmup_iter=-1)

      Update the learning rate.

      :param current_iter: Current iteration.
      :type current_iter: int
      :param warmup_iter: Number of warmup iterations; -1 for no warmup. Default: -1.
      :type warmup_iter: int

   .. py:method:: get_current_learning_rate()

   .. py:method:: save_network(net, net_label, current_iter=None, param_key='params')

      Save networks.

      :param net: Network(s) to be saved.
      :type net: nn.Module | list[nn.Module]
      :param net_label: Network label.
      :type net_label: str
      :param current_iter: Current iteration number.
      :type current_iter: int
      :param param_key: The parameter key(s) under which the network(s) are saved. Default: 'params'.
      :type param_key: str | list[str]

   .. py:method:: load_network(net, load_path, strict=True, param_key='params')

      Load a network.

      :param net: Network.
      :type net: nn.Module
      :param load_path: The path of the network to be loaded.
      :type load_path: str
      :param strict: Whether to load strictly, i.e. require the checkpoint keys to match the network's keys exactly.
      :type strict: bool
      :param param_key: The parameter key of the loaded network. If set to None, the root of the loaded checkpoint is used. Default: 'params'.
      :type param_key: str

   .. py:method:: save_training_state(epoch, current_iter)

      Save training states during training, which will be used for resuming.

      :param epoch: Current epoch.
      :type epoch: int
      :param current_iter: Current iteration.
      :type current_iter: int

   .. py:method:: resume_training(resume_state)

      Reload the optimizers and schedulers for resumed training.

      :param resume_state: Resume state.
      :type resume_state: dict

   .. py:method:: reduce_loss_dict(loss_dict)

      Reduce the loss dict. In distributed training, it averages the losses across GPUs.

      :param loss_dict: Loss dict.
      :type loss_dict: OrderedDict
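
``model_ema`` is listed above without a description. Assuming it keeps an exponential moving average (EMA) of the network weights, as its name and its ``decay`` argument (default 0.999) suggest, each call would apply ``ema = decay * ema + (1 - decay) * param`` to every parameter. The framework-free sketch below shows that update on a plain dict of floats; ``ema_update`` is a hypothetical helper, not part of pyiqa.

.. code-block:: python

   from typing import Dict


   def ema_update(ema_params: Dict[str, float], params: Dict[str, float], decay: float = 0.999) -> None:
       """Hypothetical helper (not pyiqa API): in-place EMA update.

       For every parameter name: ema = decay * ema + (1 - decay) * param.
       """
       for name, value in params.items():
           ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


   # Usage: with decay=0.5 the EMA moves halfway toward the current value on each call.
   ema = {"w": 0.0}
   current = {"w": 1.0}
   for _ in range(3):
       ema_update(ema, current, decay=0.5)
   print(ema["w"])  # 0.875

With PyTorch tensors the same update is usually written in place, for example ``ema_p.mul_(decay).add_(p, alpha=1 - decay)`` inside ``torch.no_grad()``. A decay close to 1, such as the default 0.999, makes the averaged weights change slowly and smooths out noise from individual optimization steps.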
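
``update_learning_rate`` takes ``warmup_iter``, where -1 disables warmup. The exact warmup rule is not documented here; a common choice, assumed in the sketch below, is linear warmup, where each initial learning rate is scaled by ``current_iter / warmup_iter`` until ``current_iter`` reaches ``warmup_iter``. ``warmup_lr`` is a hypothetical helper that only computes the rates; it does not touch an optimizer.

.. code-block:: python

   from typing import List


   def warmup_lr(init_lrs: List[float], current_iter: int, warmup_iter: int = -1) -> List[float]:
       """Hypothetical helper (not pyiqa API): linear warmup of each initial learning rate.

       Returns the initial rates unchanged when warmup is disabled (warmup_iter <= 0)
       or when the warmup phase is over (current_iter >= warmup_iter).
       """
       if warmup_iter <= 0 or current_iter >= warmup_iter:
           return list(init_lrs)
       return [lr * current_iter / warmup_iter for lr in init_lrs]


   # Usage: a quarter of the way through warmup, every rate is a quarter of its base value.
   print(warmup_lr([1.0, 2.0], current_iter=25, warmup_iter=100))  # [0.25, 0.5]

In a typical training loop a regular scheduler controls the rate once warmup is over, which is why the helper simply returns the base rates in that case.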
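
``save_network`` and ``load_network`` both take ``param_key``, which names the entry of the checkpoint that holds the weights (default ``'params'``); for ``load_network``, ``None`` means the root of the loaded checkpoint is used. The sketch below shows only that key-selection step on an ordinary dict standing in for a loaded checkpoint. ``select_state_dict`` is a hypothetical name; real loading would also read the file and call the network's ``load_state_dict``.

.. code-block:: python

   from typing import Any, Dict, Optional


   def select_state_dict(checkpoint: Dict[str, Any], param_key: Optional[str] = "params") -> Dict[str, Any]:
       """Hypothetical helper (not pyiqa API): pick the weights out of a checkpoint dict.

       param_key=None treats the whole checkpoint (its root) as the state dict;
       otherwise the sub-dict stored under param_key is returned.
       """
       if param_key is None:
           return checkpoint
       return checkpoint[param_key]


   # Usage: a checkpoint saved under param_key='params' nests the weights one level down.
   nested = {"params": {"conv.weight": [0.1, 0.2]}}
   flat = {"conv.weight": [0.1, 0.2]}
   print(select_state_dict(nested) == select_state_dict(flat, param_key=None))  # True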
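
``reduce_loss_dict`` averages each named loss across GPUs in distributed training. The framework-free sketch below reproduces only that arithmetic: given one loss dict per process (rank), it sums each entry over the ranks and divides by their number (the world size). ``average_loss_dicts`` and the loss names are hypothetical; in real multi-GPU code the sum would come from a collective such as ``torch.distributed.all_reduce`` or ``torch.distributed.reduce`` rather than from a Python list.

.. code-block:: python

   from collections import OrderedDict
   from typing import Dict, List


   def average_loss_dicts(per_rank: List[Dict[str, float]]) -> OrderedDict:
       """Hypothetical helper (not pyiqa API): mean of each named loss over all ranks.

       Mirrors what a distributed reduction computes: sum over ranks, then
       divide by the number of ranks (the world size).
       """
       world_size = len(per_rank)
       reduced = OrderedDict()
       for name in per_rank[0]:
           reduced[name] = sum(d[name] for d in per_rank) / world_size
       return reduced


   # Usage: two ranks report different values for the same (hypothetical) losses.
   rank0 = OrderedDict(loss_a=0.5, loss_b=1.0)
   rank1 = OrderedDict(loss_a=1.5, loss_b=3.0)
   print(dict(average_loss_dicts([rank0, rank1])))  # {'loss_a': 1.0, 'loss_b': 2.0}

Averaging rather than summing keeps the logged values on the same scale as single-GPU training.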