import math
import collections.abc
from itertools import repeat
import numpy as np
from typing import Tuple
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
[docs]
def symm_pad(im: torch.Tensor, padding: Tuple[int, int, int, int]):
    """Pad the last two dims of ``im`` symmetrically, like TensorFlow's SYMMETRIC mode.

    Edge pixels are repeated: ``[a b c]`` padded by 2 on the left becomes
    ``[b a a b c]``. Pads longer than the input keep following the mirror
    pattern, so any pad length is allowed.

    Ref: https://discuss.pytorch.org/t/symmetric-padding/19866/3

    Args:
        im: tensor whose last two dims are (height, width).
        padding: ``(left, right, top, bottom)`` pad sizes.

    Returns:
        The padded tensor, gathered from ``im`` via advanced indexing.
    """

    def _mirror(idx, lo, hi):
        """Fold positions into ``[lo, hi]`` with a triangular wave of period 2 * (hi - lo)."""
        span = hi - lo
        period = span * 2
        phase = np.fmod(idx - lo, period)
        # np.fmod keeps the sign of the dividend; shift negatives into [0, period).
        phase = np.where(phase < 0, phase + period, phase)
        folded = np.where(phase >= span, period - phase, phase)
        # Cast back to the integer index dtype of ``idx``.
        return np.array(folded + lo, dtype=idx.dtype)

    height, width = im.shape[-2:]
    pad_left, pad_right, pad_top, pad_bottom = padding

    # Mirroring about -0.5 and size - 0.5 duplicates the edge element.
    cols = _mirror(np.arange(-pad_left, width + pad_right), -0.5, width - 0.5)
    rows = _mirror(np.arange(-pad_top, height + pad_bottom), -0.5, height - 0.5)

    grid_x, grid_y = np.meshgrid(cols, rows)
    return im[..., grid_y, grid_x]
def exact_padding_2d(x, kernel, stride=1, dilation=1, mode='same'):
    """Pad a 4D tensor so a following convolution reproduces TensorFlow "SAME" sizing.

    The total padding per spatial dim is chosen so that an unpadded convolution
    with the given ``kernel`` / ``stride`` / ``dilation`` produces
    ``ceil(size / stride)`` outputs. When the total is odd, the extra pixel goes
    to the right / bottom. As in TensorFlow, the total is clamped at zero, so
    the input is never cropped.

    Args:
        x (Tensor): input with exactly 4 dims; the last two are treated as
            (height, width).
        kernel (int or tuple): kernel size.
        stride (int or tuple): stride size.
        dilation (int or tuple): dilation size.
        mode (str): 'same' (zero / constant padding), 'symmetric', or any
            other mode accepted by ``torch.nn.functional.pad``.

    Returns:
        Tensor: the padded input.
    """
    assert len(x.shape) == 4, f'Only support 4D tensor input, but got {x.shape}'
    kernel = to_2tuple(kernel)
    stride = to_2tuple(stride)
    dilation = to_2tuple(dilation)
    h, w = x.shape[-2:]

    out_h = math.ceil(h / stride[0])
    out_w = math.ceil(w / stride[1])
    eff_kh = (kernel[0] - 1) * dilation[0] + 1
    eff_kw = (kernel[1] - 1) * dilation[1] + 1
    # Without the clamp, a stride larger than the kernel span yields a negative
    # total. F.pad would then crop the input (constant mode) or raise (other modes).
    pad_row = max((out_h - 1) * stride[0] + eff_kh - h, 0)
    pad_col = max((out_w - 1) * stride[1] + eff_kw - w, 0)

    padding = (
        pad_col // 2,
        pad_col - pad_col // 2,
        pad_row // 2,
        pad_row - pad_row // 2,
    )

    mode = 'constant' if mode == 'same' else mode
    if mode == 'symmetric':
        return symm_pad(x, padding)
    return F.pad(x, padding, mode=mode)
class ExactPadding2d(nn.Module):
    r"""Module wrapper around :func:`exact_padding_2d`.

    This module calculates exact padding values for 4D tensor inputs and
    supports the same padding modes as tensorflow.

    Args:
        kernel (int or tuple): kernel size.
        stride (int or tuple): stride size, default with 1.
        dilation (int or tuple): dilation size, default with 1.
        mode (str or None): padding mode, one of ('same', 'symmetric',
            'replicate', 'circular'); 'same' means zero (constant) padding.
            ``None`` disables padding, and ``forward`` returns its input
            unchanged.
    """

    def __init__(self, kernel, stride=1, dilation=1, mode='same'):
        super().__init__()
        # Normalize to (h, w) pairs once, at construction time.
        self.kernel = to_2tuple(kernel)
        self.stride = to_2tuple(stride)
        self.dilation = to_2tuple(dilation)
        self.mode = mode

    def forward(self, x):
        """Pad ``x`` with the stored configuration, or return it as-is when ``mode`` is None."""
        if self.mode is None:
            return x
        return exact_padding_2d(
            x, self.kernel, self.stride, self.dilation, self.mode
        )

    def extra_repr(self):
        """Show the padding configuration when the module is printed."""
        return (
            f'kernel={self.kernel}, stride={self.stride}, '
            f'dilation={self.dilation}, mode={self.mode!r}'
        )