pyiqa.archs.iqt_arch

IQA metric introduced by

@inproceedings{cheon2021iqt,
  title={Perceptual image quality assessment with transformers},
  author={Cheon, Manri and Yoon, Sung-Jun and Kang, Byungyeon and Lee, Junwoo},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={433--442},
  year={2021}
}

Ref url: https://github.com/anse3832/IQT
Re-implemented by: Chaofeng Chen (https://github.com/chaofengc)

Module Contents

class pyiqa.archs.iqt_arch.IQARegression(config)

Bases: torch.nn.Module

forward(enc_inputs, enc_inputs_embed, dec_inputs, dec_inputs_embed)

class pyiqa.archs.iqt_arch.Transformer(config)

Bases: torch.nn.Module

forward(enc_inputs, enc_inputs_embed, dec_inputs, dec_inputs_embed)

class pyiqa.archs.iqt_arch.Encoder(config)

Bases: torch.nn.Module

forward(inputs, inputs_embed)

class pyiqa.archs.iqt_arch.EncoderLayer(config)

Bases: torch.nn.Module

forward(inputs, attn_mask)

pyiqa.archs.iqt_arch.get_sinusoid_encoding_table(n_seq, d_hidn)
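
Only the signature of get_sinusoid_encoding_table is listed here. Assuming it builds the standard sinusoidal position-encoding table from the original Transformer paper (an assumption; this is an independent sketch, not the pyiqa source), a NumPy version could look like this. The name sinusoid_table is hypothetical.

```python
import numpy as np

def sinusoid_table(n_seq: int, d_hidn: int) -> np.ndarray:
    """Hypothetical helper: (n_seq, d_hidn) table of sinusoidal position encodings."""
    positions = np.arange(n_seq)[:, None]  # (n_seq, 1)
    dims = np.arange(d_hidn)[None, :]      # (1, d_hidn)
    # Dimensions 2i and 2i+1 share the frequency 1 / 10000^(2i / d_hidn).
    angles = positions / np.power(10000.0, 2 * (dims // 2) / d_hidn)
    table = np.empty((n_seq, d_hidn))
    table[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions use sine
    table[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions use cosine
    return table

table = sinusoid_table(4, 6)
```

Each row encodes one sequence position; in the standard formulation these rows are added to the input embeddings.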

pyiqa.archs.iqt_arch.get_attn_pad_mask(seq_q, seq_k, i_pad)
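
get_attn_pad_mask(seq_q, seq_k, i_pad) follows the naming of a common padding-mask helper. A minimal sketch, assuming it marks every key position whose token equals i_pad and broadcasts that row over all query positions; attn_pad_mask is a hypothetical name.

```python
import numpy as np

def attn_pad_mask(seq_q: np.ndarray, seq_k: np.ndarray, i_pad: int) -> np.ndarray:
    """Hypothetical helper: (batch, len_q, len_k) mask, True where the key is padding."""
    batch, len_q = seq_q.shape
    _, len_k = seq_k.shape
    is_pad = seq_k == i_pad  # (batch, len_k)
    # Every query in a sequence sees the same padded keys, so broadcast over len_q.
    return np.broadcast_to(is_pad[:, None, :], (batch, len_q, len_k))

queries = np.ones((2, 3))
keys = np.array([[1, 1, 0, 0], [1, 0, 1, 0]])
mask = attn_pad_mask(queries, keys, i_pad=0)
```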

class pyiqa.archs.iqt_arch.MultiHeadAttention(config)

Bases: torch.nn.Module

forward(Q, K, V, attn_mask)

class pyiqa.archs.iqt_arch.ScaledDotProductAttention(config)

Bases: torch.nn.Module

forward(Q, K, V, attn_mask)
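
For ScaledDotProductAttention.forward(Q, K, V, attn_mask), here is a NumPy sketch of the standard computation softmax(Q K^T / sqrt(d_k)) V. It assumes attn_mask is boolean with True marking blocked positions, which receive a large negative score before the softmax. This is an illustration of the technique, not the module's code; the function name is hypothetical.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, attn_mask):
    """Q: (..., len_q, d_k), K: (..., len_k, d_k), V: (..., len_k, d_v); mask True = blocked."""
    d_k = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)   # (..., len_q, len_k)
    scores = np.where(attn_mask, -1e9, scores)            # blocked keys get ~zero weight
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

Q = np.zeros((1, 2, 4))
K = np.ones((1, 3, 4))
V = np.array([[[1.0, 0.0], [3.0, 0.0], [100.0, 0.0]]])
mask = np.array([[[False, False, True], [False, False, True]]])  # hide the last key
out, weights = scaled_dot_product_attention(Q, K, V, mask)
```

With equal scores, the two visible keys split the attention evenly and the masked key contributes nothing.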

class pyiqa.archs.iqt_arch.PoswiseFeedForwardNet(config)

Bases: torch.nn.Module

forward(inputs)

class pyiqa.archs.iqt_arch.Decoder(config)

Bases: torch.nn.Module

forward(dec_inputs, dec_inputs_embed, enc_inputs, enc_outputs)

class pyiqa.archs.iqt_arch.DecoderLayer(config)

Bases: torch.nn.Module

forward(dec_inputs, enc_outputs, self_attn_mask, dec_enc_attn_mask)

pyiqa.archs.iqt_arch.get_attn_decoder_mask(seq)
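
get_attn_decoder_mask(seq) presumably builds the usual subsequent (causal) mask for decoder self-attention. A sketch under that assumption, using the hypothetical name attn_decoder_mask: entries above the diagonal are 1, so query position i cannot attend to any position after i.

```python
import numpy as np

def attn_decoder_mask(seq: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (batch, len, len) mask; 1 marks future positions to block."""
    batch, length = seq.shape
    # Strict upper triangle (k=1) keeps the diagonal visible, so each token sees itself.
    future = np.triu(np.ones((length, length), dtype=np.int64), k=1)
    return np.broadcast_to(future, (batch, length, length))

mask = attn_decoder_mask(np.ones((2, 3)))
```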

class pyiqa.archs.iqt_arch.SaveOutput

clear(device)

class pyiqa.archs.iqt_arch.DeformFusion(patch_size=8, in_channels=768 * 5, cnn_channels=256 * 3, out_channels=256 * 3)

Bases: torch.nn.Module

forward(cnn_feat, vit_feat)

class pyiqa.archs.iqt_arch.Pixel_Prediction(inchannels=768 * 5 + 256 * 3, outchannels=256, d_hidn=1024)

Bases: torch.nn.Module

forward(f_dis, f_ref, cnn_dis, cnn_ref)

class pyiqa.archs.iqt_arch.IQT(num_crop=20, config_dataset='live', default_mean=timm.data.IMAGENET_INCEPTION_MEAN, default_std=timm.data.IMAGENET_INCEPTION_STD, pretrained=False, pretrained_model_path=None)

Bases: torch.nn.Module

Image Quality Transformer (IQT) model for image quality assessment.

Parameters:
  • num_crop (int) – Number of crops to take from the input image.

  • config_dataset (str) – Name of the dataset configuration to use (e.g. 'live').

  • default_mean (tuple) – Default per-channel mean values for input normalization.

  • default_std (tuple) – Default per-channel standard deviation values for input normalization.

  • pretrained (bool) – Whether to load pretrained weights.

  • pretrained_model_path (str or None) – Path to the pretrained model weights.

Attributes:
  • backbone (nn.Module) – Inception ResNet V2 backbone model.

  • config (Config) – Configuration object for the IQT model.

  • enc_inputs (torch.Tensor) – Encoder input tensor.

  • dec_inputs (torch.Tensor) – Decoder input tensor.

  • regressor (IQARegression) – Regression model for IQT.

  • default_mean (torch.Tensor) – Default mean values for input normalization.

  • default_std (torch.Tensor) – Default standard deviation values for input normalization.

  • eps (float) – Epsilon value for numerical stability.

  • crops (int) – Number of crops to take from the input image.

  • crop_size (int) – Size of each crop taken from the input image.

init_saveoutput()

Initialize the SaveOutput object, register it as a hook on the backbone model, and keep the returned hook handles.
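
init_saveoutput and SaveOutput rely on PyTorch forward hooks to capture intermediate backbone features. The sketch below shows the general pattern on a toy backbone. FeatureRecorder is a hypothetical stand-in for SaveOutput; storing outputs per device is an assumption suggested by the clear(device) signature, and the choice of hooked layers is illustrative only.

```python
import torch
import torch.nn as nn

class FeatureRecorder:
    """Hypothetical stand-in for SaveOutput: records hooked module outputs per device."""

    def __init__(self):
        self.outputs = {}

    def __call__(self, module, module_in, module_out):
        # PyTorch calls this after each hooked module's forward pass.
        self.outputs.setdefault(module_out.device, []).append(module_out)

    def clear(self, device):
        self.outputs[device] = []

backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 3, padding=1))
recorder = FeatureRecorder()
# Hook every conv layer and keep the handles so the hooks can be removed later.
handles = [m.register_forward_hook(recorder) for m in backbone.modules() if isinstance(m, nn.Conv2d)]

x = torch.randn(1, 3, 16, 16)
backbone(x)
feats = recorder.outputs[x.device]  # one feature map per hooked conv layer
```

Clearing the recorder between batches keeps features from earlier forward passes from accumulating.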

fix_network(model)

Freeze the network by setting requires_grad to False for all of its parameters.

Parameters:

model (nn.Module) – The model to fix.
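
Freezing a module as fix_network describes amounts to the loop below, a minimal sketch with the hypothetical name freeze. The example pairs a frozen module with a trainable head to show that gradients reach only the head; PyTorch also offers the equivalent one-liner model.requires_grad_(False).

```python
import torch
import torch.nn as nn

def freeze(model: nn.Module) -> nn.Module:
    """Hypothetical helper: disable gradients for every parameter of `model`."""
    for param in model.parameters():
        # backward() will no longer compute a gradient for this tensor.
        param.requires_grad = False
    return model

backbone = freeze(nn.Linear(8, 4))
head = nn.Linear(4, 1)

loss = head(backbone(torch.randn(2, 8))).sum()
loss.backward()
```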

preprocess(x)

Normalize the input tensor with default_mean and default_std.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The normalized input tensor.

Return type:

torch.Tensor
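
A sketch of the channel-wise normalization that preprocess describes, for an (N, C, H, W) batch and assuming inputs in [0, 1]. The defaults, timm's IMAGENET_INCEPTION_MEAN and IMAGENET_INCEPTION_STD, are both (0.5, 0.5, 0.5), so this maps [0, 1] to [-1, 1]. The function name normalize is hypothetical, and the timm values are hard-coded here so the block runs without timm.

```python
import torch

def normalize(x: torch.Tensor, mean, std) -> torch.Tensor:
    """Hypothetical helper: channel-wise (x - mean) / std for an (N, C, H, W) batch."""
    # Reshape per-channel statistics to (1, C, 1, 1) so they broadcast over N, H and W.
    mean = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
    std = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
    return (x - mean) / std

INCEPTION_MEAN = (0.5, 0.5, 0.5)  # same values as timm.data.IMAGENET_INCEPTION_MEAN
INCEPTION_STD = (0.5, 0.5, 0.5)   # same values as timm.data.IMAGENET_INCEPTION_STD

x = torch.rand(2, 3, 8, 8)  # assumed to lie in [0, 1]
y = normalize(x, INCEPTION_MEAN, INCEPTION_STD)
```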

get_backbone_feature(x)

Get the backbone features for the input tensor.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The backbone features for the input tensor.

Return type:

torch.Tensor

regress_score(dis, ref)

Predict the quality score of the distorted image dis with respect to the reference image ref.

Parameters:
  • dis (torch.Tensor) – The distorted image.

  • ref (torch.Tensor) – The reference image.

Returns:

The predicted quality score of the distorted image.

Return type:

torch.Tensor

forward(x, y)