pyiqa.archs.ahiq_arch
This module implements the Attention-based Hybrid Image Quality (AHIQ) assessment network as introduced in the following paper:
@article{lao2022attentions,
  title = {Attentions Help CNNs See Better: Attention-based Hybrid Image Quality Assessment Network},
  author = {Lao, Shanshan and Gong, Yuan and Shi, Shuwei and Yang, Sidi and Wu, Tianhe and Wang, Jiahao and Xia, Weihao and Yang, Yujiu},
  journal = {arXiv preprint arXiv:2204.10485},
  year = {2022}
}
Reference URL: https://github.com/IIGROUP/AHIQ
Re-implemented by: Chaofeng Chen (https://github.com/chaofengc)
Module Contents
- class pyiqa.archs.ahiq_arch.SaveOutput
Stores the intermediate outputs of layers during the forward pass.
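Intermediate outputs are commonly captured in PyTorch with forward hooks. The sketch below illustrates that general pattern only; `OutputRecorder` is a hypothetical stand-in, and its methods and storage layout are assumptions rather than the actual SaveOutput interface.

```python
import torch
import torch.nn as nn


class OutputRecorder:
    """Hypothetical stand-in for SaveOutput: collects outputs of hooked layers."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, inputs, output):
        # Signature expected by nn.Module.register_forward_hook.
        self.outputs.append(output)

    def clear(self):
        self.outputs = []


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
recorder = OutputRecorder()

# Hook only the Linear layers; each forward pass appends their outputs in order.
handles = [
    layer.register_forward_hook(recorder)
    for layer in net
    if isinstance(layer, nn.Linear)
]

with torch.no_grad():
    net(torch.randn(3, 4))

shapes = [tuple(out.shape) for out in recorder.outputs]

# Remove the hooks once the features are no longer needed.
for handle in handles:
    handle.remove()
```

Clearing the recorder between forward passes keeps outputs from different batches from accumulating in the same list.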
- class pyiqa.archs.ahiq_arch.DeformFusion(patch_size=8, in_channels=768 * 5, cnn_channels=256 * 3, out_channels=256 * 3)
Bases: torch.nn.Module
Deformable Fusion Network.
- Parameters:
patch_size (int, optional) – Size of the patches. Default is 8.
in_channels (int, optional) – Number of input channels. Default is 768 * 5.
cnn_channels (int, optional) – Number of CNN channels. Default is 256 * 3.
out_channels (int, optional) – Number of output channels. Default is 256 * 3.
- class pyiqa.archs.ahiq_arch.Pixel_Prediction(inchannels=768 * 5 + 256 * 3, outchannels=256, d_hidn=1024)
Bases: torch.nn.Module
Pixel Prediction Network.
- Parameters:
inchannels (int, optional) – Number of input channels. Default is 768 * 5 + 256 * 3.
outchannels (int, optional) – Number of output channels. Default is 256.
d_hidn (int, optional) – Hidden dimension. Default is 1024.
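The default `inchannels=768 * 5 + 256 * 3` reads like a channel-wise concatenation of a 3840-channel map and a 768-channel map. The sketch below works through that arithmetic with dummy tensors. Reading the defaults as stacked ViT features plus stacked CNN features is an assumption taken from the expressions themselves, and the 4 x 4 spatial size is arbitrary.

```python
import torch

vit_channels = 768 * 5   # 3840, assumed: five 768-dim ViT feature maps stacked
cnn_channels = 256 * 3   # 768, assumed: three 256-channel CNN feature maps stacked

# Dummy feature maps with a small, arbitrary spatial size.
vit_map = torch.zeros(1, vit_channels, 4, 4)
cnn_map = torch.zeros(1, cnn_channels, 4, 4)

# Concatenating along the channel dimension gives the default input width.
fused = torch.cat([vit_map, cnn_map], dim=1)
total_channels = fused.shape[1]
```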
- class pyiqa.archs.ahiq_arch.AHIQ(num_crop=20, crop_size=224, default_mean=[0.485, 0.456, 0.406], default_std=[0.229, 0.224, 0.225], pretrained=True, pretrained_model_path=None)
Bases: torch.nn.Module
AHIQ model implementation.
This class implements the Attention-based Hybrid Image Quality (AHIQ) assessment network, which combines ResNet50 and Vision Transformer (ViT) backbones with deformable convolution layers for enhanced image quality assessment.
- Parameters:
num_crop (int, optional) – Number of crops to use for testing. Default is 20.
crop_size (int, optional) – Size of the crops. Default is 224.
default_mean (list, optional) – List of mean values for normalization. Default is [0.485, 0.456, 0.406].
default_std (list, optional) – List of standard deviation values for normalization. Default is [0.229, 0.224, 0.225].
pretrained (bool, optional) – Whether to use a pretrained model. Default is True.
pretrained_model_path (str, optional) – Path to a pretrained model. Default is None.
- Attributes:
resnet50 (nn.Module) – ResNet50 backbone.
vit (nn.Module) – Vision Transformer backbone.
deform_net (nn.Module) – Deformable fusion network.
regressor (nn.Module) – Pixel prediction network.
default_mean (torch.Tensor) – Mean values for normalization.
default_std (torch.Tensor) – Standard deviation values for normalization.
eps (float) – Small value to avoid division by zero.
crops (int) – Number of crops to use for testing.
crop_size (int) – Size of the crops.
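A minimal sketch of multi-crop testing as the crop settings suggest: cut several fixed-size crops from an image, score each, and average. Random crop positions, the averaging step, and the `random_crops` and `dummy_score` helpers are all assumptions for illustration; the crop strategy actually used by AHIQ may differ.

```python
import torch


def random_crops(img, crop_size, num_crop, generator=None):
    """Cut num_crop random square crops out of a (B, C, H, W) batch."""
    _, _, height, width = img.shape
    crops = []
    for _ in range(num_crop):
        top = int(torch.randint(0, height - crop_size + 1, (1,), generator=generator))
        left = int(torch.randint(0, width - crop_size + 1, (1,), generator=generator))
        crops.append(img[:, :, top:top + crop_size, left:left + crop_size])
    return torch.stack(crops, dim=1)  # (B, num_crop, C, crop_size, crop_size)


def dummy_score(patches):
    """Hypothetical scorer: mean intensity per crop, standing in for the network."""
    return patches.mean(dim=(-3, -2, -1))


gen = torch.Generator().manual_seed(0)
image = torch.rand(1, 3, 256, 320, generator=gen)

patches = random_crops(image, crop_size=224, num_crop=20, generator=gen)
per_crop = dummy_score(patches)      # one score per crop, shape (1, 20)
final_score = per_crop.mean(dim=1)   # average over crops, shape (1,)
```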
- fix_network(model)
Freezes the model by setting requires_grad to False on all of its parameters.
- Parameters:
model (nn.Module) – The model to fix.
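Freezing a module so that none of its parameters require gradients can be sketched as follows. The `freeze` helper and the toy backbone are hypothetical; the code shows the general PyTorch idiom rather than the body of fix_network.

```python
import torch.nn as nn


def freeze(model: nn.Module) -> None:
    """Disable gradient tracking for every parameter of the model."""
    for param in model.parameters():
        param.requires_grad = False


# Toy backbone used only for illustration.
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
freeze(backbone)

trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
```

Freezing parameters this way does not change the behavior of layers such as batch normalization or dropout; that still requires putting the module in evaluation mode with `model.eval()`.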
- preprocess(x)
Preprocesses the input tensor by normalizing it with default_mean and default_std.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The normalized tensor.
- Return type:
torch.Tensor
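Per-channel normalization with the default mean and standard deviation can be sketched as below. The `normalize` helper and the (1, 3, 1, 1) broadcasting layout are assumptions for illustration, not necessarily how preprocess stores or applies the statistics.

```python
import torch

# Default statistics from the AHIQ signature, shaped to broadcast over (B, C, H, W).
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def normalize(x: torch.Tensor) -> torch.Tensor:
    """Subtract the channel mean and divide by the channel std."""
    # .to(x) matches the device and dtype of the input tensor.
    return (x - mean.to(x)) / std.to(x)


x = torch.full((2, 3, 4, 4), 0.5)
y = normalize(x)
```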
- get_vit_feature(x)
Gets the intermediate features from the Vision Transformer backbone.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The intermediate features.
- Return type:
torch.Tensor
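One way token outputs from several transformer blocks can be combined into a single feature map is shown below. This is a sketch under stated assumptions: five blocks, a leading class token that is dropped, 768-dim tokens, and a 28 x 28 patch grid (a 224-pixel crop with patch size 8). The layout actually returned by get_vit_feature may differ.

```python
import torch

batch, dim, grid = 1, 768, 28   # grid = 224 // 8, assumed from the defaults
num_blocks = 5                  # assumed from in_channels = 768 * 5

# Dummy per-block outputs: one class token followed by grid * grid patch tokens.
block_outputs = [torch.randn(batch, 1 + grid * grid, dim) for _ in range(num_blocks)]

# Drop the class token and concatenate the blocks along the feature dimension.
tokens = torch.cat([out[:, 1:, :] for out in block_outputs], dim=2)   # (B, 784, 3840)

# Move features to the channel axis and restore the 2D patch grid.
feature_map = tokens.transpose(1, 2).reshape(batch, dim * num_blocks, grid, grid)
```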
- get_resnet_feature(x)
Gets the intermediate features from the ResNet50 backbone.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The intermediate features.
- Return type:
torch.Tensor