Source code for pyiqa.archs.msswd_arch
"""Perceptual color difference metric, MS-SWD.
@inproceedings{he2024ms-swd,
title={Multiscale Sliced {Wasserstein} Distances as Perceptual Color Difference Measures},
author={He, Jiaqi and Wang, Zhihua and Wang, Leon and Liu, Tsein-I and Fang, Yuming and Sun, Qilin and Ma, Kede},
booktitle={European Conference on Computer Vision},
pages={1--18},
year={2024},
url={http://arxiv.org/abs/2407.10181}
}
Reference:
- Official github: https://github.com/real-hjq/MS-SWD
"""
import torch
from torch import nn
import torchvision.transforms.functional as TF
from pyiqa.archs.arch_util import load_pretrained_network, get_url_from_name
from pyiqa.utils.registry import ARCH_REGISTRY
[docs]
def _d65_white_point(device, inverse=False):
    """Return the D65 reference white (or its reciprocal) as a ``(3, 1, 1)`` tensor.

    The tensor is created in the default dtype on ``device`` so that it
    broadcasts over the channel axis of an ``(N, 3, H, W)`` color tensor.
    """
    if inverse:
        values = [[[1.052156925]], [[1.000000000]], [[0.918357670]]]
    else:
        values = [[[0.950428545]], [[1.000000000]], [[1.088900371]]]
    return torch.tensor(values, device=device)


def color_space_transform(input_color, fromSpace2toSpace):
    """Transform color tensors between supported color spaces.

    The primitive conversions ``srgb2linrgb``, ``linrgb2srgb``,
    ``linrgb2xyz``, ``xyz2linrgb``, ``xyz2ycxcz``, ``ycxcz2xyz``,
    ``xyz2lab`` and ``lab2xyz`` are implemented directly. Every other
    supported key (for example ``'srgb2lab'``) is evaluated as a chain of
    those primitives through recursive calls.

    Args:
        input_color (torch.Tensor): Color tensor with shape ``(N, C, H, W)``.
            The matrix conversions (``linrgb2xyz`` / ``xyz2linrgb``) accept
            non-contiguous tensors and any floating-point dtype.
        fromSpace2toSpace (str): Conversion key, for example ``'srgb2lab'``.

    Returns:
        torch.Tensor: Transformed tensor with shape ``(N, C, H, W)``.

    Raises:
        ValueError: If the conversion key is not defined.
    """
    dim = input_color.size()
    device = input_color.device

    if fromSpace2toSpace == 'srgb2linrgb':
        limit = 0.04045
        # The clamp keeps the base of the power away from the region below
        # ``limit`` (selected out by ``where`` anyway) to stabilize training.
        transformed_color = torch.where(
            input_color > limit,
            torch.pow((torch.clamp(input_color, min=limit) + 0.055) / 1.055, 2.4),
            input_color / 12.92,
        )

    elif fromSpace2toSpace == 'linrgb2srgb':
        limit = 0.0031308
        transformed_color = torch.where(
            input_color > limit,
            1.055 * torch.pow(torch.clamp(input_color, min=limit), (1.0 / 2.4)) - 0.055,
            12.92 * input_color,
        )

    elif fromSpace2toSpace in ['linrgb2xyz', 'xyz2linrgb']:
        # Source: https://www.image-engineering.de/library/technotes/958-how-to-convert-between-srgb-and-ciexyz
        # Assumes D65 standard illuminant
        if fromSpace2toSpace == 'linrgb2xyz':
            a11 = 10135552 / 24577794
            a12 = 8788810 / 24577794
            a13 = 4435075 / 24577794
            a21 = 2613072 / 12288897
            a22 = 8788810 / 12288897
            a23 = 887015 / 12288897
            a31 = 1425312 / 73733382
            a32 = 8788810 / 73733382
            a33 = 70074185 / 73733382
        else:
            # Constants found by taking the inverse of the matrix
            # defined by the constants for linrgb2xyz
            a11 = 3.241003275
            a12 = -1.537398934
            a13 = -0.498615861
            a21 = -0.969224334
            a22 = 1.875930071
            a23 = 0.041554224
            a31 = 0.055639423
            a32 = -0.204011202
            a33 = 1.057148933

        # matmul needs both operands in the same dtype, so the matrix follows
        # the input's floating dtype. Non-floating inputs keep the default
        # dtype so the coefficients are never truncated to integers.
        if input_color.is_floating_point():
            matrix_dtype = input_color.dtype
        else:
            matrix_dtype = torch.get_default_dtype()
        A = torch.tensor(
            [[a11, a12, a13], [a21, a22, a23], [a31, a32, a33]],
            dtype=matrix_dtype,
            device=device,
        )

        # reshape (rather than view) also handles non-contiguous inputs such
        # as cropped or transposed tensors.
        flat_color = input_color.reshape(dim[0], dim[1], dim[2] * dim[3])  # NC(HW)
        transformed_color = torch.matmul(A, flat_color)
        transformed_color = transformed_color.reshape(dim[0], dim[1], dim[2], dim[3])

    elif fromSpace2toSpace == 'xyz2ycxcz':
        input_color = torch.mul(input_color, _d65_white_point(device, inverse=True))
        y = 116 * input_color[:, 1:2, :, :] - 16
        cx = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
        cz = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
        transformed_color = torch.cat((y, cx, cz), 1)

    elif fromSpace2toSpace == 'ycxcz2xyz':
        y = (input_color[:, 0:1, :, :] + 16) / 116
        cx = input_color[:, 1:2, :, :] / 500
        cz = input_color[:, 2:3, :, :] / 200

        x = y + cx
        z = y - cz
        transformed_color = torch.cat((x, y, z), 1)
        transformed_color = torch.mul(transformed_color, _d65_white_point(device))

    elif fromSpace2toSpace == 'xyz2lab':
        input_color = torch.mul(input_color, _d65_white_point(device, inverse=True))
        delta = 6 / 29
        delta_square = delta * delta
        delta_cube = delta * delta_square
        factor = 1 / (3 * delta_square)

        # Cube root above delta^3, linear segment below it. The clamp keeps
        # the cube root away from the low range to stabilize training.
        clamped_term = torch.pow(
            torch.clamp(input_color, min=delta_cube), 1.0 / 3.0
        ).to(dtype=input_color.dtype)
        div = (factor * input_color + (4 / 29)).to(dtype=input_color.dtype)
        input_color = torch.where(input_color > delta_cube, clamped_term, div)

        L = 116 * input_color[:, 1:2, :, :] - 16
        a = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
        b = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
        transformed_color = torch.cat((L, a, b), 1)

    elif fromSpace2toSpace == 'lab2xyz':
        y = (input_color[:, 0:1, :, :] + 16) / 116
        a = input_color[:, 1:2, :, :] / 500
        b = input_color[:, 2:3, :, :] / 200

        x = y + a
        z = y - b
        xyz = torch.cat((x, y, z), 1)

        delta = 6 / 29
        delta_square = delta * delta
        factor = 3 * delta_square
        # Inverse of the piecewise function used by 'xyz2lab'.
        xyz = torch.where(xyz > delta, torch.pow(xyz, 3), factor * (xyz - 4 / 29))
        transformed_color = torch.mul(xyz, _d65_white_point(device))

    elif fromSpace2toSpace == 'srgb2xyz':
        transformed_color = color_space_transform(input_color, 'srgb2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
    elif fromSpace2toSpace == 'srgb2ycxcz':
        transformed_color = color_space_transform(input_color, 'srgb2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2ycxcz')
    elif fromSpace2toSpace == 'linrgb2ycxcz':
        transformed_color = color_space_transform(input_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2ycxcz')
    elif fromSpace2toSpace == 'srgb2lab':
        transformed_color = color_space_transform(input_color, 'srgb2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2lab')
    elif fromSpace2toSpace == 'linrgb2lab':
        transformed_color = color_space_transform(input_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2lab')
    elif fromSpace2toSpace == 'ycxcz2linrgb':
        transformed_color = color_space_transform(input_color, 'ycxcz2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2linrgb')
    elif fromSpace2toSpace == 'lab2srgb':
        transformed_color = color_space_transform(input_color, 'lab2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2srgb')
    elif fromSpace2toSpace == 'ycxcz2lab':
        transformed_color = color_space_transform(input_color, 'ycxcz2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2lab')
    else:
        raise ValueError(
            'Error: The color transform %s is not defined!' % fromSpace2toSpace
        )

    return transformed_color
@ARCH_REGISTRY.register()
class MS_SWD_learned(nn.Module):
    """MS-SWD perceptual color difference metric.

    Images are converted from sRGB to Lab, passed through an 11x11
    convolution, a LeakyReLU and a 1x1 convolution, and the per-channel
    responses of the two images are compared after sorting.

    Args:
        resize_input (bool): Whether to resize inputs with short side larger
            than ``256`` before scoring.
        pretrained (bool): Whether to load pretrained weights.
        pretrained_model_path (str | None): Optional local checkpoint path.
            When given, it takes precedence over ``pretrained``.
        **kwargs: Reserved compatibility arguments.
    """

    def __init__(
        self,
        resize_input: bool = True,
        pretrained: bool = True,
        pretrained_model_path: str = None,
        **kwargs,
    ):
        super().__init__()

        # 3 input channels -> 128 responses; reflect padding keeps H x W.
        self.conv11x11 = nn.Conv2d(
            in_channels=3,
            out_channels=128,
            kernel_size=11,
            stride=1,
            padding=5,
            dilation=1,
            bias=False,
            padding_mode='reflect',
        )
        # Pointwise mixing of the 128 responses down to 64 channels.
        self.conv_m1 = nn.Conv2d(
            in_channels=128, out_channels=64, kernel_size=1, bias=False
        )
        self.relu = nn.LeakyReLU()

        self.resize_input = resize_input

        # An explicit checkpoint path wins over the default pretrained weights.
        if pretrained_model_path is not None:
            load_pretrained_network(self, pretrained_model_path, weight_keys='params')
        elif pretrained:
            weight_url = get_url_from_name('msswd_weights.pth')
            load_pretrained_network(self, weight_url, weight_keys='net_dict')

    def preprocess_img(self, x):
        """Resize the batch to ``256`` when enabled and its short side exceeds ``256``."""
        if not self.resize_input:
            return x
        if min(x.shape[2:]) <= 256:
            return x
        return TF.resize(x, 256)

    def forward_once(self, x):
        """Project one image batch and flatten the spatial dimensions.

        Args:
            x (torch.Tensor): sRGB image tensor with shape ``(N, 3, H, W)``.

        Returns:
            torch.Tensor: Projected responses reshaped to ``(N, C, H * W)``.
        """
        lab = color_space_transform(x, 'srgb2lab')
        responses = self.conv_m1(self.relu(self.conv11x11(lab)))
        batch, channels = responses.shape[0], responses.shape[1]
        return responses.reshape(batch, channels, -1)

    def forward(self, x, y):
        """Compute MS-SWD distance.

        Args:
            x (torch.Tensor): Distorted image tensor with shape ``(N, 3, H, W)``.
            y (torch.Tensor): Reference image tensor with shape ``(N, 3, H, W)``.

        Returns:
            torch.Tensor: Distance scores with shape ``(N,)``.
        """
        x = self.preprocess_img(x)
        y = self.preprocess_img(y)

        proj_x = self.forward_once(x)
        proj_y = self.forward_once(y)

        # Sort each channel's responses over the flattened pixels, then take
        # the mean absolute difference between the two sorted sets.
        sorted_x = torch.sort(proj_x, dim=2).values
        sorted_y = torch.sort(proj_y, dim=2).values
        return (sorted_x - sorted_y).abs().mean(dim=[1, 2])