r"""VSI Metric.
@article{zhang2014vsi,
title={VSI: A visual saliency-induced index for perceptual image quality assessment},
author={Zhang, Lin and Shen, Ying and Li, Hongyu},
journal={IEEE Transactions on Image processing},
volume={23},
number={10},
pages={4270--4281},
year={2014},
publisher={IEEE}
}
Created by: https://github.com/photosynthesis-team/piq/blob/master/piq/vsi.py
Modified by: Jiadi Mo (https://github.com/JiadiMo)
Refer to:
IQA-Optimization from https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/VSI.py
Official matlab code is not available
"""
import warnings
import functools
from typing import Union, Tuple
import torch
import torch.nn as nn
from torch.nn.functional import avg_pool2d, interpolate, pad
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.utils.color_util import rgb2lmn, rgb2lab
from .func_util import (
ifftshift,
gradient_map,
get_meshgrid,
similarity_map,
scharr_filter,
safe_sqrt,
)
def vsi(
    x: torch.Tensor,
    y: torch.Tensor,
    data_range: Union[int, float] = 1.0,
    c1: float = 1.27,
    c2: float = 386.0,
    c3: float = 130.0,
    alpha: float = 0.4,
    beta: float = 0.02,
    omega_0: float = 0.021,
    sigma_f: float = 1.34,
    sigma_d: float = 145.0,
    sigma_c: float = 0.001,
) -> torch.Tensor:
    r"""Compute Visual Saliency-induced Index for a batch of images.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        c1: coefficient to calculate saliency component of VSI.
        c2: coefficient to calculate gradient component of VSI.
        c3: coefficient to calculate color component of VSI.
        alpha: power for gradient component of VSI.
        beta: power for color component of VSI.
        omega_0: coefficient to get log Gabor filter at SDSP.
        sigma_f: coefficient to get log Gabor filter at SDSP.
        sigma_d: coefficient to get SDSP.
        sigma_c: coefficient to get SDSP.

    Returns:
        Index of similarity between two images, computed in double precision
        and reduced over the spatial dimensions. Usually in [0, 1] range.

    References:
        L. Zhang, Y. Shen and H. Li, "VSI: A Visual Saliency-Induced Index for Perceptual
        Image Quality Assessment," IEEE Transactions on Image Processing, vol. 23, no. 10,
        pp. 4270-4281, Oct. 2014, doi: 10.1109/TIP.2014.2346028
        https://ieeexplore.ieee.org/document/6873260

    Note:
        The original method supports only RGB image. Single-channel inputs are
        expanded to three channels by repeating the grey channel, and a
        warning is emitted when that happens.
    """
    x, y = x.double(), y.double()

    # Expand every single-channel input on its own. Checking only ``x`` and
    # then repeating both tensors would turn an RGB ``y`` into 9 channels, and
    # would leave a grey ``y`` untouched when ``x`` is already RGB.
    x_is_grey = x.size(1) == 1
    y_is_grey = y.size(1) == 1
    if x_is_grey or y_is_grey:
        if x_is_grey:
            x = x.repeat(1, 3, 1, 1)
        if y_is_grey:
            y = y.repeat(1, 3, 1, 1)
        warnings.warn(
            'The original VSI supports only RGB images. The input images were converted to RGB by copying '
            'the grey channel 3 times.'
        )

    # Scale to [0, 255] range to match scale of constant
    x = x * 255.0 / data_range
    y = y * 255.0 / data_range

    # Visual saliency maps; both images use identical SDSP settings.
    sdsp_kwargs = dict(
        data_range=255,
        omega_0=omega_0,
        sigma_f=sigma_f,
        sigma_d=sigma_d,
        sigma_c=sigma_c,
    )
    vs_x = sdsp(x, **sdsp_kwargs)
    vs_y = sdsp(y, **sdsp_kwargs)

    # Convert to LMN colour space
    x_lmn = rgb2lmn(x)
    y_lmn = rgb2lmn(y)

    # Averaging image if the size is large enough
    kernel_size = max(1, round(min(vs_x.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        upper_pad = padding
        bottom_pad = (kernel_size - 1) // 2
        pad_to_use = [upper_pad, bottom_pad, upper_pad, bottom_pad]
        mode = 'replicate'
        vs_x = pad(vs_x, pad=pad_to_use, mode=mode)
        vs_y = pad(vs_y, pad=pad_to_use, mode=mode)
        x_lmn = pad(x_lmn, pad=pad_to_use, mode=mode)
        y_lmn = pad(y_lmn, pad=pad_to_use, mode=mode)

    # With kernel_size == 1 the pooling leaves the maps unchanged.
    vs_x = avg_pool2d(vs_x, kernel_size=kernel_size)
    vs_y = avg_pool2d(vs_y, kernel_size=kernel_size)
    x_lmn = avg_pool2d(x_lmn, kernel_size=kernel_size)
    y_lmn = avg_pool2d(y_lmn, kernel_size=kernel_size)

    # Calculate gradient map (first LMN channel only)
    kernels = torch.stack([scharr_filter(), scharr_filter().transpose(1, 2)]).to(x_lmn)
    gm_x = gradient_map(x_lmn[:, :1], kernels)
    gm_y = gradient_map(y_lmn[:, :1], kernels)

    # Calculate all similarity maps
    s_vs = similarity_map(vs_x, vs_y, c1)
    s_gm = similarity_map(gm_x, gm_y, c2)
    s_m = similarity_map(x_lmn[:, 1:2], y_lmn[:, 1:2], c3)
    s_n = similarity_map(x_lmn[:, 2:], y_lmn[:, 2:], c3)
    s_c = s_m * s_n

    # ``s_c`` may be negative, so ``s_c ** beta`` is evaluated in polar form
    # (magnitude ** beta, phase * beta) and only the real part is kept.
    s_c_magnitude = s_c.abs()
    s_c_phase = torch.atan2(torch.zeros_like(s_c), s_c)
    s_c_real_pow = (s_c_magnitude**beta) * torch.cos(s_c_phase * beta)

    s = s_vs * s_gm.pow(alpha) * s_c_real_pow

    # Pool the combined similarity map, weighting every pixel by the larger of
    # the two saliency values; eps guards against a zero denominator.
    vs_max = torch.max(vs_x, vs_y)
    eps = torch.finfo(vs_max.dtype).eps
    output = s * vs_max
    output = (
        (output.sum(dim=(-1, -2)) + eps) / (vs_max.sum(dim=(-1, -2)) + eps)
    ).squeeze(-1)
    return output
def sdsp(
    x: torch.Tensor,
    data_range: Union[int, float] = 255,
    omega_0: float = 0.021,
    sigma_f: float = 1.34,
    sigma_d: float = 145.0,
    sigma_c: float = 0.001,
) -> torch.Tensor:
    r"""Build a visual saliency map for an image batch with the SDSP algorithm.

    The map is the product of three priors evaluated on a 256x256 working
    copy of the input: a log-Gabor band-pass response, a spatial prior and a
    colour prior. The product is resized back to the input resolution and
    min-max normalised per image.

    Supports only colour images with RGB channel order.

    Args:
        x: Tensor. Shape :math:`(N, 3, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        omega_0: coefficient for log Gabor filter
        sigma_f: coefficient for log Gabor filter
        sigma_d: coefficient for the central areas, which have a bias towards attention
        sigma_c: coefficient for the warm colors, which have a bias towards attention

    Returns:
        torch.Tensor: Visual saliency map
    """

    def _spatial_min(t: torch.Tensor) -> torch.Tensor:
        # Minimum over the last two dimensions, kept as singleton dims.
        return t.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values

    def _spatial_max(t: torch.Tensor) -> torch.Tensor:
        # Maximum over the last two dimensions, kept as singleton dims.
        return t.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values

    target_hw = x.size()[-2:]
    work_hw = (256, 256)

    # Bring values to the [0, 255] scale and resample to the working size.
    rgb = interpolate(
        input=x / data_range * 255,
        size=work_hw,
        mode='bilinear',
        align_corners=False,
    )
    lab = rgb2lab(rgb, data_range=255)
    eps = torch.finfo(lab.dtype).eps

    # Frequency prior: filter every Lab channel in the Fourier domain
    # (torch.fft requires torch >= 1.8.0) and take the channel-wise norm.
    log_gabor = _log_gabor(work_hw, omega_0, sigma_f).to(rgb).view(1, 1, *work_hw)
    band_passed = torch.fft.ifft2(torch.fft.fft2(lab) * log_gabor).real
    freq_prior = safe_sqrt(band_passed.pow(2).sum(dim=1, keepdim=True))

    # Location prior: Gaussian fall-off over the scaled mesh-grid coordinates.
    grid = torch.stack(get_meshgrid(work_hw), dim=0).to(rgb)
    grid = grid * work_hw[0] + 1
    sq_dist = torch.sum(grid**2, dim=0)
    loc_prior = torch.exp(-sq_dist / sigma_d**2).view(1, 1, *work_hw)

    # Colour prior: min-max normalise Lab per channel, then use the squared
    # norm of every channel after the first.
    lab_lo = _spatial_min(lab)
    lab_hi = _spatial_max(lab)
    lab_unit = (lab - lab_lo) / (lab_hi - lab_lo + eps)
    chroma_sq = lab_unit[:, 1:].pow(2).sum(dim=1, keepdim=True)
    colour_prior = 1 - torch.exp(-chroma_sq / sigma_c**2)

    saliency = freq_prior * loc_prior * colour_prior
    saliency = interpolate(saliency, size=target_hw, mode='bilinear', align_corners=True)

    # Per-image min-max normalisation of the final map.
    sal_lo = _spatial_min(saliency)
    sal_hi = _spatial_max(saliency)
    return (saliency - sal_lo) / (sal_hi - sal_lo + eps)
def _log_gabor(size: Tuple[int, int], omega_0: float, sigma_f: float) -> torch.Tensor:
    r"""Build a log Gabor filter of the requested spatial size.

    Args:
        size: size of the required log Gabor filter
        omega_0: center frequency of the filter
        sigma_f: bandwidth of the filter

    Returns:
        log Gabor filter
    """
    grid_x, grid_y = get_meshgrid(size)
    dist = torch.sqrt(grid_x**2 + grid_y**2)

    # Zero every radius above 0.5; log(0) = -inf then yields a response of
    # exactly 0 there once squared, negated and exponentiated.
    inside_band = dist <= 0.5
    shifted = ifftshift(dist * inside_band)

    # Placeholder radius at [0, 0] for the logarithm; that entry of the
    # filter is overwritten with 0 below.
    shifted[0, 0] = 1

    log_ratio = torch.log(shifted / omega_0)
    response = torch.exp(-log_ratio.pow(2) / (2 * sigma_f**2))
    response[0, 0] = 0
    return response
@ARCH_REGISTRY.register()
class VSI(nn.Module):
    r"""Module wrapper that computes the Visual Saliency-induced Index between
    an input batch and a target batch.

    The module returns the VSI similarity score itself (it is not converted
    into an error such as ``1 - score``), so larger values mean the two images
    are more similar.

    Args:
        c1: coefficient to calculate saliency component of VSI
        c2: coefficient to calculate gradient component of VSI
        c3: coefficient to calculate color component of VSI
        alpha: power for gradient component of VSI
        beta: power for color component of VSI
        data_range: Maximum value range of images (usually 1.0 or 255).
        omega_0: coefficient to get log Gabor filter at SDSP
        sigma_f: coefficient to get log Gabor filter at SDSP
        sigma_d: coefficient to get SDSP
        sigma_c: coefficient to get SDSP

    References:
        L. Zhang, Y. Shen and H. Li, "VSI: A Visual Saliency-Induced Index for Perceptual
        Image Quality Assessment," IEEE Transactions on Image Processing, vol. 23, no. 10,
        pp. 4270-4281, Oct. 2014, doi: 10.1109/TIP.2014.2346028
        https://ieeexplore.ieee.org/document/6873260
    """

    def __init__(
        self,
        c1: float = 1.27,
        c2: float = 386.0,
        c3: float = 130.0,
        alpha: float = 0.4,
        beta: float = 0.02,
        data_range: Union[int, float] = 1.0,
        omega_0: float = 0.021,
        sigma_f: float = 1.34,
        sigma_d: float = 145.0,
        sigma_c: float = 0.001,
    ) -> None:
        super().__init__()
        self.data_range = data_range

        # Bind every hyper-parameter once so that forward() only has to pass
        # the two image batches.
        self.vsi = functools.partial(
            vsi,
            c1=c1,
            c2=c2,
            c3=c3,
            alpha=alpha,
            beta=beta,
            omega_0=omega_0,
            sigma_f=sigma_f,
            sigma_d=sigma_d,
            sigma_c=sigma_c,
            data_range=data_range,
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""Compute the VSI score for a pair of image batches.

        Args:
            x: An input tensor. Shape :math:`(N, C, H, W)`.
            y: A target tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            VSI similarity score as returned by :func:`vsi`, usually in the
            [0, 1] range; higher means more similar. Use ``1 - score`` if a
            quantity to be minimized is needed.

        Note:
            Both inputs are supposed to have RGB channels order in accordance with the original approach.
            Nevertheless, the method supports greyscale images, which are converted to RGB by copying the grey
            channel 3 times.
        """
        return self.vsi(x=x, y=y)