Source code for pyiqa.archs.fsim_arch

r"""FSIM Metric

@article{zhang2011fsim,
  title={FSIM: A feature similarity index for image quality assessment},
  author={Zhang, Lin and Zhang, Lei and Mou, Xuanqin and Zhang, David},
  journal={IEEE transactions on Image Processing},
  volume={20},
  number={8},
  pages={2378--2386},
  year={2011},
  publisher={IEEE}
}

Created by: https://github.com/photosynthesis-team/piq/blob/master/piq/fsim.py
Modified by: Jiadi Mo (https://github.com/JiadiMo)

Refer to:
    Official matlab code from https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/Files/FeatureSIM.m
    PIQA from https://github.com/francois-rozet/piqa/blob/master/piqa/fsim.py

"""

import math
import functools
from typing import Tuple
import torch.nn as nn
import torch

from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.utils.color_util import rgb2yiq
from .func_util import gradient_map, similarity_map, ifftshift, get_meshgrid


def fsim(
    x: torch.Tensor,
    y: torch.Tensor,
    chromatic: bool = True,
    scales: int = 4,
    orientations: int = 4,
    min_length: int = 6,
    mult: int = 2,
    sigma_f: float = 0.55,
    delta_theta: float = 1.2,
    k: float = 2.0,
) -> torch.Tensor:
    r"""Compute Feature Similarity Index Measure for a batch of images.

    Args:
        - x: An input tensor. Shape :math:`(N, C, H, W)`.
        - y: A target tensor. Shape :math:`(N, C, H, W)`.
        - chromatic: Flag to compute FSIMc, which also takes into account chromatic components.
          Requires 3-channel (RGB) inputs.
        - scales: Number of wavelets used for computation of phase congruency maps
        - orientations: Number of filter orientations used for computation of phase congruency maps
        - min_length: Wavelength of smallest scale filter
        - mult: Scaling factor between successive filters
        - sigma_f: Ratio of the standard deviation of the Gaussian describing the log Gabor filter's
          transfer function in the frequency domain to the filter center frequency.
        - delta_theta: Ratio of angular interval between filter orientations and the standard deviation
          of the angular Gaussian function used to construct filters in the frequency plane.
        - k: No of standard deviations of the noise energy beyond the mean at which we set the noise
          threshold point, below which phase congruency values get penalized.

    Returns:
        - Index of similarity between two images, one value per batch element (shape :math:`(N,)`).
          Usually in [0, 1] interval. Can be bigger than 1 for predicted :math:`x` images with
          higher contrast than the original ones.

    Raises:
        AssertionError: If ``chromatic`` is True and the inputs do not have exactly 3 channels.

    References:
        L. Zhang, L. Zhang, X. Mou and D. Zhang, "FSIM: A Feature Similarity Index for Image
        Quality Assessment," IEEE Transactions on Image Processing, vol. 20, no. 8,
        pp. 2378-2386, Aug. 2011, doi: 10.1109/TIP.2011.2109730.
        https://ieeexplore.ieee.org/document/5705575
    """
    num_channels = x.size(1)

    # Validate up front: the chromatic term needs the I/Q channels, which only exist
    # for RGB input. Failing here avoids running the FFT-heavy pipeline first.
    if chromatic:
        assert num_channels == 3, (
            'Chromatic component can be computed only for RGB images!'
        )

    # Rescale to [0, 255] range, because all constants are calculated for this factor.
    # The true division by 1.0 also promotes integer tensors to floating point.
    x = x / float(1.0) * 255
    y = y / float(1.0) * 255

    # Apply average pooling so that the shorter side is brought to roughly 256 pixels
    kernel_size = max(1, round(min(x.shape[-2:]) / 256))
    x = torch.nn.functional.avg_pool2d(x, kernel_size)
    y = torch.nn.functional.avg_pool2d(y, kernel_size)

    # Convert RGB to YIQ color space; any other channel count is used as luminance as-is
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)

        x_lum = x_yiq[:, :1]
        y_lum = y_yiq[:, :1]

        x_i = x_yiq[:, 1:2]
        y_i = y_yiq[:, 1:2]
        x_q = x_yiq[:, 2:]
        y_q = y_yiq[:, 2:]
    else:
        x_lum = x
        y_lum = y

    # Compute phase congruency maps
    pc_x = _phase_congruency(
        x_lum,
        scales=scales,
        orientations=orientations,
        min_length=min_length,
        mult=mult,
        sigma_f=sigma_f,
        delta_theta=delta_theta,
        k=k,
    )
    pc_y = _phase_congruency(
        y_lum,
        scales=scales,
        orientations=orientations,
        min_length=min_length,
        mult=mult,
        sigma_f=sigma_f,
        delta_theta=delta_theta,
        k=k,
    )

    # Gradient maps (Scharr operator and its transpose for the two directions)
    scharr_filter = (
        torch.tensor([[[-3.0, 0.0, 3.0], [-10.0, 0.0, 10.0], [-3.0, 0.0, 3.0]]]) / 16
    )
    kernels = torch.stack([scharr_filter, scharr_filter.transpose(-1, -2)])
    grad_map_x = gradient_map(x_lum, kernels)
    grad_map_y = gradient_map(y_lum, kernels)

    # Constants from the paper
    T1, T2, T3, T4, lmbda = 0.85, 160, 200, 200, 0.03

    # Compute FSIM: similarity maps weighted by the larger phase congruency of the pair
    PC = similarity_map(pc_x, pc_y, T1)
    GM = similarity_map(grad_map_x, grad_map_y, T2)
    pc_max = torch.where(pc_x > pc_y, pc_x, pc_y)
    score = GM * PC * pc_max

    if chromatic:
        S_I = similarity_map(x_i, y_i, T3)
        S_Q = similarity_map(x_q, y_q, T4)
        # torch.abs keeps the fractional power real-valued in case the product is
        # negative (a negative base raised to `lmbda` would produce NaN).
        score = score * torch.abs(S_I * S_Q) ** lmbda

    # Weighted average: torch.sum(score) / torch.sum(pc_max), per batch element
    result = score.sum(dim=[1, 2, 3]) / pc_max.sum(dim=[1, 2, 3])
    return result
def _construct_filters(
    x: torch.Tensor,
    scales: int = 4,
    orientations: int = 4,
    min_length: int = 6,
    mult: int = 2,
    sigma_f: float = 0.55,
    delta_theta: float = 1.2,
    k: float = 2.0,
    use_lowpass_filter=True,
):
    """Creates a stack of filters used for computation of phase congruency maps

    Args:
        - x: Tensor. Shape :math:`(N, 1, H, W)`. Only its spatial size, dtype and device are used.
        - scales: Number of wavelets
        - orientations: Number of filter orientations
        - min_length: Wavelength of smallest scale filter
        - mult: Scaling factor between successive filters
        - sigma_f: Ratio of the standard deviation of the Gaussian describing the log Gabor filter's
          transfer function in the frequency domain to the filter center frequency.
        - delta_theta: Ratio of angular interval between filter orientations and the standard deviation
          of the angular Gaussian function used to construct filters in the freq. plane.
        - k: No of standard deviations of the noise energy beyond the mean at which we set the noise
          threshold point, below which phase congruency values get penalized.
          Not used by this function; kept for signature parity with `_phase_congruency`.
        - use_lowpass_filter: If True, every log Gabor filter is multiplied by a Butterworth
          low-pass filter (cutoff 0.45, order 15).

    Returns:
        Tensor of shape :math:`(1, orientations * scales, H, W)`, ordered orientation-major
        (all scales of orientation 0, then all scales of orientation 1, ...), cast to the
        dtype and device of ``x``.
    """
    _, _, H, W = x.shape

    # Calculate the standard deviation of the angular Gaussian function
    # used to construct filters in the freq. plane.
    theta_sigma = math.pi / (orientations * delta_theta)

    # Pre-compute some stuff to speed up filter construction
    # NOTE(review): get_meshgrid presumably returns centred, normalised frequency
    # coordinates - confirm against func_util.
    grid_x, grid_y = get_meshgrid((H, W))
    radius = torch.sqrt(grid_x**2 + grid_y**2)
    theta = torch.atan2(-grid_y, grid_x)

    # Quadrant shift radius and theta so that filters are constructed with 0 frequency at the corners.
    # Get rid of the 0 radius value at the 0 frequency point (now at top-left corner)
    # so that taking the log of the radius will not cause trouble.
    radius = ifftshift(radius)
    theta = ifftshift(theta)
    radius[0, 0] = 1

    sintheta = torch.sin(theta)
    costheta = torch.cos(theta)

    # Filters are constructed in terms of two components.
    # 1) The radial component, which controls the frequency band that the filter responds to
    # 2) The angular component, which controls the orientation that the filter responds to.
    # The two components are multiplied together to construct the overall filter.

    # First construct a low-pass filter that is as large as possible, yet falls
    # away to zero at the boundaries. All log Gabor filters are multiplied by
    # this to ensure no extra frequencies at the 'corners' of the FFT are
    # incorporated as this seems to upset the normalisation process when
    # calculating phase congruency. It is only built when it is actually applied.
    lp = _lowpassfilter(size=(H, W), cutoff=0.45, n=15) if use_lowpass_filter else None

    # Construct the radial filter components...
    log_gabor = []
    for s in range(scales):
        wavelength = min_length * mult**s
        omega_0 = 1.0 / wavelength  # centre frequency of the filter at this scale
        gabor_filter = torch.exp(
            (-(torch.log(radius / omega_0) ** 2)) / (2 * math.log(sigma_f) ** 2)
        )
        if use_lowpass_filter:
            gabor_filter = gabor_filter * lp
        # Set the value at the 0 frequency point of the filter back to zero
        # (it was made non-zero by forcing radius[0, 0] = 1 above).
        gabor_filter[0, 0] = 0
        log_gabor.append(gabor_filter)

    # Then construct the angular filter components...
    spread = []
    for o in range(orientations):
        angl = o * math.pi / orientations

        # For each point in the filter matrix calculate the angular distance from
        # the specified filter orientation. To overcome the angular wrap-around
        # problem sine difference and cosine difference values are first computed
        # and then the atan2 function is used to determine angular distance.
        ds = sintheta * math.cos(angl) - costheta * math.sin(angl)  # Difference in sine.
        dc = costheta * math.cos(angl) + sintheta * math.sin(angl)  # Difference in cosine.
        dtheta = torch.abs(torch.atan2(ds, dc))
        spread.append(torch.exp((-(dtheta**2)) / (2 * theta_sigma**2)))

    spread = torch.stack(spread)
    log_gabor = torch.stack(log_gabor)

    # Multiply, add batch dimension and transfer to correct device.
    filters = (
        (spread.repeat_interleave(scales, dim=0) * log_gabor.repeat(orientations, 1, 1))
        .unsqueeze(0)
        .to(x)
    )
    return filters


def _phase_congruency(
    x: torch.Tensor,
    scales: int = 4,
    orientations: int = 4,
    min_length: int = 6,
    mult: int = 2,
    sigma_f: float = 0.55,
    delta_theta: float = 1.2,
    k: float = 2.0,
) -> torch.Tensor:
    r"""Compute Phase Congruence for a batch of greyscale images

    Args:
        x: Tensor. Shape :math:`(N, 1, H, W)`.
        scales: Number of wavelet scales
        orientations: Number of filter orientations
        min_length: Wavelength of smallest scale filter
        mult: Scaling factor between successive filters
        sigma_f: Ratio of the standard deviation of the Gaussian describing the log Gabor filter's
            transfer function in the frequency domain to the filter center frequency.
        delta_theta: Ratio of angular interval between filter orientations and the standard deviation
            of the angular Gaussian function used to construct filters in the freq. plane.
        k: No of standard deviations of the noise energy beyond the mean at which we set the noise
            threshold point, below which phase congruency values get penalized.

    Returns:
        Phase Congruency map with shape :math:`(N, 1, H, W)`
    """
    EPS = torch.finfo(x.dtype).eps

    N, _, H, W = x.shape

    # Fourier transform
    filters = _construct_filters(
        x, scales, orientations, min_length, mult, sigma_f, delta_theta, k
    )
    imagefft = torch.fft.fft2(x)
    # Spatial-domain version of the filters, used below for the noise energy estimate
    filters_ifft = torch.fft.ifft2(filters)
    filters_ifft = filters_ifft.real * math.sqrt(H * W)

    # Filter in the frequency domain; real part = even response, imaginary part = odd response
    even_odd = torch.view_as_real(torch.fft.ifft2(imagefft * filters)).view(
        N, orientations, scales, H, W, 2
    )

    # Amplitude of even & odd filter response. An = sqrt(real^2 + imag^2)
    an = torch.sqrt(torch.sum(even_odd**2, dim=-1))

    # Take filter at scale 0 and sum its squared values spatially.
    # This is used for noise estimation.
    em_n = (filters.view(1, orientations, scales, H, W)[:, :, :1, ...] ** 2).sum(
        dim=[-2, -1], keepdim=True
    )

    # Sum of even filter convolution results.
    sum_e = even_odd[..., 0].sum(dim=2, keepdim=True)

    # Sum of odd filter convolution results.
    sum_o = even_odd[..., 1].sum(dim=2, keepdim=True)

    # Get weighted mean filter response vector, this gives the weighted mean phase angle.
    x_energy = torch.sqrt(sum_e**2 + sum_o**2) + EPS

    mean_e = sum_e / x_energy
    mean_o = sum_o / x_energy

    # Now calculate An(cos(phase_deviation) - | sin(phase_deviation)) | by
    # using dot and cross products between the weighted mean filter response
    # vector and the individual filter response vectors at each scale.
    # This quantity is phase congruency multiplied by An, which we call energy.

    # Extract even and odd convolution results.
    even = even_odd[..., 0]
    odd = even_odd[..., 1]

    energy = (
        even * mean_e + odd * mean_o - torch.abs(even * mean_o - odd * mean_e)
    ).sum(dim=2, keepdim=True)

    # Compensate for noise
    # We estimate the noise power from the energy squared response at the
    # smallest scale. If the noise is Gaussian the energy squared will have a
    # Chi-squared 2DOF pdf. We calculate the median energy squared response
    # as this is a robust statistic. From this we estimate the mean.
    # The estimate of noise power is obtained by dividing the mean squared
    # energy value by the squared filter value summed above.
    abs_eo = torch.sqrt(torch.sum(even_odd[:, :, :1, ...] ** 2, dim=-1)).reshape(
        N, orientations, 1, 1, H * W
    )
    median_e2n = torch.median(abs_eo**2, dim=-1, keepdim=True).values

    mean_e2n = -median_e2n / math.log(0.5)

    # Estimate of noise power.
    noise_power = mean_e2n / em_n

    # Now estimate the total energy^2 due to noise
    # Estimate for sum(An^2) + sum(Ai.*Aj.*(cphi.*cphj + sphi.*sphj))
    filters_ifft = filters_ifft.view(1, orientations, scales, H, W)

    sum_an2 = torch.sum(filters_ifft**2, dim=-3, keepdim=True)

    # Cross terms between every pair of distinct scales (s, s') with s < s'
    sum_ai_aj = torch.zeros(N, orientations, 1, H, W).to(x)
    for s in range(scales - 1):
        sum_ai_aj = sum_ai_aj + (
            filters_ifft[:, :, s : s + 1] * filters_ifft[:, :, s + 1 :]
        ).sum(dim=-3, keepdim=True)

    sum_an2 = torch.sum(sum_an2, dim=[-1, -2], keepdim=True)
    sum_ai_aj = torch.sum(sum_ai_aj, dim=[-1, -2], keepdim=True)

    noise_energy2 = 2 * noise_power * sum_an2 + 4 * noise_power * sum_ai_aj

    # Rayleigh parameter
    tau = torch.sqrt(noise_energy2 / 2)

    # Expected value of noise energy
    noise_energy = tau * math.sqrt(math.pi / 2)
    noise_energy_sigma = torch.sqrt((2 - math.pi / 2) * tau**2)

    # Noise threshold
    T = noise_energy + k * noise_energy_sigma

    # The estimated noise effect calculated above is only valid for the PC_1 measure.
    # The PC_2 measure does not lend itself readily to the same analysis. However
    # empirically it seems that the noise effect is overestimated roughly by a factor
    # of 1.7 for the filter parameters used here.

    # Empirical rescaling of the estimated noise effect to suit the PC_2 phase congruency measure
    T = T / 1.7

    # Apply noise threshold (T broadcasts over the spatial dimensions)
    energy = torch.max(energy - T, torch.zeros_like(T))

    eps = torch.finfo(energy.dtype).eps
    # Accumulate over orientations and scales; eps guards against division by zero
    energy_all = energy.sum(dim=[1, 2]) + eps
    an_all = an.sum(dim=[1, 2]) + eps
    result_pc = energy_all / an_all
    return result_pc.unsqueeze(1)


def _lowpassfilter(size: Tuple[int, int], cutoff: float, n: int) -> torch.Tensor:
    r"""
    Constructs a low-pass Butterworth filter.

    Args:
        size: Tuple with height and width of filter to construct
        cutoff: Cutoff frequency of the filter in (0, 0.5]
        n: Filter order. Higher `n` means sharper transition.
            Note that `n` is doubled so that it is always an even integer.

    Returns:
        f = 1 / (1 + (w / cutoff) ^ 2n), quadrant-shifted with `ifftshift`

    Raises:
        AssertionError: If `cutoff` is outside (0, 0.5] or `n` is not an integer >= 1.
    """
    assert 0 < cutoff <= 0.5, 'Cutoff frequency must be between 0 and 0.5'
    # Order 1 is a valid Butterworth filter, matching the error message below.
    assert n >= 1 and int(n) == n, 'n must be an integer >= 1'

    grid_x, grid_y = get_meshgrid(size)

    # A matrix with every pixel = radius relative to centre.
    radius = torch.sqrt(grid_x**2 + grid_y**2)

    return ifftshift(1.0 / (1.0 + (radius / cutoff) ** (2 * n)))


@ARCH_REGISTRY.register()
class FSIM(nn.Module):
    r"""Feature Similarity Index Measure (FSIM / FSIMc) as an ``nn.Module``.

    Args:
        - chromatic: Flag to compute FSIMc, which also takes into account chromatic components
        - scales: Number of wavelets used for computation of phase congruency maps
        - orientations: Number of filter orientations used for computation of phase congruency maps
        - min_length: Wavelength of smallest scale filter
        - mult: Scaling factor between successive filters
        - sigma_f: Ratio of the standard deviation of the Gaussian describing the log Gabor filter's
          transfer function in the frequency domain to the filter center frequency.
        - delta_theta: Ratio of angular interval between filter orientations and the standard deviation
          of the angular Gaussian function used to construct filters in the frequency plane.
        - k: No of standard deviations of the noise energy beyond the mean at which we set the noise
          threshold point, below which phase congruency values get penalized.

    References:
        L. Zhang, L. Zhang, X. Mou and D. Zhang, "FSIM: A Feature Similarity Index for Image
        Quality Assessment," IEEE Transactions on Image Processing, vol. 20, no. 8,
        pp. 2378-2386, Aug. 2011, doi: 10.1109/TIP.2011.2109730.
        https://ieeexplore.ieee.org/document/5705575
    """

    def __init__(
        self,
        chromatic: bool = True,
        scales: int = 4,
        orientations: int = 4,
        min_length: int = 6,
        mult: int = 2,
        sigma_f: float = 0.55,
        delta_theta: float = 1.2,
        k: float = 2.0,
    ) -> None:
        super().__init__()

        # Save function with predefined parameters, rather than parameters themselves
        self.fsim = functools.partial(
            fsim,
            chromatic=chromatic,
            scales=scales,
            orientations=orientations,
            min_length=min_length,
            mult=mult,
            sigma_f=sigma_f,
            delta_theta=delta_theta,
            k=k,
        )

    def forward(
        self,
        X: torch.Tensor,
        Y: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute the FSIM score between two batches of images.

        Args:
            - X: An input tensor. Shape :math:`(N, C, H, W)`.
            - Y: A target tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            - FSIM similarity score as returned by :func:`fsim` (one value per batch
              element). This is the raw similarity index, not a loss.

        Raises:
            AssertionError: If ``X`` and ``Y`` have different shapes.
        """
        assert X.shape == Y.shape, (
            f'Input and reference images should have the same shape, but got {X.shape} and {Y.shape}'
        )
        score = self.fsim(X, Y)
        return score