r"""LPIPS Model.
Created by: https://github.com/richzhang/PerceptualSimilarity.
Modified by: Jiadi Mo (https://github.com/JiadiMo)
Reference:
Zhang, Richard, et al. "The unreasonable effectiveness of deep features as
a perceptual metric." Proceedings of the IEEE conference on computer vision
and pattern recognition. 2018.
TOPIQ: A Top-down Approach from Semantics to Distortions for Image Quality Assessment.
Chaofeng Chen, Jiadi Mo, Jingwen Hou, Haoning Wu, Liang Liao, Wenxiu Sun, Qiong Yan, Weisi Lin.
Transactions on Image Processing, 2024.
"""
import torch
from torchvision import models
import torch.nn as nn
from collections import namedtuple
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network
from pyiqa.archs.arch_util import get_url_from_name
[docs]
# Download URLs of the released LPIPS checkpoints, keyed by '<version>_<net>'
# (the key format that LPIPS.__init__ builds when it looks an entry up).
default_model_urls = {
    key: get_url_from_name(filename)
    for key, filename in (
        ('0.0_alex', 'LPIPS_v0.0_alex-18720f55.pth'),
        ('0.0_vgg', 'LPIPS_v0.0_vgg-b9e42362.pth'),
        ('0.0_squeeze', 'LPIPS_v0.0_squeeze-c27abd3a.pth'),
        ('0.1_alex', 'LPIPS_v0.1_alex-df73285e.pth'),
        ('0.1_vgg', 'LPIPS_v0.1_vgg-a78928a0.pth'),
        ('0.1_squeeze', 'LPIPS_v0.1_squeeze-4a5350f2.pth'),
    )
}
[docs]
def upsample(in_tens, out_HW=(64, 64)):
    """Bilinearly resize a tensor to the spatial size ``out_HW``.

    Args:
        in_tens: Tensor to resize; the last two dimensions are resized.
        out_HW: Target ``(height, width)``. Default: ``(64, 64)``.

    Returns:
        The resized tensor (bilinear, ``align_corners=False``).

    Note:
        The original comment states that the scale factor is assumed to be
        the same for H and W.
    """
    resized = nn.functional.interpolate(
        in_tens, size=out_HW, mode='bilinear', align_corners=False
    )
    return resized
[docs]
def spatial_average(in_tens, keepdim=True):
    """Average a tensor over dimensions 2 and 3.

    Args:
        in_tens: Tensor with at least four dimensions.
        keepdim: Whether the two averaged dimensions are kept with size 1.
            Default: True.

    Returns:
        The mean over dimensions 2 and 3.
    """
    return torch.mean(in_tens, dim=(2, 3), keepdim=keepdim)
[docs]
def normalize_tensor(in_feat, eps=1e-10):
    """Scale a tensor to (approximately) unit L2 norm along dimension 1.

    Args:
        in_feat: Tensor to normalize.
        eps: Small constant added to the norm so that an all-zero vector
            does not cause a division by zero. Default: 1e-10.

    Returns:
        ``in_feat`` divided by ``(norm + eps)``, where ``norm`` is the
        square root of the sum of squares over dimension 1.
    """
    squared = torch.pow(in_feat, 2)
    channel_norm = torch.sqrt(squared.sum(dim=1, keepdim=True))
    denominator = channel_norm + eps
    return in_feat / denominator
@ARCH_REGISTRY.register()
class LPIPS(nn.Module):
    """LPIPS model.

    Args:
        pretrained (Boolean): Whether to load the released checkpoint selected
            by ``version`` and ``net``; per the original description this
            means the linear layers are calibrated with human perceptual
            judgments. Only consulted when ``pretrained_model_path`` is None.
        net (String): Base/trunk network. One of 'alex', 'vgg' (alias
            'vgg16') or 'squeeze'.
        version (String): '0.1' is the default and latest; '0.0' contained a
            normalization bug. Only '0.1' passes the inputs through the
            scaling layer. Note the value carries no 'v' prefix.
        lpips (Boolean): Whether to use linear layers on top of the base/trunk
            network. False means the baseline of just averaging all layers.
        spatial (Boolean): If True, every per-layer result is upsampled to the
            input's spatial size instead of being averaged spatially.
        pnet_rand (Boolean): Whether to randomly initialize the trunk.
        pretrained_model_path (String): Pretrained model path. Takes
            precedence over ``pretrained``.
        semantic_weight_layer (int): When ``>= 0`` (and ``spatial`` is False,
            ``lpips`` is True), the trunk feature of ``in0`` at this index is
            used to weight the per-layer score maps. Default: -1 (disabled).
        **kwargs: Accepted and ignored.

    The following parameters should only be changed if training the network:
        eval_mode (Boolean): choose the mode; True is for test mode (default).
        pnet_tune (Boolean): Whether to tune the base/trunk network.
        use_dropout (Boolean): Whether to use dropout when training linear layers.

    Raises:
        ValueError: If ``net`` is not one of the supported trunk names.
    """

    def __init__(
        self,
        pretrained=True,
        net='alex',
        version='0.1',
        lpips=True,
        spatial=False,
        pnet_rand=False,
        pnet_tune=False,
        use_dropout=True,
        pretrained_model_path=None,
        eval_mode=True,
        semantic_weight_layer=-1,
        **kwargs,
    ):
        super(LPIPS, self).__init__()

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips  # false means baseline of just averaging all layers
        self.version = version
        self.scaling_layer = ScalingLayer()
        self.semantic_weight_layer = semantic_weight_layer

        if self.pnet_type in ['vgg', 'vgg16']:
            net_type = vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == 'alex':
            net_type = alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == 'squeeze':
            net_type = squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        else:
            # Fail early with a clear message instead of crashing later on a
            # missing attribute.
            raise ValueError(
                f"Unsupported net {net!r}; expected one of "
                "'alex', 'vgg', 'vgg16', 'squeeze'."
            )
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if lpips:
            # One 1x1-conv layer per tapped trunk feature (7 for squeezenet,
            # 5 otherwise). Each layer is registered both as `lin<k>` and
            # inside `lins`, so the state-dict keys stay unchanged.
            lin_layers = []
            for idx, chn in enumerate(self.chns):
                layer = NetLinLayer(chn, use_dropout=use_dropout)
                setattr(self, f'lin{idx}', layer)
                lin_layers.append(layer)
            self.lins = nn.ModuleList(lin_layers)

            if pretrained_model_path is not None:
                load_pretrained_network(self, pretrained_model_path, False)
            elif pretrained:
                # 'vgg16' is an accepted alias, but the URL table only has
                # '<version>_vgg' entries.
                url_net = 'vgg' if self.pnet_type == 'vgg16' else self.pnet_type
                load_pretrained_network(
                    self, default_model_urls[f'{version}_{url_net}'], False
                )

        if eval_mode:
            self.eval()

    def forward(self, in1, in0, retPerLayer=False, normalize=True):
        r"""Computation IQA using LPIPS.

        Args:
            in1: An input tensor. Shape :math:`(N, C, H, W)`.
            in0: A reference tensor. Shape :math:`(N, C, H, W)`.
            retPerLayer (Boolean): return result contains result of
                each layer or not. Default: False.
            normalize (Boolean): Whether to normalize image data range
                in [0,1] to [-1,1]. Default: True.

        Returns:
            Quality score. If ``retPerLayer`` is True, the tuple
            ``(val, res)`` where ``res`` is the list of per-layer results and
            ``val`` their sum; otherwise ``val`` with its last two dimensions
            squeezed (only size-1 dimensions are removed).
        """
        if normalize:
            # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # v0.0 - original release had a bug, where input was not scaled
        if self.version == '0.1':
            in0_input = self.scaling_layer(in0)
            in1_input = self.scaling_layer(in1)
        else:
            in0_input, in1_input = in0, in1

        outs0, outs1 = self.net(in0_input), self.net(in1_input)

        # Squared difference of the unit-normalized features, per layer.
        diffs = {}
        for kk in range(self.L):
            feat0 = normalize_tensor(outs0[kk])
            feat1 = normalize_tensor(outs1[kk])
            diffs[kk] = (feat0 - feat1) ** 2

        if self.lpips:
            if self.spatial:
                res = [
                    upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:])
                    for kk in range(self.L)
                ]
            elif self.semantic_weight_layer >= 0:
                res = []
                # Feature of in0 at the chosen layer, used as spatial weights.
                semantic_feat = outs0[self.semantic_weight_layer]
                for kk in range(self.L):
                    diff_score = self.lins[kk](diffs[kk])
                    # Resize the weights to this layer's score-map size.
                    semantic_weight = torch.nn.functional.interpolate(
                        semantic_feat,
                        size=diff_score.shape[2:],
                        mode='bilinear',
                        align_corners=False,
                    )
                    # Weighted average of the score map.
                    avg_score = torch.sum(
                        diff_score * semantic_weight, dim=[1, 2, 3], keepdim=True
                    ) / torch.sum(semantic_weight, dim=[1, 2, 3], keepdim=True)
                    res.append(avg_score)
            else:
                res = [
                    spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
                    for kk in range(self.L)
                ]
        else:
            # Baseline without linear layers: sum the differences over channels.
            if self.spatial:
                res = [
                    upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:])
                    for kk in range(self.L)
                ]
            else:
                res = [
                    spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True)
                    for kk in range(self.L)
                ]

        # Total score is the sum of the per-layer results.
        val = sum(res)

        if retPerLayer:
            return (val, res)
        return val.squeeze(-1).squeeze(-1)
[docs]
class ScalingLayer(nn.Module):
    """Apply a fixed per-channel shift and scale: ``(inp - shift) / scale``.

    ``shift`` and ``scale`` are registered as buffers of shape
    ``(1, 3, 1, 1)``, so the input must broadcast against that shape.
    """

    def __init__(self):
        super().__init__()
        per_channel = (
            ('shift', [-0.030, -0.088, -0.188]),
            ('scale', [0.458, 0.448, 0.450]),
        )
        for buffer_name, values in per_channel:
            # Reshape the three values to (1, 3, 1, 1) for broadcasting.
            self.register_buffer(buffer_name, torch.Tensor(values).view(1, -1, 1, 1))

    def forward(self, inp):
        """Return ``inp`` shifted and scaled by the registered buffers."""
        centered = inp - self.shift
        return centered / self.scale
[docs]
class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv.

    Args:
        chn_in: Number of input channels.
        chn_out: Number of output channels. Default: 1.
        use_dropout: Whether an ``nn.Dropout`` is placed before the
            convolution. Default: False.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        stages = []
        if use_dropout:
            stages.append(nn.Dropout())
        # Bias-free 1x1 convolution.
        stages.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the (optional) dropout and the 1x1 convolution."""
        return self.model(x)
[docs]
class squeezenet(torch.nn.Module):
    """torchvision SqueezeNet 1.1 ``features`` split into seven slices.

    ``forward`` runs the slices in sequence and returns the output of every
    slice.

    Args:
        requires_grad (bool): If False (default), every parameter of the
            trunk is frozen.
        pretrained (bool): If True (default), load the torchvision
            ``IMAGENET1K_V1`` weights; otherwise keep the random
            initialization.
    """

    # Half-open index ranges of `features` assigned to slice1..slice7.
    _SLICE_BOUNDS = ((0, 2), (2, 5), (5, 8), (8, 10), (10, 11), (11, 12), (12, 13))

    # Built once instead of on every forward call.
    _Outputs = namedtuple(
        'SqueezeOutputs',
        ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'],
    )

    def __init__(self, requires_grad=False, pretrained=True):
        super(squeezenet, self).__init__()
        # Use the `weights=` API like the sibling trunks; the `pretrained=`
        # keyword is deprecated in torchvision.
        weights = 'IMAGENET1K_V1' if pretrained else None
        pretrained_features = models.squeezenet1_1(weights=weights).features
        self.N_slices = 7
        for slice_no, (start, stop) in enumerate(self._SLICE_BOUNDS, start=1):
            block = torch.nn.Sequential()
            for x in range(start, stop):
                # Keep the original layer index as the submodule name so the
                # state-dict keys stay unchanged.
                block.add_module(str(x), pretrained_features[x])
            setattr(self, f'slice{slice_no}', block)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return a ``SqueezeOutputs`` namedtuple with the output of each slice."""
        h = X
        taps = []
        for block in (
            self.slice1,
            self.slice2,
            self.slice3,
            self.slice4,
            self.slice5,
            self.slice6,
            self.slice7,
        ):
            h = block(h)
            taps.append(h)
        return self._Outputs(*taps)
[docs]
class alexnet(torch.nn.Module):
    """torchvision AlexNet ``features`` split into five slices.

    ``forward`` runs the slices in sequence and returns the output of every
    slice.

    Args:
        requires_grad (bool): If False (default), every parameter of the
            trunk is frozen.
        pretrained (bool): If True (default), load the torchvision
            ``IMAGENET1K_V1`` weights; otherwise keep the random
            initialization.
    """

    # Half-open index ranges of `features` assigned to slice1..slice5.
    _SLICE_BOUNDS = ((0, 2), (2, 5), (5, 8), (8, 10), (10, 12))

    # Built once instead of on every forward call.
    _Outputs = namedtuple(
        'AlexnetOutputs', ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']
    )

    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        # Honor `pretrained`: previously the ImageNet weights were loaded
        # unconditionally, so a random trunk could never be requested.
        weights = 'IMAGENET1K_V1' if pretrained else None
        alexnet_pretrained_features = models.alexnet(weights=weights).features
        self.N_slices = 5
        for slice_no, (start, stop) in enumerate(self._SLICE_BOUNDS, start=1):
            block = torch.nn.Sequential()
            for x in range(start, stop):
                # Keep the original layer index as the submodule name so the
                # state-dict keys stay unchanged.
                block.add_module(str(x), alexnet_pretrained_features[x])
            setattr(self, f'slice{slice_no}', block)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return an ``AlexnetOutputs`` namedtuple with the output of each slice."""
        h = X
        taps = []
        for block in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = block(h)
            taps.append(h)
        return self._Outputs(*taps)
[docs]
class vgg16(torch.nn.Module):
    """torchvision VGG-16 ``features`` split into five slices.

    ``forward`` runs the slices in sequence and returns the output of every
    slice.

    Args:
        requires_grad (bool): If False (default), every parameter of the
            trunk is frozen.
        pretrained (bool): If True (default), load the torchvision
            ``IMAGENET1K_V1`` weights; otherwise keep the random
            initialization.
    """

    # Half-open index ranges of `features` assigned to slice1..slice5.
    _SLICE_BOUNDS = ((0, 4), (4, 9), (9, 16), (16, 23), (23, 30))

    # Built once instead of on every forward call.
    _Outputs = namedtuple(
        'VggOutputs', ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
    )

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        # Honor `pretrained`: previously the ImageNet weights were loaded
        # unconditionally, so a random trunk could never be requested.
        weights = 'IMAGENET1K_V1' if pretrained else None
        vgg_pretrained_features = models.vgg16(weights=weights).features
        self.N_slices = 5
        for slice_no, (start, stop) in enumerate(self._SLICE_BOUNDS, start=1):
            block = torch.nn.Sequential()
            for x in range(start, stop):
                # Keep the original layer index as the submodule name so the
                # state-dict keys stay unchanged.
                block.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, f'slice{slice_no}', block)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return a ``VggOutputs`` namedtuple with the output of each slice."""
        h = X
        taps = []
        for block in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = block(h)
            taps.append(h)
        return self._Outputs(*taps)
[docs]
class resnet(torch.nn.Module):
    """torchvision ResNet exposing five intermediate activations.

    Args:
        requires_grad (bool): Accepted for signature parity with the other
            trunks but not applied by this class.
            NOTE(review): parameters are never frozen here, unlike the other
            trunks -- confirm whether that is intended.
        pretrained (bool): If True (default), load the torchvision
            ``IMAGENET1K_V1`` weights; otherwise keep the random
            initialization.
        num (int): ResNet depth; one of 18, 34, 50, 101, 152. Default: 18.

    Raises:
        ValueError: If ``num`` is not a supported depth.
    """

    # Supported depths mapped to the torchvision constructor name.
    _ARCH_NAMES = {
        18: 'resnet18',
        34: 'resnet34',
        50: 'resnet50',
        101: 'resnet101',
        152: 'resnet152',
    }

    # Built once instead of on every forward call.
    _Outputs = namedtuple('Outputs', ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])

    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        arch_name = self._ARCH_NAMES.get(num)
        if arch_name is None:
            # Previously an unsupported depth left `self.net` unset and
            # failed later with an opaque attribute error.
            raise ValueError(
                f'Unsupported ResNet depth {num!r}; '
                f'expected one of {sorted(self._ARCH_NAMES)}.'
            )
        # `weights=` replaces the deprecated `pretrained=` keyword.
        weights = 'IMAGENET1K_V1' if pretrained else None
        self.net = getattr(models, arch_name)(weights=weights)
        self.N_slices = 5

        # Aliases of the backbone stages; `self.net` stays registered as
        # well, so the state-dict keys are unchanged.
        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

    def forward(self, X):
        """Return an ``Outputs`` namedtuple (relu1, conv2, conv3, conv4, conv5)."""
        # Stem output is tapped before the max-pool.
        h_relu1 = self.relu(self.bn1(self.conv1(X)))
        h_conv2 = self.layer1(self.maxpool(h_relu1))
        h_conv3 = self.layer2(h_conv2)
        h_conv4 = self.layer3(h_conv3)
        h_conv5 = self.layer4(h_conv4)
        return self._Outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)