r"""FSIM Metric
@article{zhang2011fsim,
title={FSIM: A feature similarity index for image quality assessment},
author={Zhang, Lin and Zhang, Lei and Mou, Xuanqin and Zhang, David},
journal={IEEE transactions on Image Processing},
volume={20},
number={8},
pages={2378--2386},
year={2011},
publisher={IEEE}
}
Created by: https://github.com/photosynthesis-team/piq/blob/master/piq/fsim.py
Modified by: Jiadi Mo (https://github.com/JiadiMo)
Refer to:
Official matlab code from https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/Files/FeatureSIM.m
PIQA from https://github.com/francois-rozet/piqa/blob/master/piqa/fsim.py
"""
import math
import functools
from typing import Tuple
import torch.nn as nn
import torch
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.utils.color_util import rgb2yiq
from .func_util import gradient_map, similarity_map, ifftshift, get_meshgrid
def fsim(
    x: torch.Tensor,
    y: torch.Tensor,
    chromatic: bool = True,
    scales: int = 4,
    orientations: int = 4,
    min_length: int = 6,
    mult: int = 2,
    sigma_f: float = 0.55,
    delta_theta: float = 1.2,
    k: float = 2.0,
) -> torch.Tensor:
    r"""Compute Feature Similarity Index Measure for a batch of images.

    Args:
        - x: An input tensor. Shape :math:`(N, C, H, W)`.
        - y: A target tensor. Shape :math:`(N, C, H, W)`.
        - chromatic: Flag to compute FSIMc, which also takes into account chromatic components.
          Requires 3-channel input.
        - scales: Number of wavelets used for computation of phase congruency maps
        - orientations: Number of filter orientations used for computation of phase congruency maps
        - min_length: Wavelength of smallest scale filter
        - mult: Scaling factor between successive filters
        - sigma_f: Ratio of the standard deviation of the Gaussian describing the log Gabor filter's
          transfer function in the frequency domain to the filter center frequency.
        - delta_theta: Ratio of angular interval between filter orientations and the standard deviation
          of the angular Gaussian function used to construct filters in the frequency plane.
        - k: No of standard deviations of the noise energy beyond the mean at which we set the noise
          threshold point, below which phase congruency values get penalized.

    Returns:
        - Index of similarity between two images, one value per batch element.
          Usually in [0, 1] interval. Can be bigger than 1 for predicted :math:`x`
          images with higher contrast than the original ones.

    Raises:
        AssertionError: If ``chromatic`` is set but the input does not have 3 channels.

    References:
        L. Zhang, L. Zhang, X. Mou and D. Zhang, "FSIM: A Feature Similarity Index for Image Quality Assessment,"
        IEEE Transactions on Image Processing, vol. 20, no. 8, pp. 2378-2386, Aug. 2011, doi: 10.1109/TIP.2011.2109730.
        https://ieeexplore.ieee.org/document/5705575
    """
    num_channels = x.size(1)

    # Validate before any heavy work: the chromatic term needs the I and Q
    # channels, which only exist for 3-channel input. Raised explicitly (rather
    # than via `assert`) so the check is not stripped when Python runs with -O.
    if chromatic and num_channels != 3:
        raise AssertionError(
            'Chromatic component can be computed only for RGB images!'
        )

    # Rescale to [0, 255] range, because all constants are calculated for this
    # factor. The true division by 1.0 is kept so the arithmetic matches the
    # original expression exactly.
    x = x / 1.0 * 255
    y = y / 1.0 * 255

    # Apply average pooling so the shorter side ends up around 256 pixels.
    # NOTE(review): Python's round() rounds halves to even (e.g. 640 / 256 = 2.5
    # gives 2); confirm this matches the reference implementation's rounding.
    kernel_size = max(1, round(min(x.shape[-2:]) / 256))
    x = torch.nn.functional.avg_pool2d(x, kernel_size)
    y = torch.nn.functional.avg_pool2d(y, kernel_size)

    # Convert RGB to YIQ color space; the first channel is used as luminance.
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)
        x_lum, y_lum = x_yiq[:, :1], y_yiq[:, :1]
    else:
        x_lum, y_lum = x, y

    # Compute phase congruency maps (same filter bank settings for both images)
    pc_kwargs = dict(
        scales=scales,
        orientations=orientations,
        min_length=min_length,
        mult=mult,
        sigma_f=sigma_f,
        delta_theta=delta_theta,
        k=k,
    )
    pc_x = _phase_congruency(x_lum, **pc_kwargs)
    pc_y = _phase_congruency(y_lum, **pc_kwargs)

    # Gradient maps from horizontal and vertical Scharr kernels
    scharr_filter = (
        torch.tensor([[[-3.0, 0.0, 3.0], [-10.0, 0.0, 10.0], [-3.0, 0.0, 3.0]]]) / 16
    )
    kernels = torch.stack([scharr_filter, scharr_filter.transpose(-1, -2)])
    grad_map_x = gradient_map(x_lum, kernels)
    grad_map_y = gradient_map(y_lum, kernels)

    # Constants from the paper
    T1, T2, T3, T4, lmbda = 0.85, 160, 200, 200, 0.03

    # Compute FSIM: per-pixel similarity weighted by the larger phase congruency
    PC = similarity_map(pc_x, pc_y, T1)
    GM = similarity_map(grad_map_x, grad_map_y, T2)
    pc_max = torch.where(pc_x > pc_y, pc_x, pc_y)
    score = GM * PC * pc_max

    if chromatic:
        # Channels 1 and 2 of the YIQ tensor are the I and Q chroma components.
        S_I = similarity_map(x_yiq[:, 1:2], y_yiq[:, 1:2], T3)
        S_Q = similarity_map(x_yiq[:, 2:], y_yiq[:, 2:], T4)
        # abs() keeps the fractional power real-valued when the product is
        # negative.
        score = score * torch.abs(S_I * S_Q) ** lmbda

    # Weighted average over channel and spatial dims -> one value per image.
    result = score.sum(dim=[1, 2, 3]) / pc_max.sum(dim=[1, 2, 3])
    return result
def _construct_filters(
    x: torch.Tensor,
    scales: int = 4,
    orientations: int = 4,
    min_length: int = 6,
    mult: int = 2,
    sigma_f: float = 0.55,
    delta_theta: float = 1.2,
    k: float = 2.0,
    use_lowpass_filter=True,
):
    """Build the bank of log-Gabor filters used for phase congruency maps.

    Each filter is the product of a radial (frequency band) component and an
    angular (orientation) component, defined in the frequency domain with the
    zero frequency at the top-left corner.

    Args:
        - x: Tensor. Shape :math:`(N, 1, H, W)`. Only its spatial size, dtype
          and device are used.
        - scales: Number of wavelets
        - orientations: Number of filter orientations
        - min_length: Wavelength of smallest scale filter
        - mult: Scaling factor between successive filters
        - sigma_f: Ratio of the standard deviation of the Gaussian
          describing the log Gabor filter's transfer function
          in the frequency domain to the filter center frequency.
        - delta_theta: Ratio of angular interval between filter orientations
          and the standard deviation of the angular Gaussian function
          used to construct filters in the freq. plane.
        - k: Not used when constructing the filters; accepted so the signature
          mirrors ``_phase_congruency``.
        - use_lowpass_filter: Whether to multiply every radial component by a
          Butterworth low-pass filter.

    Returns:
        Tensor of filters with a leading batch dimension of 1 and
        ``orientations * scales`` channels (orientation-major order), converted
        to the dtype and device of ``x``.
    """
    _, _, height, width = x.shape

    # Standard deviation of the angular Gaussian function used to construct
    # filters in the freq. plane.
    angular_sigma = math.pi / (orientations * delta_theta)

    # Polar coordinates of every frequency sample. Both maps are quadrant
    # shifted so that filters are constructed with 0 frequency at the corners.
    grid_x, grid_y = get_meshgrid((height, width))
    radius = ifftshift(torch.sqrt(grid_x**2 + grid_y**2))
    theta = ifftshift(torch.atan2(-grid_y, grid_x))

    # Get rid of the 0 radius value at the 0 frequency point (now at the
    # top-left corner) so that taking the log of the radius causes no trouble.
    radius[0, 0] = 1

    sin_theta = torch.sin(theta)
    cos_theta = torch.cos(theta)

    # Low-pass filter that is as large as possible, yet falls away to zero at
    # the boundaries. All log Gabor filters are multiplied by this to ensure no
    # extra frequencies at the 'corners' of the FFT are incorporated, as this
    # seems to upset the normalisation process when calculating phase
    # congruency.
    lowpass = (
        _lowpassfilter(size=(height, width), cutoff=0.45, n=15)
        if use_lowpass_filter
        else None
    )

    # Radial components: control the frequency band each filter responds to.
    radial_denominator = 2 * math.log(sigma_f) ** 2
    radial_bank = []
    for scale_index in range(scales):
        wavelength = min_length * mult**scale_index
        center_frequency = 1.0 / wavelength
        radial = torch.exp(
            (-(torch.log(radius / center_frequency) ** 2)) / radial_denominator
        )
        if use_lowpass_filter:
            radial = radial * lowpass
        # Undo the radius fudge: the filter must be zero at the 0 frequency.
        radial[0, 0] = 0
        radial_bank.append(radial)

    def _angular_component(angle: float) -> torch.Tensor:
        # For each point in the filter matrix calculate the angular distance
        # from the specified filter orientation. To overcome the angular
        # wrap-around problem, sine difference and cosine difference values are
        # computed first and atan2 then gives the angular distance.
        sin_diff = sin_theta * math.cos(angle) - cos_theta * math.sin(angle)
        cos_diff = cos_theta * math.cos(angle) + sin_theta * math.sin(angle)
        distance = torch.abs(torch.atan2(sin_diff, cos_diff))
        return torch.exp((-(distance**2)) / (2 * angular_sigma**2))

    # Angular components: control the orientation each filter responds to.
    angular_bank = torch.stack(
        [
            _angular_component(index * math.pi / orientations)
            for index in range(orientations)
        ]
    )
    radial_bank = torch.stack(radial_bank)

    # Pair every orientation with every scale (orientation-major), then add the
    # batch dimension and move to the dtype/device of the input.
    per_orientation = angular_bank.repeat_interleave(scales, dim=0)
    per_scale = radial_bank.repeat(orientations, 1, 1)
    return (per_orientation * per_scale).unsqueeze(0).to(x)
def _phase_congruency(
    x: torch.Tensor,
    scales: int = 4,
    orientations: int = 4,
    min_length: int = 6,
    mult: int = 2,
    sigma_f: float = 0.55,
    delta_theta: float = 1.2,
    k: float = 2.0,
) -> torch.Tensor:
    r"""Compute Phase Congruency for a batch of greyscale images.

    Args:
        x: Tensor. Shape :math:`(N, 1, H, W)`.
        scales: Number of wavelet scales
        orientations: Number of filter orientations
        min_length: Wavelength of smallest scale filter
        mult: Scaling factor between successive filters
        sigma_f: Ratio of the standard deviation of the Gaussian
            describing the log Gabor filter's transfer function
            in the frequency domain to the filter center frequency.
        delta_theta: Ratio of angular interval between filter orientations
            and the standard deviation of the angular Gaussian function
            used to construct filters in the freq. plane.
        k: No of standard deviations of the noise energy beyond the mean
            at which we set the noise threshold point, below which phase
            congruency values get penalized.

    Returns:
        Phase Congruency map with shape :math:`(N, 1, H, W)`
    """
    input_eps = torch.finfo(x.dtype).eps
    N, _, H, W = x.shape

    filters = _construct_filters(
        x, scales, orientations, min_length, mult, sigma_f, delta_theta, k
    )

    # Filter in the frequency domain. After the inverse transform the real part
    # is the even filter response and the imaginary part the odd one.
    image_fft = torch.fft.fft2(x)
    responses = torch.fft.ifft2(image_fft * filters)
    even_odd = torch.view_as_real(responses).view(N, orientations, scales, H, W, 2)
    even = even_odd[..., 0]
    odd = even_odd[..., 1]

    # Amplitude of even & odd filter response. An = sqrt(real^2 + imag^2)
    an = torch.sqrt(torch.sum(even_odd**2, dim=-1))

    # Sum of squared filter values at the smallest scale, per orientation.
    # This is used for noise estimation.
    filters_by_scale = filters.view(1, orientations, scales, H, W)
    em_n = (filters_by_scale[:, :, :1, ...] ** 2).sum(dim=[-2, -1], keepdim=True)

    # Sums of even / odd filter responses over scales.
    sum_e = even.sum(dim=2, keepdim=True)
    sum_o = odd.sum(dim=2, keepdim=True)

    # Get weighted mean filter response vector, this gives the weighted mean
    # phase angle.
    x_energy = torch.sqrt(sum_e**2 + sum_o**2) + input_eps
    mean_e = sum_e / x_energy
    mean_o = sum_o / x_energy

    # Now calculate An(cos(phase_deviation) - |sin(phase_deviation)|) by using
    # dot and cross products between the weighted mean filter response vector
    # and the individual filter response vectors at each scale. This quantity
    # is phase congruency multiplied by An, which we call energy.
    dot = even * mean_e + odd * mean_o
    cross = even * mean_o - odd * mean_e
    energy = (dot - torch.abs(cross)).sum(dim=2, keepdim=True)

    # Compensate for noise.
    # We estimate the noise power from the energy squared response at the
    # smallest scale. If the noise is Gaussian the energy squared will have a
    # Chi-squared 2DOF pdf. We calculate the median energy squared response as
    # this is a robust statistic. From this we estimate the mean. The estimate
    # of noise power is obtained by dividing the mean squared energy value by
    # the mean squared filter value.
    smallest_scale_amp = an[:, :, :1].reshape(N, orientations, 1, 1, H * W)
    median_e2n = torch.median(smallest_scale_amp**2, dim=-1, keepdim=True).values
    mean_e2n = -median_e2n / math.log(0.5)
    noise_power = mean_e2n / em_n

    # Now estimate the total energy^2 due to noise.
    # Estimate for sum(An^2) + sum(Ai.*Aj.*(cphi.*cphj + sphi.*sphj)), computed
    # from the spatial-domain filters.
    spatial_filters = torch.fft.ifft2(filters).real * math.sqrt(H * W)
    spatial_filters = spatial_filters.view(1, orientations, scales, H, W)

    sum_an2 = torch.sum(spatial_filters**2, dim=-3, keepdim=True)
    sum_ai_aj = torch.zeros(N, orientations, 1, H, W).to(x)
    # Cross terms between every scale and all coarser scales.
    for scale_index in range(scales - 1):
        current = spatial_filters[:, :, scale_index : scale_index + 1]
        coarser = spatial_filters[:, :, scale_index + 1 :]
        sum_ai_aj = sum_ai_aj + (current * coarser).sum(dim=-3, keepdim=True)

    sum_an2 = torch.sum(sum_an2, dim=[-1, -2], keepdim=True)
    sum_ai_aj = torch.sum(sum_ai_aj, dim=[-1, -2], keepdim=True)
    noise_energy2 = 2 * noise_power * sum_an2 + 4 * noise_power * sum_ai_aj

    # Rayleigh parameter
    tau = torch.sqrt(noise_energy2 / 2)
    # Expected value of noise energy and its standard deviation
    noise_energy = tau * math.sqrt(math.pi / 2)
    noise_energy_sigma = torch.sqrt((2 - math.pi / 2) * tau**2)

    # Noise threshold
    T = noise_energy + k * noise_energy_sigma
    # The estimated noise effect calculated above is only valid for the PC_1
    # measure. The PC_2 measure does not lend itself readily to the same
    # analysis. However empirically it seems that the noise effect is
    # overestimated roughly by a factor of 1.7 for the filter parameters used
    # here, hence this empirical rescaling.
    T = T / 1.7

    # Apply noise threshold: energy below it is clipped to zero.
    energy = torch.max(energy - T, torch.zeros_like(T))

    # Total thresholded energy over total amplitude, summed across orientations
    # and scales; eps keeps the ratio finite where both are zero.
    eps = torch.finfo(energy.dtype).eps
    energy_all = energy.sum(dim=[1, 2]) + eps
    an_all = an.sum(dim=[1, 2]) + eps
    result_pc = energy_all / an_all
    return result_pc.unsqueeze(1)
def _lowpassfilter(size: Tuple[int, int], cutoff: float, n: int) -> torch.Tensor:
    r"""Construct a low-pass Butterworth filter.

    Args:
        size: Tuple with height and width of filter to construct
        cutoff: Cutoff frequency of the filter in (0, 0.5]
        n: Filter order, an integer >= 1. Higher `n` means sharper transition.
            Note that `n` is doubled so that the exponent is always an even integer.

    Returns:
        The filter ``f = 1 / (1 + (w / cutoff) ^ (2n))``, where ``w`` is the
        radius of each sample taken from the mesh grid, passed through
        ``ifftshift``.

    Raises:
        AssertionError: If ``cutoff`` is outside (0, 0.5] or ``n`` is not an
            integer >= 1.
    """
    assert 0 < cutoff <= 0.5, 'Cutoff frequency must be between 0 and 0.5'
    # `n == 1` is a valid filter order; the previous `n > 1` check contradicted
    # the error message and rejected it.
    assert n >= 1 and int(n) == n, 'n must be an integer >= 1'

    grid_x, grid_y = get_meshgrid(size)
    # A matrix with every pixel = radius relative to centre.
    radius = torch.sqrt(grid_x**2 + grid_y**2)
    return ifftshift(1.0 / (1.0 + (radius / cutoff) ** (2 * n)))
@ARCH_REGISTRY.register()
class FSIM(nn.Module):
    r"""Module wrapper around :func:`fsim` with the metric settings fixed at construction.

    Args:
        - chromatic: Flag to compute FSIMc, which also takes into account chromatic components
        - scales: Number of wavelets used for computation of phase congruency maps
        - orientations: Number of filter orientations used for computation of phase congruency maps
        - min_length: Wavelength of smallest scale filter
        - mult: Scaling factor between successive filters
        - sigma_f: Ratio of the standard deviation of the Gaussian describing the log Gabor filter's
          transfer function in the frequency domain to the filter center frequency.
        - delta_theta: Ratio of angular interval between filter orientations and the standard deviation
          of the angular Gaussian function used to construct filters in the frequency plane.
        - k: No of standard deviations of the noise energy beyond the mean at which we set the noise
          threshold point, below which phase congruency values get penalized.

    References:
        L. Zhang, L. Zhang, X. Mou and D. Zhang, "FSIM: A Feature Similarity Index for Image Quality Assessment,"
        IEEE Transactions on Image Processing, vol. 20, no. 8, pp. 2378-2386, Aug. 2011, doi: 10.1109/TIP.2011.2109730.
        https://ieeexplore.ieee.org/document/5705575
    """

    def __init__(
        self,
        chromatic: bool = True,
        scales: int = 4,
        orientations: int = 4,
        min_length: int = 6,
        mult: int = 2,
        sigma_f: float = 0.55,
        delta_theta: float = 1.2,
        k: float = 2.0,
    ) -> None:
        super().__init__()

        # Keep a ready-to-call function with the settings bound, rather than
        # storing each setting as a separate attribute.
        settings = dict(
            chromatic=chromatic,
            scales=scales,
            orientations=orientations,
            min_length=min_length,
            mult=mult,
            sigma_f=sigma_f,
            delta_theta=delta_theta,
            k=k,
        )
        self.fsim = functools.partial(fsim, **settings)

    def forward(
        self,
        X: torch.Tensor,
        Y: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute FSIM between two batches of images.

        Args:
            - X: An input tensor. Shape :math:`(N, C, H, W)`.
            - Y: A target tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            - The FSIM score exactly as returned by :func:`fsim` with the
              settings given at construction.
        """
        assert X.shape == Y.shape, (
            f'Input and reference images should have the same shape, but got {X.shape} and {Y.shape}'
        )
        return self.fsim(X, Y)