pyiqa.archs.musiq_arch
MUSIQ model.
- Reference:
Ke, Junjie, Qifei Wang, Yilin Wang, Peyman Milanfar, and Feng Yang. “MUSIQ: Multi-scale Image Quality Transformer.” In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), pp. 5148-5157. 2021.
Ref url: https://github.com/google-research/google-research/tree/master/musiq
Re-implemented by: Chaofeng Chen (https://github.com/chaofengc)
Module Contents
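- class pyiqa.archs.musiq_arch.StdConv
Bases: torch.nn.Conv2d
Conv2d with Weight Standardization: the weights of each output channel are normalized to zero mean and unit variance before the convolution. Reference: https://github.com/joe-siyuan-qiao/WeightStandardization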
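Weight standardization rescales each output channel's kernel before it is applied, so the convolution sees standardized weights while the stored parameters stay unconstrained. The following is a minimal sketch of that idea, not the pyiqa source: the class name `WSConv2d`, the helper `standardized_weight`, and `eps=1e-5` are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WSConv2d(nn.Conv2d):
    """Hypothetical weight-standardized Conv2d (illustrative sketch only)."""

    def __init__(self, *args, eps: float = 1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps  # assumed value; guards against division by zero

    def standardized_weight(self) -> torch.Tensor:
        w = self.weight
        # Statistics over (in_channels, kH, kW) for each output channel.
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = (w - mean).pow(2).mean(dim=(1, 2, 3), keepdim=True)
        return (w - mean) / torch.sqrt(var + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same convolution as nn.Conv2d (zero padding), but with standardized weights.
        return F.conv2d(x, self.standardized_weight(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


conv = WSConv2d(3, 8, kernel_size=3, padding=1, bias=False)
y = conv(torch.randn(2, 3, 16, 16))
print(y.shape)  # torch.Size([2, 8, 16, 16])
```

Weight standardization is typically paired with GroupNorm rather than BatchNorm, which is consistent with the `gn_root` layer that follows the root convolution in MUSIQ.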
- class pyiqa.archs.musiq_arch.Bottleneck(inplanes, outplanes, stride=1)
Bases: torch.nn.Module
- class pyiqa.archs.musiq_arch.DropPath(drop_prob=None)
Bases: torch.nn.Module
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
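Stochastic depth keeps or zeroes the whole residual branch for each sample during training and divides the kept samples by 1/(1 - p)'s inverse, that is, by the keep probability, so the expected output is unchanged. A minimal sketch, assuming the common timm-style formulation; `DropPathSketch` is a hypothetical name, not the pyiqa class.

```python
import torch
import torch.nn as nn


class DropPathSketch(nn.Module):
    """Hypothetical per-sample stochastic depth (illustrative sketch only)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity when disabled (None or 0) or at inference time.
        if not self.drop_prob or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.bernoulli(torch.full(shape, keep_prob, dtype=x.dtype, device=x.device))
        # Rescale kept samples so the expected value matches the input.
        return x * mask / keep_prob


torch.manual_seed(0)
dp = DropPathSketch(drop_prob=0.5)  # nn.Module starts in training mode
x = torch.ones(8, 4, 16)
out = dp(x)
```

In a residual block it wraps only the branch, e.g. `x = x + drop_path(branch(x))` with `branch` standing in for the attention or MLP sub-layer, so the skip connection is never dropped.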
- class pyiqa.archs.musiq_arch.Mlp(in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0)
Bases: torch.nn.Module
- class pyiqa.archs.musiq_arch.MultiHeadAttention(dim, num_heads=6, bias=False, attn_drop=0.0, out_drop=0.0)
Bases: torch.nn.Module
- class pyiqa.archs.musiq_arch.TransformerBlock(dim, mlp_dim, num_heads, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm)
Bases: torch.nn.Module
- class pyiqa.archs.musiq_arch.AddHashSpatialPositionEmbs(spatial_pos_grid_size, dim)
Bases: torch.nn.Module
Adds learnable hash-based spatial embeddings to the inputs.
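As described in the MUSIQ paper, hash-based spatial embedding maps a patch at row i, column j of an n_rows × n_cols patch grid to cell (floor(i·G/n_rows), floor(j·G/n_cols)) of a shared G × G table of learnable vectors, with G playing the role of spatial_pos_grid_size. Images of any size and aspect ratio therefore reuse the same table. The sketch below illustrates this mapping under those assumptions; `HashSpatialPosEmbSketch`, its `cell_indices` method, the forward signature, and the 0.02 init scale are hypothetical, and the pyiqa class may receive precomputed indices instead.

```python
import torch
import torch.nn as nn


class HashSpatialPosEmbSketch(nn.Module):
    """Hypothetical hash-based 2D spatial embedding (illustrative sketch only)."""

    def __init__(self, grid_size: int, dim: int):
        super().__init__()
        self.grid_size = grid_size
        # One learnable vector per cell of the G x G grid (init scale assumed).
        self.table = nn.Parameter(torch.randn(grid_size * grid_size, dim) * 0.02)

    def cell_indices(self, n_rows: int, n_cols: int) -> torch.Tensor:
        g = self.grid_size
        rows = torch.arange(n_rows) * g // n_rows  # floor(i * G / n_rows)
        cols = torch.arange(n_cols) * g // n_cols  # floor(j * G / n_cols)
        # Row-major flat index into the G*G table, one entry per patch.
        return (rows[:, None] * g + cols[None, :]).reshape(-1)

    def forward(self, tokens: torch.Tensor, n_rows: int, n_cols: int) -> torch.Tensor:
        # tokens: (batch, n_rows * n_cols, dim), patches in row-major order.
        idx = self.cell_indices(n_rows, n_cols).to(tokens.device)
        return tokens + self.table[idx].unsqueeze(0)


emb = HashSpatialPosEmbSketch(grid_size=10, dim=384)
print(emb.cell_indices(2, 3).tolist())  # [0, 3, 6, 50, 53, 56]
```

A 2 × 3 patch grid spreads its six patches across the 10 × 10 table, while a 20 × 20 grid would map each 2 × 2 block of patches to one cell.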
- class pyiqa.archs.musiq_arch.AddScaleEmbs(num_scales, dim)
Bases: torch.nn.Module
Adds learnable scale embeddings to the inputs.
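Because patches from all scales are concatenated into a single sequence, a learnable vector per scale lets the encoder tell which resolution each token came from. A minimal sketch, assuming the caller supplies an integer scale index for every token; `ScaleEmbSketch` and the 0.02 init scale are hypothetical.

```python
import torch
import torch.nn as nn


class ScaleEmbSketch(nn.Module):
    """Hypothetical per-scale learnable embedding (illustrative sketch only)."""

    def __init__(self, num_scales: int, dim: int):
        super().__init__()
        self.table = nn.Parameter(torch.randn(num_scales, dim) * 0.02)

    def forward(self, tokens: torch.Tensor, scale_idx: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, seq_len, dim); scale_idx: (batch, seq_len) ints in [0, num_scales).
        return tokens + self.table[scale_idx]


emb = ScaleEmbSketch(num_scales=3, dim=4)
# One image with 5 tokens: 2 from scale 0, 2 from scale 1, 1 from scale 2.
scale_idx = torch.tensor([[0, 0, 1, 1, 2]])
out = emb(torch.zeros(1, 5, 4), scale_idx)
```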
- class pyiqa.archs.musiq_arch.TransformerEncoder(input_dim, mlp_dim=1152, attention_dropout_rate=0.0, dropout_rate=0, num_heads=6, num_layers=14, num_scales=3, spatial_pos_grid_size=10, use_scale_emb=True, use_sinusoid_pos_emb=False)
Bases: torch.nn.Module
- class pyiqa.archs.musiq_arch.MUSIQ(patch_size=32, num_class=1, hidden_size=384, mlp_dim=1152, attention_dropout_rate=0.0, dropout_rate=0, num_heads=6, num_layers=14, num_scales=3, spatial_pos_grid_size=10, use_scale_emb=True, use_sinusoid_pos_emb=False, pretrained=True, pretrained_model_path=None, longer_side_lengths=[224, 384], max_seq_len_from_original_res=-1)
Bases: torch.nn.Module
MUSIQ model architecture.
- Parameters:
patch_size (int) – Size of the square patches extracted from the images.
num_class (int) – Number of outputs of the prediction head. With 1, the model regresses a single quality score; with more, it predicts a distribution over score bins.
hidden_size (int) – Hidden (token) dimension of the transformer encoder.
mlp_dim (int) – Hidden dimension of the feed-forward layers in the transformer encoder.
attention_dropout_rate (float) – Dropout rate for the attention layers in the transformer encoder.
dropout_rate (float) – Dropout rate for the transformer encoder.
num_heads (int) – Number of attention heads in the transformer encoder.
num_layers (int) – Number of layers in the transformer encoder.
num_scales (int) – Number of image scales fed to the transformer encoder; by default, the original resolution plus the two resized scales given by longer_side_lengths.
spatial_pos_grid_size (int) – Side length of the square grid of hash-based spatial position embeddings.
use_scale_emb (bool) – Whether to add scale embeddings in the transformer encoder.
use_sinusoid_pos_emb (bool) – Whether to use sinusoidal position embeddings in the transformer encoder.
pretrained (bool or str) – Whether to load a pretrained model. If a str, it specifies the path to the pretrained model.
pretrained_model_path (str, optional) – Path to the pretrained model.
longer_side_lengths (list of int) – Longer-side lengths, in pixels, of the aspect-preserving resized images used for multi-scale evaluation.
max_seq_len_from_original_res (int) – Maximum sequence length (number of patches) taken from the original-resolution image during multi-scale evaluation.
- Attributes:
conv_root (StdConv) – Convolutional layer at the root (stem) of the network.
gn_root (nn.GroupNorm) – Group normalization layer at the root of the network.
root_pool (nn.Sequential) – Max-pooling stage at the root of the network.
block1 (Bottleneck) – First bottleneck block of the network.
embedding (nn.Linear) – Linear layer that projects patch features to the transformer encoder input.
transformer_encoder (TransformerEncoder) – Transformer encoder.
head (nn.Sequential or nn.Linear) – Output layer of the network.
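With the defaults (longer_side_lengths=[224, 384], patch_size=32, num_scales=3), each image contributes tokens from its original resolution plus two aspect-preserving resizes. The arithmetic below sketches how many patches each scale yields. Rounding resized sides to the nearest pixel and counting partial border patches as whole patches are assumptions of this sketch, and `multiscale_sizes` and `num_patches` are hypothetical helpers, not pyiqa functions.

```python
import math
from typing import List, Tuple


def multiscale_sizes(h: int, w: int, longer_side_lengths: List[int]) -> List[Tuple[int, int]]:
    """Original (h, w) first, then one aspect-preserving resize per target length."""
    sizes = [(h, w)]
    for target in longer_side_lengths:
        scale = target / max(h, w)  # shrink or enlarge so the longer side equals target
        # Rounding to the nearest pixel is an assumption of this sketch.
        sizes.append((max(1, round(h * scale)), max(1, round(w * scale))))
    return sizes


def num_patches(h: int, w: int, patch_size: int = 32) -> int:
    # Assumes partial patches at the border are padded and counted.
    return math.ceil(h / patch_size) * math.ceil(w / patch_size)


sizes = multiscale_sizes(768, 1024, [224, 384])
print(sizes)                                  # [(768, 1024), (168, 224), (288, 384)]
print([num_patches(h, w) for h, w in sizes])  # [768, 42, 108]
```

For this 768 × 1024 input, the original-resolution scale supplies 768 of the 918 patches, which is the portion that max_seq_len_from_original_res bounds.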
- forward(x, return_mos=True, return_dist=False)
Forward pass of the MUSIQ network.
- Parameters:
x (torch.Tensor) – Input tensor.
return_mos (bool) – Whether to return the mean opinion score (MOS).
return_dist (bool) – Whether to return the predicted distribution.
- Returns:
If only one of return_mos and return_dist is True, returns a tensor. If both are True, returns a tuple of tensors.
- Return type:
torch.Tensor or tuple
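When num_class > 1, the head predicts a distribution over score bins rather than a single score. A common way to reduce such a distribution to a MOS is its expected value over bins labelled 1 to K. The sketch below shows that reduction under this assumption; `dist_to_mean_score` is a hypothetical helper, and pyiqa's own conversion may differ in detail.

```python
import torch


def dist_to_mean_score(dist: torch.Tensor) -> torch.Tensor:
    """Expected score of per-row distributions over bins 1..K (hypothetical helper).

    dist: (batch, K), each row summing to 1. Returns a (batch, 1) tensor.
    """
    bins = torch.arange(1, dist.shape[-1] + 1, dtype=dist.dtype, device=dist.device)
    return (dist * bins).sum(dim=-1, keepdim=True)


# Two 10-bin distributions, e.g. from a softmax head with num_class=10.
dist = torch.zeros(2, 10)
dist[0, 4] = 1.0       # all mass on bin 5
dist[1, [0, 9]] = 0.5  # half on bin 1, half on bin 10
print(dist_to_mean_score(dist).tolist())  # [[5.0], [5.5]]
```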