pyiqa.archs.stlpips_arch
ST-LPIPS Model
GitHub repository: https://github.com/abhijay9/ShiftTolerant-LPIPS
Cite as:
@inproceedings{ghildyal2022stlpips,
  title={Shift-tolerant Perceptual Similarity Metric},
  author={Abhijay Ghildyal and Feng Liu},
  booktitle={European Conference on Computer Vision},
  year={2022}
}
Module Contents
- class pyiqa.archs.stlpips_arch.STLPIPS(pretrained=True, net='alex', variant='shift_tolerant', lpips=True, spatial=False, pnet_tune=False, use_dropout=True, pretrained_model_path=None, eval_mode=True, blur_filter_size=3)
Bases:
torch.nn.Module
ST-LPIPS model.
- Parameters:
lpips (Boolean) – Whether to use linear layers on top of the base/trunk network.
pretrained (Boolean) – Whether the linear layers are calibrated with human perceptual judgments.
net (String) – Base/trunk network; one of ['alex', 'vgg', 'squeeze'].
pretrained_model_path (String) – Path to the pretrained model.
The following parameters should only be changed when training the network:
eval_mode (Boolean) – Whether to run in test mode. Default: True.
pnet_tune (Boolean) – Whether to tune the base/trunk network.
use_dropout (Boolean) – Whether to use dropout when training the linear layers.
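The `lpips` and `pretrained` flags govern the linear layers that weight per-channel feature differences. Below is a minimal plain-Python sketch of the per-layer distance used by LPIPS-style metrics, assuming this module follows the original LPIPS formulation. Features are unit-normalized along the channel axis. The differences are then squared and weighted per channel, or simply summed when no linear layer is used. Finally they are averaged over spatial positions. The helper names `unit_normalize` and `layer_distance` are hypothetical and not part of pyiqa. The `eps` value is also an assumption.

```python
import math
from typing import List, Optional

# Toy feature maps are nested lists indexed as [channel][row][col].
FeatureMap = List[List[List[float]]]


def unit_normalize(feat: FeatureMap, eps: float = 1e-10) -> FeatureMap:
    """Scale the channel vector at every spatial position to unit L2 norm."""
    chans, rows, cols = len(feat), len(feat[0]), len(feat[0][0])
    out = [[[0.0] * cols for _ in range(rows)] for _ in range(chans)]
    for h in range(rows):
        for x in range(cols):
            norm = math.sqrt(sum(feat[c][h][x] ** 2 for c in range(chans)))
            for c in range(chans):
                out[c][h][x] = feat[c][h][x] / (norm + eps)
    return out


def layer_distance(f0: FeatureMap, f1: FeatureMap,
                   weights: Optional[List[float]] = None) -> float:
    """Distance contributed by a single backbone layer (hypothetical helper).

    weights=None models lpips=False (plain sum over channels); a list of
    non-negative per-channel weights models a learned linear layer.
    """
    n0, n1 = unit_normalize(f0), unit_normalize(f1)
    chans, rows, cols = len(n0), len(n0[0]), len(n0[0][0])
    ch_w = weights if weights is not None else [1.0] * chans
    total = 0.0
    for h in range(rows):
        for x in range(cols):
            # Weighted squared difference of the normalized channel vectors.
            total += sum(ch_w[c] * (n0[c][h][x] - n1[c][h][x]) ** 2
                         for c in range(chans))
    return total / (rows * cols)  # average over spatial positions
```

Because the features are unit-normalized, scaling a feature map leaves the distance essentially unchanged. The weights only decide how much each channel's difference counts.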
- forward(in0, in1, retPerLayer=False, normalize=True)
Compute the IQA score using LPIPS.
- Parameters:
in0 – Reference tensor. Shape (N, C, H, W).
in1 – Input tensor. Shape (N, C, H, W).
retPerLayer (Boolean) – Whether the returned result also contains the result of each layer. Default: False.
normalize (Boolean) – Whether to rescale image data from the range [0, 1] to [-1, 1]. Default: True.
- Returns:
Quality score.
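The following framework-free sketch illustrates how `normalize` and `retPerLayer` interact. It operates on flattened toy "images" and is not pyiqa's implementation. The names `forward_sketch`, `scale_to_signed_range`, and `toy_layers` are hypothetical. Two points are assumptions taken from the original LPIPS code rather than from this page: the final score is the sum of the per-layer distances, and `retPerLayer=True` returns that sum together with the per-layer list.

```python
from typing import Callable, List, Sequence, Tuple, Union


def scale_to_signed_range(img: Sequence[float]) -> List[float]:
    """Map pixel values from [0, 1] to [-1, 1], as normalize=True is documented to do."""
    return [2.0 * v - 1.0 for v in img]


def forward_sketch(
    in0: Sequence[float],
    in1: Sequence[float],
    per_layer_fn: Callable[[List[float], List[float]], List[float]],
    ret_per_layer: bool = False,
    normalize: bool = True,
) -> Union[float, Tuple[float, List[float]]]:
    """Hypothetical stand-in for STLPIPS.forward on flattened toy images.

    per_layer_fn plays the role of the backbone plus linear layers and
    returns one distance per layer.
    """
    a, b = list(in0), list(in1)
    if normalize:
        a, b = scale_to_signed_range(a), scale_to_signed_range(b)
    per_layer = per_layer_fn(a, b)
    score = sum(per_layer)  # assumed: overall score = sum of layer terms
    return (score, per_layer) if ret_per_layer else score


def toy_layers(a: List[float], b: List[float]) -> List[float]:
    """Two made-up 'layers': total and worst-case absolute difference."""
    diffs = [abs(x - y) for x, y in zip(a, b)]
    return [sum(diffs), max(diffs)]
```

For example, `forward_sketch([0.0, 0.5, 1.0], [0.0, 0.5, 0.5], toy_layers)` returns `2.0`. With `normalize=False` it returns `1.0`, because the [-1, 1] rescaling doubles every pixel difference.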
- class pyiqa.archs.stlpips_arch.NetLinLayer(chn_in, chn_out=1, use_dropout=False)
Bases:
torch.nn.Module
A single linear layer implemented as a 1x1 convolution.
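A 1x1 convolution applies the same `chn_in` → `chn_out` linear map independently at every spatial position. Here is a minimal plain-Python sketch of that operation. It assumes there is no bias term, as in the original LPIPS linear layers. It also omits the optional dropout controlled by `use_dropout`. `conv1x1` is a hypothetical helper, not part of pyiqa.

```python
from typing import List

FeatureMap = List[List[List[float]]]  # indexed as [channel][row][col]


def conv1x1(x: FeatureMap, weight: List[List[float]]) -> FeatureMap:
    """Apply a bias-free 1x1 convolution (hypothetical helper).

    weight[o][c] maps input channel c to output channel o; the same
    weights are reused at every spatial position.
    """
    chn_in, rows, cols = len(x), len(x[0]), len(x[0][0])
    return [
        [[sum(weight[o][c] * x[c][h][w] for c in range(chn_in))
          for w in range(cols)]
         for h in range(rows)]
        for o in range(len(weight))
    ]
```

With `chn_out=1`, the class default, the output is a single map. Each position holds a weighted sum of the input channels at that position. For example, `conv1x1([[[1.0, 2.0]], [[3.0, 4.0]]], [[0.5, 1.0]])` gives `[[[3.5, 5.0]]]`.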
- class pyiqa.archs.stlpips_arch.alexnet(requires_grad=False, variant='shift_tolerant', filter_size=3)
Bases:
torch.nn.Module
- class pyiqa.archs.stlpips_arch.vggnet(requires_grad=False, variant='shift_tolerant', filter_size=3)
Bases:
torch.nn.Module
- class pyiqa.archs.stlpips_arch.Downsample(pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0, pad_size='', pad_more=False)
Bases:
torch.nn.Module
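The `Downsample` signature (reflect padding, `filt_size=3`, `stride=2`, per-layer `channels`) resembles the blur-pooling layer from anti-aliased CNNs. That layer low-pass filters before subsampling, so small input shifts change the output less. The sketch below is a 1D illustration under that assumption. It uses the binomial [1, 2, 1]/4 kernel that anti-aliased CNN implementations commonly use for a size-3 filter. The 2D module is assumed to apply the outer product of this kernel to each channel. The function names are hypothetical. `pad_off`, `pad_size`, and `pad_more` are not modeled.

```python
from typing import List


def reflect_pad(x: List[float], pad: int) -> List[float]:
    """Reflect-pad a 1D signal without repeating the edge samples."""
    left = x[1:pad + 1][::-1]
    right = x[-pad - 1:-1][::-1]
    return left + x + right


def blur_downsample(x: List[float], stride: int = 2) -> List[float]:
    """Low-pass filter with a normalized [1, 2, 1] kernel, then subsample."""
    kernel = [0.25, 0.5, 0.25]  # [1, 2, 1] / 4
    padded = reflect_pad(x, 1)
    blurred = [sum(k * padded[i + j] for j, k in enumerate(kernel))
               for i in range(len(x))]
    return blurred[::stride]


def naive_downsample(x: List[float], stride: int = 2) -> List[float]:
    """Plain strided subsampling with no anti-aliasing, for comparison."""
    return x[::stride]


def l1(a: List[float], b: List[float]) -> float:
    """Sum of absolute differences between two equal-length signals."""
    return sum(abs(p - q) for p, q in zip(a, b))
```

The sketch can be checked on an impulse `[0, 0, 0, 1, 0, 0, 0, 0]` shifted right by one sample. Plain subsampling changes the output by 1.0 in L1 distance, while the blurred version changes by only 0.5. The filter therefore reduces shift sensitivity, but does not remove it.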
- class pyiqa.archs.stlpips_arch.VGG(features, num_classes=1000, init_weights=True)
Bases:
torch.nn.Module