Source code for pyiqa.archs.vsi_arch

r"""VSI Metric.

@article{zhang2014vsi,
  title={VSI: A visual saliency-induced index for perceptual image quality assessment},
  author={Zhang, Lin and Shen, Ying and Li, Hongyu},
  journal={IEEE Transactions on Image processing},
  volume={23},
  number={10},
  pages={4270--4281},
  year={2014},
  publisher={IEEE}
}

Created by: https://github.com/photosynthesis-team/piq/blob/master/piq/vsi.py

Modified by: Jiadi Mo (https://github.com/JiadiMo)

Refer to:
    IQA-Optimization from https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/VSI.py
    Official matlab code is not available
"""

import warnings
import functools
from typing import Union, Tuple
import torch
import torch.nn as nn
from torch.nn.functional import avg_pool2d, interpolate, pad

from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.utils.color_util import rgb2lmn, rgb2lab
from .func_util import (
    ifftshift,
    gradient_map,
    get_meshgrid,
    similarity_map,
    scharr_filter,
    safe_sqrt,
)


def vsi(
    x: torch.Tensor,
    y: torch.Tensor,
    data_range: Union[int, float] = 1.0,
    c1: float = 1.27,
    c2: float = 386.0,
    c3: float = 130.0,
    alpha: float = 0.4,
    beta: float = 0.02,
    omega_0: float = 0.021,
    sigma_f: float = 1.34,
    sigma_d: float = 145.0,
    sigma_c: float = 0.001,
) -> torch.Tensor:
    r"""Compute the Visual Saliency-induced Index (VSI) for a batch of image pairs.

    Both inputs are cast to double precision, rescaled to the [0, 255] range,
    and compared through three similarity maps (visual saliency, gradient
    magnitude of the first LMN channel, and the two remaining LMN channels).
    The combined map is pooled with the element-wise maximum of the two
    saliency maps as weights.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        c1: coefficient to calculate saliency component of VSI.
        c2: coefficient to calculate gradient component of VSI.
        c3: coefficient to calculate color component of VSI.
        alpha: power for gradient component of VSI.
        beta: power for color component of VSI.
        omega_0: coefficient to get log Gabor filter at SDSP.
        sigma_f: coefficient to get log Gabor filter at SDSP.
        sigma_d: coefficient to get SDSP.
        sigma_c: coefficient to get SDSP.

    Returns:
        Index of similarity between two images. Usually in [0, 1] range.

    References:
        L. Zhang, Y. Shen and H. Li, "VSI: A Visual Saliency-Induced Index for
        Perceptual Image Quality Assessment," IEEE Transactions on Image
        Processing, vol. 23, no. 10, pp. 4270-4281, Oct. 2014,
        doi: 10.1109/TIP.2014.2346028
        https://ieeexplore.ieee.org/document/6873260

    Note:
        The original method supports only RGB images. Single-channel inputs
        are expanded to three channels (with a warning) before processing.
    """
    x = x.double()
    y = y.double()

    if x.size(1) == 1:
        # Greyscale input: replicate the single channel so the colour
        # conversions below receive three channels.
        x, y = (img.repeat(1, 3, 1, 1) for img in (x, y))
        warnings.warn(
            'The original VSI supports only RGB images. The input images were converted to RGB by copying '
            'the grey channel 3 times.'
        )

    # The constants c1..c3 are tuned for intensities in [0, 255].
    x, y = (img * 255.0 / data_range for img in (x, y))

    # Saliency maps; inputs are already in [0, 255], hence data_range=255.
    sdsp_options = dict(
        data_range=255,
        omega_0=omega_0,
        sigma_f=sigma_f,
        sigma_d=sigma_d,
        sigma_c=sigma_c,
    )
    sal_x = sdsp(x, **sdsp_options)
    sal_y = sdsp(y, **sdsp_options)

    # LMN colour space representation of both images.
    lmn_x = rgb2lmn(x)
    lmn_y = rgb2lmn(y)

    # Average-pool everything when the smaller spatial side is large enough
    # for the window to exceed 1.
    window = max(1, round(min(sal_x.size()[-2:]) / 256))
    maps = [sal_x, sal_y, lmn_x, lmn_y]

    lead = window // 2
    if lead:
        # Replicate-pad the last two dimensions before pooling.
        trail = (window - 1) // 2
        borders = [lead, trail, lead, trail]
        maps = [pad(m, pad=borders, mode='replicate') for m in maps]

    sal_x, sal_y, lmn_x, lmn_y = (avg_pool2d(m, kernel_size=window) for m in maps)

    # Gradient magnitude of the first LMN channel, using a Scharr kernel and
    # its transpose.
    scharr_a = scharr_filter()
    scharr_b = scharr_filter().transpose(1, 2)
    kernels = torch.stack([scharr_a, scharr_b]).to(lmn_x)
    grad_x = gradient_map(lmn_x[:, :1], kernels)
    grad_y = gradient_map(lmn_y[:, :1], kernels)

    # Individual similarity maps.
    sim_saliency = similarity_map(sal_x, sal_y, c1)
    sim_gradient = similarity_map(grad_x, grad_y, c2)
    sim_m = similarity_map(lmn_x[:, 1:2], lmn_y[:, 1:2], c3)
    sim_n = similarity_map(lmn_x[:, 2:], lmn_y[:, 2:], c3)
    chroma = sim_m * sim_n

    # chroma may be negative, so chroma ** beta is evaluated as the real part
    # of a complex power: |c|^beta * cos(beta * arg(c)).
    magnitude = chroma.abs()
    phase = torch.atan2(torch.zeros_like(chroma), chroma)
    chroma_pow = (magnitude ** beta) * torch.cos(phase * beta)

    combined = sim_saliency * sim_gradient.pow(alpha) * chroma_pow
    weights = torch.max(sal_x, sal_y)

    # eps keeps the ratio finite when the weights sum to zero.
    eps = torch.finfo(weights.dtype).eps
    weighted = combined * weights
    numerator = weighted.sum(dim=(-1, -2)) + eps
    denominator = weights.sum(dim=(-1, -2)) + eps
    return (numerator / denominator).squeeze(-1)
def sdsp(
    x: torch.Tensor,
    data_range: Union[int, float] = 255,
    omega_0: float = 0.021,
    sigma_f: float = 1.34,
    sigma_d: float = 145.0,
    sigma_c: float = 0.001,
) -> torch.Tensor:
    r"""SDSP algorithm for salient region detection from a given image.

    The image is resized to a fixed 256x256 working resolution, three priors
    (frequency, location, colour) are computed there and multiplied, and the
    product is resized back to the input resolution and min-max normalised.

    Supports only colour images with RGB channel order.

    Args:
        x: Tensor. Shape :math:`(N, 3, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        omega_0: coefficient for log Gabor filter
        sigma_f: coefficient for log Gabor filter
        sigma_d: coefficient for the central areas, which have a bias towards attention
        sigma_c: coefficient for the warm colors, which have a bias towards attention

    Returns:
        torch.Tensor: Visual saliency map
    """

    def _spatial_min(t: torch.Tensor) -> torch.Tensor:
        # Minimum over the last two dimensions, keeping them as size 1.
        return t.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values

    def _spatial_max(t: torch.Tensor) -> torch.Tensor:
        # Maximum over the last two dimensions, keeping them as size 1.
        return t.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values

    work_size = (256, 256)

    x = x / data_range * 255
    original_shape = x.size()

    # All priors are computed at the fixed working resolution.
    x = interpolate(input=x, size=work_size, mode='bilinear', align_corners=False)
    lab = rgb2lab(x, data_range=255)

    # Frequency prior: filter each channel with a log-Gabor filter in the
    # Fourier domain (requires torch >= 1.8.0 for torch.fft.fft2), then take
    # the magnitude across channels.
    log_gabor = _log_gabor(work_size, omega_0, sigma_f).to(x).view(1, 1, *work_size)
    spectrum = torch.fft.fft2(lab)
    filtered = torch.fft.ifft2(spectrum * log_gabor).real
    freq_prior = safe_sqrt(filtered.pow(2).sum(dim=1, keepdim=True))

    # Location prior: Gaussian fall-off of the squared grid coordinates.
    grid = torch.stack(get_meshgrid(work_size), dim=0).to(x)
    grid = grid * work_size[0] + 1
    location_prior = torch.exp(-torch.sum(grid**2, dim=0) / sigma_d**2).view(
        1, 1, *work_size
    )

    # Colour prior: min-max normalise each channel, then use the squared norm
    # of every channel after the first.
    eps = torch.finfo(lab.dtype).eps
    lab_min = _spatial_min(lab)
    lab_max = _spatial_max(lab)
    lab_unit = (lab - lab_min) / (lab_max - lab_min + eps)
    chroma_energy = lab_unit[:, 1:].pow(2).sum(dim=1, keepdim=True)
    colour_prior = 1 - torch.exp(-chroma_energy / sigma_c**2)

    # Combine the priors and return to the input resolution.
    saliency = freq_prior * location_prior * colour_prior
    saliency = interpolate(
        saliency, original_shape[-2:], mode='bilinear', align_corners=True
    )

    # Min-max normalise each map; eps guards against a constant map.
    lowest = _spatial_min(saliency)
    highest = _spatial_max(saliency)
    return (saliency - lowest) / (highest - lowest + eps)
def _log_gabor(size: Tuple[int, int], omega_0: float, sigma_f: float) -> torch.Tensor:
    r"""Creates log Gabor filter

    Args:
        size: size of the required log Gabor filter
        omega_0: center frequency of the filter
        sigma_f: bandwidth of the filter

    Returns:
        log Gabor filter
    """
    grid_x, grid_y = get_meshgrid(size)
    distance = (grid_x**2 + grid_y**2).sqrt()

    # Zero out everything beyond radius 0.5, then apply ifftshift.
    distance = ifftshift(distance * (distance <= 0.5))

    # Avoid log(0) at index [0, 0]; that entry of the response is zeroed below.
    distance[0, 0] = 1

    log_ratio = (distance / omega_0).log()
    response = torch.exp((-(log_ratio.pow(2))) / (2 * sigma_f**2))
    response[0, 0] = 0
    return response


@ARCH_REGISTRY.register()
class VSI(nn.Module):
    r"""Module wrapper around :func:`vsi` (Visual Saliency-induced Index).

    The constructor binds all hyper-parameters once; ``forward`` then only
    needs the two image batches.

    Args:
        data_range: Maximum value range of images (usually 1.0 or 255).
        c1: coefficient to calculate saliency component of VSI
        c2: coefficient to calculate gradient component of VSI
        c3: coefficient to calculate color component of VSI
        alpha: power for gradient component of VSI
        beta: power for color component of VSI
        omega_0: coefficient to get log Gabor filter at SDSP
        sigma_f: coefficient to get log Gabor filter at SDSP
        sigma_d: coefficient to get SDSP
        sigma_c: coefficient to get SDSP

    References:
        L. Zhang, Y. Shen and H. Li, "VSI: A Visual Saliency-Induced Index for
        Perceptual Image Quality Assessment," IEEE Transactions on Image
        Processing, vol. 23, no. 10, pp. 4270-4281, Oct. 2014,
        doi: 10.1109/TIP.2014.2346028
        https://ieeexplore.ieee.org/document/6873260
    """

    def __init__(
        self,
        c1: float = 1.27,
        c2: float = 386.0,
        c3: float = 130.0,
        alpha: float = 0.4,
        beta: float = 0.02,
        data_range: Union[int, float] = 1.0,
        omega_0: float = 0.021,
        sigma_f: float = 1.34,
        sigma_d: float = 145.0,
        sigma_c: float = 0.001,
    ) -> None:
        super().__init__()
        self.data_range = data_range

        # Freeze every hyper-parameter into a callable taking only (x, y).
        bound_options = {
            'c1': c1,
            'c2': c2,
            'c3': c3,
            'alpha': alpha,
            'beta': beta,
            'omega_0': omega_0,
            'sigma_f': sigma_f,
            'sigma_d': sigma_d,
            'sigma_c': sigma_c,
            'data_range': data_range,
        }
        self.vsi = functools.partial(vsi, **bound_options)

    def forward(self, x, y):
        r"""Compute VSI between two batches of images.

        Args:
            x: An input tensor. Shape :math:`(N, C, H, W)`.
            y: A target tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            The result of :func:`vsi` for the given inputs, returned as is
            (no ``1 - value`` loss conversion is applied here).

        Note:
            Both inputs are supposed to have RGB channels order in accordance
            with the original approach. Nevertheless, the method supports
            greyscale images, which are converted to RGB by copying the grey
            channel 3 times.
        """
        return self.vsi(x=x, y=y)