pyiqa.archs.iqt_arch
====================

.. py:module:: pyiqa.archs.iqt_arch

.. autoapi-nested-parse::

   IQA metric introduced by::

      @inproceedings{cheon2021iqt,
        title={Perceptual image quality assessment with transformers},
        author={Cheon, Manri and Yoon, Sung-Jun and Kang, Byungyeon and Lee, Junwoo},
        booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
        pages={433--442},
        year={2021}
      }

   Ref url: https://github.com/anse3832/IQT

   Re-implemented by: Chaofeng Chen (https://github.com/chaofengc)


Module Contents
---------------

.. py:class:: IQARegression(config)

   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(enc_inputs, enc_inputs_embed, dec_inputs, dec_inputs_embed)


.. py:class:: Transformer(config)

   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(enc_inputs, enc_inputs_embed, dec_inputs, dec_inputs_embed)


.. py:class:: Encoder(config)

   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(inputs, inputs_embed)


.. py:class:: EncoderLayer(config)

   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(inputs, attn_mask)


.. py:function:: get_sinusoid_encoding_table(n_seq, d_hidn)


.. py:function:: get_attn_pad_mask(seq_q, seq_k, i_pad)


.. py:class:: MultiHeadAttention(config)

   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(Q, K, V, attn_mask)


.. py:class:: ScaledDotProductAttention(config)

   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(Q, K, V, attn_mask)


.. py:class:: PoswiseFeedForwardNet(config)

   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(inputs)


.. py:class:: Decoder(config)

   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(dec_inputs, dec_inputs_embed, enc_inputs, enc_outputs)


.. py:class:: DecoderLayer(config)

   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(dec_inputs, enc_outputs, self_attn_mask, dec_enc_attn_mask)


.. py:function:: get_attn_decoder_mask(seq)


.. py:class:: SaveOutput

   .. py:method:: clear(device)


.. py:class:: DeformFusion(patch_size=8, in_channels=768 * 5, cnn_channels=256 * 3, out_channels=256 * 3)

   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(cnn_feat, vit_feat)


.. py:class:: Pixel_Prediction(inchannels=768 * 5 + 256 * 3, outchannels=256, d_hidn=1024)

   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(f_dis, f_ref, cnn_dis, cnn_ref)


.. py:class:: IQT(num_crop=20, config_dataset='live', default_mean=timm.data.IMAGENET_INCEPTION_MEAN, default_std=timm.data.IMAGENET_INCEPTION_STD, pretrained=False, pretrained_model_path=None)

   Bases: :py:obj:`torch.nn.Module`

   Image Quality Transformer (IQT) model for image quality assessment.

   :param num_crop: Number of crops to take from the input image.
   :type num_crop: int
   :param config_dataset: Name of the dataset whose configuration is used.
   :type config_dataset: str
   :param default_mean: Default mean values for input normalization.
   :type default_mean: list
   :param default_std: Default standard deviation values for input normalization.
   :type default_std: list
   :param pretrained: Whether to load pretrained weights.
   :type pretrained: bool
   :param pretrained_model_path: Path to the pretrained model weights.
   :type pretrained_model_path: str

   .. attribute:: backbone

      Inception-ResNet-v2 backbone model.

      :type: nn.Module

   .. attribute:: config

      Configuration object for the IQT model.

      :type: Config

   .. attribute:: enc_inputs

      Encoded input tensor.

      :type: torch.Tensor

   .. attribute:: dec_inputs

      Decoded input tensor.

      :type: torch.Tensor

   .. attribute:: regressor

      Regression model for IQT.

      :type: IQARegression

   .. attribute:: default_mean

      Default mean values for input normalization.

      :type: torch.Tensor

   .. attribute:: default_std

      Default standard deviation values for input normalization.

      :type: torch.Tensor

   .. attribute:: eps

      Epsilon value for numerical stability.

      :type: float

   .. attribute:: crops

      Number of crops to take from the input image.

      :type: int

   .. attribute:: crop_size

      Size of each input image crop.

      :type: int

   .. py:method:: init_saveoutput()

      Initialize the SaveOutput object and register hooks on the backbone model.

   .. py:method:: fix_network(model)

      Freeze the network by setting all of its parameters to not require gradients.

      :param model: The model to freeze.
      :type model: nn.Module

   .. py:method:: preprocess(x)

      Preprocess the input tensor by normalizing it.

      :param x: The input tensor.
      :type x: torch.Tensor
      :returns: The normalized input tensor.
      :rtype: torch.Tensor

   .. py:method:: get_backbone_feature(x)

      Get the backbone features for the input tensor.

      :param x: The input tensor.
      :type x: torch.Tensor
      :returns: The backbone features for the input tensor.
      :rtype: torch.Tensor

   .. py:method:: regress_score(dis, ref)

      Regress the quality score for a distorted image given its reference.

      :param dis: The distorted image.
      :type dis: torch.Tensor
      :param ref: The reference image.
      :type ref: torch.Tensor
      :returns: The predicted quality score.
      :rtype: torch.Tensor

   .. py:method:: forward(x, y)
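
``get_sinusoid_encoding_table(n_seq, d_hidn)`` is assumed to build the standard sinusoidal position-encoding table from the original Transformer paper. The NumPy sketch below illustrates that table under this assumption; the function name ``sinusoid_encoding_table`` is hypothetical, and the pyiqa version may differ in framework or dtype.

.. code-block:: python

   import numpy as np


   def sinusoid_encoding_table(n_seq: int, d_hidn: int) -> np.ndarray:
       """Hypothetical sketch: (n_seq, d_hidn) sinusoidal position encodings.

       Entry [pos, i] uses the angle pos / 10000 ** (2 * (i // 2) / d_hidn);
       even columns take the sine of that angle, odd columns the cosine.
       """
       positions = np.arange(n_seq)[:, None]   # shape (n_seq, 1)
       dims = np.arange(d_hidn)[None, :]       # shape (1, d_hidn)
       angles = positions / np.power(10000.0, 2 * (dims // 2) / d_hidn)
       table = np.empty((n_seq, d_hidn))
       table[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions: sine
       table[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions: cosine
       return table


   table = sinusoid_encoding_table(n_seq=4, d_hidn=6)
   print(table.shape)  # (4, 6)

Position 0 therefore encodes as alternating 0 and 1, and each pair of columns shares one frequency, so nearby positions get similar encodings.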
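
``get_attn_pad_mask(seq_q, seq_k, i_pad)`` and ``get_attn_decoder_mask(seq)`` are assumed to produce the usual transformer masks: one hiding key positions equal to the padding id ``i_pad``, and one hiding future positions in decoder self-attention. The NumPy sketch below shows both under that assumption; the helper names are hypothetical, and combining the two masks with a logical OR is likewise an assumption based on the common transformer recipe, not confirmed pyiqa behavior.

.. code-block:: python

   import numpy as np


   def pad_attention_mask(seq_q: np.ndarray, seq_k: np.ndarray, i_pad: int) -> np.ndarray:
       """Hypothetical sketch: (batch, len_q, len_k) mask, True where the key is padding."""
       batch, len_q = seq_q.shape
       _, len_k = seq_k.shape
       is_pad = seq_k == i_pad                  # shape (batch, len_k)
       # Every query row sees the same key-padding pattern.
       return np.broadcast_to(is_pad[:, None, :], (batch, len_q, len_k))


   def causal_attention_mask(seq: np.ndarray) -> np.ndarray:
       """Hypothetical sketch: (batch, len, len) mask, 1 above the diagonal (future tokens)."""
       batch, length = seq.shape
       upper = np.triu(np.ones((length, length), dtype=np.int64), k=1)
       return np.broadcast_to(upper, (batch, length, length))


   tokens = np.array([[5, 7, 0]])  # toy batch; 0 plays the role of i_pad
   pad = pad_attention_mask(tokens, tokens, i_pad=0)
   causal = causal_attention_mask(tokens)
   self_mask = np.logical_or(pad, causal.astype(bool))  # assumed combination

In ``self_mask``, the first query may attend only to itself, and no query attends to the padded third token.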
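
``ScaledDotProductAttention.forward(Q, K, V, attn_mask)`` is assumed to compute standard scaled dot-product attention, :math:`\mathrm{softmax}(QK^\top / \sqrt{d_k})\,V`, with masked positions excluded from the softmax. The NumPy sketch below shows that computation; the function name is hypothetical, dropout is omitted, and the large negative fill value ``-1e9`` is an assumption rather than a documented pyiqa constant.

.. code-block:: python

   import numpy as np


   def scaled_dot_product_attention(Q, K, V, attn_mask=None):
       """Hypothetical sketch returning (context, weights).

       Q: (..., len_q, d_k), K: (..., len_k, d_k), V: (..., len_k, d_v).
       attn_mask: boolean (..., len_q, len_k); True entries are excluded.
       """
       d_k = Q.shape[-1]
       scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)
       if attn_mask is not None:
           scores = np.where(attn_mask, -1e9, scores)  # masked logits become effectively -inf
       scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
       weights = np.exp(scores)
       weights = weights / weights.sum(axis=-1, keepdims=True)
       return weights @ V, weights


   Q = np.zeros((1, 2, 4))                          # zero queries give equal raw scores
   K = np.ones((1, 3, 4))
   V = np.arange(3, dtype=float).reshape(1, 3, 1)   # values 0, 1, 2
   mask = np.array([[[False, False, True], [False, False, True]]])
   context, weights = scaled_dot_product_attention(Q, K, V, mask)

With the third key masked, each query splits its weight evenly over the first two keys, so every context value is the mean of 0 and 1.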
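
``IQT.preprocess(x)`` normalizes the input using ``default_mean`` and ``default_std``, whose defaults are timm's ``IMAGENET_INCEPTION_MEAN`` and ``IMAGENET_INCEPTION_STD`` (both ``(0.5, 0.5, 0.5)``). The NumPy sketch below assumes the normalization is the usual channel-wise ``(x - mean) / std`` on an ``(N, C, H, W)`` batch with values in ``[0, 1]``; the helper name ``normalize`` is hypothetical.

.. code-block:: python

   import numpy as np

   # Same values as timm's IMAGENET_INCEPTION_MEAN / IMAGENET_INCEPTION_STD.
   INCEPTION_MEAN = (0.5, 0.5, 0.5)
   INCEPTION_STD = (0.5, 0.5, 0.5)


   def normalize(x: np.ndarray, mean=INCEPTION_MEAN, std=INCEPTION_STD) -> np.ndarray:
       """Hypothetical sketch: channel-wise (x - mean) / std for an (N, C, H, W) batch."""
       mean = np.asarray(mean, dtype=float).reshape(1, -1, 1, 1)  # broadcast over N, H, W
       std = np.asarray(std, dtype=float).reshape(1, -1, 1, 1)
       return (x - mean) / std


   batch = np.stack([
       np.zeros((3, 2, 2)),       # black image
       np.full((3, 2, 2), 0.5),   # mid-gray image
       np.ones((3, 2, 2)),        # white image
   ])
   out = normalize(batch)

With the Inception defaults, this maps the ``[0, 1]`` input range onto ``[-1, 1]``: black becomes -1, mid-gray 0, and white 1.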