# Source code for pyiqa.matlab_utils.math_util

r"""Mathematical utilities

Created by: https://github.com/tomrunia/PyTorchSteerablePyramid/blob/master/steerable/math_utils.py

Modified by: Jiadi Mo (https://github.com/JiadiMo)
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch


def abs(x):
    """Return the magnitude of values stored as (real, imag) pairs.

    Args:
        x: Tensor whose last dimension is indexed at 0 for the real part
            and at 1 for the imaginary part.

    Returns:
        Tensor with the last dimension removed, holding
        ``sqrt(real ** 2 + imag ** 2 + 1e-12)``.
    """
    real_part = x[..., 0]
    imag_part = x[..., 1]
    # The small constant keeps the argument of sqrt strictly positive
    # (presumably for numerical stability when the magnitude is zero).
    squared_norm = real_part ** 2 + imag_part ** 2 + 1e-12
    return torch.sqrt(squared_norm)
def roll_n(X, axis, n):
    """Circularly shift a tensor along one axis.

    The first ``n`` entries along ``axis`` are moved to the end, so the
    element that was at index ``n`` becomes the first one.

    Args:
        X: Input tensor.
        axis: Dimension to roll. Negative values count from the last
            dimension, as in standard PyTorch indexing.
        n: Number of leading entries along ``axis`` to move to the back.

    Returns:
        A new tensor with the same shape as ``X``.
    """
    # Normalise a negative axis: the index builder below compares against
    # non-negative dimension positions, so without this step a negative
    # axis would select the whole tensor twice and duplicate it.
    if axis < 0:
        axis += X.dim()

    def _index_for(span):
        # Full slice on every dimension except the rolled one.
        return tuple(
            span if i == axis else slice(None) for i in range(X.dim())
        )

    front = X[_index_for(slice(0, n))]
    back = X[_index_for(slice(n, None))]
    return torch.cat([back, front], axis)
def batch_fftshift2d(x):
    """Shift the zero-frequency component to the centre of a batch.

    Args:
        x: A complex tensor, documented upstream as shape
            :math:`(N, C, H, W)`. Every dimension after the first is
            shifted. PyTorch version >= 1.8.0.

    Returns:
        A real tensor with one extra trailing dimension of size 2 holding
        the shifted real and imaginary parts.
    """
    real_part = x.real
    imag_part = x.imag
    for dim in range(1, real_part.dim()):
        # Ceiling division: odd-sized dimensions shift one extra element.
        shift = (real_part.size(dim) + 1) // 2
        real_part = roll_n(real_part, axis=dim, n=shift)
        imag_part = roll_n(imag_part, axis=dim, n=shift)
    # last dim = 2 (real & imag)
    return torch.stack((real_part, imag_part), -1)
def batch_ifftshift2d(x):
    """Undo the centring shift and rebuild a complex tensor.

    Args:
        x: An input tensor whose last dimension has size 2 (real and
            imaginary parts), documented upstream as shape
            :math:`(N, C, H, W, 2)`.

    Returns:
        A complex tensor without the trailing dimension, documented
        upstream as shape :math:`(N, C, H, W)`.
    """
    real_part, imag_part = torch.unbind(x, -1)
    # Walk the dimensions from last to second; the first is left untouched.
    for dim in reversed(range(1, real_part.dim())):
        real_part = roll_n(real_part, axis=dim, n=real_part.size(dim) // 2)
        imag_part = roll_n(imag_part, axis=dim, n=imag_part.size(dim) // 2)
    # convert to complex (real & imag)
    return torch.complex(real_part, imag_part)
def prepare_grid(m, n):
    """Build log-radius and angle maps on an ``m`` x ``n`` grid.

    Args:
        m: Number of rows of the grid.
        n: Number of columns of the grid.

    Returns:
        Tuple ``(log_rad, angle)`` of arrays with shape ``(m, n)``:
        the base-2 logarithm of the distance from the grid centre and
        the ``arctan2`` angle of each sample.
    """

    def _axis(size):
        # Evenly spaced coordinates; for even sizes the upper end stops
        # one step short so that 0 falls exactly on index ``size // 2``.
        half = (size // 2) / (size / 2)
        return np.linspace(-half, half - (1 - size % 2) * 2 / size, num=size)

    rows = _axis(m)
    cols = _axis(n)
    xv, yv = np.meshgrid(cols, rows)

    angle = np.arctan2(yv, xv)
    rad = np.sqrt(xv**2 + yv**2)

    # The centre sample has radius 0; copy its left neighbour there
    # before taking the logarithm.
    centre_row, centre_col = m // 2, n // 2
    rad[centre_row][centre_col] = rad[centre_row][centre_col - 1]

    return np.log2(rad), angle
def rcosFn(width, position):
    """Tabulate a squared-cosine transition curve.

    Args:
        width: Scales the extent of the returned abscissa.
        position: Offset added to the returned abscissa.

    Returns:
        Tuple ``(X, Y)`` of 1-D arrays of equal length: the sample
        positions and the ``cos ** 2`` values at those samples, with the
        first and last values replaced by their inner neighbours.
    """
    n_samples = 256  # arbitrary
    phase = np.pi * np.arange(-n_samples - 1, 2) / 2 / n_samples

    values = np.cos(phase) ** 2
    # Flatten both ends of the table by repeating the neighbouring value.
    values[0] = values[1]
    values[n_samples + 2] = values[n_samples + 1]

    scale = 2 * width / np.pi
    positions = position + scale * (phase + np.pi / 4)
    return positions, values
def pointOp(im, Y, X):
    """Map every value of ``im`` through a piecewise-linear lookup table.

    Args:
        im: Array of values to transform.
        Y: Table ordinates, passed to ``np.interp`` as ``fp``.
        X: Table abscissae, passed to ``np.interp`` as ``xp``.

    Returns:
        Array with the same shape as ``im`` holding the interpolated
        values.
    """
    flat_values = im.flatten()
    mapped = np.interp(flat_values, X, Y)
    return mapped.reshape(im.shape)
def getlist(coeff):
    """Flatten a nested coefficient sequence into a single list.

    Args:
        coeff: Sequence whose first and last items are kept as single
            entries and whose intermediate items are each iterable.

    Returns:
        List made of the first item, then every element of each
        intermediate item in order, then the last item.
    """
    flattened = [coeff[0]]
    for scale in coeff[1:-1]:
        flattened.extend(scale)
    flattened.append(coeff[-1])
    return flattened