pyiqa.archs.ahiq_arch

This module implements the Attention-based Hybrid Image Quality (AHIQ) assessment network as introduced in the following paper:

@article{lao2022attentions,
  title = {Attentions Help CNNs See Better: Attention-based Hybrid Image Quality Assessment Network},
  author = {Lao, Shanshan and Gong, Yuan and Shi, Shuwei and Yang, Sidi and Wu, Tianhe and Wang, Jiahao and Xia, Weihao and Yang, Yujiu},
  journal = {arXiv preprint arXiv:2204.10485},
  year = {2022}
}

Reference URL: https://github.com/IIGROUP/AHIQ
Re-implemented by: Chaofeng Chen (https://github.com/chaofengc)

Module Contents

pyiqa.archs.ahiq_arch.default_model_urls
class pyiqa.archs.ahiq_arch.SaveOutput

Stores the intermediate outputs of hooked layers during the forward pass.

clear(device)
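The hook mechanism behind SaveOutput can be illustrated with a minimal sketch. It assumes the object is registered through PyTorch's `register_forward_hook` and that captured outputs are grouped by the device of the layer input. The class name `SaveOutputSketch` and the layout of its `outputs` dict are illustrative assumptions, not necessarily the pyiqa implementation.

```python
import torch
from torch import nn


class SaveOutputSketch:
    """Collects layer outputs via forward hooks, grouped by device (assumed layout)."""

    def __init__(self):
        self.outputs = {}  # device -> list of captured output tensors

    def __call__(self, module, module_in, module_out):
        # Forward-hook signature: hook(module, inputs_tuple, output)
        device = module_in[0].device
        self.outputs.setdefault(device, []).append(module_out)

    def clear(self, device):
        # Drop everything captured for this device so the next pass starts fresh
        self.outputs[device] = []


# Usage: capture the outputs of both linear layers during one forward pass
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
saver = SaveOutputSketch()
handles = [m.register_forward_hook(saver) for m in model if isinstance(m, nn.Linear)]

x = torch.randn(3, 4)
model(x)
captured = saver.outputs[x.device]
print([t.shape for t in captured])  # [torch.Size([3, 8]), torch.Size([3, 2])]
saver.clear(x.device)
for h in handles:
    h.remove()
```

Clearing after each pass matters because the hook appends on every forward call; without it, outputs from earlier images would accumulate.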
class pyiqa.archs.ahiq_arch.DeformFusion(patch_size=8, in_channels=768 * 5, cnn_channels=256 * 3, out_channels=256 * 3)

Bases: torch.nn.Module

Deformable Fusion Network.

Parameters:
  • patch_size (int, optional) – Size of the patches. Default is 8.

  • in_channels (int, optional) – Number of input channels. Default is 768 * 5.

  • cnn_channels (int, optional) – Number of CNN channels. Default is 256 * 3.

  • out_channels (int, optional) – Number of output channels. Default is 256 * 3.

forward(cnn_feat, vit_feat)
class pyiqa.archs.ahiq_arch.Pixel_Prediction(inchannels=768 * 5 + 256 * 3, outchannels=256, d_hidn=1024)

Bases: torch.nn.Module

Pixel Prediction Network.

Parameters:
  • inchannels (int, optional) – Number of input channels. Default is 768 * 5 + 256 * 3.

  • outchannels (int, optional) – Number of output channels. Default is 256.

  • d_hidn (int, optional) – Hidden dimension. Default is 1024.

forward(f_dis, f_ref, cnn_dis, cnn_ref)
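The default inchannels equals DeformFusion's default in_channels plus out_channels: 768 * 5 + 256 * 3 = 3840 + 768 = 4608. This suggests that the regressor receives the ViT features concatenated with the fused CNN features. The sketch below shows one plausible head under that assumption. The [dis - ref, dis, ref] comparison, the sigmoid weight map, the layer layout, and the name `PixelPredictionSketch` are assumptions, not a transcription of Pixel_Prediction.

```python
import torch
from torch import nn


class PixelPredictionSketch(nn.Module):
    """Attention-weighted pooling of per-pixel quality scores (illustrative sketch)."""

    def __init__(self, inchannels=768 * 5 + 256 * 3, outchannels=256, d_hidn=1024):
        super().__init__()
        self.down_channel = nn.Conv2d(inchannels, outchannels, kernel_size=1)
        # Input is [dis - ref, dis, ref], hence 3 * outchannels
        self.body = nn.Sequential(
            nn.Conv2d(3 * outchannels, d_hidn, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(d_hidn, outchannels, 3, padding=1),
            nn.ReLU(),
        )
        self.score_head = nn.Conv2d(outchannels, 1, kernel_size=1)
        self.weight_head = nn.Sequential(nn.Conv2d(outchannels, 1, kernel_size=1), nn.Sigmoid())

    def forward(self, f_dis, f_ref, cnn_dis, cnn_ref):
        f_dis = self.down_channel(torch.cat((f_dis, cnn_dis), dim=1))
        f_ref = self.down_channel(torch.cat((f_ref, cnn_ref), dim=1))
        feat = self.body(torch.cat((f_dis - f_ref, f_dis, f_ref), dim=1))
        score = self.score_head(feat)    # (B, 1, H, W) per-pixel scores
        weight = self.weight_head(feat)  # (B, 1, H, W) weights in (0, 1)
        # Weighted spatial average gives one score per image, shape (B, 1)
        return (score * weight).sum(dim=(2, 3)) / weight.sum(dim=(2, 3))


print(768 * 5 + 256 * 3)  # 4608
head = PixelPredictionSketch(inchannels=12 + 4, outchannels=8, d_hidn=16)
f_dis, f_ref = torch.randn(2, 12, 7, 7), torch.randn(2, 12, 7, 7)
cnn_dis, cnn_ref = torch.randn(2, 4, 7, 7), torch.randn(2, 4, 7, 7)
pred = head(f_dis, f_ref, cnn_dis, cnn_ref)
print(pred.shape)  # torch.Size([2, 1])
```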
class pyiqa.archs.ahiq_arch.AHIQ(num_crop=20, crop_size=224, default_mean=[0.485, 0.456, 0.406], default_std=[0.229, 0.224, 0.225], pretrained=True, pretrained_model_path=None)

Bases: torch.nn.Module

AHIQ model implementation.

This class implements the Attention-based Hybrid Image Quality (AHIQ) assessment network, a full-reference metric. It extracts features from ResNet50 and Vision Transformer (ViT) backbones, fuses them with deformable convolution layers, and predicts a quality score for a distorted image relative to its reference.

Parameters:
  • num_crop (int, optional) – Number of crops to use for testing. Default is 20.

  • crop_size (int, optional) – Side length, in pixels, of each square crop. Default is 224.

  • default_mean (list of float, optional) – Per-channel mean values for normalization. Default is [0.485, 0.456, 0.406].

  • default_std (list of float, optional) – Per-channel standard deviation values for normalization. Default is [0.229, 0.224, 0.225].

  • pretrained (bool, optional) – Whether to load pretrained weights. Default is True.

  • pretrained_model_path (str, optional) – Path to a pretrained model checkpoint. Default is None.

Attributes:
  • resnet50 (nn.Module) – ResNet50 backbone.

  • vit (nn.Module) – Vision Transformer backbone.

  • deform_net (nn.Module) – Deformable fusion network.

  • regressor (nn.Module) – Pixel prediction network.

  • default_mean (torch.Tensor) – Mean values for normalization.

  • default_std (torch.Tensor) – Standard deviation values for normalization.

  • eps (float) – Small value to avoid division by zero.

  • crops (int) – Number of crops to use for testing.

  • crop_size (int) – Size of the crops.

init_saveoutput()

Initializes the SaveOutput hook to get intermediate features.

fix_network(model)

Freezes the given model by setting requires_grad to False for all of its parameters.

Parameters:

model (nn.Module) – The model to freeze.
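Freezing as described above amounts to a loop over the parameters. The helper name `freeze` is hypothetical; a typical use is keeping a pretrained backbone fixed while layers built on top of it are trained.

```python
from torch import nn


def freeze(model: nn.Module) -> nn.Module:
    """Disable gradient tracking for every parameter of `model` (hypothetical helper)."""
    for param in model.parameters():
        param.requires_grad = False
    return model


backbone = freeze(nn.Linear(4, 2))
print(any(p.requires_grad for p in backbone.parameters()))  # False
```

PyTorch's built-in `nn.Module.requires_grad_(False)` has the same effect in a single call.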

preprocess(x)

Normalizes the input tensor per channel using the default_mean and default_std attributes.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The normalized tensor.

Return type:

torch.Tensor
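Per-channel normalization can be sketched as (x - mean) / std with broadcasting, assuming an RGB batch of shape (B, 3, H, W) with values in [0, 1] and the default statistics listed in the constructor. The function name `normalize` is hypothetical.

```python
import torch

# Default statistics from the constructor, shaped for broadcasting over (B, 3, H, W)
default_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
default_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def normalize(x: torch.Tensor) -> torch.Tensor:
    """Per-channel (x - mean) / std for a (B, 3, H, W) batch in [0, 1]."""
    # .to(x) matches the input's dtype and device before broadcasting
    return (x - default_mean.to(x)) / default_std.to(x)


x = torch.full((1, 3, 2, 2), 0.5)
y = normalize(x)
print(y[0, :, 0, 0])
```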

get_vit_feature(x)

Gets the intermediate features from the Vision Transformer backbone.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The intermediate features.

Return type:

torch.Tensor
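One way intermediate ViT outputs can become a spatial feature map is sketched below. It assumes a ViT-B with 768-dimensional tokens and 8 x 8 patches (suggested by DeformFusion's patch_size=8 default), a leading CLS token, and five hooked blocks (suggested by in_channels=768 * 5). The function name `tokens_to_feature_map` is hypothetical.

```python
import torch


def tokens_to_feature_map(block_outputs, grid_size):
    """Merge ViT block outputs into a (B, C, H, W) map (illustrative sketch).

    block_outputs: list of (B, 1 + N, D) token tensors with a leading CLS token.
    """
    # Drop the CLS token and concatenate the blocks along the embedding dim
    tokens = torch.cat([t[:, 1:, :] for t in block_outputs], dim=2)  # (B, N, D * k)
    b, n, c = tokens.shape
    h, w = grid_size
    assert n == h * w, "token count must match the patch grid"
    # (B, N, C) -> (B, C, N) -> (B, C, H, W), tokens in row-major patch order
    return tokens.transpose(1, 2).reshape(b, c, h, w)


# A 224 x 224 crop with 8 x 8 patches gives a 28 x 28 grid (784 tokens + CLS)
blocks = [torch.randn(2, 1 + 28 * 28, 768) for _ in range(5)]
feat = tokens_to_feature_map(blocks, (28, 28))
print(feat.shape)  # torch.Size([2, 3840, 28, 28])
```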

get_resnet_feature(x)

Gets the intermediate features from the ResNet50 backbone.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The intermediate features.

Return type:

torch.Tensor

regress_score(dis, ref)

Computes the quality score for a distorted and reference image pair.

Parameters:
  • dis (torch.Tensor) – The distorted image.

  • ref (torch.Tensor) – The reference image.

Returns:

The quality score.

Return type:

torch.Tensor

forward(x, y)

Computes the quality score for a batch of distorted and reference image pairs.

Parameters:
  • x (torch.Tensor) – The batch of distorted images.

  • y (torch.Tensor) – The batch of reference images.

Returns:

The quality scores.

Return type:

torch.Tensor
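The num_crop and crop_size parameters suggest crop-based scoring at test time. The sketch below shows one plausible scheme: cut crops at identical positions from the distorted and reference images, score every crop pair in one batch, and average the scores per image. Random crop positions and plain averaging are assumptions; the actual sampling may differ (for example, a uniform grid) and may change during training. The function names and `toy_score`, a stand-in for regress_score, are hypothetical.

```python
import torch


def aligned_random_crops(x, y, crop_size, num_crop):
    """Cut num_crop crops at identical positions from x and y, both (B, C, H, W)."""
    _, _, h, w = x.shape
    xs, ys = [], []
    for _ in range(num_crop):
        top = torch.randint(0, h - crop_size + 1, (1,)).item()
        left = torch.randint(0, w - crop_size + 1, (1,)).item()
        xs.append(x[..., top:top + crop_size, left:left + crop_size])
        ys.append(y[..., top:top + crop_size, left:left + crop_size])
    # Stack crops along a new dim, then fold it into the batch: (B * num_crop, C, s, s)
    xs = torch.stack(xs, dim=1).flatten(0, 1)
    ys = torch.stack(ys, dim=1).flatten(0, 1)
    return xs, ys


def crop_averaged_score(score_fn, x, y, crop_size=224, num_crop=20):
    """Score every crop pair, then average the num_crop scores of each image."""
    bsz = x.shape[0]
    xc, yc = aligned_random_crops(x, y, crop_size, num_crop)
    scores = score_fn(xc, yc)  # (B * num_crop, 1)
    return scores.reshape(bsz, num_crop, 1).mean(dim=1)  # (B, 1)


def toy_score(dis, ref):
    # Hypothetical stand-in for regress_score: negative MSE per crop pair
    return -((dis - ref) ** 2).mean(dim=(1, 2, 3)).unsqueeze(1)


x, y = torch.rand(2, 3, 256, 256), torch.rand(2, 3, 256, 256)
scores = crop_averaged_score(toy_score, x, y)
print(scores.shape)  # torch.Size([2, 1])
```

Folding the crops into the batch dimension lets a single call to the scorer process all crops; the reshape then groups each image's num_crop scores back together before averaging.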