r"""URanker model.
This file contains the code for the URanker model, as described in the paper:
Underwater Ranker: Learn Which Is Better and How to Be Better
Chunle Guo#, Ruiqi Wu#, Xin Jin, Linghao Han, Zhi Chai, Weidong Zhang, Chongyi Li*
Proceedings of the AAAI conference on artificial intelligence (AAAI), 2023
Official code: https://github.com/RQ-Wu/UnderwaterRanker
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from timm.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
from functools import partial
import math
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network
from pyiqa.archs.arch_util import get_url_from_name
# Download locations of the released URanker checkpoints, keyed by model name.
default_model_urls = {
    'uranker': get_url_from_name('URanker_ckpt-450eb36d.pth'),
}
def padding_img(img, multiple=32):
    """Zero-pad an image batch so its height and width are multiples of ``multiple``.

    The padding is split as evenly as possible between the two sides of each
    spatial axis; when the total padding is odd, the extra pixel goes to the
    bottom/right side.

    Args:
        img: Image tensor whose last two dimensions are (H, W), e.g. ``(B, C, H, W)``.
        multiple: Spatial granularity to pad up to. The default of 32 equals the
            total downsampling of URanker's default patch embeddings (4 * 2 * 2 * 2).

    Returns:
        The zero-padded tensor with spatial size rounded up to ``multiple``.
    """
    h, w = img.shape[-2:]
    # Integer ceiling avoids float rounding for very large sizes.
    pad_h = (h + multiple - 1) // multiple * multiple - h
    pad_w = (w + multiple - 1) // multiple * multiple - w
    top = pad_h // 2
    left = pad_w // 2
    # F.pad orders the last two dims as (left, right, top, bottom); constant mode pads with 0.
    return F.pad(img, (left, pad_w - left, top, pad_h - top))
@torch.no_grad()
def build_historgram(img, bins=64):
    """Compute per-sample RGB histograms over the value range [0, 1].

    Only channels 0, 1 and 2 of each sample are used; values outside [0, 1]
    are ignored by ``torch.histc``.

    Args:
        img: Tensor indexed as ``img[sample][channel]`` (e.g. ``(B, C, H, W)``)
            with at least 3 channels.
        bins: Number of histogram bins per channel. The URanker histogram
            embedding expects ``3 * bins`` features (192 for the default 64).

    Returns:
        Tensor of shape ``(B, 1, 3 * bins)``: the channel-0, channel-1 and
        channel-2 histograms concatenated for each sample.
    """
    per_sample = [
        torch.cat([torch.histc(sample[ch], bins, min=0.0, max=1.0) for ch in range(3)])
        for sample in img
    ]
    # Stack once instead of repeatedly concatenating inside the loop.
    return torch.stack(per_sample).unsqueeze(1)
def preprocessing(d_img_org):
    """Prepare a raw image batch for URanker.

    The batch is zero-padded to multiples of 32 and its RGB histogram is then
    computed on the *padded* image (so the padding contributes to bin 0).

    Returns:
        Tuple ``(padded_img, histogram)``.
    """
    padded = padding_img(d_img_org)
    return padded, build_historgram(padded)
class Mlp(nn.Module):
    """Feed-forward network (FFN, a.k.a. MLP): Linear -> act -> Dropout -> Linear -> Dropout."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        """
        Args:
            in_features: Input feature size.
            hidden_features: Hidden size; defaults to ``in_features``.
            out_features: Output size; defaults to ``in_features``.
            act_layer: Activation class, instantiated with no arguments.
            drop: Dropout probability applied after the activation and after fc2.
        """
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the FFN to the last dimension of ``x``."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class ConvRelPosEnc(nn.Module):
    """Convolutional relative position encoding.

    Value tokens are reshaped to an image, passed through depthwise convolutions
    (one per head split) and multiplied element-wise with the queries. Leading
    non-image tokens (CLS / histogram) receive a zero encoding.
    """

    def __init__(self, Ch, h, window):
        """
        Args:
            Ch: Channels per head.
            h: Number of heads.
            window: Window size(s) of the convolutional relative position encoding, either
                1. an int, which gives every attention head the same window size, or
                2. a dict mapping window size to number of heads, e.g.
                   ``{window size 1: #heads 1, window size 2: #heads 2}``; each head
                   split uses its own window size.

        Raises:
            ValueError: If ``window`` is neither an int nor a dict.
        """
        super().__init__()
        if isinstance(window, int):
            window = {window: h}  # Same window size for all attention heads.
            self.window = window
        elif isinstance(window, dict):
            self.window = window
        else:
            raise ValueError(
                f'window must be an int or a dict, got {type(window).__name__}'
            )

        self.conv_list = nn.ModuleList()
        self.head_splits = []
        for cur_window, cur_head_split in window.items():
            dilation = 1  # Use dilation=1 by default.
            # "Same" padding for an odd kernel with the given dilation.
            # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
            padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
            cur_conv = nn.Conv2d(
                cur_head_split * Ch,
                cur_head_split * Ch,
                kernel_size=(cur_window, cur_window),
                padding=(padding_size, padding_size),
                dilation=(dilation, dilation),
                groups=cur_head_split * Ch,  # Depthwise.
            )
            self.conv_list.append(cur_conv)
            self.head_splits.append(cur_head_split)
        self.channel_splits = [x * Ch for x in self.head_splits]

    def forward(self, q, v, size):
        """Return the encoding for queries ``q`` and values ``v``.

        Args:
            q, v: Tensors of shape ``[B, h, N, Ch]``, where ``N`` is ``H * W`` plus
                one or two leading non-image tokens.
            size: Image token grid ``(H, W)``.

        Returns:
            Tensor of shape ``[B, h, N, Ch]``; the leading non-image tokens are zero.
        """
        B, h, N, Ch = q.shape
        H, W = size
        assert N == 1 + H * W or N == 2 + H * W
        diff = N - H * W

        q_img = q[:, :, diff:, :]  # Shape: [B, h, H*W, Ch].
        v_img = v[:, :, diff:, :]  # Shape: [B, h, H*W, Ch].
        # [B, h, H*W, Ch] -> [B, h*Ch, H, W].
        v_img = rearrange(v_img, 'B h (H W) Ch -> B (h Ch) H W', H=H, W=W)
        v_img_list = torch.split(v_img, self.channel_splits, dim=1)
        conv_v_img_list = [conv(x) for conv, x in zip(self.conv_list, v_img_list)]
        conv_v_img = torch.cat(conv_v_img_list, dim=1)
        # [B, h*Ch, H, W] -> [B, h, H*W, Ch].
        conv_v_img = rearrange(conv_v_img, 'B (h Ch) H W -> B h (H W) Ch', h=h)

        EV_hat_img = q_img * conv_v_img
        zero = torch.zeros(
            (B, h, diff, Ch), dtype=q.dtype, layout=q.layout, device=q.device
        )
        EV_hat = torch.cat((zero, EV_hat_img), dim=2)  # Shape: [B, h, N, Ch].
        return EV_hat
class FactorAtt_ConvRelPosEnc(nn.Module):
    """Factorized attention with convolutional relative position encoding.

    Output per head: ``scale * q @ (softmax_N(k)^T @ v) + crpe(q, v)``, followed by
    an output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        shared_crpe=None,
    ):
        """
        Args:
            dim: Token embedding size (split evenly across ``num_heads``).
            num_heads: Number of attention heads.
            qkv_bias: Whether the QKV projection has a bias.
            qk_scale: Scale of the factorized attention; defaults to ``head_dim ** -0.5``.
            attn_drop: Stored as a Dropout module but not applied in ``forward``.
            proj_drop: Dropout after the output projection.
            shared_crpe: Callable ``(q, v, size) -> Tensor[B, h, N, Ch]``, typically a
                ``ConvRelPosEnc`` shared between blocks.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)  # Note: attn_drop is actually not used.
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Shared convolutional relative position encoding.
        self.crpe = shared_crpe

    def forward(self, x, size):
        """Apply attention to tokens ``x`` of shape ``[B, N, C]`` on image grid ``size``."""
        B, N, C = x.shape

        # Generate Q, K, V: [B, N, 3C] -> [3, B, h, N, Ch].
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]  # Shape: [B, h, N, Ch].

        # Factorized attention: softmax over the token axis, then K^T V before Q,
        # which keeps the cost linear in N.
        k_softmax = k.softmax(dim=2)
        k_softmax_T_dot_v = einsum(
            'b h n k, b h n v -> b h k v', k_softmax, v
        )  # Shape: [B, h, Ch, Ch].
        factor_att = einsum(
            'b h n k, b h k v -> b h n v', q, k_softmax_T_dot_v
        )  # Shape: [B, h, N, Ch].

        # Convolutional relative position encoding.
        crpe = self.crpe(q, v, size=size)  # Shape: [B, h, N, Ch].

        # Merge heads: [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C].
        x = self.scale * factor_att + crpe
        x = x.transpose(1, 2).reshape(B, N, C)

        # Output projection.
        x = self.proj(x)
        x = self.proj_drop(x)
        return x  # Shape: [B, N, C].
class ConvPosEnc(nn.Module):
    """Convolutional position encoding.

    A residual depthwise convolution over the image tokens, similar to the
    conditional position encoding in CPVT. Leading non-image tokens are passed
    through unchanged.
    """

    def __init__(self, dim, k=3):
        """
        Args:
            dim: Token embedding size.
            k: Kernel size of the depthwise convolution (padding ``k // 2``).
        """
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)

    def forward(self, x, size):
        """Encode tokens ``x`` of shape ``[B, N, C]`` laid out on image grid ``size``.

        ``N`` must be ``H * W`` plus one or two leading non-image tokens.
        """
        B, N, C = x.shape
        H, W = size
        assert N == 1 + H * W or N == 2 + H * W
        diff = N - H * W

        # Split the leading non-image tokens from the image tokens.
        other_token = x[:, :diff]  # Shape: [B, diff, C].
        img_tokens = x[:, diff:]  # Shape: [B, H*W, C].

        # Residual depthwise convolution on the [B, C, H, W] feature map.
        # reshape (not view) tolerates non-contiguous token slices.
        feat = img_tokens.transpose(1, 2).reshape(B, C, H, W)
        x = self.proj(feat) + feat
        x = x.flatten(2).transpose(1, 2)

        # Re-attach the non-image tokens in front.
        return torch.cat((other_token, x), dim=1)
class SerialBlock(nn.Module):
    """Serial block: conv-attention followed by an FFN (MLP), each with a residual.

    Note: In this implementation, each serial block only contains a
    conv-attention and an FFN (MLP) module.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        shared_cpe=None,
        shared_crpe=None,
    ):
        """
        Args:
            dim: Token embedding size.
            num_heads: Number of attention heads.
            mlp_ratio: FFN hidden size as a multiple of ``dim``.
            qkv_bias, qk_scale, attn_drop: Forwarded to ``FactorAtt_ConvRelPosEnc``.
            drop: Dropout used in the attention projection and the FFN.
            drop_path: Stochastic-depth rate; 0 disables it.
            act_layer: FFN activation class.
            norm_layer: Normalization layer class (called with ``dim``).
            shared_cpe: Convolutional position encoding shared with other blocks.
            shared_crpe: Convolutional relative position encoding shared with other blocks.
        """
        super().__init__()

        # Conv-Attention.
        self.cpe = shared_cpe
        self.norm1 = norm_layer(dim)
        self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            shared_crpe=shared_crpe,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # MLP.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, size):
        """Process tokens ``x`` of shape ``[B, N, C]`` on image grid ``size``."""
        # Conv-Attention.
        x = self.cpe(x, size)  # Apply convolutional position encoding.
        cur = self.norm1(x)
        # Factorized attention + convolutional relative position encoding.
        cur = self.factoratt_crpe(cur, size)
        x = x + self.drop_path(cur)

        # MLP.
        cur = self.norm2(x)
        cur = self.mlp(cur)
        x = x + self.drop_path(cur)
        return x
class ParallelBlock(nn.Module):
    """Parallel block that processes stages 2-4 jointly and mixes them across scales.

    Stage 1 tokens (``x1``) are passed through unchanged. After conv-attention,
    each stage's attention output is combined with resized outputs of the other
    stages according to ``connect_type``; an FFN shared across the three stages
    then follows.
    """

    def __init__(
        self,
        dims,
        num_heads,
        mlp_ratios=None,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        shared_cpes=None,
        shared_crpes=None,
        connect_type='neighbor',
    ):
        """
        Args:
            dims: Embedding size per stage (4 entries); ``dims[1:4]`` must be equal.
            num_heads: Number of attention heads.
            mlp_ratios: FFN ratio per stage (4 entries); ``mlp_ratios[1:4]`` must be equal.
            qkv_bias, qk_scale, attn_drop: Forwarded to ``FactorAtt_ConvRelPosEnc``.
            drop: Dropout used in the attention projections and the FFN.
            drop_path: Stochastic-depth rate; 0 disables it.
            act_layer: FFN activation class.
            norm_layer: Normalization layer class.
            shared_cpes: Per-stage convolutional position encodings (4 entries).
            shared_crpes: Per-stage convolutional relative position encodings (4 entries).
            connect_type: Cross-scale fusion: ``'neighbor'``, ``'dense'``, ``'direct'``
                or ``'dynamic'`` (learned scalar weights). Any other value behaves
                like ``'direct'`` (no fusion).

        Raises:
            ValueError: If ``mlp_ratios`` is not given.
        """
        super().__init__()
        if mlp_ratios is None:
            raise ValueError('mlp_ratios must provide one FFN ratio per stage (4 entries)')

        self.connect_type = connect_type

        if self.connect_type == 'dynamic':
            # Learnable fusion weights, all initialized to 0.05.
            self.alpha1 = nn.Parameter(torch.zeros(1) + 0.05)
            self.alpha2 = nn.Parameter(torch.zeros(1) + 0.05)
            self.alpha3 = nn.Parameter(torch.zeros(1) + 0.05)
            self.alpha4 = nn.Parameter(torch.zeros(1) + 0.05)
            self.alpha5 = nn.Parameter(torch.zeros(1) + 0.05)
            self.alpha6 = nn.Parameter(torch.zeros(1) + 0.05)

        # Conv-Attention.
        self.cpes = shared_cpes
        self.norm12 = norm_layer(dims[1])
        self.norm13 = norm_layer(dims[2])
        self.norm14 = norm_layer(dims[3])
        self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(
            dims[1],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            shared_crpe=shared_crpes[1],
        )
        self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(
            dims[2],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            shared_crpe=shared_crpes[2],
        )
        self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(
            dims[3],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            shared_crpe=shared_crpes[3],
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # MLP.
        self.norm22 = norm_layer(dims[1])
        self.norm23 = norm_layer(dims[2])
        self.norm24 = norm_layer(dims[3])
        # Dimensions are assumed equal so the three stages can share one FFN.
        assert dims[1] == dims[2] == dims[3]
        assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
        mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
        self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
            in_features=dims[1],
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def upsample(self, x, output_size, size):
        """Feature map up-sampling."""
        return self.interpolate(x, output_size=output_size, size=size)

    def downsample(self, x, output_size, size):
        """Feature map down-sampling."""
        return self.interpolate(x, output_size=output_size, size=size)

    def interpolate(self, x, output_size, size):
        """Bilinearly resize the image tokens of ``x`` from grid ``size`` to ``output_size``.

        Args:
            x: Tokens of shape ``[B, N, C]`` with ``N = H * W`` plus one (CLS) or
                two (CLS + histogram) leading tokens.
            output_size: Target grid ``(H_out, W_out)``.
            size: Current grid ``(H, W)``.

        Returns:
            Tokens of shape ``[B, diff + H_out * W_out, C]``; the leading tokens are
            passed through unchanged.
        """
        B, N, C = x.shape
        H, W = size
        # Exactly one or two leading non-image tokens are supported.
        assert N in (1 + H * W, 2 + H * W), (
            f'token count {N} does not match grid {H}x{W} plus 1 or 2 extra tokens'
        )
        diff = N - H * W

        other_token = x[:, :diff, :]
        img_tokens = x[:, diff:, :]

        img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
        # align_corners=False is the PyTorch default; stated explicitly.
        img_tokens = F.interpolate(
            img_tokens, size=output_size, mode='bilinear', align_corners=False
        )  # FIXME: May have alignment issue.
        img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)

        return torch.cat((other_token, img_tokens), dim=1)

    def forward(self, x1, x2, x3, x4, sizes):
        """Run one parallel step.

        Args:
            x1, x2, x3, x4: Token sequences of stages 1-4 (``x1`` is returned as-is).
            sizes: Four ``(H, W)`` grids, one per stage.

        Returns:
            Tuple ``(x1, x2, x3, x4)``.
        """
        _, (H2, W2), (H3, W3), (H4, W4) = sizes

        # Conv-Attention.
        x2 = self.cpes[1](x2, size=(H2, W2))  # Note: x1 is ignored.
        x3 = self.cpes[2](x3, size=(H3, W3))
        x4 = self.cpes[3](x4, size=(H4, W4))

        cur2 = self.norm12(x2)
        cur3 = self.norm13(x3)
        cur4 = self.norm14(x4)
        cur2 = self.factoratt_crpe2(cur2, size=(H2, W2))
        cur3 = self.factoratt_crpe3(cur3, size=(H3, W3))
        cur4 = self.factoratt_crpe4(cur4, size=(H4, W4))

        # Resize every stage's attention output to the other stages' grids.
        upsample3_2 = self.upsample(cur3, output_size=(H2, W2), size=(H3, W3))
        upsample4_3 = self.upsample(cur4, output_size=(H3, W3), size=(H4, W4))
        upsample4_2 = self.upsample(cur4, output_size=(H2, W2), size=(H4, W4))
        downsample2_3 = self.downsample(cur2, output_size=(H3, W3), size=(H2, W2))
        downsample3_4 = self.downsample(cur3, output_size=(H4, W4), size=(H3, W3))
        downsample2_4 = self.downsample(cur2, output_size=(H4, W4), size=(H2, W2))

        if self.connect_type == 'neighbor':
            cur2 = cur2 + upsample3_2
            cur3 = cur3 + upsample4_3 + downsample2_3
            cur4 = cur4 + downsample3_4
        elif self.connect_type == 'dense':
            cur2 = cur2 + upsample3_2 + upsample4_2
            cur3 = cur3 + upsample4_3 + downsample2_3
            cur4 = cur4 + downsample3_4 + downsample2_4
        elif self.connect_type == 'dynamic':
            cur2 = cur2 + self.alpha1 * upsample3_2 + self.alpha2 * upsample4_2
            cur3 = cur3 + self.alpha3 * upsample4_3 + self.alpha4 * downsample2_3
            cur4 = cur4 + self.alpha5 * downsample3_4 + self.alpha6 * downsample2_4
        # 'direct' (and any unrecognized value): no cross-scale fusion.

        # Release the resized tensors before the FFN to reduce peak memory.
        del (
            upsample3_2,
            upsample4_3,
            upsample4_2,
            downsample2_3,
            downsample2_4,
            downsample3_4,
        )

        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)
        del cur2, cur3, cur4

        # MLP.
        cur2 = self.norm22(x2)
        cur3 = self.norm23(x3)
        cur4 = self.norm24(x4)
        cur2 = self.mlp2(cur2)
        cur3 = self.mlp3(cur3)
        cur4 = self.mlp4(cur4)
        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)

        return x1, x2, x3, x4
class PatchEmbed(nn.Module):
    """Image to patch embedding: strided convolution followed by LayerNorm."""

    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        """
        Args:
            patch_size: Patch size (int or pair); also used as the conv stride.
            in_chans: Number of input channels.
            embed_dim: Output embedding size.
        """
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Embed ``x`` of shape ``[B, C, H, W]``.

        Returns:
            Tuple ``(tokens, (H_out, W_out))`` with tokens of shape
            ``[B, H_out * W_out, embed_dim]``.
        """
        _, _, H, W = x.shape
        out_H, out_W = H // self.patch_size[0], W // self.patch_size[1]

        # [B, embed_dim, H_out, W_out] -> [B, H_out * W_out, embed_dim].
        x = self.proj(x).flatten(2).transpose(1, 2)
        out = self.norm(x)

        return out, (out_H, out_W)
@ARCH_REGISTRY.register()
class URanker(nn.Module):
    """URanker: a CoaT-style ranking network for underwater image quality.

    The input is processed by four serial stages, each consisting of a patch
    embedding and serial conv-attention blocks. Parallel blocks may follow
    (``parallel_depth > 0``) and exchange information across stages 2-4. When
    ``add_historgram`` is set, an embedding of the RGB histogram is inserted as
    an extra token after the CLS token in every stage. With parallel blocks, the
    score is the mean of three linear heads applied to the CLS tokens of
    stages 2-4.

    Args:
        patch_size: Patch size of the first patch embedding (later stages use 2).
        in_chans: Number of input image channels.
        num_classes: Output size of each prediction head.
        embed_dims: Embedding size of each of the four stages.
        serial_depths: Number of serial blocks per stage.
        parallel_depth: Number of parallel blocks (0 disables them).
        num_heads: Attention heads in every block.
        mlp_ratios: FFN ratio per stage.
        qkv_bias, qk_scale: Attention projection options.
        drop_rate, attn_drop_rate, drop_path_rate: Regularization rates; the
            same drop-path rate is used for every block.
        norm_layer: Normalization layer class.
        return_interm_layers: If True, ``forward`` returns a dict of feature maps
            selected by ``out_features`` instead of a score.
        out_features: Names (``'x1_nocls'`` .. ``'x4_nocls'``) of the feature maps
            to return when ``return_interm_layers`` is True.
        crpe_window: Window specification for ``ConvRelPosEnc``.
        add_historgram: Whether to add the histogram token.
        his_channel: Input size of the histogram embedding; ``preprocessing``
            produces 3 * 64 = 192 features.
        connect_type: Cross-scale fusion mode of the parallel blocks.
        pretrained: Load the default checkpoint when no explicit path is given.
        pretrained_model_path: Checkpoint path; if given it is loaded (with
            ``weight_keys='params'``) regardless of ``pretrained``.
    """

    def __init__(
        self,
        patch_size=4,
        in_chans=3,
        num_classes=1,
        embed_dims=[152, 320, 320, 320],
        serial_depths=[2, 2, 2, 2],
        parallel_depth=6,
        num_heads=8,
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        return_interm_layers=False,
        out_features=None,
        crpe_window={3: 2, 5: 3, 7: 3},
        add_historgram=True,
        his_channel=192,
        connect_type='dynamic',
        pretrained=True,
        pretrained_model_path=None,
        **kwargs,
    ):
        super().__init__()
        self.return_interm_layers = return_interm_layers
        self.out_features = out_features
        self.num_classes = num_classes
        self.add_historgram = add_historgram
        self.connect_type = connect_type
        # Kept so reset_classifier() knows the head input size.
        self.embed_dims = list(embed_dims)

        if self.add_historgram:
            # Histogram embeddings (one per stage).
            self.historgram_embed1 = nn.Linear(his_channel, embed_dims[0])
            self.historgram_embed2 = nn.Linear(his_channel, embed_dims[1])
            self.historgram_embed3 = nn.Linear(his_channel, embed_dims[2])
            self.historgram_embed4 = nn.Linear(his_channel, embed_dims[3])

        # Patch embeddings.
        self.patch_embed1 = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]
        )
        self.patch_embed2 = PatchEmbed(
            patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]
        )
        self.patch_embed3 = PatchEmbed(
            patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]
        )
        self.patch_embed4 = PatchEmbed(
            patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]
        )

        # Class tokens.
        self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
        self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
        self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
        self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))

        # Convolutional position encodings (shared by all blocks of a stage).
        self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3)
        self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3)
        self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3)
        self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)

        # Convolutional relative position encodings (shared by all blocks of a stage).
        self.crpe1 = ConvRelPosEnc(
            Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window
        )
        self.crpe2 = ConvRelPosEnc(
            Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window
        )
        self.crpe3 = ConvRelPosEnc(
            Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window
        )
        self.crpe4 = ConvRelPosEnc(
            Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window
        )
        cpes = [self.cpe1, self.cpe2, self.cpe3, self.cpe4]
        crpes = [self.crpe1, self.crpe2, self.crpe3, self.crpe4]

        # Stochastic depth: a single constant rate for every block.
        dpr = drop_path_rate

        def make_serial_blocks(stage):
            """Build the serial blocks of 0-based ``stage``."""
            return nn.ModuleList(
                [
                    SerialBlock(
                        dim=embed_dims[stage],
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratios[stage],
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        drop=drop_rate,
                        attn_drop=attn_drop_rate,
                        drop_path=dpr,
                        norm_layer=norm_layer,
                        shared_cpe=cpes[stage],
                        shared_crpe=crpes[stage],
                    )
                    for _ in range(serial_depths[stage])
                ]
            )

        # Serial blocks (created in stage order so parameter order is unchanged).
        self.serial_blocks1 = make_serial_blocks(0)
        self.serial_blocks2 = make_serial_blocks(1)
        self.serial_blocks3 = make_serial_blocks(2)
        self.serial_blocks4 = make_serial_blocks(3)

        # Parallel blocks.
        self.parallel_depth = parallel_depth
        if self.parallel_depth > 0:
            self.parallel_blocks = nn.ModuleList(
                [
                    ParallelBlock(
                        dims=embed_dims,
                        num_heads=num_heads,
                        mlp_ratios=mlp_ratios,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        drop=drop_rate,
                        attn_drop=attn_drop_rate,
                        drop_path=dpr,
                        norm_layer=norm_layer,
                        shared_cpes=cpes,
                        shared_crpes=crpes,
                        connect_type=self.connect_type,
                    )
                    for _ in range(parallel_depth)
                ]
            )

        # Classification head(s).
        if not self.return_interm_layers:
            self.norm1 = norm_layer(embed_dims[0])
            self.norm2 = norm_layer(embed_dims[1])
            self.norm3 = norm_layer(embed_dims[2])
            self.norm4 = norm_layer(embed_dims[3])
            if self.parallel_depth > 0:
                # Heads 2-4 read CLS tokens of stages 2-4, so their sizes must match.
                assert embed_dims[1] == embed_dims[2] == embed_dims[3]
            self.head2 = nn.Linear(embed_dims[3], num_classes)
            self.head3 = nn.Linear(embed_dims[3], num_classes)
            self.head4 = nn.Linear(embed_dims[3], num_classes)

        # Initialize weights.
        trunc_normal_(self.cls_token1, std=0.02)
        trunc_normal_(self.cls_token2, std=0.02)
        trunc_normal_(self.cls_token3, std=0.02)
        trunc_normal_(self.cls_token4, std=0.02)
        self.apply(self._init_weights)

        if pretrained and pretrained_model_path is None:
            load_pretrained_network(self, default_model_urls['uranker'])
        elif pretrained_model_path is not None:
            load_pretrained_network(
                self, pretrained_model_path, True, weight_keys='params'
            )

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names that should be excluded from weight decay."""
        return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}

    def get_classifier(self):
        """Return the prediction heads ``(head2, head3, head4)``."""
        return self.head2, self.head3, self.head4

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the three prediction heads with fresh ones of ``num_classes`` outputs.

        ``num_classes <= 0`` installs identity heads. ``global_pool`` is accepted
        for interface compatibility and ignored.
        """
        self.num_classes = num_classes
        in_dim = self.embed_dims[3]
        for name in ('head2', 'head3', 'head4'):
            head = nn.Linear(in_dim, num_classes) if num_classes > 0 else nn.Identity()
            setattr(self, name, head)

    def insert_cls(self, x, cls_token):
        """Insert CLS token."""
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        return torch.cat((cls_tokens, x), dim=1)

    def insert_his(self, x, his_token):
        """Prepend the histogram token(s) to ``x`` along the token axis."""
        return torch.cat((his_token, x), dim=1)

    def remove_token(self, x):
        """Remove the CLS token (and the histogram token when enabled)."""
        if self.add_historgram:
            return x[:, 2:, :]
        return x[:, 1:, :]

    def _run_serial_stage(self, stage, x, x_his):
        """Embed ``x`` for 1-based ``stage``, add the extra tokens and run its serial blocks.

        Returns the token sequence ``[CLS, (HIST), patches...]`` and the patch grid ``(H, W)``.
        """
        x, size = getattr(self, f'patch_embed{stage}')(x)
        if self.add_historgram:
            x = self.insert_his(x, getattr(self, f'historgram_embed{stage}')(x_his))
        x = self.insert_cls(x, getattr(self, f'cls_token{stage}'))
        for blk in getattr(self, f'serial_blocks{stage}'):
            x = blk(x, size=size)
        return x, size

    def _tokens_to_map(self, x, size):
        """Drop the non-image tokens and reshape patch tokens to ``[B, C, H, W]``."""
        H, W = size
        return (
            self.remove_token(x)
            .reshape(x.shape[0], H, W, -1)
            .permute(0, 3, 1, 2)
            .contiguous()
        )

    def forward_features(self, x0, x_his):
        """Run the backbone.

        Returns:
            - a dict of feature maps if ``return_interm_layers`` is True;
            - otherwise, with parallel blocks, the normalized CLS tokens of stages
              2-4, each of shape ``[B, 1, C]``;
            - otherwise (no parallel blocks), the normalized stage-4 CLS token of
              shape ``[B, C]``.
        """
        x1, size1 = self._run_serial_stage(1, x0, x_his)
        x2, size2 = self._run_serial_stage(2, self._tokens_to_map(x1, size1), x_his)
        x3, size3 = self._run_serial_stage(3, self._tokens_to_map(x2, size2), x_his)
        x4, size4 = self._run_serial_stage(4, self._tokens_to_map(x3, size3), x_his)

        # Parallel blocks update the tokens of stages 2-4 in place of the serial outputs.
        if self.parallel_depth > 0:
            sizes = [size1, size2, size3, size4]
            for blk in self.parallel_blocks:
                x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=sizes)

        if self.return_interm_layers:
            # Intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
            stages = {
                'x1_nocls': (x1, size1),
                'x2_nocls': (x2, size2),
                'x3_nocls': (x3, size3),
                'x4_nocls': (x4, size4),
            }
            return {
                name: self._tokens_to_map(tokens, size)
                for name, (tokens, size) in stages.items()
                if name in self.out_features
            }

        if self.parallel_depth == 0:
            return self.norm4(x4)[:, 0]

        x2 = self.norm2(x2)
        x3 = self.norm3(x3)
        x4 = self.norm4(x4)
        return x2[:, :1], x3[:, :1], x4[:, :1]  # Each of shape [B, 1, C].

    def forward(self, x):
        """Score an image batch ``x`` of shape ``[B, C, H, W]``.

        Returns a tensor of shape ``[B, num_classes]`` (or the feature dict when
        ``return_interm_layers`` is True).
        """
        x, x_his = preprocessing(x)
        if self.return_interm_layers:
            return self.forward_features(x, x_his)

        if self.parallel_depth == 0:
            # Only the stage-4 CLS token ([B, C]) is available.
            return self.head4(self.forward_features(x, x_his))

        x2, x3, x4 = self.forward_features(x, x_his)
        pred = (self.head2(x2) + self.head3(x3) + self.head4(x4)) / 3.0
        return pred.squeeze(1)