r"""BRISQUE Metric

Adapted from: https://github.com/photosynthesis-team/piq/blob/master/piq/brisque.py
Modified by: Jiadi Mo (https://github.com/JiadiMo)

Reference:
    MATLAB code (BRISQUE): https://live.ece.utexas.edu/research/Quality/index_algorithms.htm
    Pretrained model from: https://github.com/photosynthesis-team/piq/releases/download/v0.4.0/brisque_svm_weights.pt
"""
import scipy
import numpy as np
import torch
from pyiqa.utils.color_util import to_y_channel
from pyiqa.matlab_utils import imresize
from pyiqa.matlab_utils.nss_feature import compute_nss_features
from .func_util import estimate_ggd_param, estimate_aggd_param, normalize_img_with_gauss
from pyiqa.utils.download_util import load_file_from_url
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import get_url_from_name
# Download locations of the pretrained SVM weights used by ``BRISQUE``:
#   'url'            -> torch-saved ``(sv_coef, sv)`` pair for version='original'
#   'brisque_matlab' -> MATLAB .mat parameter file for version='matlab'
default_model_urls = dict(
    url=get_url_from_name('brisque_svm_weights.pth'),
    brisque_matlab=get_url_from_name('brisque_matlab.mat'),
)
def brisque(
    x: torch.Tensor,
    kernel_size: int = 7,
    kernel_sigma: float = 7 / 6,
    test_y_channel: bool = True,
    sv_coef: torch.Tensor = None,
    sv: torch.Tensor = None,
    gamma: float = 0.05,
    rho: float = -153.591,
    scale: float = 1,
    version: str = 'original',
) -> torch.Tensor:
    r"""Interface of BRISQUE index.

    Features are extracted at two scales (the input, then a half-resolution
    copy), concatenated, scaled, and scored with an RBF-kernel SVM:
    ``score = rbf_kernel(scaled_features, sv, gamma) @ sv_coef - rho``.

    Args:
        x (torch.Tensor): An input tensor. Shape :math:`(N, C, H, W)`.
        kernel_size (int): The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma (float): Sigma of normal distribution.
        test_y_channel (bool): Whether to use the y-channel of YCBCR. Only applied
            when ``x`` has 3 channels; otherwise ``x`` is just multiplied by 255.
        sv_coef (torch.Tensor): Support vector coefficients. Required.
        sv (torch.Tensor): Support vectors. Required. Used as-is for
            ``version='original'``; transposed before use for ``version='matlab'``.
        gamma (float): Gamma parameter for the RBF kernel.
        rho (float): Bias term in the decision function.
        scale (float): Divisor applied to the features for ``version='matlab'``.
            Not used by ``version='original'``, which calls ``scale_features`` instead.
        version (str): Version of the BRISQUE implementation ('original' or 'matlab').

    Returns:
        torch.Tensor: Value of BRISQUE index.

    Raises:
        ValueError: If ``version`` is unknown, ``sv`` or ``sv_coef`` is missing,
            or ``kernel_size`` is even.

    References:
        Mittal, Anish, Anush Krishna Moorthy, and Alan Conrad Bovik.
        "No-reference image quality assessment in the spatial domain."
        IEEE Transactions on image processing 21, no. 12 (2012): 4695-4708.
    """
    # Validate cheaply before the expensive multi-scale feature extraction.
    if version not in ('original', 'matlab'):
        raise ValueError(f"version must be 'original' or 'matlab', got {version!r}")
    if sv_coef is None or sv is None:
        raise ValueError('Both sv_coef and sv must be provided.')
    if kernel_size % 2 != 1:
        raise ValueError(f'Kernel size must be odd, got [{kernel_size}]')

    if test_y_channel and x.size(1) == 3:
        x = to_y_channel(x, 255.0)
    else:
        x = x * 255

    features = []
    num_of_scales = 2
    for _ in range(num_of_scales):
        if version == 'matlab':
            xnorm = normalize_img_with_gauss(
                x, kernel_size, kernel_sigma, padding='replicate'
            )
            features.append(compute_nss_features(xnorm))
        else:  # version == 'original'
            features.append(natural_scene_statistics(x, kernel_size, kernel_sigma))
        # Halve the resolution for the next scale.
        x = imresize(x, scale=0.5, antialiasing=True)
    features = torch.cat(features, dim=-1)

    # Match the SVM parameters to the input's dtype/device.
    sv_coef = sv_coef.to(x)
    sv = sv.to(x)
    if version == 'original':
        scaled_features = scale_features(features)
    else:  # version == 'matlab'
        scaled_features = features / scale
        # rbf_kernel expects sv laid out as (num_features, num_sv); the MATLAB
        # support vectors are transposed to match that layout.
        sv = sv.t()

    kernel_features = rbf_kernel(features=scaled_features, sv=sv, gamma=gamma)
    score = kernel_features @ sv_coef - rho
    return score
def natural_scene_statistics(
    luma: torch.Tensor, kernel_size: int = 7, sigma: float = 7.0 / 6
) -> torch.Tensor:
    """Compute single-scale BRISQUE natural scene statistics of a luminance image.

    The image is first normalized with ``normalize_img_with_gauss`` (``padding='same'``).

    * The first 2 features are ``alpha`` and ``sigma ** 2`` from ``estimate_ggd_param``
      on the normalized image.
    * For each of 4 (row, column) roll offsets, 4 more features come from
      ``estimate_aggd_param`` on the product of the normalized image and its rolled
      copy: ``alpha``, ``eta``, ``sigma_l ** 2`` and ``sigma_r ** 2``.

    This gives 2 + 4 * 4 = 18 values in total.

    Args:
        luma (torch.Tensor): Luminance image tensor.
        kernel_size (int): Size of the Gaussian kernel.
        sigma (float): Standard deviation of the Gaussian kernel.

    Returns:
        torch.Tensor: The 18 features stacked along the last dimension.
    """
    normalized = normalize_img_with_gauss(luma, kernel_size, sigma, padding='same')

    ggd_alpha, ggd_sigma = estimate_ggd_param(normalized)
    stats = [ggd_alpha, ggd_sigma.pow(2)]

    # (row, column) offsets applied over the last two dimensions.
    for offset in ((0, 1), (1, 0), (1, 1), (-1, 1)):
        neighbour = torch.roll(normalized, shifts=offset, dims=(-2, -1))
        aggd_alpha, left_sigma, right_sigma = estimate_aggd_param(
            normalized * neighbour, return_sigma=True
        )
        # eta = (sigma_r - sigma_l) * Gamma(2/a) / sqrt(Gamma(1/a) * Gamma(3/a)),
        # evaluated in log space via lgamma.
        gamma_factor = torch.exp(
            torch.lgamma(2.0 / aggd_alpha)
            - (torch.lgamma(1.0 / aggd_alpha) + torch.lgamma(3.0 / aggd_alpha)) / 2
        )
        eta = (right_sigma - left_sigma) * gamma_factor
        stats += [aggd_alpha, eta, left_sigma.pow(2), right_sigma.pow(2)]

    return torch.stack(stats, dim=-1)
# Per-feature [lower, upper] reference ranges, one row per entry of the
# 36-dimensional BRISQUE feature vector. Taken from the official MATLAB
# implementation of BRISQUE:
# https://live.ece.utexas.edu/research/Quality/index_algorithms.htm
# Stored in float64 so that double-precision features are not scaled with
# float32-rounded bounds; the table is cast to the features' dtype/device at use.
_BRISQUE_FEATURE_RANGES = torch.tensor(
    [
        [0.338, 10],
        [0.017204, 0.806612],
        [0.236, 1.642],
        [-0.123884, 0.20293],
        [0.000155, 0.712298],
        [0.001122, 0.470257],
        [0.244, 1.641],
        [-0.123586, 0.179083],
        [0.000152, 0.710456],
        [0.000975, 0.470984],
        [0.249, 1.555],
        [-0.135687, 0.100858],
        [0.000174, 0.684173],
        [0.000913, 0.534174],
        [0.258, 1.561],
        [-0.143408, 0.100486],
        [0.000179, 0.685696],
        [0.000888, 0.536508],
        [0.471, 3.264],
        [0.012809, 0.703171],
        [0.218, 1.046],
        [-0.094876, 0.187459],
        [1.5e-05, 0.442057],
        [0.001272, 0.40803],
        [0.222, 1.042],
        [-0.115772, 0.162604],
        [1.6e-05, 0.444362],
        [0.001374, 0.40243],
        [0.227, 0.996],
        [-0.117188, 0.09832299999999999],
        [3e-05, 0.531903],
        [0.001122, 0.369589],
        [0.228, 0.99],
        [-0.12243, 0.098658],
        [2.8e-05, 0.530092],
        [0.001118, 0.370399],
    ],
    dtype=torch.float64,
)


def scale_features(features: torch.Tensor) -> torch.Tensor:
    """Linearly map each feature from its reference range onto [-1, 1].

    Feature ``i`` is mapped so that its reference lower bound goes to -1 and
    its upper bound goes to 1. Values outside the reference range land outside
    [-1, 1]; no clipping is applied.

    Args:
        features (torch.Tensor): Input features. The last dimension must
            broadcast against the 36 reference ranges.

    Returns:
        torch.Tensor: Scaled features, with the same dtype and device as ``features``.
    """
    lower_bound = -1
    upper_bound = 1
    feature_ranges = _BRISQUE_FEATURE_RANGES.to(features)
    range_min = feature_ranges[..., 0]
    range_max = feature_ranges[..., 1]
    scaled_features = lower_bound + (upper_bound - lower_bound) * (
        features - range_min
    ) / (range_max - range_min)
    return scaled_features
def rbf_kernel(
    features: torch.Tensor, sv: torch.Tensor, gamma: float = 0.05
) -> torch.Tensor:
    """Evaluate the Gaussian RBF kernel ``exp(-gamma * ||f - s||^2)``.

    The kernel is evaluated between every feature vector and every support vector.
    The broadcasting below is written for 2-D inputs: ``features`` of shape
    ``(N, F)`` and ``sv`` of shape ``(F, S)``, i.e. one support vector per column.
    With those inputs the result has shape ``(N, S)``.

    Args:
        features (torch.Tensor): Input features.
        sv (torch.Tensor): Support vectors, stored column-wise.
        gamma (float): Gamma parameter for the RBF kernel.

    Returns:
        torch.Tensor: RBF kernel values.
    """
    # (N, F, 1) - (1, F, S) -> (N, F, S): every feature/support-vector difference.
    pairwise_diff = features[..., None] - sv[None, ...]
    # Reduce over the feature axis to get squared Euclidean distances.
    squared_dist = pairwise_diff.square().sum(dim=1)
    return (-squared_dist * gamma).exp()
@ARCH_REGISTRY.register()
class BRISQUE(torch.nn.Module):
    r"""Creates a criterion that measures the BRISQUE score.

    Args:
        kernel_size (int): The side-length of the sliding window used in
            comparison. Must be an odd value.
        kernel_sigma (float): Standard deviation for Gaussian kernel.
        test_y_channel (bool): Whether to use the y-channel of YCBCR. Only
            ``True`` is supported for the current BRISQUE model.
        version (str): Version of the BRISQUE implementation ('original' or 'matlab').
        pretrained_model_path (str, optional): Path to a ``torch.save``-d
            ``(sv_coef, sv)`` pair. If omitted, the default weights for
            ``version`` are downloaded.

    Attributes:
        kernel_size (int): The side-length of the sliding window used in comparison.
        kernel_sigma (float): Sigma of normal distribution.
        test_y_channel (bool): Whether to use the y-channel of YCBCR.
        version (str): Version of the BRISQUE implementation ('original' or 'matlab').
        sv_coef (torch.Tensor): Support vector coefficients.
        sv (torch.Tensor): Support vectors, pre-divided by ``scale``.
        gamma (float): Gamma parameter for the RBF kernel.
        rho (float): Bias term in the decision function.
        scale (float): Scaling factor for the features.

    Raises:
        ValueError: If ``version`` is not 'original' or 'matlab'.
    """

    def __init__(
        self,
        kernel_size: int = 7,
        kernel_sigma: float = 7 / 6,
        test_y_channel: bool = True,
        version: str = 'original',
        pretrained_model_path: str = None,
    ) -> None:
        super().__init__()

        # Validate eagerly so a misconfigured metric fails at construction time
        # rather than on the first forward call.
        assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
        assert test_y_channel, (
            'Only [test_y_channel=True] is supported for current BRISQUE model, which is taken directly from official codes: https://github.com/utlive/BRISQUE.'
        )
        if version not in ('original', 'matlab'):
            raise ValueError(
                f"version must be 'original' or 'matlab', got {version!r}"
            )

        self.kernel_size = kernel_size
        self.kernel_sigma = kernel_sigma
        self.test_y_channel = test_y_channel
        # Read by forward() to select the feature extraction / scaling path.
        self.version = version

        # sv_coef / sv are kept as plain tensor attributes (not buffers).
        if pretrained_model_path is not None:
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only load files from a trusted source.
            self.sv_coef, self.sv = torch.load(
                pretrained_model_path, weights_only=False
            )
        elif version == 'original':
            pretrained_model_path = load_file_from_url(default_model_urls['url'])
            self.sv_coef, self.sv = torch.load(
                pretrained_model_path, weights_only=False
            )
        else:  # version == 'matlab'
            # A bare ``import scipy`` does not guarantee that the ``io``
            # submodule is loaded, so import it explicitly.
            import scipy.io

            pretrained_model_path = load_file_from_url(
                default_model_urls['brisque_matlab']
            )
            params = scipy.io.loadmat(pretrained_model_path)
            self.sv = torch.from_numpy(params['sv'])
            self.sv_coef = torch.from_numpy(np.ravel(params['sv_coef']))

        # gamma and rho are SVM model parameters taken from the official
        # implementation of BRISQUE on MATLAB.
        # Source: https://live.ece.utexas.edu/research/Quality/index_algorithms.htm
        if version == 'original':
            self.gamma = 0.05
            self.rho = -153.591
            self.scale = 1
        else:  # version == 'matlab'
            self.gamma = 1
            self.rho = -43.4582
            self.scale = 0.3210

        # Divide the support vectors by the same factor that the 'matlab' path
        # divides the features by (a no-op for 'original', where scale == 1).
        self.sv = self.sv / self.scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Computation of BRISQUE score as a loss function.

        Args:
            x (torch.Tensor): An input tensor with (N, C, H, W) shape. RGB channel order for colour images.

        Returns:
            torch.Tensor: Value of BRISQUE metric.
        """
        return brisque(
            x,
            kernel_size=self.kernel_size,
            kernel_sigma=self.kernel_sigma,
            test_y_channel=self.test_y_channel,
            sv_coef=self.sv_coef,
            sv=self.sv,
            gamma=self.gamma,
            rho=self.rho,
            scale=self.scale,
            version=self.version,
        )