pyiqa.archs.maniqa_arch

MANIQA, proposed in:

MANIQA: Multi-dimension Attention Network for No-Reference Image Quality Assessment. Sidi Yang, Tianhe Wu, Shuwei Shi, Shanshan Lao, Yuan Gong, Mingdeng Cao, Jiahao Wang and Yujiu Yang. CVPR Workshop 2022. Winner of the NTIRE 2022 NR-IQA challenge.

Reference:

Module Contents

pyiqa.archs.maniqa_arch.default_model_urls

Pretrained checkpoint URLs, keyed by the train_dataset name accepted by MANIQA.
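Selecting a checkpoint by train_dataset can be sketched as follows. This is not pyiqa's loader. `resolve_checkpoint` is a hypothetical helper and the URL is a placeholder. The precedence rule, where an explicit pretrained_model_path wins over the table lookup, is an assumption.

```python
from typing import Mapping, Optional


def resolve_checkpoint(
    train_dataset: str,
    urls: Mapping[str, str],
    pretrained: bool = True,
    pretrained_model_path: Optional[str] = None,
) -> Optional[str]:
    """Hypothetical helper: pick which checkpoint MANIQA would load.

    Assumption: an explicit local path takes precedence over the URL table.
    """
    if pretrained_model_path is not None:
        return pretrained_model_path
    if not pretrained:
        return None  # nothing to load from the table
    try:
        return urls[train_dataset]
    except KeyError:
        raise KeyError(
            f"no checkpoint for train_dataset={train_dataset!r}; "
            f"available keys: {sorted(urls)}"
        ) from None


# Placeholder table; the real default_model_urls entries are not reproduced here.
example_urls = {"pipal": "https://example.invalid/maniqa_pipal.pth"}
print(resolve_checkpoint("pipal", example_urls))
print(resolve_checkpoint("pipal", example_urls, pretrained_model_path="ckpt.pth"))
```

Raising a `KeyError` that lists the available keys makes a misspelled train_dataset obvious, rather than failing later during weight loading.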
class pyiqa.archs.maniqa_arch.TABlock(dim, drop=0.1)

Bases: torch.nn.Module

Transposed attention block (TAB) used in MANIQA stages; it applies self-attention across the channel dimension.

forward(x)

Apply the block to x and return a tensor of the same shape.
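The sketch below is a minimal pure-Python version of channel-wise self-attention with a residual connection, which is the operation a transposed attention block performs. It is not pyiqa's implementation. It assumes the input is laid out as (C, N) with N spatial tokens. The projections `w_q`, `w_k` and `w_v` are hypothetical (N x N) matrices, and the helper names are illustrative only.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Product of an (m x k) matrix and a (k x n) matrix."""
    cols = list(zip(*b))  # columns of b
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def softmax_rows(a: Matrix) -> Matrix:
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    result = []
    for row in a:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


def transposed_attention(x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix) -> Matrix:
    """Channel-wise self-attention on x of shape (C, N) with a residual add.

    The projections act along the token axis, so the attention map is (C x C).
    """
    n = len(x[0])
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)  # each (C, N)
    k_t = [list(col) for col in zip(*k)]                        # (N, C)
    scores = [[s / math.sqrt(n) for s in row] for row in matmul(q, k_t)]  # (C, C)
    attn = softmax_rows(scores)
    out = matmul(attn, v)                                       # (C, N)
    # Residual connection: add the input back, so the shape is unchanged.
    return [[o + r for o, r in zip(orow, xrow)] for orow, xrow in zip(out, x)]


identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
x = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.0, -0.5, 1.0], [2.0, 2.0, 2.0, 2.0]]
y = transposed_attention(x, identity, identity, identity)
print(len(y), len(y[0]))  # 3 4
```

Attending over channels keeps the attention map at C x C. Its size therefore does not grow with the number of spatial tokens.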
class pyiqa.archs.maniqa_arch.SaveOutput

Forward-hook collector for intermediate ViT block outputs.

clear()

Discard all collected outputs.
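A forward-hook collector of this kind can be sketched without PyTorch. The class below is a simplified, hypothetical stand-in (`SaveOutputSketch`) that assumes outputs are kept in a plain list. The hook is called by hand to simulate one forward pass.

```python
from typing import Any, List


class SaveOutputSketch:
    """Hypothetical stand-in for SaveOutput: records each output it is called with."""

    def __init__(self) -> None:
        self.outputs: List[Any] = []

    def __call__(self, module: Any, module_in: Any, module_out: Any) -> None:
        # PyTorch invokes forward hooks as hook(module, inputs, output).
        self.outputs.append(module_out)

    def clear(self) -> None:
        # Drop stored references so the previous batch's tensors can be freed.
        self.outputs = []


# Simulate one forward pass through three hooked blocks (no PyTorch needed).
collector = SaveOutputSketch()
for i, block_out in enumerate(["feat0", "feat1", "feat2"]):
    collector(f"block{i}", ("input",), block_out)
print(collector.outputs)  # ['feat0', 'feat1', 'feat2']
collector.clear()
print(collector.outputs)  # []
```

With PyTorch, such an object would be attached with `module.register_forward_hook(collector)` on each block of interest. Calling `clear()` after every forward pass keeps outputs from piling up across batches.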
class pyiqa.archs.maniqa_arch.MANIQA(embed_dim=768, num_outputs=1, patch_size=8, drop=0.1, depths=[2, 2], window_size=4, dim_mlp=768, num_heads=[4, 4], img_size=224, num_tab=2, scale=0.13, test_sample=20, pretrained=True, pretrained_model_path=None, train_dataset='pipal', default_mean=None, default_std=None, **kwargs)

Bases: torch.nn.Module

MANIQA no-reference IQA model.

Parameters:
  • embed_dim (int) – Embedding dimension.

  • num_outputs (int) – Number of output channels.

  • patch_size (int) – Patch size used by the ViT backbone.

  • drop (float) – Dropout ratio for prediction heads.

  • depths (list[int]) – Depths of Swin blocks.

  • window_size (int) – Swin attention window size.

  • dim_mlp (int) – MLP dimension used in the Swin blocks.

  • num_heads (list[int]) – Number of attention heads in Swin blocks.

  • img_size (int) – Input crop size.

  • num_tab (int) – Number of transposed attention blocks (TABs) per stage.

  • scale (float) – Scaling factor used by the Scale Swin Transformer Blocks (SSTB).

  • test_sample (int) – Number of crops sampled from each image at evaluation time.

  • pretrained (bool) – Whether to load pretrained model weights.

  • pretrained_model_path (str | None) – Optional local checkpoint path.

  • train_dataset (str) – Key into default_model_urls that selects the pretrained checkpoint to load.

  • default_mean (torch.Tensor | None) – Optional custom normalization mean.

  • default_std (torch.Tensor | None) – Optional custom normalization std.

  • **kwargs – Reserved compatibility arguments.
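The default parameters fix most intermediate tensor sizes, and the arithmetic can be checked quickly. The sketch uses the defaults from the signature (img_size=224, patch_size=8, embed_dim=768, test_sample=20). `maniqa_shapes` and `num_hooked_blocks` are hypothetical. Two further points are assumptions: that four ViT block outputs are concatenated, and that every evaluation crop is batched through the backbone.

```python
from typing import Dict, Tuple


def maniqa_shapes(
    batch: int,
    img_size: int = 224,
    patch_size: int = 8,
    embed_dim: int = 768,
    test_sample: int = 20,
    num_hooked_blocks: int = 4,  # assumption: number of ViT blocks concatenated
) -> Dict[str, Tuple[int, ...]]:
    """Back-of-the-envelope tensor shapes for MANIQA evaluation (a sketch)."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be a multiple of patch_size")
    side = img_size // patch_size  # the token grid is side x side
    tokens = side * side           # patch tokens per crop (class token excluded)
    crops = batch * test_sample    # assumption: all crops go through the backbone
    return {
        "token_grid": (side, side),
        "backbone_input": (crops, 3, img_size, img_size),
        "concat_features": (crops, tokens, embed_dim * num_hooked_blocks),
    }


shapes = maniqa_shapes(batch=2)
print(shapes["token_grid"])       # (28, 28)
print(shapes["concat_features"])  # (40, 784, 3072)
```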

extract_feature(save_output)

Concatenate the tokens of selected ViT blocks, as collected by save_output, into the MANIQA feature tensor.
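The concatenation step can be sketched for a single image as follows. `extract_feature_sketch` is hypothetical. It assumes each collected block output has shape (1 + num_tokens, channels) with the class token first. The default block indices (6, 7, 8, 9) are an assumption for illustration only.

```python
from itertools import chain
from typing import List, Sequence

Tokens = List[List[float]]  # one image: (num_tokens, channels)


def extract_feature_sketch(
    outputs: Sequence[Tokens], selected: Sequence[int] = (6, 7, 8, 9)
) -> Tokens:
    """Drop the class token from each selected block output and concatenate
    the remaining tokens along the channel axis."""
    picked = [outputs[i][1:] for i in selected]  # index 0 is the class token
    num_tokens = len(picked[0])
    return [list(chain.from_iterable(blk[t] for blk in picked)) for t in range(num_tokens)]


# 12 fake block outputs: 1 class token + 4 patch tokens, 2 channels each,
# filled with the block index so the concatenation order is visible.
fake = [[[float(b)] * 2 for _ in range(5)] for b in range(12)]
feat = extract_feature_sketch(fake)
print(len(feat), len(feat[0]))  # 4 8
print(feat[0])  # [6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0, 9.0]
```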

forward(x)

Predict the image quality score.

Parameters:

x (torch.Tensor) – Input tensor with shape (N, 3, H, W).

Returns:

Predicted score tensor with shape (N, 1).

Return type:

torch.Tensor
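Reducing per-patch predictions to one score per image can be sketched as below. This is patch-weighted pooling in the spirit of MANIQA's dual-branch (score and weight) head, not pyiqa's code. The helper names are hypothetical. Two behaviours are assumptions: averaging the test_sample crop scores of each image, and crops being ordered image by image.

```python
from typing import List, Sequence


def crop_score(scores: Sequence[float], weights: Sequence[float], eps: float = 1e-8) -> float:
    """Weighted average of per-patch scores for one crop."""
    return sum(s * w for s, w in zip(scores, weights)) / (sum(weights) + eps)


def image_scores(crop_scores: Sequence[float], batch: int) -> List[List[float]]:
    """Average crop scores per image; nested lists mirror an (N, 1) tensor.

    Assumes crops are ordered image by image (all crops of image 0 first).
    """
    if batch <= 0 or len(crop_scores) % batch != 0:
        raise ValueError("number of crop scores must be a positive multiple of batch")
    per_image = len(crop_scores) // batch
    return [
        [sum(crop_scores[i * per_image:(i + 1) * per_image]) / per_image]
        for i in range(batch)
    ]


print(crop_score([1.0, 3.0], [1.0, 1.0]))           # close to 2.0
print(image_scores([1.0, 2.0, 3.0, 4.0], batch=2))  # [[1.5], [3.5]]
```

The small `eps` guards against division by zero when every patch weight is zero, at the cost of a negligible bias in the pooled score.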