"""ST-LPIPS Model
github repo link: https://github.com/abhijay9/ShiftTolerant-LPIPS
Cite as:
@inproceedings{ghildyal2022stlpips,
title={Shift-tolerant Perceptual Similarity Metric},
author={Abhijay Ghildyal and Feng Liu},
booktitle={European Conference on Computer Vision},
year={2022}
}
"""
import torch
import torch.nn as nn
from collections import namedtuple
import numpy as np
import torch.nn.functional as F
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network
from pyiqa.archs.arch_util import get_url_from_name
[docs]
default_model_urls = {
'alex_shift_tolerant': get_url_from_name('alex_shift_tolerant.pth'),
'vgg_shift_tolerant': get_url_from_name('vgg_shift_tolerant.pth'),
}
[docs]
def spatial_average(in_tens, keepdim=True):
return in_tens.mean([2, 3], keepdim=keepdim)
[docs]
def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W
in_H, in_W = in_tens.shape[2], in_tens.shape[3]
return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)
[docs]
def normalize_tensor(in_feat, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
return in_feat / (norm_factor + eps)
@ARCH_REGISTRY.register()
[docs]
class STLPIPS(nn.Module):
"""ST-LPIPS model.
Args:
lpips (Boolean) : Whether to use linear layers on top of base/trunk network.
pretrained (Boolean): Whether means linear layers are calibrated with human
perceptual judgments.
net (String): ['alex','vgg','squeeze'] are the base/trunk networks available.
pretrained_model_path (String): Petrained model path.
The following parameters should only be changed if training the network:
eval_mode (Boolean): choose the mode; True is for test mode (default).
pnet_tune (Boolean): Whether to tune the base/trunk network.
use_dropout (Boolean): Whether to use dropout when training linear layers.
"""
def __init__(
self,
pretrained=True,
net='alex',
variant='shift_tolerant',
lpips=True,
spatial=False,
pnet_tune=False,
use_dropout=True,
pretrained_model_path=None,
eval_mode=True,
blur_filter_size=3,
):
super(STLPIPS, self).__init__()
self.pnet_type = net
self.pnet_tune = pnet_tune
self.spatial = spatial
self.lpips = lpips # false means baseline of just averaging all layers
self.scaling_layer = ScalingLayer()
if self.pnet_type in ['vgg']:
net_type = vggnet
self.chns = [64, 128, 256, 512, 512]
elif self.pnet_type == 'alex':
net_type = alexnet
self.chns = [64, 192, 384, 256, 256]
self.net = net_type(
requires_grad=self.pnet_tune, variant=variant, filter_size=blur_filter_size
)
self.L = len(self.chns)
if lpips:
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
if self.pnet_type == 'squeeze': # 7 layers for squeezenet
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
self.lins += [self.lin5, self.lin6]
self.lins = nn.ModuleList(self.lins)
if pretrained_model_path is not None:
load_pretrained_network(self, pretrained_model_path, False)
elif pretrained:
load_pretrained_network(
self, default_model_urls[f'{net}_{variant}'], False
)
if eval_mode:
self.eval()
[docs]
def forward(self, in0, in1, retPerLayer=False, normalize=True):
"""Computation IQA using LPIPS.
Args:
in1: An input tensor. Shape :math:`(N, C, H, W)`.
in0: A reference tensor. Shape :math:`(N, C, H, W)`.
retPerLayer (Boolean): return result contains result of
each layer or not. Default: False.
normalize (Boolean): Whether to normalize image data range
in [0,1] to [-1,1]. Default: True.
Returns:
Quality score.
"""
if (
normalize
): # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
in0 = 2 * in0 - 1
in1 = 2 * in1 - 1
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1))
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}
for kk in range(self.L):
feats0[kk], feats1[kk] = (
normalize_tensor(outs0[kk]),
normalize_tensor(outs1[kk]),
)
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
if self.lpips:
if self.spatial:
res = [
upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:])
for kk in range(self.L)
]
else:
res = [
spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
for kk in range(self.L)
]
else:
if self.spatial:
res = [
upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:])
for kk in range(self.L)
]
else:
res = [
spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True)
for kk in range(self.L)
]
val = res[0]
for l in range(1, self.L):
val += res[l]
if retPerLayer:
return (val, res)
else:
return val
[docs]
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer(
'shift', torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
)
self.register_buffer(
'scale', torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
)
[docs]
def forward(self, inp):
return (inp - self.shift) / self.scale
[docs]
class NetLinLayer(nn.Module):
"""A single linear layer which does a 1x1 conv"""
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = (
[
nn.Dropout(),
]
if (use_dropout)
else []
)
layers += [
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
]
self.model = nn.Sequential(*layers)
[docs]
def forward(self, x):
return self.model(x)
[docs]
class alexnet(nn.Module):
def __init__(self, requires_grad=False, variant='shift_tolerant', filter_size=3):
super(alexnet, self).__init__()
self.slice1 = nn.Sequential()
self.slice2 = nn.Sequential()
self.slice3 = nn.Sequential()
self.slice4 = nn.Sequential()
self.slice5 = nn.Sequential()
if variant == 'vanilla':
features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True), # 1
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True), # 4
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True), # 7
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True), # 9
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True), # 11
nn.MaxPool2d(kernel_size=3, stride=2),
)
for x in range(2):
self.slice1.add_module(str(x), features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), features[x])
for x in range(10, 12):
self.slice5.add_module(str(x), features[x])
elif variant == 'antialiased':
features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=2, padding=2),
nn.ReLU(inplace=True), # 1
Downsample(filt_size=filter_size, stride=2, channels=64),
nn.MaxPool2d(kernel_size=3, stride=1),
Downsample(filt_size=filter_size, stride=2, channels=64),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True), # 6
nn.MaxPool2d(kernel_size=3, stride=1),
Downsample(filt_size=filter_size, stride=2, channels=192),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True), # 10
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True), # 12
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True), # 14
nn.MaxPool2d(kernel_size=3, stride=1),
Downsample(filt_size=filter_size, stride=2, channels=256),
)
for x in range(2):
self.slice1.add_module(str(x), features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), features[x])
for x in range(7, 11):
self.slice3.add_module(str(x), features[x])
for x in range(11, 13):
self.slice4.add_module(str(x), features[x])
for x in range(13, 15):
self.slice5.add_module(str(x), features[x])
elif (
variant == 'shift_tolerant'
): # antialiased_blurpoolReflectionPad2_conv1stride1_blurAfter
features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=1, padding=2),
Downsample(filt_size=filter_size, stride=2, channels=64, pad_more=True),
nn.ReLU(inplace=True), # 2
nn.MaxPool2d(kernel_size=3, stride=1),
Downsample(filt_size=filter_size, stride=2, channels=64, pad_more=True),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True), # 6
nn.MaxPool2d(kernel_size=3, stride=1),
Downsample(
filt_size=filter_size, stride=2, channels=192, pad_more=True
),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True), # 10
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True), # 12
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True), # 14
nn.MaxPool2d(kernel_size=3, stride=1),
Downsample(
filt_size=filter_size, stride=2, channels=256, pad_more=True
),
)
for x in range(3):
self.slice1.add_module(str(x), features[x])
for x in range(3, 7):
self.slice2.add_module(str(x), features[x])
for x in range(7, 11):
self.slice3.add_module(str(x), features[x])
for x in range(11, 13):
self.slice4.add_module(str(x), features[x])
for x in range(13, 15):
self.slice5.add_module(str(x), features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
[docs]
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
alexnet_outputs = namedtuple(
'AlexnetOutputs', ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']
)
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
return out
[docs]
class vggnet(nn.Module):
def __init__(self, requires_grad=False, variant='shift_tolerant', filter_size=3):
super(vggnet, self).__init__()
filter_size = 3
self.slice1 = nn.Sequential()
self.slice2 = nn.Sequential()
self.slice3 = nn.Sequential()
self.slice4 = nn.Sequential()
self.slice5 = nn.Sequential()
self.N_slices = 5
if variant == 'vanilla':
vgg_features = tv.vgg16(pretrained=False).features
for x in range(4):
self.slice1.add_module(str(x), vgg_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_features[x])
elif variant == 'shift_tolerant':
vgg_features = vgg16(filter_size=filter_size, pad_more=True).features
for x in range(4):
self.slice1.add_module(str(x), vgg_features[x])
for x in range(4, 10):
self.slice2.add_module(str(x), vgg_features[x])
for x in range(10, 18):
self.slice3.add_module(str(x), vgg_features[x])
for x in range(18, 26):
self.slice4.add_module(str(x), vgg_features[x])
for x in range(26, 34):
self.slice5.add_module(str(x), vgg_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
[docs]
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple(
'VggOutputs', ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
)
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
# Copyright (c) 2019, Adobe Inc. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International Public License. To view a copy of this license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
[docs]
class Downsample(nn.Module):
def __init__(
self,
pad_type='reflect',
filt_size=3,
stride=2,
channels=None,
pad_off=0,
pad_size='',
pad_more=False,
):
super(Downsample, self).__init__()
self.filt_size = filt_size
self.pad_off = pad_off
if pad_size == '2k' or pad_more == True:
self.pad_sizes = [
int(1.0 * (filt_size - 1)),
int(np.ceil(1.0 * (filt_size - 1))),
int(1.0 * (filt_size - 1)),
int(np.ceil(1.0 * (filt_size - 1))),
]
elif pad_size == 'none':
self.pad_sizes = [0, 0, 0, 0]
else:
self.pad_sizes = [
int(1.0 * (filt_size - 1) / 2),
int(np.ceil(1.0 * (filt_size - 1) / 2)),
int(1.0 * (filt_size - 1) / 2),
int(np.ceil(1.0 * (filt_size - 1) / 2)),
]
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
self.stride = stride
self.off = int((self.stride - 1) / 2.0)
self.channels = channels
if self.filt_size == 1:
a = np.array(
[
1.0,
]
)
elif self.filt_size == 2:
a = np.array([1.0, 1.0])
elif self.filt_size == 3:
a = np.array([1.0, 2.0, 1.0])
elif self.filt_size == 4:
a = np.array([1.0, 3.0, 3.0, 1.0])
elif self.filt_size == 5:
a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
elif self.filt_size == 6:
a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
elif self.filt_size == 7:
a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
filt = torch.Tensor(a[:, None] * a[None, :])
filt = filt / torch.sum(filt)
self.register_buffer(
'filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
)
self.pad = get_pad_layer(pad_type)(self.pad_sizes)
[docs]
def forward(self, inp):
if self.filt_size == 1:
if self.pad_off == 0:
return inp[:, :, :: self.stride, :: self.stride]
else:
return self.pad(inp)[:, :, :: self.stride, :: self.stride]
else:
return F.conv2d(
self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]
)
[docs]
def get_pad_layer(pad_type):
if pad_type in ['refl', 'reflect']:
PadLayer = nn.ReflectionPad2d
elif pad_type in ['repl', 'replicate']:
PadLayer = nn.ReplicationPad2d
elif pad_type == 'zero':
PadLayer = nn.ZeroPad2d
else:
print('Pad type [%s] not recognized' % pad_type)
return PadLayer
# This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models.
# Copyright (c) 2017 Torch Contributors.
# The Pytorch examples are available under the BSD 3-Clause License.
#
# ==========================================================================================
#
# Adobe’s modifications are Copyright 2019 Adobe. All rights reserved.
# Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
#
# ==========================================================================================
#
# BSD-3 License
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
[docs]
class VGG(nn.Module):
def __init__(self, features, num_classes=1000, init_weights=True):
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
if init_weights:
self._initialize_weights()
[docs]
def forward(self, x):
x = self.features(x)
# print(x.shape)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
if (
m.in_channels != m.out_channels
or m.out_channels != m.groups
or m.bias is not None
):
# don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
else:
print('Not initializing')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
[docs]
def make_layers(cfg, batch_norm=False, filter_size=1, pad_more=False, fconv=False):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
# layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
layers += [
nn.MaxPool2d(kernel_size=2, stride=1),
Downsample(
filt_size=filter_size,
stride=2,
channels=in_channels,
pad_more=pad_more,
),
]
else:
if fconv:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=2)
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
[docs]
cfg = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [
64,
64,
'M',
128,
128,
'M',
256,
256,
256,
'M',
512,
512,
512,
'M',
512,
512,
512,
'M',
],
'E': [
64,
64,
'M',
128,
128,
'M',
256,
256,
256,
256,
'M',
512,
512,
512,
512,
'M',
512,
512,
512,
512,
'M',
],
}
[docs]
def vgg16(pretrained=False, filter_size=1, pad_more=False, fconv=False, **kwargs):
"""VGG 16-layer model (configuration "D")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
if pretrained:
kwargs['init_weights'] = False
model = VGG(
make_layers(cfg['D'], filter_size=filter_size, pad_more=pad_more, fconv=fconv),
**kwargs,
)
# if pretrained:
# model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
return model