# Source code for pyiqa.matlab_utils.padding

import math
import collections.abc
from itertools import repeat
import numpy as np
from typing import Tuple

import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init


def _ntuple(n):
    """Build a converter that broadcasts a scalar argument to an ``n``-tuple.

    The returned ``parse(x)`` gives ``x`` back unchanged when it is already
    iterable (no length check and no copy: a list stays the same list object,
    and a ``str`` counts as iterable too); any other value is repeated ``n``
    times into a tuple.

    Args:
        n (int): length of the tuple built for non-iterable inputs.

    Returns:
        Callable: the ``parse`` converter.
    """

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


# Ready-made converters, e.g. ``to_2tuple(3) == (3, 3)``; ``to_ntuple`` is the
# factory itself, for arbitrary ``n``.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def symm_pad(im: torch.Tensor, padding: Tuple[int, int, int, int]):
    """Symmetrically pad the last two dimensions of ``im`` (TensorFlow ``SYMMETRIC`` style).

    Edge samples are repeated, e.g. ``[a, b, c]`` padded by 2 on both sides
    becomes ``[b, a, a, b, c, c, b]``. Pads longer than the input keep folding
    back and forth over the source indices, so they are not limited by the
    input size.

    Ref: https://discuss.pytorch.org/t/symmetric-padding/19866/3

    Args:
        im (Tensor): input whose last two dimensions are (height, width).
        padding (tuple): ``(left, right, top, bottom)`` pad amounts.

    Returns:
        Tensor: ``im`` gathered to shape ``(..., top + h + bottom, left + w + right)``.
    """
    h, w = im.shape[-2:]
    left, right, top, bottom = padding

    def fold_indices(before, length, after):
        # Sample positions -before .. length + after - 1 and mirror them about
        # -0.5 and length - 0.5. Mirroring half a pixel outside the valid range
        # is what makes the border sample repeat. The mirror is periodic (a
        # triangular wave with period 2 * length), hence fmod plus a wrap of the
        # negative remainders.
        positions = np.arange(-before, length + after)
        low, high = -0.5, length - 0.5
        span = high - low
        period = 2 * span
        phase = np.fmod(positions - low, period)
        phase = np.where(phase < 0, phase + period, phase)
        mirrored = np.where(phase >= span, period - phase, phase) + low
        return np.array(mirrored, dtype=positions.dtype)

    cols = fold_indices(left, w, right)
    rows = fold_indices(top, h, bottom)
    col_grid, row_grid = np.meshgrid(cols, rows)
    return im[..., row_grid, col_grid]
def _same_padding_1d(size, kernel, stride, dilation):
    """Return the ``(before, after)`` TensorFlow-SAME padding for one spatial axis."""
    out_size = math.ceil(size / stride)
    effective_kernel = (kernel - 1) * dilation + 1
    # Clamp at zero like TensorFlow does: when the stride exceeds the effective
    # kernel, the raw value can go negative, and a negative pad would make
    # F.pad / symm_pad crop the input instead of leaving it untouched.
    total = max((out_size - 1) * stride + effective_kernel - size, 0)
    # The smaller half goes before (top/left); any odd remainder goes after.
    return total // 2, total - total // 2


def exact_padding_2d(x, kernel, stride=1, dilation=1, mode='same'):
    """Pad a 4D tensor so that a following conv/pool gives TensorFlow ``SAME`` output size.

    Along each of the last two axes, the total padding is
    ``max((ceil(in / stride) - 1) * stride + (kernel - 1) * dilation + 1 - in, 0)``.
    It is split with the smaller half before (top/left) and the remainder
    after (bottom/right).

    Args:
        x (Tensor): 4D input; the last two dimensions are padded.
        kernel (int or tuple): kernel size.
        stride (int or tuple): stride. Default: 1.
        dilation (int or tuple): dilation. Default: 1.
        mode (str): ``'same'`` (zero padding), ``'symmetric'`` (edge-repeating
            mirror, see ``symm_pad``), or any mode accepted by
            ``torch.nn.functional.pad`` (``'constant'``, ``'reflect'``,
            ``'replicate'``, ``'circular'``).

    Returns:
        Tensor: the padded tensor.

    Raises:
        AssertionError: if ``x`` is not 4D.
    """
    assert len(x.shape) == 4, f'Only support 4D tensor input, but got {x.shape}'
    kernel = to_2tuple(kernel)
    stride = to_2tuple(stride)
    dilation = to_2tuple(dilation)
    h, w = x.shape[-2:]

    pad_t, pad_b = _same_padding_1d(h, kernel[0], stride[0], dilation[0])
    pad_l, pad_r = _same_padding_1d(w, kernel[1], stride[1], dilation[1])
    padding = (pad_l, pad_r, pad_t, pad_b)

    if mode == 'symmetric':
        return symm_pad(x, padding)
    # 'same' is TensorFlow's name for zero padding, i.e. F.pad's 'constant' mode.
    return F.pad(x, padding, mode='constant' if mode == 'same' else mode)
class ExactPadding2d(nn.Module):
    r"""Module wrapper around ``exact_padding_2d``.

    Pads 4D tensor inputs so that a following conv/pool layer produces the
    same output size as TensorFlow ``SAME`` padding.

    Args:
        kernel (int or tuple): kernel size.
        stride (int or tuple): stride size. Default: 1.
        dilation (int or tuple): dilation size. Default: 1.
        mode (str or None): padding mode. ``'same'`` means zero padding,
            ``'symmetric'`` mirrors with the edge sample repeated, and other
            values (e.g. ``'constant'``, ``'reflect'``, ``'replicate'``,
            ``'circular'``) are forwarded to ``torch.nn.functional.pad``.
            ``None`` disables padding: ``forward`` returns its input unchanged.
            Default: ``'same'``.
    """

    def __init__(self, kernel, stride=1, dilation=1, mode='same'):
        super().__init__()
        self.kernel = to_2tuple(kernel)
        self.stride = to_2tuple(stride)
        self.dilation = to_2tuple(dilation)
        self.mode = mode

    def extra_repr(self):
        # Shown by print(model), so the padding configuration is visible.
        return (
            f'kernel={self.kernel}, stride={self.stride}, '
            f'dilation={self.dilation}, mode={self.mode!r}'
        )

    def forward(self, x):
        if self.mode is None:
            return x
        return exact_padding_2d(x, self.kernel, self.stride, self.dilation, self.mode)