pyiqa.archs.tres_arch
TReS model.
- Reference:
No-Reference Image Quality Assessment via Transformers, Relative Ranking, and Self-Consistency. S. Alireza Golestaneh, Saba Dadsetan, Kris M. Kitani. WACV 2022.
Official code: https://github.com/isalirezag/TReS
Module Contents
- class pyiqa.archs.tres_arch.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation='relu', normalize_before=False, return_intermediate_dec=False)
Bases: torch.nn.Module
Transformer encoder used by TReS to aggregate multiscale features.
- Parameters:
d_model (int) – Feature dimension.
nhead (int) – Number of attention heads.
num_encoder_layers (int) – Number of encoder layers.
num_decoder_layers (int) – Unused legacy argument kept for compatibility.
dim_feedforward (int) – Hidden dimension in feed-forward blocks.
dropout (float) – Dropout ratio.
activation (str) – Activation function name.
normalize_before (bool) – Whether to apply pre-normalization.
return_intermediate_dec (bool) – Legacy compatibility argument.
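The `d_model` and `nhead` parameters interact through multi-head attention: each head attends over a `d_model // nhead` slice of the features. The plain-Python sketch below illustrates that split with scaled dot-product self-attention. It is not the pyiqa code: the learned query, key, value, and output projections are assumed to be identity maps, and `multi_head_self_attention` is a hypothetical helper.

```python
import math


def softmax(row):
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def multi_head_self_attention(tokens, nhead):
    """Self-attention over a list of L token vectors of length d_model.

    Sketch only: learned Q/K/V and output projections are treated as identity,
    so head h attends using feature slice [h * head_dim, (h + 1) * head_dim).
    """
    d_model = len(tokens[0])
    if d_model % nhead != 0:
        raise ValueError("d_model must be divisible by nhead")
    head_dim = d_model // nhead
    out = [[0.0] * d_model for _ in tokens]
    for h in range(nhead):
        lo, hi = h * head_dim, (h + 1) * head_dim
        heads = [t[lo:hi] for t in tokens]  # this head's slice of every token
        for i, q in enumerate(heads):
            # Scaled dot-product score of query i against every key.
            scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(head_dim) for k in heads]
            weights = softmax(scores)
            # Output slice = attention-weighted sum of the value slices.
            for j, v in enumerate(heads):
                for c in range(head_dim):
                    out[i][lo + c] += weights[j] * v[c]
    return out
```

`torch.nn.MultiheadAttention` performs the same split after its learned projections and rejects an `embed_dim` that is not divisible by `num_heads`.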
- class pyiqa.archs.tres_arch.TransformerEncoder(encoder_layer, num_layers, norm=None)
Bases: torch.nn.Module
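`TransformerEncoder` takes one `encoder_layer` template, a `num_layers` count, and an optional `norm`. Assuming a DETR-style design, it clones the template `num_layers` times, runs the layers in sequence, and applies `norm` once at the end. The sketch below shows only that control flow; `ToyEncoderLayer` and `run_encoder` are hypothetical, and the masks and `pos` tensor that the real layers receive are omitted.

```python
import copy


class ToyEncoderLayer:
    """Hypothetical stand-in for an encoder layer: adds a constant to each feature."""

    def __init__(self, bias):
        self.bias = bias

    def __call__(self, src):
        return [x + self.bias for x in src]


def run_encoder(encoder_layer, num_layers, src, norm=None):
    # Clone the template so each layer would own independent parameters.
    layers = [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
    output = src
    for layer in layers:
        # Real encoder layers would also receive masks and `pos` here.
        output = layer(output)
    # The optional final normalization runs once, after the last layer.
    if norm is not None:
        output = norm(output)
    return output
```

For example, `run_encoder(ToyEncoderLayer(1.0), 3, [0.0, 1.0])` passes the input through three cloned layers and yields `[3.0, 4.0]`.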
- class pyiqa.archs.tres_arch.TransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='relu', normalize_before=False)
Bases: torch.nn.Module
- forward_post(src, src_mask: torch.Tensor | None = None, src_key_padding_mask: torch.Tensor | None = None, pos: torch.Tensor | None = None)
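`forward_post` is the post-normalization path selected when `normalize_before=False`. Assuming the DETR-style layer layout, the positional embedding `pos` is added to the queries and keys but not to the values, and LayerNorm runs after each residual addition; the pre-normalization path instead normalizes before each sublayer. The sketch below contrasts the two orderings for a single token vector. `encoder_layer`, `attn`, and `ffn` are hypothetical stand-ins, LayerNorm has no learned scale or shift, and dropout is omitted.

```python
import math


def layer_norm(v, eps=1e-5):
    # LayerNorm without learned scale/shift: zero mean, unit variance.
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def add(a, b):
    return [x + y for x, y in zip(a, b)]


def encoder_layer(src, attn, ffn, pos=None, normalize_before=False):
    """One encoder layer for a single token vector (dropout omitted).

    attn(query_key, value) and ffn(x) are hypothetical stand-ins for the
    self-attention and feed-forward sublayers.
    """
    def with_pos(x):
        # Positional embedding is added to queries/keys, never to values.
        return x if pos is None else add(x, pos)

    if not normalize_before:
        # Post-norm (forward_post): sublayer -> residual add -> LayerNorm.
        src = layer_norm(add(src, attn(with_pos(src), src)))
        src = layer_norm(add(src, ffn(src)))
    else:
        # Pre-norm: LayerNorm -> sublayer -> residual add.
        x = layer_norm(src)
        src = add(src, attn(with_pos(x), x))
        src = add(src, ffn(layer_norm(src)))
    return src
```

Post-norm leaves every layer output normalized, whereas pre-norm keeps an unnormalized residual stream. In DETR, the encoder's final `norm` is created only when `normalize_before=True` for this reason.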
- class pyiqa.archs.tres_arch.PositionEmbeddingSine(num_pos_feats=64, temperature=10000, normalize=False, scale=None)
Bases: torch.nn.Module
Sine-cosine positional encoding for 2D feature maps.
This implementation is adapted from the DETR positional encoding and generates a fixed (non-learned) embedding tensor with shape (N, C, H, W).
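The encoding can be sketched in plain Python, assuming the DETR formulation for a map with no padding mask: row and column positions count from 1, are optionally rescaled so the last position maps to approximately `scale` (default 2π), and are divided by `temperature ** (2 * (k // 2) / num_pos_feats)` before taking the sine on even channels and the cosine on odd ones. Under that assumption, `C = 2 * num_pos_feats`. `sine_position_embedding` is a hypothetical helper that returns nested lists instead of a tensor.

```python
import math


def sine_position_embedding(h, w, num_pos_feats=64, temperature=10000,
                            normalize=False, scale=None):
    """DETR-style 2-D sine/cosine embedding for an unmasked h x w map.

    Returns a nested list indexed [channel][y][x] with 2 * num_pos_feats
    channels: the first half encodes the row (y), the second half the column (x).
    """
    if scale is not None and not normalize:
        raise ValueError("normalize should be True if scale is passed")
    if scale is None:
        scale = 2 * math.pi
    eps = 1e-6
    # Positions are 1-based cumulative counts over an all-valid mask.
    ys = [float(i + 1) for i in range(h)]
    xs = [float(j + 1) for j in range(w)]
    if normalize:
        ys = [y / (ys[-1] + eps) * scale for y in ys]
        xs = [x / (xs[-1] + eps) * scale for x in xs]
    # Channel k uses the divisor temperature ** (2 * (k // 2) / num_pos_feats).
    dim_t = [temperature ** (2 * (k // 2) / num_pos_feats) for k in range(num_pos_feats)]

    def encode(p, k):
        # Even channels take sin, odd channels take cos.
        return math.sin(p / dim_t[k]) if k % 2 == 0 else math.cos(p / dim_t[k])

    pos_y = [[[encode(ys[i], k) for _ in range(w)] for i in range(h)]
             for k in range(num_pos_feats)]
    pos_x = [[[encode(xs[j], k) for j in range(w)] for _ in range(h)]
             for k in range(num_pos_feats)]
    return pos_y + pos_x
```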
- class pyiqa.archs.tres_arch.L2pooling(filter_size=5, stride=1, channels=None, pad_off=0)
Bases: torch.nn.Module
L2 pooling with Hann-window smoothing.
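The pooling can be sketched for a single channel, assuming the DISTS-style formulation: square the input, convolve it depthwise with a normalized 2-D Hann window whose zero endpoints are trimmed (a 3×3 kernel for `filter_size=5`), then take the square root. The zero padding of `(filter_size - 2) // 2` and the `1e-12` stabilizer are likewise assumptions; `hann_kernel` and `l2pool2d` are hypothetical helpers.

```python
import math


def hann_kernel(filter_size=5):
    # Symmetric Hann window (as numpy.hanning) with its two zero endpoints dropped.
    n = filter_size
    a = [0.5 - 0.5 * math.cos(2 * math.pi * i / (n - 1)) for i in range(n)][1:-1]
    g = [[ai * aj for aj in a] for ai in a]
    total = sum(sum(row) for row in g)
    # Normalize so the 2-D kernel sums to 1.
    return [[v / total for v in row] for row in g]


def l2pool2d(x, filter_size=5, stride=1):
    """L2 pooling of one channel x (list of rows): sqrt(conv(x ** 2, g))."""
    g = hann_kernel(filter_size)
    k = len(g)
    pad = (filter_size - 2) // 2
    h, w = len(x), len(x[0])

    def sq(i, j):
        # Squared input with implicit zero padding outside the map.
        return x[i][j] ** 2 if 0 <= i < h and 0 <= j < w else 0.0

    out_h = (h + 2 * pad - k) // stride + 1
    out_w = (w + 2 * pad - k) // stride + 1
    out = []
    for oi in range(out_h):
        row = []
        for oj in range(out_w):
            i0, j0 = oi * stride - pad, oj * stride - pad
            acc = sum(g[a][b] * sq(i0 + a, j0 + b) for a in range(k) for b in range(k))
            row.append(math.sqrt(acc + 1e-12))
        out.append(row)
    return out
```

On a constant interior region the output equals the input magnitude because the kernel sums to 1; near the borders, the zero padding pulls the values down.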
- class pyiqa.archs.tres_arch.TReS(network='resnet50', train_dataset='koniq', nheadt=16, num_encoder_layerst=2, dim_feedforwardt=64, test_sample=50, default_mean=[0.485, 0.456, 0.406], default_std=[0.229, 0.224, 0.225], pretrained=True, pretrained_model_path=None)
Bases: torch.nn.Module
TReS no-reference IQA model.
- Parameters:
network (str) – ResNet backbone name.
train_dataset (str) – Dataset key used to choose the default checkpoint.
nheadt (int) – Number of transformer attention heads.
num_encoder_layerst (int) – Number of transformer encoder blocks.
dim_feedforwardt (int) – Transformer feed-forward hidden size.
test_sample (int) – Number of uniform crops during evaluation.
default_mean (list[float]) – Input normalization mean in RGB order.
default_std (list[float]) – Input normalization std in RGB order.
pretrained (bool) – Whether to load the default pretrained checkpoint.
pretrained_model_path (str | None) – Optional local checkpoint path.
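`default_mean` and `default_std` are the standard ImageNet statistics. Assuming the metric receives RGB values in [0, 1], as in the `torch.rand` example, and normalizes them per channel before the backbone, that step amounts to the sketch below; `normalize_chw` is a hypothetical helper operating on a nested `[C][H][W]` list.

```python
def normalize_chw(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Per-channel (x - mean) / std for an RGB image stored as [C][H][W] lists."""
    if len(image) != len(mean):
        raise ValueError("expected one mean/std entry per channel")
    return [
        [[(v - m) / s for v in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]
```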
Example
>>> import torch
>>> from pyiqa.archs.tres_arch import TReS
>>> metric = TReS(train_dataset='koniq')
>>> x = torch.rand(1, 3, 512, 512)
>>> score = metric(x)