pyiqa.archs.ahiq_arch
=====================

.. py:module:: pyiqa.archs.ahiq_arch

.. autoapi-nested-parse::

   AHIQ Metric Implementation
   ==========================

   This module implements the Attention-based Hybrid Image Quality (AHIQ)
   assessment network as introduced in the following paper::

      @article{lao2022attentions,
        title   = {Attentions Help CNNs See Better: Attention-based Hybrid Image Quality Assessment Network},
        author  = {Lao, Shanshan and Gong, Yuan and Shi, Shuwei and Yang, Sidi and Wu, Tianhe and Wang, Jiahao and Xia, Weihao and Yang, Yujiu},
        journal = {arXiv preprint arXiv:2204.10485},
        year    = {2022}
      }

   Reference URL: https://github.com/IIGROUP/AHIQ

   Re-implemented by: Chaofeng Chen (https://github.com/chaofengc)


Module Contents
---------------

.. py:data:: default_model_urls

.. py:class:: SaveOutput

   SaveOutput class to save intermediate outputs of layers during forward pass.

   .. py:method:: clear(device)


.. py:class:: DeformFusion(patch_size=8, in_channels=768 * 5, cnn_channels=256 * 3, out_channels=256 * 3)

   Bases: :py:obj:`torch.nn.Module`

   Deformable Fusion Network.

   :param patch_size: Size of the patches. Default is 8.
   :type patch_size: int, optional
   :param in_channels: Number of input channels. Default is 768 * 5.
   :type in_channels: int, optional
   :param cnn_channels: Number of CNN channels. Default is 256 * 3.
   :type cnn_channels: int, optional
   :param out_channels: Number of output channels. Default is 256 * 3.
   :type out_channels: int, optional

   .. py:method:: forward(cnn_feat, vit_feat)


.. py:class:: Pixel_Prediction(inchannels=768 * 5 + 256 * 3, outchannels=256, d_hidn=1024)

   Bases: :py:obj:`torch.nn.Module`

   Pixel Prediction Network.

   :param inchannels: Number of input channels. Default is 768 * 5 + 256 * 3.
   :type inchannels: int, optional
   :param outchannels: Number of output channels. Default is 256.
   :type outchannels: int, optional
   :param d_hidn: Hidden dimension. Default is 1024.
   :type d_hidn: int, optional

   .. py:method:: forward(f_dis, f_ref, cnn_dis, cnn_ref)


.. py:class:: AHIQ(num_crop=20, crop_size=224, default_mean=[0.485, 0.456, 0.406], default_std=[0.229, 0.224, 0.225], pretrained=True, pretrained_model_path=None)

   Bases: :py:obj:`torch.nn.Module`

   AHIQ model implementation.

   This class implements the Attention-based Hybrid Image Quality (AHIQ)
   assessment network, which combines ResNet50 and Vision Transformer (ViT)
   backbones with deformable convolution layers for enhanced image quality
   assessment.

   :param num_crop: Number of crops to use for testing. Default is 20.
   :type num_crop: int, optional
   :param crop_size: Size of the crops. Default is 224.
   :type crop_size: int, optional
   :param default_mean: List of mean values for normalization. Default is [0.485, 0.456, 0.406].
   :type default_mean: list, optional
   :param default_std: List of standard deviation values for normalization. Default is [0.229, 0.224, 0.225].
   :type default_std: list, optional
   :param pretrained: Whether to use a pretrained model. Default is True.
   :type pretrained: bool, optional
   :param pretrained_model_path: Path to a pretrained model. Default is None.
   :type pretrained_model_path: str, optional

   .. attribute:: resnet50

      ResNet50 backbone.

      :type: nn.Module

   .. attribute:: vit

      Vision Transformer backbone.

      :type: nn.Module

   .. attribute:: deform_net

      Deformable fusion network.

      :type: nn.Module

   .. attribute:: regressor

      Pixel prediction network.

      :type: nn.Module

   .. attribute:: default_mean

      Mean values for normalization.

      :type: torch.Tensor

   .. attribute:: default_std

      Standard deviation values for normalization.

      :type: torch.Tensor

   .. attribute:: eps

      Small value to avoid division by zero.

      :type: float

   .. attribute:: crops

      Number of crops to use for testing.

      :type: int

   .. attribute:: crop_size

      Size of the crops.

      :type: int

   .. py:method:: init_saveoutput()

      Initializes the SaveOutput hook to get intermediate features.

   .. py:method:: fix_network(model)

      Fixes the network by setting all parameters to not require gradients.

      :param model: The model to fix.
      :type model: nn.Module

   .. py:method:: preprocess(x)

      Preprocesses the input tensor by normalizing it.

      :param x: The input tensor.
      :type x: torch.Tensor

      :returns: The normalized tensor.
      :rtype: torch.Tensor

   .. py:method:: get_vit_feature(x)

      Gets the intermediate features from the Vision Transformer backbone.

      :param x: The input tensor.
      :type x: torch.Tensor

      :returns: The intermediate features.
      :rtype: torch.Tensor

   .. py:method:: get_resnet_feature(x)

      Gets the intermediate features from the ResNet50 backbone.

      :param x: The input tensor.
      :type x: torch.Tensor

      :returns: The intermediate features.
      :rtype: torch.Tensor

   .. py:method:: regress_score(dis, ref)

      Computes the quality score for a distorted and reference image pair.

      :param dis: The distorted image.
      :type dis: torch.Tensor
      :param ref: The reference image.
      :type ref: torch.Tensor

      :returns: The quality score.
      :rtype: torch.Tensor

   .. py:method:: forward(x, y)

      Computes the quality score for a batch of distorted and reference image pairs.

      :param x: The batch of distorted images.
      :type x: torch.Tensor
      :param y: The batch of reference images.
      :type y: torch.Tensor

      :returns: The quality scores.
      :rtype: torch.Tensor
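
The ``preprocess`` method is documented only as normalizing its input with ``default_mean`` and ``default_std``. As a rough illustration, the sketch below assumes the usual per-channel form ``(x - mean) / std`` applied to a channels-first image with values in [0, 1]. The function name ``normalize`` and the nested-list image layout are hypothetical and chosen to keep the example dependency-free; the actual method operates on ``torch.Tensor`` inputs and may differ in detail.

```python
# Sketch only: per-channel normalization using the documented default values.
# `normalize` is a hypothetical helper, not part of pyiqa.
DEFAULT_MEAN = [0.485, 0.456, 0.406]
DEFAULT_STD = [0.229, 0.224, 0.225]


def normalize(image, mean=DEFAULT_MEAN, std=DEFAULT_STD):
    """Normalize a channels-first image given as nested lists (C x H x W).

    Assumes pixel values are already scaled to [0, 1].
    """
    if len(image) != len(mean) or len(image) != len(std):
        raise ValueError("image must have one channel per mean/std entry")
    return [
        [[(pixel - m) / s for pixel in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]


# A 3 x 1 x 1 image whose pixels equal the channel means maps to zeros.
mean_image = [[[m]] for m in DEFAULT_MEAN]
normalized = normalize(mean_image)
```

With tensors, the same arithmetic is typically written with the mean and standard deviation reshaped to broadcast over the batch and spatial dimensions.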
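
The ``num_crop`` and ``crop_size`` parameters indicate that testing scores several crops per image pair. The page does not say how crop positions are chosen or how per-crop scores are combined, so the sketch below makes two explicit assumptions: positions are sampled uniformly at random and shared between the distorted and reference images, and the final score is the plain mean of the per-crop scores. ``random_crop_boxes``, ``average_crop_scores``, and ``score_fn`` are hypothetical names used only for illustration.

```python
import random


def random_crop_boxes(height, width, crop_size=224, num_crop=20, seed=None):
    """Return `num_crop` boxes as (top, left, bottom, right) tuples.

    Assumption: crops are sampled uniformly at random. The same boxes would
    be applied to both the distorted and the reference image.
    """
    if height < crop_size or width < crop_size:
        raise ValueError("image is smaller than the crop size")
    rng = random.Random(seed)
    boxes = []
    for _ in range(num_crop):
        top = rng.randint(0, height - crop_size)    # inclusive bounds
        left = rng.randint(0, width - crop_size)
        boxes.append((top, left, top + crop_size, left + crop_size))
    return boxes


def average_crop_scores(score_fn, boxes):
    """Score every crop with `score_fn(box)` and return the mean.

    Assumption: per-crop scores are combined by a simple average.
    """
    scores = [score_fn(box) for box in boxes]
    return sum(scores) / len(scores)


boxes = random_crop_boxes(300, 400, seed=0)
# A stand-in scorer; a real one would crop both images and run the model.
mean_score = average_crop_scores(lambda box: 1.0, boxes)
```

Sharing the crop boxes between the two images keeps each distorted crop aligned with its reference crop, which a full-reference metric requires.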