# Source code for pyiqa.archs.uranker_arch

r"""URanker model.

This file contains the code for the URanker model, as described in the paper:

    Underwater Ranker: Learn Which Is Better and How to Be Better
    Chunle Guo#, Ruiqi Wu#, Xin Jin, Linghao Han, Zhi Chai, Weidong Zhang, Chongyi Li*
    Proceedings of the AAAI conference on artificial intelligence (AAAI), 2023

Official codes: https://github.com/RQ-Wu/UnderwaterRanker

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum

from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from einops import rearrange
from functools import partial
import math

from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network
from pyiqa.archs.arch_util import get_url_from_name


# Download locations of pretrained checkpoints, keyed by model name
# (``URanker.__init__`` loads ``default_model_urls['uranker']`` by default).
default_model_urls = {
    'uranker': get_url_from_name('URanker_ckpt-450eb36d.pth'),
}
def padding_img(img, multiple=32):
    """Zero-pad an image batch so its height and width are multiples of ``multiple``.

    The padding of each spatial dimension is split between its two sides; when
    the total padding is odd, the extra pixel goes to the right / bottom side.
    The default of 32 matches the total downsampling of the default URanker
    (patch size 4 followed by three patch embeddings of size 2).

    Args:
        img (Tensor): Image batch whose last two dims are ``(H, W)``, e.g.
            ``(B, C, H, W)``.
        multiple (int): Divisor the padded height and width must satisfy.
            Default: 32.

    Returns:
        Tensor: Zero-padded tensor whose spatial size is the smallest multiple
        of ``multiple`` that is ``>=`` the input size. Already-aligned inputs
        come back with unchanged values.
    """
    h, w = img.shape[-2:]
    h_out = math.ceil(h / multiple) * multiple
    w_out = math.ceil(w / multiple) * multiple
    left_pad = (w_out - w) // 2
    right_pad = w_out - w - left_pad
    top_pad = (h_out - h) // 2
    bottom_pad = h_out - h - top_pad
    # F.pad uses the same (left, right, top, bottom) ordering as nn.ZeroPad2d
    # and pads with zeros by default, without creating a module per call.
    return F.pad(img, (left_pad, right_pad, top_pad, bottom_pad))
@torch.no_grad()
def build_historgram(img, bins=64):
    """Build per-image colour histograms of the first three channels.

    For every image in the batch, a ``bins``-bin histogram over ``[0, 1]`` is
    computed for each of channels 0, 1 and 2 (``torch.histc`` ignores values
    outside that range), and the three histograms are concatenated.

    Args:
        img (Tensor): Image batch of shape ``(B, C, H, W)`` with ``C >= 3``.
        bins (int): Number of bins per channel. Default: 64, giving a
            ``3 * 64 = 192``-dim feature.

    Returns:
        Tensor: Histograms of shape ``(B, 1, 3 * bins)``.
    """
    if img.shape[0] == 0:
        return img.new_zeros((0, 1, 3 * bins))
    # Collect each image's histogram and stack once, instead of growing the
    # result with torch.cat inside the loop (which copies quadratically).
    per_image = [
        torch.cat([torch.histc(image[ch], bins, min=0.0, max=1.0) for ch in range(3)])
        for image in img
    ]
    return torch.stack(per_image).unsqueeze(1)
def preprocessing(d_img_org):
    """Prepare an input image batch for :class:`URanker`.

    Args:
        d_img_org (Tensor): Image batch of shape ``(B, C, H, W)``.

    Returns:
        tuple[Tensor, Tensor]: The zero-padded batch produced by
        ``padding_img`` and its histograms produced by ``build_historgram``.
    """
    d_img_org = padding_img(d_img_org)
    # The histogram is taken on the *padded* image, so the zero border is
    # counted in the lowest bin of each channel.
    x_his = build_historgram(d_img_org)
    return d_img_org, x_his
class Mlp(nn.Module):
    """Feed-forward network (FFN, a.k.a. MLP) class.

    Computes ``drop(fc2(drop(act(fc1(x)))))`` over the last dimension of ``x``.
    The same dropout module is applied after both linear layers.

    Args:
        in_features (int): Size of the last input dimension.
        hidden_features (int, optional): Hidden width; defaults to
            ``in_features``.
        out_features (int, optional): Output width; defaults to
            ``in_features``.
        act_layer (type): Activation module class. Default: ``nn.GELU``.
        drop (float): Dropout probability. Default: 0.0.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the FFN to the last dimension of ``x``."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class ConvRelPosEnc(nn.Module):
    """Convolutional relative position encoding.

    Applies depthwise convolutions to the value image tokens (one convolution
    per group of heads, with a per-group window size) and multiplies the result
    element-wise with the query image tokens. Positions of the leading special
    tokens receive zeros.
    """

    def __init__(self, Ch, h, window):
        """Initialization.

        Args:
            Ch (int): Channels per head.
            h (int): Number of heads.
            window (int | dict): Window size(s) of the convolutional relative
                positional encoding, in one of two forms:

                1. An integer window size, which assigns all attention heads
                   the same window size.
                2. A dict mapping window size to the number of heads that use
                   it (e.g. ``{3: 2, 5: 3, 7: 3}``); each head split gets its
                   own depthwise convolution.

        Raises:
            ValueError: If ``window`` is neither an int nor a dict.
        """
        super().__init__()

        if isinstance(window, int):
            # Set the same window size for all attention heads.
            window = {window: h}
            self.window = window
        elif isinstance(window, dict):
            self.window = window
        else:
            raise ValueError(
                'window must be an int or a dict mapping window size to a '
                f'number of heads, got {type(window).__name__}'
            )

        self.conv_list = nn.ModuleList()
        self.head_splits = []
        for cur_window, cur_head_split in window.items():
            dilation = 1  # Use dilation=1 at default.
            # Padding that keeps H and W unchanged for odd windows.
            # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
            padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
            cur_conv = nn.Conv2d(
                cur_head_split * Ch,
                cur_head_split * Ch,
                kernel_size=(cur_window, cur_window),
                padding=(padding_size, padding_size),
                dilation=(dilation, dilation),
                groups=cur_head_split * Ch,
            )
            self.conv_list.append(cur_conv)
            self.head_splits.append(cur_head_split)
        # Number of channels handled by each convolution.
        self.channel_splits = [x * Ch for x in self.head_splits]

    def forward(self, q, v, size):
        """Compute the relative position term.

        Args:
            q (Tensor): Queries of shape ``(B, h, N, Ch)``.
            v (Tensor): Values of shape ``(B, h, N, Ch)``.
            size (tuple[int, int]): ``(H, W)`` of the image tokens; ``N`` must
                be ``1 + H*W`` or ``2 + H*W`` (leading special tokens).

        Returns:
            Tensor: ``(B, h, N, Ch)``, zero at the special-token positions.
        """
        B, h, N, Ch = q.shape
        H, W = size
        assert N == 1 + H * W or N == 2 + H * W
        diff = N - H * W  # Number of leading special tokens.

        # Convolutional relative position encoding.
        q_img = q[:, :, diff:, :]  # Shape: [B, h, H*W, Ch].
        v_img = v[:, :, diff:, :]  # Shape: [B, h, H*W, Ch].

        # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W].
        v_img = rearrange(v_img, 'B h (H W) Ch -> B (h Ch) H W', H=H, W=W)
        # Split according to channels.
        v_img_list = torch.split(v_img, self.channel_splits, dim=1)
        conv_v_img_list = [conv(x) for conv, x in zip(self.conv_list, v_img_list)]
        conv_v_img = torch.cat(conv_v_img_list, dim=1)
        # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch].
        conv_v_img = rearrange(conv_v_img, 'B (h Ch) H W -> B h (H W) Ch', h=h)

        EV_hat_img = q_img * conv_v_img
        zero = torch.zeros(
            (B, h, diff, Ch), dtype=q.dtype, layout=q.layout, device=q.device
        )
        EV_hat = torch.cat((zero, EV_hat_img), dim=2)  # Shape: [B, h, N, Ch].
        return EV_hat
class FactorAtt_ConvRelPosEnc(nn.Module):
    """Factorized attention with convolutional relative position encoding class.

    Output is ``proj(scale * q @ (softmax_N(k)^T @ v) + crpe(q, v))`` computed per
    head, where the softmax runs over the token dimension.

    Args:
        dim (int): Token width ``C``.
        num_heads (int): Number of heads; ``C // num_heads`` channels each.
        qkv_bias (bool): Bias in the QKV projection.
        qk_scale (float, optional): Attention scale; defaults to
            ``head_dim ** -0.5`` (also used when a falsy value is passed).
        attn_drop (float): Stored but not applied in ``forward``.
        proj_drop (float): Dropout after the output projection.
        shared_crpe (callable): Relative position module called as
            ``crpe(q, v, size=size)``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        shared_crpe=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)  # Note: attn_drop is actually not used.
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Shared convolutional relative position encoding.
        self.crpe = shared_crpe

    def forward(self, x, size):
        """Attend over ``x`` of shape ``(B, N, C)``; returns the same shape."""
        B, N, C = x.shape

        # Generate Q, K, V.
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )  # Shape: [3, B, h, N, Ch].
        q, k, v = qkv[0], qkv[1], qkv[2]  # Shape: [B, h, N, Ch].

        # Factorized attention.
        k_softmax = k.softmax(dim=2)  # Softmax on dim N.
        k_softmax_T_dot_v = einsum(
            'b h n k, b h n v -> b h k v', k_softmax, v
        )  # Shape: [B, h, Ch, Ch].
        factor_att = einsum(
            'b h n k, b h k v -> b h n v', q, k_softmax_T_dot_v
        )  # Shape: [B, h, N, Ch].

        # Convolutional relative position encoding.
        crpe = self.crpe(q, v, size=size)  # Shape: [B, h, N, Ch].

        # Merge and reshape: [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C].
        x = self.scale * factor_att + crpe
        x = x.transpose(1, 2).reshape(B, N, C)

        # Output projection.
        x = self.proj(x)
        x = self.proj_drop(x)
        return x  # Shape: [B, N, C].
class ConvPosEnc(nn.Module):
    """Convolutional Position Encoding.

    Adds a depthwise ``k x k`` convolution of the image tokens to themselves
    (a residual connection) and leaves the leading special tokens untouched.

    Note: This module is similar to the conditional position encoding in CPVT.
    """

    def __init__(self, dim, k=3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)

    def forward(self, x, size):
        """Encode positions of ``x`` of shape ``(B, N, C)``.

        Args:
            x (Tensor): Tokens; the first ``N - H*W`` (1 or 2) are special
                tokens, the rest are image tokens in row-major order.
            size (tuple[int, int]): ``(H, W)`` of the image tokens.

        Returns:
            Tensor: Same shape as ``x``.
        """
        B, N, C = x.shape
        H, W = size
        assert N == 1 + H * W or N == 2 + H * W
        diff = N - H * W

        # Extract special tokens and image tokens.
        other_token, img_tokens = x[:, :diff], x[:, diff:]  # [B, diff, C], [B, H*W, C].

        # Depthwise convolution with a residual connection.
        feat = img_tokens.transpose(1, 2).view(B, C, H, W)
        x = self.proj(feat) + feat
        x = x.flatten(2).transpose(1, 2)

        # Put the special tokens back in front.
        x = torch.cat((other_token, x), dim=1)
        return x
class SerialBlock(nn.Module):
    """Serial block class.

    Note: In this implementation, each serial block only contains a
    conv-attention and a FFN (MLP) module, both wrapped in pre-norm residual
    connections: ``x = cpe(x); x += drop_path(att(norm1(x))); x += drop_path(mlp(norm2(x)))``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        shared_cpe=None,
        shared_crpe=None,
    ):
        super().__init__()

        # Conv-Attention.
        self.cpe = shared_cpe

        self.norm1 = norm_layer(dim)
        self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            shared_crpe=shared_crpe,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # MLP.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, size):
        """Run the block on tokens ``x`` whose image part has spatial ``size``."""
        # Conv-Attention.
        x = self.cpe(x, size)  # Apply convolutional position encoding.
        cur = self.norm1(x)
        # Apply factorized attention and convolutional relative position encoding.
        cur = self.factoratt_crpe(cur, size)
        x = x + self.drop_path(cur)

        # MLP.
        cur = self.norm2(x)
        cur = self.mlp(cur)
        x = x + self.drop_path(cur)

        return x
class ParallelBlock(nn.Module):
    """Parallel block class.

    Processes the token sequences of scales 2-4 side by side (scale 1 is passed
    through unchanged). After the attention step, the attention outputs of the
    scales are resampled to each other's resolution and added according to
    ``connect_type``:

    * ``'neighbor'``: only adjacent scales exchange information.
    * ``'dense'``: every scale receives all other scales.
    * ``'direct'``: no exchange.
    * ``'dynamic'``: like ``'dense'`` but each term is scaled by a learnable
      weight (``alpha1`` ... ``alpha6``, initialised to 0.05).

    Scales 2-4 must share one width because they share a single MLP.
    """

    def __init__(
        self,
        dims,
        num_heads,
        mlp_ratios=(),
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        shared_cpes=None,
        shared_crpes=None,
        connect_type='neighbor',
    ):
        super().__init__()

        self.connect_type = connect_type

        if self.connect_type == 'dynamic':
            # Learnable cross-scale fusion weights.
            self.alpha1 = nn.Parameter(torch.zeros(1) + 0.05)
            self.alpha2 = nn.Parameter(torch.zeros(1) + 0.05)
            self.alpha3 = nn.Parameter(torch.zeros(1) + 0.05)
            self.alpha4 = nn.Parameter(torch.zeros(1) + 0.05)
            self.alpha5 = nn.Parameter(torch.zeros(1) + 0.05)
            self.alpha6 = nn.Parameter(torch.zeros(1) + 0.05)

        # Conv-Attention.
        # A plain list, so these shared modules are not registered as
        # submodules of this block.
        self.cpes = shared_cpes

        self.norm12 = norm_layer(dims[1])
        self.norm13 = norm_layer(dims[2])
        self.norm14 = norm_layer(dims[3])
        self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(
            dims[1],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            shared_crpe=shared_crpes[1],
        )
        self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(
            dims[2],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            shared_crpe=shared_crpes[2],
        )
        self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(
            dims[3],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            shared_crpe=shared_crpes[3],
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # MLP.
        self.norm22 = norm_layer(dims[1])
        self.norm23 = norm_layer(dims[2])
        self.norm24 = norm_layer(dims[3])
        # In parallel block, we assume dimensions are the same and share the
        # linear transformation.
        assert dims[1] == dims[2] == dims[3]
        assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
        mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
        self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
            in_features=dims[1],
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def upsample(self, x, output_size, size):
        """Feature map up-sampling."""
        return self.interpolate(x, output_size=output_size, size=size)

    def downsample(self, x, output_size, size):
        """Feature map down-sampling."""
        return self.interpolate(x, output_size=output_size, size=size)

    def interpolate(self, x, output_size, size):
        """Bilinearly resize the image tokens of ``x``.

        Args:
            x (Tensor): Tokens of shape ``(B, N, C)``; the first ``N - H*W``
                (1 or 2) are special tokens.
            output_size (tuple[int, int]): Target ``(H', W')``.
            size (tuple[int, int]): Current ``(H, W)`` of the image tokens.

        Returns:
            Tensor: ``(B, N - H*W + H'*W', C)`` with the special tokens
            copied unchanged in front.
        """
        B, N, C = x.shape
        H, W = size
        # Both alternatives must compare against N: a bare `2 + H * W` is
        # always truthy and would silently disable the check.
        assert N == 1 + H * W or N == 2 + H * W
        diff = N - H * W

        other_token = x[:, :diff, :]
        img_tokens = x[:, diff:, :]
        img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
        # align_corners=False is PyTorch's default for bilinear; stated
        # explicitly. FIXME: May have alignment issue.
        img_tokens = F.interpolate(
            img_tokens, size=output_size, mode='bilinear', align_corners=False
        )
        img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)

        return torch.cat((other_token, img_tokens), dim=1)

    def forward(self, x1, x2, x3, x4, sizes):
        """Update the scale-2..4 token sequences; ``x1`` is returned as is.

        Args:
            x1, x2, x3, x4 (Tensor): Token sequences of the four scales.
            sizes (list[tuple[int, int]]): ``(H, W)`` of each scale.
        """
        _, (H2, W2), (H3, W3), (H4, W4) = sizes

        # Conv-Attention.
        x2 = self.cpes[1](x2, size=(H2, W2))  # Note: x1 is ignored.
        x3 = self.cpes[2](x3, size=(H3, W3))
        x4 = self.cpes[3](x4, size=(H4, W4))

        cur2 = self.norm12(x2)
        cur3 = self.norm13(x3)
        cur4 = self.norm14(x4)
        cur2 = self.factoratt_crpe2(cur2, size=(H2, W2))
        cur3 = self.factoratt_crpe3(cur3, size=(H3, W3))
        cur4 = self.factoratt_crpe4(cur4, size=(H4, W4))

        # All resamplings use the attention outputs *before* any fusion.
        upsample3_2 = self.upsample(cur3, output_size=(H2, W2), size=(H3, W3))
        upsample4_3 = self.upsample(cur4, output_size=(H3, W3), size=(H4, W4))
        upsample4_2 = self.upsample(cur4, output_size=(H2, W2), size=(H4, W4))
        downsample2_3 = self.downsample(cur2, output_size=(H3, W3), size=(H2, W2))
        downsample3_4 = self.downsample(cur3, output_size=(H4, W4), size=(H3, W3))
        downsample2_4 = self.downsample(cur2, output_size=(H4, W4), size=(H2, W2))

        if self.connect_type == 'neighbor':
            cur2 = cur2 + upsample3_2
            cur3 = cur3 + upsample4_3 + downsample2_3
            cur4 = cur4 + downsample3_4
        elif self.connect_type == 'dense':
            cur2 = cur2 + upsample3_2 + upsample4_2
            cur3 = cur3 + upsample4_3 + downsample2_3
            cur4 = cur4 + downsample3_4 + downsample2_4
        elif self.connect_type == 'direct':
            pass  # No cross-scale exchange.
        elif self.connect_type == 'dynamic':
            cur2 = cur2 + self.alpha1 * upsample3_2 + self.alpha2 * upsample4_2
            cur3 = cur3 + self.alpha3 * upsample4_3 + self.alpha4 * downsample2_3
            cur4 = cur4 + self.alpha5 * downsample3_4 + self.alpha6 * downsample2_4
        # NOTE(review): any other connect_type value falls through and behaves
        # like 'direct'.

        del (
            upsample3_2,
            upsample4_3,
            upsample4_2,
            downsample2_3,
            downsample2_4,
            downsample3_4,
        )

        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)
        del cur2, cur3, cur4

        # MLP.
        cur2 = self.norm22(x2)
        cur3 = self.norm23(x3)
        cur4 = self.norm24(x4)
        cur2 = self.mlp2(cur2)
        cur3 = self.mlp3(cur3)
        cur4 = self.mlp4(cur4)
        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)

        return x1, x2, x3, x4
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    A convolution with kernel size == stride == ``patch_size`` maps each
    non-overlapping patch to one token, followed by LayerNorm over the
    embedding dimension.

    Args:
        patch_size (int | tuple[int, int]): Patch size. Default: 16.
        in_chans (int): Input channels. Default: 3.
        embed_dim (int): Output token width. Default: 768.
    """

    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)

        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Embed ``x`` of shape ``(B, C, H, W)``.

        Returns:
            tuple: Tokens of shape ``(B, out_H * out_W, embed_dim)`` and the
            token grid size ``(out_H, out_W)``.
        """
        _, _, H, W = x.shape
        out_H, out_W = H // self.patch_size[0], W // self.patch_size[1]

        x = self.proj(x).flatten(2).transpose(1, 2)
        out = self.norm(x)

        return out, (out_H, out_W)
@ARCH_REGISTRY.register()
class URanker(nn.Module):
    """URanker: ranks underwater image quality with a CoaT-style backbone.

    The input is zero-padded and a colour histogram is computed from it
    (``preprocessing``). When ``add_historgram`` is set, the embedded histogram
    is inserted as an extra token at every scale. Tokens pass through four
    serial stages and then ``parallel_depth`` parallel blocks. The score is the
    mean of three linear heads applied to the CLS tokens of scales 2-4 (or
    ``head4`` alone when ``parallel_depth == 0``).

    Args:
        patch_size (int): Patch size of the first patch embedding.
        in_chans (int): Number of input channels.
        num_classes (int): Output size of each head.
        embed_dims (list[int]): Token width of the four scales.
        serial_depths (list[int]): Number of serial blocks per scale.
        parallel_depth (int): Number of parallel blocks.
        num_heads (int): Attention heads per block.
        mlp_ratios (list[float]): FFN expansion ratio per scale.
        qkv_bias (bool): Bias in the QKV projections.
        qk_scale (float, optional): Override for the attention scale.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate (passed through).
        drop_path_rate (float): Drop-path rate, identical for every block.
        norm_layer (callable): Normalisation layer factory.
        return_interm_layers (bool): If True, ``forward`` returns a dict of
            intermediate feature maps instead of scores.
        out_features (iterable[str], optional): Subset of ``'x1_nocls'`` ..
            ``'x4_nocls'`` returned when ``return_interm_layers`` is True.
        crpe_window (int | dict): Window spec for ``ConvRelPosEnc``.
        add_historgram (bool): Whether to add the histogram token.
        his_channel (int): Length of the histogram feature.
        connect_type (str): Cross-scale fusion mode of the parallel blocks.
        pretrained (bool): Load the default checkpoint when no
            ``pretrained_model_path`` is given.
        pretrained_model_path (str, optional): Checkpoint to load instead;
            it is loaded regardless of ``pretrained``.
    """

    def __init__(
        self,
        patch_size=4,
        in_chans=3,
        num_classes=1,
        embed_dims=[152, 320, 320, 320],
        serial_depths=[2, 2, 2, 2],
        parallel_depth=6,
        num_heads=8,
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        return_interm_layers=False,
        out_features=None,
        crpe_window={3: 2, 5: 3, 7: 3},
        add_historgram=True,
        his_channel=192,
        connect_type='dynamic',
        pretrained=True,
        pretrained_model_path=None,
        **kwargs,
    ):
        super().__init__()
        self.return_interm_layers = return_interm_layers
        self.out_features = out_features
        self.num_classes = num_classes
        self.add_historgram = add_historgram
        self.connect_type = connect_type
        # Remembered so that reset_classifier() can rebuild the heads.
        self.embed_dims = list(embed_dims)

        if self.add_historgram:
            # Historgram embeddings: one extra token per scale.
            self.historgram_embed1 = nn.Linear(his_channel, embed_dims[0])
            self.historgram_embed2 = nn.Linear(his_channel, embed_dims[1])
            self.historgram_embed3 = nn.Linear(his_channel, embed_dims[2])
            self.historgram_embed4 = nn.Linear(his_channel, embed_dims[3])

        # Patch embeddings: scale 1 downsamples by patch_size, later scales by 2.
        self.patch_embed1 = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]
        )
        self.patch_embed2 = PatchEmbed(
            patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]
        )
        self.patch_embed3 = PatchEmbed(
            patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]
        )
        self.patch_embed4 = PatchEmbed(
            patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]
        )

        # Class tokens.
        self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
        self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
        self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
        self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))

        # Convolutional position encodings, shared by all blocks of a scale.
        self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3)
        self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3)
        self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3)
        self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)

        # Convolutional relative position encodings, shared by all blocks of a scale.
        self.crpe1 = ConvRelPosEnc(
            Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window
        )
        self.crpe2 = ConvRelPosEnc(
            Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window
        )
        self.crpe3 = ConvRelPosEnc(
            Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window
        )
        self.crpe4 = ConvRelPosEnc(
            Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window
        )

        # Stochastic depth: every block uses the same drop-path rate.
        dpr = drop_path_rate

        def make_serial_blocks(i, cpe, crpe):
            # Serial blocks of scale i + 1, sharing that scale's CPE/CRPE.
            return nn.ModuleList(
                [
                    SerialBlock(
                        dim=embed_dims[i],
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratios[i],
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        drop=drop_rate,
                        attn_drop=attn_drop_rate,
                        drop_path=dpr,
                        norm_layer=norm_layer,
                        shared_cpe=cpe,
                        shared_crpe=crpe,
                    )
                    for _ in range(serial_depths[i])
                ]
            )

        self.serial_blocks1 = make_serial_blocks(0, self.cpe1, self.crpe1)
        self.serial_blocks2 = make_serial_blocks(1, self.cpe2, self.crpe2)
        self.serial_blocks3 = make_serial_blocks(2, self.cpe3, self.crpe3)
        self.serial_blocks4 = make_serial_blocks(3, self.cpe4, self.crpe4)

        # Parallel blocks.
        self.parallel_depth = parallel_depth
        if self.parallel_depth > 0:
            self.parallel_blocks = nn.ModuleList(
                [
                    ParallelBlock(
                        dims=embed_dims,
                        num_heads=num_heads,
                        mlp_ratios=mlp_ratios,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        drop=drop_rate,
                        attn_drop=attn_drop_rate,
                        drop_path=dpr,
                        norm_layer=norm_layer,
                        shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4],
                        shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4],
                        connect_type=self.connect_type,
                    )
                    for _ in range(parallel_depth)
                ]
            )

        # Classification head(s).
        if not self.return_interm_layers:
            self.norm1 = norm_layer(embed_dims[0])
            self.norm2 = norm_layer(embed_dims[1])
            self.norm3 = norm_layer(embed_dims[2])
            self.norm4 = norm_layer(embed_dims[3])

            if self.parallel_depth > 0:
                # CoaT series: the features of the last three scales are
                # aggregated for classification, so their widths must match.
                assert embed_dims[1] == embed_dims[2] == embed_dims[3]
            # All heads read width embed_dims[3]. In the serial-only
            # (CoaT-Lite-style) configuration forward() uses head4 only.
            self.head2 = nn.Linear(embed_dims[3], num_classes)
            self.head3 = nn.Linear(embed_dims[3], num_classes)
            self.head4 = nn.Linear(embed_dims[3], num_classes)

        # Initialize weights.
        trunc_normal_(self.cls_token1, std=0.02)
        trunc_normal_(self.cls_token2, std=0.02)
        trunc_normal_(self.cls_token3, std=0.02)
        trunc_normal_(self.cls_token4, std=0.02)
        self.apply(self._init_weights)

        if pretrained and pretrained_model_path is None:
            load_pretrained_network(self, default_model_urls['uranker'])
        elif pretrained_model_path is not None:
            load_pretrained_network(
                self, pretrained_model_path, True, weight_keys='params'
            )

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, identity LayerNorms."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}

    def get_classifier(self):
        """Return the classification heads ``(head2, head3, head4)``.

        The model has one head per scale; there is no single ``self.head``.
        """
        return self.head2, self.head3, self.head4

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the three classification heads.

        Args:
            num_classes (int): New output size. A value ``<= 0`` replaces each
                head with ``nn.Identity`` so the heads pass features through.
            global_pool (str): Unused; accepted for interface compatibility.
        """
        self.num_classes = num_classes
        in_features = self.embed_dims[3]
        for name in ('head2', 'head3', 'head4'):
            head = (
                nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
            )
            head.apply(self._init_weights)
            setattr(self, name, head)

    def insert_cls(self, x, cls_token):
        """Insert CLS token: prepend ``cls_token`` to every sequence in ``x``."""
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        return torch.cat((cls_tokens, x), dim=1)

    def insert_his(self, x, his_token):
        """Prepend the histogram token (presumably ``(B, 1, C)``) to ``x``."""
        return torch.cat((his_token, x), dim=1)

    def remove_token(self, x):
        """Remove the leading CLS token and, if enabled, the histogram token."""
        if self.add_historgram:
            return x[:, 2:, :]
        return x[:, 1:, :]

    def _tokens_to_map(self, x, size):
        """Drop special tokens and reshape image tokens to ``(B, C, H, W)``."""
        H, W = size
        x = self.remove_token(x)
        return x.reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()

    def forward_features(self, x0, x_his):
        """Run the backbone.

        Token order at every scale is ``[CLS, (histogram), image tokens...]``.

        Returns:
            * ``return_interm_layers``: dict of the requested ``(B, C, H, W)``
              feature maps.
            * ``parallel_depth == 0``: the normalised CLS token of scale 4,
              shape ``(B, C)``.
            * otherwise: the normalised CLS tokens of scales 2-4, each of
              shape ``(B, 1, C)``.
        """
        # Serial blocks 1.
        x1, (H1, W1) = self.patch_embed1(x0)
        if self.add_historgram:
            x1 = self.insert_his(x1, self.historgram_embed1(x_his))
        x1 = self.insert_cls(x1, self.cls_token1)
        for blk in self.serial_blocks1:
            x1 = blk(x1, size=(H1, W1))
        x1_nocls = self._tokens_to_map(x1, (H1, W1))

        # Serial blocks 2.
        x2, (H2, W2) = self.patch_embed2(x1_nocls)
        if self.add_historgram:
            x2 = self.insert_his(x2, self.historgram_embed2(x_his))
        x2 = self.insert_cls(x2, self.cls_token2)
        for blk in self.serial_blocks2:
            x2 = blk(x2, size=(H2, W2))
        x2_nocls = self._tokens_to_map(x2, (H2, W2))

        # Serial blocks 3.
        x3, (H3, W3) = self.patch_embed3(x2_nocls)
        if self.add_historgram:
            x3 = self.insert_his(x3, self.historgram_embed3(x_his))
        x3 = self.insert_cls(x3, self.cls_token3)
        for blk in self.serial_blocks3:
            x3 = blk(x3, size=(H3, W3))
        x3_nocls = self._tokens_to_map(x3, (H3, W3))

        # Serial blocks 4.
        x4, (H4, W4) = self.patch_embed4(x3_nocls)
        if self.add_historgram:
            x4 = self.insert_his(x4, self.historgram_embed4(x_his))
        x4 = self.insert_cls(x4, self.cls_token4)
        for blk in self.serial_blocks4:
            x4 = blk(x4, size=(H4, W4))
        x4_nocls = self._tokens_to_map(x4, (H4, W4))

        # Only serial blocks: Early return.
        if self.parallel_depth == 0:
            if self.return_interm_layers:
                # Return intermediate features for down-stream tasks
                # (e.g. Deformable DETR and Detectron2).
                feat_out = {}
                if 'x1_nocls' in self.out_features:
                    feat_out['x1_nocls'] = x1_nocls
                if 'x2_nocls' in self.out_features:
                    feat_out['x2_nocls'] = x2_nocls
                if 'x3_nocls' in self.out_features:
                    feat_out['x3_nocls'] = x3_nocls
                if 'x4_nocls' in self.out_features:
                    feat_out['x4_nocls'] = x4_nocls
                return feat_out
            # Return features for classification.
            x4 = self.norm4(x4)
            x4_cls = x4[:, 0]
            return x4_cls

        # Parallel blocks.
        for blk in self.parallel_blocks:
            x1, x2, x3, x4 = blk(
                x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)]
            )

        if self.return_interm_layers:
            # Return intermediate features for down-stream tasks
            # (e.g. Deformable DETR and Detectron2).
            feat_out = {}
            if 'x1_nocls' in self.out_features:
                feat_out['x1_nocls'] = self._tokens_to_map(x1, (H1, W1))
            if 'x2_nocls' in self.out_features:
                feat_out['x2_nocls'] = self._tokens_to_map(x2, (H2, W2))
            if 'x3_nocls' in self.out_features:
                feat_out['x3_nocls'] = self._tokens_to_map(x3, (H3, W3))
            if 'x4_nocls' in self.out_features:
                feat_out['x4_nocls'] = self._tokens_to_map(x4, (H4, W4))
            return feat_out

        x2 = self.norm2(x2)
        x3 = self.norm3(x3)
        x4 = self.norm4(x4)
        x2_cls = x2[:, :1]  # Shape: [B, 1, C].
        x3_cls = x3[:, :1]
        x4_cls = x4[:, :1]
        return x2_cls, x3_cls, x4_cls

    def forward(self, x):
        """Score an image batch ``x`` of shape ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, num_classes)``, or the intermediate feature
            dict when ``return_interm_layers`` is True.
        """
        x, x_his = preprocessing(x)
        if self.return_interm_layers:
            # Return intermediate features (for down-stream tasks).
            return self.forward_features(x, x_his)

        if self.parallel_depth == 0:
            # Serial-only model: forward_features returns a single (B, C)
            # CLS feature of the last scale, so only head4 applies.
            return self.head4(self.forward_features(x, x_his))

        # Return features for classification.
        x2, x3, x4 = self.forward_features(x, x_his)
        pred2 = self.head2(x2)
        pred3 = self.head3(x3)
        pred4 = self.head4(x4)
        x = (pred2 + pred3 + pred4) / 3.0
        return x.squeeze(1)