Source code for pyiqa.archs.psnr_arch

r"""Peak signal-to-noise ratio (PSNR) Metric

Created by: https://github.com/photosynthesis-team/piq

Modified by: Jiadi Mo (https://github.com/JiadiMo)

Refer to:
    Wikipedia from https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    PIQA from https://github.com/francois-rozet/piqa/blob/master/piqa/psnr.py

"""

import torch
import torch.nn as nn

from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.utils.color_util import to_y_channel


def _to_float(t):
    r"""Return ``t`` unchanged if it is floating point or complex, else cast it to float32."""
    if t.is_floating_point() or t.is_complex():
        return t
    return t.float()


def psnr(x, y, test_y_channel=False, data_range=1.0, eps=1e-8, color_space='yiq'):
    r"""Compute Peak Signal-to-Noise Ratio for a batch of images.

    Supports both greyscale and color images with RGB channel order.

    Args:
        - x: An input tensor. Shape :math:`(N, C, H, W)`.
        - y: A target tensor. Shape :math:`(N, C, H, W)`.
        - test_y_channel (Boolean): Convert RGB image to YCbCr format and computes PSNR
          only on luminance channel if `True`. Compute on all 3 channels otherwise.
          Only takes effect when the inputs have exactly 3 channels.
        - data_range: Maximum value range of images (default 1.0). Also forwarded to
          ``to_y_channel`` when the luminance conversion is used.
        - eps: Small constant added to the mean squared error so that identical
          images give a finite score instead of dividing by zero (default 1e-8).
        - color_space: Forwarded unchanged to ``to_y_channel`` to select the
          luminance conversion (default ``'yiq'``).

    Returns:
        PSNR Index of similarity between two images, one value per batch
        element. Shape :math:`(N,)`.
    """
    if (x.shape[1] == 3) and test_y_channel:
        # Convert RGB image to YCbCr and use Y-channel
        x = to_y_channel(x, data_range, color_space)
        y = to_y_channel(y, data_range, color_space)

    # Integer tensors (e.g. uint8 images) wrap around on subtraction and are
    # rejected by torch.mean, so promote them to float before taking the error.
    x, y = _to_float(x), _to_float(y)

    # Per-image mean squared error over channel and spatial dimensions.
    mse = torch.mean((x - y) ** 2, dim=[1, 2, 3])
    score = 10 * torch.log10(data_range**2 / (mse + eps))

    return score
@ARCH_REGISTRY.register()
class PSNR(nn.Module):
    r"""Module wrapper around :func:`psnr`.

    Args:
        - X, Y (torch.Tensor): distorted image and reference image tensor with shape (B, 3, H, W)
        - test_y_channel (Boolean): Convert RGB image to YCbCr format and computes PSNR
          only on luminance channel if `True`. Compute on all 3 channels otherwise.
        - crop_border (int): Number of pixels removed from every side of the last two
          dimensions of both inputs before the score is computed. ``0`` disables cropping.
        - kwargs: other parameters, forwarded unchanged to :func:`psnr`, including
            - data_range: maximum numeric value
            - eps: small constant for numeric stability

    Return:
        score (torch.Tensor): one value per batch element, shape (B,)
    """

    def __init__(self, test_y_channel=False, crop_border=0, **kwargs):
        super().__init__()
        self.test_y_channel = test_y_channel
        self.kwargs = kwargs
        self.crop_border = crop_border

    def forward(self, X, Y):
        r"""Compute the PSNR between ``X`` and ``Y``.

        Raises:
            AssertionError: if ``X`` and ``Y`` do not have the same shape.
            ValueError: if ``crop_border`` is negative or would remove every pixel.
        """
        if X.shape != Y.shape:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped when Python runs with -O; the exception type is unchanged.
            raise AssertionError(
                f'Input and reference images should have the same shape, but got {X.shape} and {Y.shape}'
            )

        if self.crop_border != 0:
            crop_border = self.crop_border
            height, width = X.shape[-2:]
            # An invalid border yields an empty slice, whose mean is NaN.
            if crop_border < 0 or 2 * crop_border >= min(height, width):
                raise ValueError(
                    f'crop_border={crop_border} is not valid for images of spatial size ({height}, {width})'
                )
            X = X[..., crop_border:-crop_border, crop_border:-crop_border]
            Y = Y[..., crop_border:-crop_border, crop_border:-crop_border]

        score = psnr(X, Y, self.test_y_channel, **self.kwargs)
        return score