r"""NRQM Metric, proposed in
Chao Ma, Chih-Yuan Yang, Xiaokang Yang, Ming-Hsuan Yang
"Learning a No-Reference Quality Metric for Single-Image Super-Resolution"
Computer Vision and Image Understanding (CVIU), 2017
Matlab reference: https://github.com/chaoma99/sr-metric
This PyTorch implementation by: Chaofeng Chen (https://github.com/chaofengc)
"""
import math
import scipy.io
import torch
from torch import Tensor
import torch.nn.functional as F
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.utils.color_util import to_y_channel
from pyiqa.utils.download_util import load_file_from_url
from pyiqa.matlab_utils import (
imresize,
fspecial,
SCFpyr_PyTorch,
dct2d,
im2col,
ExactPadding2d,
)
from pyiqa.archs.func_util import extract_2d_patches
from pyiqa.archs.ssim_arch import ssim as ssim_func
from pyiqa.archs.niqe_arch import NIQE
from pyiqa.archs.arch_util import get_url_from_name
[docs]
# Download location of the pretrained NRQM parameters (MATLAB .mat file).
default_model_urls = dict(url=get_url_from_name('NRQM_model.mat'))
[docs]
def get_gauss_pyramid(x: Tensor, scale: int = 2):
    r"""Build a Gaussian pyramid by repeated blur-and-decimate.

    Each new level filters the previous one with the 3x3 kernel from
    ``fspecial(3, 0.5, c)`` (grouped per channel, ``same`` padding via
    ``ExactPadding2d``) and then keeps every second pixel starting at index 1
    along both spatial dimensions.

    Args:
        x (Tensor): input image of shape (b, c, h, w).
        scale (int): number of downsampled levels to add.

    Returns:
        list[Tensor]: ``scale + 1`` tensors, starting with the input itself.
    """
    blur = fspecial(3, 0.5, x.shape[1]).to(x)
    pad = ExactPadding2d(3, stride=1, mode='same')
    levels = [x]
    current = x
    for _ in range(scale):
        filtered = F.conv2d(pad(current), blur, groups=current.shape[1])
        current = filtered[:, :, 1::2, 1::2]
        levels.append(current)
    return levels
[docs]
def get_var_gen_gauss(x, eps=1e-7):
    r"""Coefficient of variation of ``|x|`` along the last dimension.

    Computes ``std(|x|) / (mean(|x|) + eps)`` using the unbiased standard
    deviation.

    Args:
        x (Tensor): values to summarise; statistics are taken over ``dim=-1``.
        eps (float): added to the mean to avoid division by zero.

    Returns:
        Tensor: ratio with shape ``x.shape[:-1]``.
    """
    magnitude = x.abs()
    spread = magnitude.std(dim=-1, unbiased=True)
    center = magnitude.mean(dim=-1)
    return spread / (center + eps)
[docs]
def gamma_gen_gauss(x: Tensor, block_seg=1e4):
    r"""General gaussian distribution estimation (shape parameter by table lookup).

    For every row of ``x`` (all leading dimensions flattened) the ratio
    ``var(x) / (mean(|x - mean(x)|) ** 2 + 1e-7)`` is compared against the
    theoretical ratio ``Gamma(1/g) * Gamma(3/g) / Gamma(2/g) ** 2`` on the grid
    ``g = 0.03, 0.031, ..., 10``; the grid value with the smallest absolute
    difference is returned.

    Args:
        x (Tensor): samples; the last dimension holds one distribution's samples.
        block_seg: maximum number of blocks in parallel to avoid OOM. When the
            number of rows exceeds it, rows are matched in
            ``int(rows // block_seg)`` chunks.

    Returns:
        Tensor: estimated shape parameter per row, shape ``x.shape[:-1]``.
    """
    lead_shape = x.shape[:-1]
    rows = x.reshape(-1, x.shape[-1])
    eps = 1e-7

    # Candidate shape parameters and their theoretical moment ratios. The ratio
    # is formed with lgamma so large Gamma values (e.g. Gamma(100) at g=0.03)
    # do not overflow.
    grid = torch.arange(0.03, 10 + 0.001, 0.001).to(rows)
    log_ratio = (
        torch.lgamma(1.0 / grid)
        + torch.lgamma(3.0 / grid)
        - 2 * torch.lgamma(2.0 / grid)
    )
    table = log_ratio.exp().unsqueeze(0)

    # Empirical ratio per row
    centered_abs = (rows - rows.mean(dim=-1, keepdim=True)).abs()
    abs_dev_sq = centered_abs.mean(dim=-1, keepdim=True) ** 2
    ratio = rows.var(dim=-1, keepdim=True, unbiased=True) / (abs_dev_sq + eps)

    # Nearest table entry; chunked so the (rows x grid) difference stays bounded
    if ratio.shape[0] > block_seg:
        pieces = ratio.chunk(int(ratio.shape[0] // block_seg))
        best = torch.cat([(piece - table).abs().argmin(dim=-1) for piece in pieces])
    else:
        best = (ratio - table).abs().argmin(dim=-1)
    return grid[best].reshape(*lead_shape)
[docs]
def gamma_dct(dct_img_block: torch.Tensor):
    r"""Generalized gaussian distribution features of the DCT coefficients.

    The last two dimensions (one DCT block) are flattened, the first
    coefficient (position (0, 0), the DC term of a standard 2-D DCT) is
    dropped, and the remaining coefficients of each block are fitted with
    ``gamma_gen_gauss``.

    Args:
        dct_img_block (torch.Tensor): 5-D tensor whose first dim is the batch
            and whose last two dims hold one DCT block.

    Returns:
        torch.Tensor: shape (b, n); shape parameters sorted ascending along
        the last dimension.
    """
    batch, _, _, bh, bw = dct_img_block.shape
    ac_coeffs = dct_img_block.reshape(batch, -1, bh * bw)[:, :, 1:]
    shape_params, _ = torch.sort(gamma_gen_gauss(ac_coeffs), dim=-1)
    return shape_params
[docs]
def coeff_var_dct(dct_img_block: torch.Tensor):
    r"""Coefficient-of-variation features of the DCT coefficients.

    The last two dimensions (one DCT block) are flattened, the first
    coefficient (position (0, 0)) is dropped, and ``std(|c|) / mean(|c|)`` of
    the remaining coefficients is computed per block via
    ``get_var_gen_gauss``.

    Args:
        dct_img_block (torch.Tensor): 5-D tensor whose first dim is the batch
            and whose last two dims hold one DCT block.

    Returns:
        torch.Tensor: shape (b, n); values sorted ascending along the last
        dimension.
    """
    batch, _, _, bh, bw = dct_img_block.shape
    ac_coeffs = dct_img_block.reshape(batch, -1, bh * bw)[:, :, 1:]
    ratios, _ = torch.sort(get_var_gen_gauss(ac_coeffs), dim=-1)
    return ratios
[docs]
def oriented_dct_rho(dct_img_block: torch.Tensor):
    r"""Oriented frequency features.

    Three fixed groups of DCT coefficients are gathered from each block by
    index (see the per-group comments). For each group,
    ``get_var_gen_gauss`` gives the coefficient of variation of the
    magnitudes. The feature is the variance of those three values, sorted
    ascending along the last dimension.

    Args:
        dct_img_block (torch.Tensor): DCT coefficients with one block stored
            in the last two dims. NOTE(review): the indices below reach 6, so
            blocks are presumably 7x7. ``.squeeze(-2)`` presumably removes a
            singleton channel dim, i.e. the input is
            (b, num_blocks, 1, 7, 7). Confirm against ``extract_2d_patches``.

    Returns:
        torch.Tensor: sorted orientation features, presumably one per block
        (under the shape assumption above).
    """
    eps = 1e-8
    # oriented 1: coefficients strictly above the main diagonal, toward the
    # top-right corner (row 0 from col 1, row 1 from col 2, row 2 from col 4,
    # row 3 from col 5)
    feat1 = torch.cat(
        [
            dct_img_block[..., 0, 1:],
            dct_img_block[..., 1, 2:],
            dct_img_block[..., 2, 4:],
            dct_img_block[..., 3, 5:],
        ],
        dim=-1,
    ).squeeze(-2)
    g1 = get_var_gen_gauss(feat1, eps)
    # oriented 2: coefficients on and near the main diagonal (rows 1-6).
    # The list index ``[1]`` keeps a trailing dim so the single entry (1, 1)
    # can be concatenated with the slices.
    feat2 = torch.cat(
        [
            dct_img_block[..., 1, [1]],
            dct_img_block[..., 2, 2:4],
            dct_img_block[..., 3, 2:5],
            dct_img_block[..., 4, 3:],
            dct_img_block[..., 5, 4:],
            dct_img_block[..., 6, 4:],
        ],
        dim=-1,
    ).squeeze(-2)
    g2 = get_var_gen_gauss(feat2, eps)
    # oriented 3: mirror of group 1 across the diagonal (col 0 from row 1,
    # col 1 from row 2, col 2 from row 4, col 3 from row 5)
    feat3 = torch.cat(
        [
            dct_img_block[..., 1:, 0],
            dct_img_block[..., 2:, 1],
            dct_img_block[..., 4:, 2],
            dct_img_block[..., 5:, 3],
        ],
        dim=-1,
    ).squeeze(-2)
    g3 = get_var_gen_gauss(feat3, eps)
    # variance across the three orientation groups, then sort ascending
    rho = torch.stack([g1, g2, g3], dim=-1).var(dim=-1)
    rho = torch.sort(rho, dim=-1)[0]
    return rho
[docs]
def block_dct(img: Tensor):
    r"""Get local frequency features.

    The image is cut into 7x7 patches (``3 + 2 * 2``) with stride 3 by
    ``extract_2d_patches``, and each patch is transformed with ``dct2d``.
    Three per-patch statistics are computed, each returned sorted ascending,
    and each contributes two numbers: a mean over one tail and the mean over
    all patches.

    Returns:
        Tensor: shape (b, 6), ordered as
        [low-tail gamma, mean gamma, high-tail coefficient variation,
        mean coefficient variation, high-tail orientation, mean orientation].
    """
    patches = extract_2d_patches(img, 3 + 2 * 2, 3)
    coeffs = dct2d(patches)

    def _upper_tail_mean(values):
        # mean of the largest values, from index floor(0.9 * n) to the end
        start = math.floor(0.9 * values.shape[-1])
        return values[:, start:].mean(dim=-1)

    gam = gamma_dct(coeffs)
    # NOTE(review): this takes the ceil(0.1 * n) + 1 smallest values, whereas
    # the upper-tail slices take n - floor(0.9 * n) (about ceil(0.1 * n))
    # values. Kept as-is: the pretrained forests presumably expect exactly
    # these features. Confirm against the MATLAB reference before changing.
    low_gamma = gam[:, : math.ceil(0.1 * gam.shape[-1]) + 1].mean(dim=-1)
    cv = coeff_var_dct(coeffs)
    ori = oriented_dct_rho(coeffs)

    return torch.stack(
        [
            low_gamma,
            gam.mean(dim=-1),
            _upper_tail_mean(cv),
            cv.mean(dim=-1),
            _upper_tail_mean(ori),
            ori.mean(dim=-1),
        ],
        dim=1,
    )
[docs]
def norm_sender_normalized(pyr, num_scale=2, num_bands=6, blksz=3, eps=1e-12):
    r"""Normalize pyramid with local spatial neighbor and band neighbor.

    For every coefficient of a subband, a neighbourhood vector ``v`` is built
    from three sources:

    - the 3x3 spatial window;
    - the co-located coefficient of the band ``num_bands`` positions later in
      ``pyr`` (the parent), resized to the current band size, when it exists;
    - the co-located coefficients of the other bands at the same scale.

    The mean-removed coefficient is divided by ``sqrt(v^T C^-1 v / N)``, where
    ``C`` is the uncentered second-moment matrix of the ``v`` vectors over the
    band. A guard band is then cropped and the result is mean-removed again.

    Args:
        pyr (list[Tensor]): subbands ordered scale-major (index
            ``si * num_bands + bi``). Each is presumably (b, h, w): the code
            unsqueezes dim 1 for ``F.unfold`` and crops the last two dims.
        num_scale (int): number of scales to normalize.
        num_bands (int): number of bands (orientations) per scale.
        blksz (int): window size used for ``border`` and ``N``.
            NOTE(review): ``F.unfold`` below hard-codes a 3x3 window, so values
            other than 3 would be inconsistent.
        eps (float): added to the divisive normalizer.

    Returns:
        list[Tensor]: ``num_scale * num_bands`` normalized subbands. Each is
        cropped by ``border`` on every side and then by a guard band of
        ``int(16 / 2**si)`` pixels.
    """
    border = blksz // 2
    guardband = 16
    subbands = []
    for si in range(num_scale):
        for bi in range(num_bands):
            idx = si * num_bands + bi
            current_band = pyr[idx]
            N = blksz**2
            # 3x3 window pixels -> tmp: (b, num_windows, 9)
            tmp = F.unfold(current_band.unsqueeze(1), 3, stride=1)
            tmp = tmp.transpose(1, 2)
            b, hw = tmp.shape[:2]
            # parent pixels (same band, next scale), if that band exists
            parent_idx = idx + num_bands
            if parent_idx < len(pyr):
                tmp_parent = pyr[parent_idx]
                tmp_parent = imresize(tmp_parent, sizes=current_band.shape[-2:])
                # crop by ``border`` so pixels align with the window centres
                tmp_parent = tmp_parent[:, border:-border, border:-border].reshape(
                    b, hw, 1
                )
                tmp = torch.cat((tmp, tmp_parent), dim=-1)
                N += 1
            # neighbor band pixels (other bands at the same scale)
            # NOTE(review): N is not incremented for these entries, unlike the
            # parent; confirm against the MATLAB reference.
            for ni in range(num_bands):
                if ni != bi:
                    ni_idx = si * num_bands + ni
                    tmp_nei = pyr[ni_idx]
                    tmp_nei = tmp_nei[:, border:-border, border:-border].reshape(
                        b, hw, 1
                    )
                    tmp = torch.cat((tmp, tmp_nei), dim=-1)
            # uncentered second-moment matrix of the neighbourhood vectors
            C_x = tmp.transpose(1, 2) @ tmp / tmp.shape[1]
            # correct possible negative eigenvalue: zero out negative
            # eigenvalues and rescale the positive ones so their sum equals the
            # original eigenvalue sum. If none is positive, the denominator
            # becomes 1 and all eigenvalues become 0.
            L, Q = torch.linalg.eigh(C_x)
            L_pos = L * (L > 0)
            L_pos_sum = L_pos.sum(dim=1, keepdim=True)
            L = (
                L_pos
                * L.sum(dim=1, keepdim=True)
                / (L_pos_sum + (L_pos_sum == 0).to(L.dtype))
            )
            C_x = Q @ torch.diag_embed(L) @ Q.transpose(1, 2)
            # centre coefficients of each window, mean-removed
            o_c = current_band[:, border:-border, border:-border]
            b, h, w = o_c.shape
            o_c = o_c.reshape(b, hw)
            o_c = o_c - o_c.mean(dim=1, keepdim=True)
            # per-pixel quadratic form v C^-1 v^T / N (summed over dim 2 below)
            tmp_y = (
                torch.linalg.solve(C_x.transpose(1, 2), tmp.transpose(1, 2)).transpose(
                    1, 2
                )
                * tmp
                / N
            )
            tmp_y = tmp_y.to(o_c)
            z = tmp_y.sum(dim=2).sqrt()
            # where the normalizer is 0, the output is forced to 0
            mask = z != 0
            g_c = o_c * mask / (z * mask + eps)
            g_c = g_c.reshape(b, h, w)
            # guard band halves with each scale
            gb = int(guardband / (2 ** (si)))
            g_c = g_c[:, gb:-gb, gb:-gb]
            g_c = g_c - g_c.mean(dim=(1, 2), keepdim=True)
            subbands.append(g_c)
    return subbands
[docs]
def global_gsm(img: Tensor):
    """Global features from a Gaussian scale mixture (GSM) model.

    A ``SCFpyr_PyTorch`` pyramid (height 2, 6 bands) is built from ``img``.
    The ``[..., 0]`` component of every band in ``pyr[1]`` and ``pyr[2]`` is
    divisively normalized with ``norm_sender_normalized``. Features are
    stacked in this order:

    1. generalized Gaussian shape parameter of each normalized subband (12);
    2. shape parameter of each band index pooled over both scales (6);
    3. second output of ``ssim_func(..., get_cs=True)`` between each
       un-normalized band (resized to the size of ``pyr[0]``) and ``pyr[0]``;
    4. the same SSIM output between every pair of distinct first-scale
       normalized subbands (15).

    Returns:
        Tensor: all features stacked along dim 1.
    """
    batch_size = img.shape[0]
    num_bands = 6
    pyramid = SCFpyr_PyTorch(height=2, nbands=num_bands, device=img.device).build(img)
    lp_bands = [band[..., 0] for level in (pyramid[1], pyramid[2]) for band in level]
    normed = norm_sender_normalized(lp_bands)

    def flat(t):
        return t.reshape(batch_size, -1)

    # shape parameter per normalized subband
    feats = [gamma_gen_gauss(flat(band)) for band in normed]
    # shape parameter per band index, pooled across the two scales
    for ori in range(num_bands):
        pooled = torch.cat((flat(normed[ori]), flat(normed[ori + num_bands])), dim=1)
        feats.append(gamma_gen_gauss(pooled))

    # structure correlation between scales
    hp_band = pyramid[0]
    for band in lp_bands:
        resized = imresize(band, sizes=hp_band.shape[1:]).unsqueeze(1)
        _, cs = ssim_func(resized, hp_band.unsqueeze(1), get_cs=True, data_range=255)
        feats.append(cs)

    # structure correlation between orientations (first scale only)
    pairs = [(a, c) for a in range(num_bands) for c in range(a + 1, num_bands)]
    for a, c in pairs:
        _, cs = ssim_func(
            normed[a].unsqueeze(1),
            normed[c].unsqueeze(1),
            get_cs=True,
            data_range=255,
        )
        feats.append(cs)

    return torch.stack(feats, dim=1)
[docs]
def tree_regression(feat, ldau, rdau, threshold_value, pred_value, best_attri):
    r"""Simple decision tree regression.

    The tree is stored as parallel arrays with MATLAB-style 1-based child and
    attribute indices. At node ``k``, ``feat[best_attri[k] - 1]`` is compared
    with ``threshold_value[k]``: ``<=`` moves to ``ldau[k]``, otherwise to
    ``rdau[k]``. A selected child index of 0 stops the walk. The result is
    ``pred_value`` at the node where the walk stopped (or at the last node
    evaluated, if ``ldau.shape[0]`` steps run out first).
    """
    node = 0
    last = 0
    for _ in range(ldau.shape[0]):
        last = node
        col = best_attri[node] - 1
        go_left = feat[col] <= threshold_value[node]
        child = ldau[node] if go_left else rdau[node]
        node = child - 1
        if node == -1:
            break
    return pred_value[last]
[docs]
def random_forest_regression(feat, ldau, rdau, threshold_value, pred_value, best_attri):
    r"""Simple random forest regression.

    Note: currently, this is non-differentiable and only support CPU.

    Column ``t`` of each parameter matrix describes tree ``t``. Every sample
    is passed through every tree with ``tree_regression``, and the per-tree
    predictions are averaged.

    Args:
        feat (Tensor): feature matrix of shape (b, dim); it is detached,
            moved to CPU and converted to numpy.
        ldau, rdau, threshold_value, pred_value, best_attri: 2-D arrays of
            shape (node_num, tree_num), sliced column-wise per tree.

    Returns:
        Tensor: shape (b, 1), the mean prediction over all trees per sample.
    """
    feat = feat.detach().cpu().numpy()
    batch_size, _ = feat.shape  # unpacking also enforces a 2-D feature matrix
    _, tree_num = ldau.shape
    pred = []
    # distinct loop names: the original reused ``i`` for both loops
    for sample_idx in range(batch_size):
        sample_feat = feat[sample_idx]
        tree_preds = [
            tree_regression(
                sample_feat,
                ldau[:, t],
                rdau[:, t],
                threshold_value[:, t],
                pred_value[:, t],
                best_attri[:, t],
            )
            for t in range(tree_num)
        ]
        pred.append(tree_preds)
    pred = torch.tensor(pred)
    return pred.mean(dim=1, keepdim=True)
[docs]
def nrqm(
    img: Tensor,
    linear_param,
    rf_param,
) -> Tensor:
    """Calculate NRQM

    Three feature groups are computed from ``img`` (converted to float64):

    - block-DCT statistics for each level of a Gaussian pyramid of ``img / 255``;
    - global GSM features of ``img`` itself;
    - singular values of ``im2col(level, 5, 'distinct')`` for each pyramid level.

    Each group goes through its own random forest. The three forest outputs,
    preceded by a bias column of ones, are combined linearly with
    ``linear_param``.

    Args:
        img (Tensor): Input image of shape (b, c, h, w). It is presumably a
            single gray / Y channel in the 0-255 range, since it is divided by
            255 for the pyramid.
        linear_param (np.array): (4, 1) linear regression params
        rf_param: params of 3 random forest for 3 kinds of features

    Returns:
        Tensor: ``(preds @ linear_param).squeeze()``.
    """
    assert img.ndim == 4, (
        'Input image must be a gray or Y (of YCbCr) image with shape (b, c, h, w).'
    )
    batch = img.shape[0]
    img = img.double()
    pyramid = get_gauss_pyramid(img / 255.0)

    # DCT features, concatenated over pyramid levels
    dct_feat = torch.cat([block_dct(level) for level in pyramid], dim=1)
    # gsm features on the 0-255 image
    gsm_feat = global_gsm(img)
    # svd features, concatenated over pyramid levels
    svd_feat = torch.cat(
        [
            torch.linalg.svd(im2col(level, 5, 'distinct'), full_matrices=False)[1]
            for level in pyramid
        ],
        dim=1,
    )

    # Random forest regression. Currently not differentiable and only support CPU.
    # The leading column of ones acts as the bias of the final linear model.
    preds = torch.ones(batch, 1)
    for feat, forest in zip((dct_feat, gsm_feat, svd_feat), rf_param):
        preds = torch.cat((preds, random_forest_regression(feat, *forest)), dim=1)
    quality = preds @ torch.tensor(linear_param)
    return quality.squeeze()
[docs]
def calculate_nrqm(
    img: torch.Tensor,
    crop_border: int = 0,
    test_y_channel: bool = True,
    color_space: str = 'yiq',
    linear_param: torch.Tensor = None,
    rf_params_list: list = None,
    **kwargs,
) -> torch.Tensor:
    """Calculate NRQM

    Args:
        img (Tensor): Input image whose quality needs to be computed.
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the metric calculation.
        test_y_channel (bool): If True and ``img`` has 3 channels, it is
            converted with ``to_y_channel(img, 255, color_space)``.
        color_space (str): color space passed to ``to_y_channel``.
        linear_param: final linear regression weights, forwarded to ``nrqm``.
        rf_params_list: random forest parameters, forwarded to ``nrqm``.
        **kwargs: accepted and ignored.

    Returns:
        Tensor: NRQM result, cast to the dtype/device of the processed image.
    """
    needs_y = test_y_channel and img.shape[1] == 3
    if needs_y:
        img = to_y_channel(img, 255, color_space)

    if crop_border != 0:
        img = img[..., crop_border:-crop_border, crop_border:-crop_border]

    score = nrqm(img, linear_param, rf_params_list)
    return score.to(img)
@ARCH_REGISTRY.register()
class NRQM(torch.nn.Module):
    r"""NRQM metric

    Ma, Chao, Chih-Yuan Yang, Xiaokang Yang, and Ming-Hsuan Yang.
    "Learning a no-reference quality metric for single-image super-resolution."
    Computer Vision and Image Understanding 158 (2017): 1-16.

    Args:
        - test_y_channel (Boolean): whether to convert 3-channel input to a
          single Y channel before computing the metric.
        - color_space (str): color space used for that conversion, default 'yiq'.
        - crop_border (int): Cropped pixels in each edge of an image. These
          pixels are not involved in the metric calculation.
        - pretrained_model_path (String): The pretrained model path. If None,
          the default model is fetched with ``load_file_from_url``.
    """

    def __init__(
        self,
        test_y_channel: bool = True,
        color_space: str = 'yiq',
        crop_border: int = 0,
        pretrained_model_path: str = None,
    ) -> None:
        super().__init__()
        self.test_y_channel = test_y_channel
        self.crop_border = crop_border
        self.color_space = color_space

        if pretrained_model_path is None:
            pretrained_model_path = load_file_from_url(default_model_urls['url'])

        # load model
        model = scipy.io.loadmat(pretrained_model_path)['model']
        forests = model['rf'][0, 0]
        # Fields kept from each MATLAB forest struct, in this order:
        # 0 ldau, 1 rdau, 4 threshold value, 5 pred value, 6 best attribute index
        kept_fields = (0, 1, 4, 5, 6)
        rf_params_list = []
        for forest_idx in range(3):
            forest = forests[0, forest_idx][0, 0]
            rf_params_list.append([forest[field] for field in kept_fields])

        self.linear_param = model['linear'][0, 0]
        self.rf_params_list = rf_params_list

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        r"""Computation of NRQM metric.

        Args:
            X: An input tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Value of nrqm metric.
        """
        return calculate_nrqm(
            X,
            crop_border=self.crop_border,
            test_y_channel=self.test_y_channel,
            color_space=self.color_space,
            linear_param=self.linear_param,
            rf_params_list=self.rf_params_list,
        )
@ARCH_REGISTRY.register()
class PI(torch.nn.Module):
    r"""Perceptual Index (PI), introduced by

    Blau, Yochai, Roey Mechrez, Radu Timofte, Tomer Michaeli, and Lihi Zelnik-Manor.
    "The 2018 pirm challenge on perceptual image super-resolution."
    In Proceedings of the European Conference on Computer Vision (ECCV) Workshops, pp. 0-0. 2018.
    Ref url: https://github.com/roimehrez/PIRM2018

    It is a combination of NIQE and NRQM: 1/2 * ((10 - NRQM) + NIQE)

    Args:
        - color_space (str): color space of y channel, default ycbcr.
        - crop_border (int): Cropped pixels in each edge of an image, default 4.
    """

    def __init__(self, crop_border=4, color_space='ycbcr'):
        super().__init__()
        shared = dict(crop_border=crop_border, color_space=color_space)
        self.nrqm = NRQM(**shared)
        self.niqe = NIQE(**shared)

    def forward(self, X: Tensor) -> Tensor:
        r"""Computation of PI metric.

        Args:
            X: An input tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Value of PI metric, ``0.5 * ((10 - NRQM) + NIQE)``.
        """
        nrqm_score = self.nrqm(X)
        niqe_score = self.niqe(X)
        return 0.5 * ((10 - nrqm_score) + niqe_score)