"""Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
- https://arxiv.org/pdf/2103.14030
Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
S3 (AutoFormerV2, https://arxiv.org/abs/2111.14725) Swin weights from
- https://github.com/microsoft/Cream/tree/main/AutoFormerV2
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import (
    PatchEmbed,
    Mlp,
    DropPath,
    to_2tuple,
    to_ntuple,
    trunc_normal_,
    _assert,
)
from timm.models._builder import build_model_with_cfg
from timm.models._manipulate import checkpoint_seq
def _cfg(url='', **kwargs):
    """Return a pretrained-config dict for one Swin variant.

    The base entries describe a 1000-class, 3x224x224-input model using
    bicubic interpolation and the ImageNet default mean/std. Any ``kwargs``
    replace matching base entries (keeping their position) or are appended.

    Args:
        url (str): Location of the pretrained weights. Default: ''
        **kwargs: Per-variant overrides / additions.

    Returns:
        dict: The merged configuration.
    """
    config = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=0.9,
        interpolation='bicubic',
        fixed_input_size=True,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.proj',
        classifier='head',
    )
    config.update(kwargs)
    return config
[docs]
# Pretrained configs keyed by model name; every key matches a model factory
# function of the same name defined in this module.
default_cfgs = dict(
    # ImageNet-22k pretrained, ImageNet-1k fine-tuned
    swin_base_patch4_window12_384=_cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
        input_size=(3, 384, 384),
        crop_pct=1.0,
    ),
    swin_base_patch4_window7_224=_cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
    ),
    swin_large_patch4_window12_384=_cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
        input_size=(3, 384, 384),
        crop_pct=1.0,
    ),
    swin_large_patch4_window7_224=_cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
    ),
    # ImageNet-1k trained
    swin_small_patch4_window7_224=_cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
    ),
    swin_tiny_patch4_window7_224=_cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
    ),
    # ImageNet-22k trained (21841-class heads)
    swin_base_patch4_window12_384_in22k=_cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
        input_size=(3, 384, 384),
        crop_pct=1.0,
        num_classes=21841,
    ),
    swin_base_patch4_window7_224_in22k=_cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
        num_classes=21841,
    ),
    swin_large_patch4_window12_384_in22k=_cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
        input_size=(3, 384, 384),
        crop_pct=1.0,
        num_classes=21841,
    ),
    swin_large_patch4_window7_224_in22k=_cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
        num_classes=21841,
    ),
    # S3 (AutoFormerV2) searched architectures
    swin_s3_tiny_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'
    ),
    swin_s3_small_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'
    ),
    swin_s3_base_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'
    ),
)
[docs]
def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
    """Resample the grid part of a position embedding to a new grid size.

    Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    Args:
        posemb: Source embedding; its grid tokens (after the prefix) are read
            from batch index 0 and must form a square grid.
        posemb_new: Target embedding; only its token count (``shape[1]``) is used.
        num_prefix_tokens (int): Leading non-grid tokens (e.g. a class token)
            copied through unchanged. Default: 1
        gs_new (sequence[int]): Target grid (h, w). When empty, a square grid is
            derived from ``posemb_new``'s grid-token count.

    Returns:
        Tensor of shape (1, num_prefix_tokens + gs_new[0] * gs_new[1], C).
    """
    target_tokens = posemb_new.shape[1]
    if num_prefix_tokens:
        prefix = posemb[:, :num_prefix_tokens]
        grid = posemb[0, num_prefix_tokens:]
        target_tokens -= num_prefix_tokens
    else:
        prefix = posemb[:, :0]
        grid = posemb[0]

    old_side = int(math.sqrt(len(grid)))
    if not len(gs_new):  # backwards compatibility: assume a square target grid
        side = int(math.sqrt(target_tokens))
        gs_new = [side, side]
    assert len(gs_new) >= 2

    # (tokens, C) -> (1, C, h, w) so F.interpolate treats the grid as an image.
    grid = grid.reshape(1, old_side, old_side, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=gs_new, mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    return torch.cat([prefix, grid], dim=1)
[docs]
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
    """Adapt a pretrained state dict to the layout expected by ``model``.

    Per-key conversions:
      * ``patch_embed.proj.weight`` stored with fewer than 4 dims (manual
        patchify + linear projection) is reshaped to the conv kernel shape of
        ``model.patch_embed.proj``.
      * ``pos_embed`` whose token count differs from ``model.pos_embed`` is
        resampled with ``resize_pos_embed``. Skipped when ``model`` has no
        ``pos_embed`` attribute.
      * With ``adapt_layer_scale``, ``gamma_<n>`` keys are renamed to ``ls<n>.gamma``.
      * ``pre_logits`` keys are dropped.

    Args:
        state_dict: Checkpoint mapping, optionally nested under a ``'model'`` key.
        model: Model the converted weights are intended for.
        adapt_layer_scale (bool): Remap layer-scale gamma names. Default: False

    Returns:
        dict: A new mapping with the converted entries.
    """
    import re

    out_dict = {}
    if 'model' in state_dict:
        # For deit models
        state_dict = state_dict['model']
    if 'visual.class_embedding' in state_dict:
        # NOTE(review): `_convert_openai_clip` is neither defined nor imported in the
        # visible part of this file; this branch raises NameError unless it is
        # provided elsewhere — confirm.
        return _convert_openai_clip(state_dict, model)

    # Models without an absolute position embedding (e.g. Swin) must not crash
    # when a checkpoint happens to carry a `pos_embed` entry.
    model_pos_embed = getattr(model, 'pos_embed', None)
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k and v.ndim < 4:
            # For old models trained prior to conv based patchification
            out_ch, _, kernel_h, kernel_w = model.patch_embed.proj.weight.shape
            v = v.reshape(out_ch, -1, kernel_h, kernel_w)
        elif (
            k == 'pos_embed'
            and model_pos_embed is not None
            and v.shape[1] != model_pos_embed.shape[1]
        ):
            # To resize pos embedding when using model at different size from pretrained weights
            num_prefix_tokens = (
                0
                if getattr(model, 'no_embed_class', False)
                else getattr(model, 'num_prefix_tokens', 1)
            )
            v = resize_pos_embed(
                v,
                model_pos_embed,
                num_prefix_tokens,
                model.patch_embed.grid_size,
            )
        elif adapt_layer_scale and 'gamma_' in k:
            # remap layer-scale gamma into sub-module (deit3 models)
            k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
        elif 'pre_logits' in k:
            # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
            continue
        out_dict[k] = v
    return out_dict
[docs]
def window_partition(x, window_size: int):
    """Split a feature map into non-overlapping square windows.

    Windows are ordered batch-major, then row-major over the window grid.

    Args:
        x: (B, H, W, C); H and W are expected to be multiples of ``window_size``.
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    tiles = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-grid axes next to each other: (B, rows, cols, ws, ws, C).
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, channels)
[docs]
def window_reverse(windows, window_size: int, H: int, W: int):
    """Reassemble windows (as produced by ``window_partition``) into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    C = windows.shape[-1]
    # Let view() infer the batch dimension (-1). This avoids a float division plus
    # int() on a tensor size. That round trip is not integer-exact in general, and
    # it can be frozen into a constant when the module is traced.
    x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x
[docs]
def get_relative_position_index(win_h, win_w):
    """Return the pair-wise relative position index for tokens inside a window.

    Tokens of a ``win_h x win_w`` window are flattened row-major. Entry [i, j]
    encodes the 2-D offset between token i and token j as an integer in
    ``[0, (2*win_h - 1) * (2*win_w - 1))``, suitable for indexing a relative
    position bias table.

    Args:
        win_h (int): Window height.
        win_w (int): Window width.

    Returns:
        Integer tensor of shape (win_h*win_w, win_h*win_w).
    """
    # Row / column coordinate of every token, identical to
    # meshgrid(arange(win_h), arange(win_w), indexing='ij'). It is built by
    # broadcasting, so it neither triggers meshgrid's implicit-indexing warning
    # nor needs the torch>=1.10 `indexing` argument.
    rows = torch.arange(win_h).view(-1, 1).expand(win_h, win_w)
    cols = torch.arange(win_w).view(1, -1).expand(win_h, win_w)
    coords_flatten = torch.stack([rows, cols]).flatten(1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += win_h - 1  # shift to start from 0
    relative_coords[:, :, 1] += win_w - 1
    relative_coords[:, :, 0] *= 2 * win_w - 1
    return relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
[docs]
class WindowAttention(nn.Module):
    r"""Window-based multi-head self-attention (W-MSA) with a learned relative position bias.

    Works for both regular and shifted windows; for shifted windows the caller
    passes an additive attention ``mask`` to ``forward``.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        head_dim (int, optional): Channels per head. Defaults to ``dim // num_heads``.
        window_size (int | tuple[int]): Window height and width. Default: 7
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        num_heads,
        head_dim=None,
        window_size=7,
        qkv_bias=True,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = to_2tuple(window_size)  # (Wh, Ww)
        win_h, win_w = self.window_size
        self.window_area = win_h * win_w
        self.num_heads = num_heads
        per_head = head_dim or dim // num_heads
        inner_dim = per_head * num_heads
        self.scale = per_head ** -0.5

        # One learnable bias per (relative offset, head): (2*Wh-1)*(2*Ww-1) rows x nH.
        num_offsets = (2 * win_h - 1) * (2 * win_w - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_offsets, num_heads))
        # For every (query, key) token pair inside a window: the row of the table above.
        self.register_buffer(
            'relative_position_index', get_relative_position_index(win_h, win_w)
        )

        # Submodules are created in this exact order so that random init is unchanged.
        self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def _get_rel_pos_bias(self) -> torch.Tensor:
        """Gather the bias table into a (1, nH, Wh*Ww, Wh*Ww) tensor."""
        flat_index = self.relative_position_index.view(-1)
        bias = self.relative_position_bias_table[flat_index]
        bias = bias.view(self.window_area, self.window_area, -1)  # Wh*Ww, Wh*Ww, nH
        bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        return bias.unsqueeze(0)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

        Returns:
            Tensor of shape (num_windows*B, N, dim).
        """
        B_, N, _ = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # unbind instead of tuple-unpacking a tensor (TorchScript)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn + self._get_rel_pos_bias()
        if mask is not None:
            # Expose the per-window axis so the (num_windows, N, N) mask broadcasts
            # over the batch and head axes, then fold it back.
            num_win = mask.shape[0]
            attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
        out = self.proj(out)
        return self.proj_drop(out)
[docs]
class PatchMerging(nn.Module):
    r"""Patch Merging Layer (2x spatial downsampling).

    Each 2x2 neighbourhood of tokens is concatenated along channels (C -> 4C),
    normalized, and linearly projected to ``out_dim`` channels.

    Args:
        input_resolution (tuple[int]): Resolution (H, W) of the input feature.
        dim (int): Number of input channels.
        out_dim (int, optional): Number of output channels. Default: 2 * dim
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, out_dim=None, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim or 2 * dim
        merged_dim = 4 * dim
        self.norm = norm_layer(merged_dim)
        self.reduction = nn.Linear(merged_dim, self.out_dim, bias=False)

    def forward(self, x):
        """
        Args:
            x: tokens of shape (B, H*W, C), where (H, W) is ``self.input_resolution``.

        Returns:
            Tokens of shape (B, H/2*W/2, out_dim).
        """
        H, W = self.input_resolution
        batch, length, channels = x.shape
        _assert(length == H * W, 'input feature has wrong size')
        _assert(H % 2 == 0 and W % 2 == 0, f'x size ({H}*{W}) are not even.')

        grid = x.view(batch, H, W, channels)
        # The concatenation order below fixes the layout of the 4C channels that
        # `norm` / `reduction` see, so it must stay stable.
        even_rows_even_cols = grid[:, 0::2, 0::2, :]  # B H/2 W/2 C
        odd_rows_even_cols = grid[:, 1::2, 0::2, :]  # B H/2 W/2 C
        even_rows_odd_cols = grid[:, 0::2, 1::2, :]  # B H/2 W/2 C
        odd_rows_odd_cols = grid[:, 1::2, 1::2, :]  # B H/2 W/2 C
        merged = torch.cat(
            [even_rows_even_cols, odd_rows_even_cols, even_rows_odd_cols, odd_rows_odd_cols],
            -1,
        )  # B H/2 W/2 4*C
        merged = merged.view(batch, -1, 4 * channels)  # B H/2*W/2 4*C
        return self.reduction(self.norm(merged))
[docs]
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Stacks ``depth`` SwinTransformerBlocks that alternate between regular
    windows (shift 0, even indices) and shifted windows (shift
    ``window_size // 2``, odd indices), optionally followed by a downsample layer.

    Args:
        dim (int): Number of input channels.
        out_dim (int): Number of output channels, passed to ``downsample``.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        head_dim (int): Channels per head (dim // num_heads if not set)
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate,
            either shared by all blocks or given per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer class, instantiated as
            ``downsample(input_resolution, dim=dim, out_dim=out_dim, norm_layer=norm_layer)``
            and applied at the end of the layer. Default: None
    """

    def __init__(
        self,
        dim,
        out_dim,
        input_resolution,
        depth,
        num_heads=4,
        head_dim=None,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.grad_checkpointing = False

        # A per-block schedule may be a list *or* a tuple (as documented);
        # anything else is a single rate shared by every block.
        per_block_drop_path = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.Sequential(
            *[
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if per_block_drop_path else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, out_dim=out_dim, norm_layer=norm_layer
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Run the blocks (with gradient checkpointing if enabled and not scripting),
        then the optional downsample layer."""
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
def _create_swin_transformer(variant, pretrained=False, **kwargs):
    """Build a ``SwinTransformer`` for ``variant`` via timm's ``build_model_with_cfg``.

    ``checkpoint_filter_fn`` is supplied as the builder's ``pretrained_filter_fn``;
    all other keyword arguments are forwarded unchanged.
    """
    return build_model_with_cfg(
        SwinTransformer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )
[docs]
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
    """Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=12,
        embed_dim=128,
        depths=(2, 2, 18, 2),
        num_heads=(4, 8, 16, 32),
    )
    return _create_swin_transformer(
        'swin_base_patch4_window12_384',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_base_patch4_window7_224(pretrained=False, **kwargs):
    """Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=7,
        embed_dim=128,
        depths=(2, 2, 18, 2),
        num_heads=(4, 8, 16, 32),
    )
    return _create_swin_transformer(
        'swin_base_patch4_window7_224',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_large_patch4_window12_384(pretrained=False, **kwargs):
    """Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=12,
        embed_dim=192,
        depths=(2, 2, 18, 2),
        num_heads=(6, 12, 24, 48),
    )
    return _create_swin_transformer(
        'swin_large_patch4_window12_384',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_large_patch4_window7_224(pretrained=False, **kwargs):
    """Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=7,
        embed_dim=192,
        depths=(2, 2, 18, 2),
        num_heads=(6, 12, 24, 48),
    )
    return _create_swin_transformer(
        'swin_large_patch4_window7_224',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_small_patch4_window7_224(pretrained=False, **kwargs):
    """Swin-S @ 224x224, trained ImageNet-1k.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=7,
        embed_dim=96,
        depths=(2, 2, 18, 2),
        num_heads=(3, 6, 12, 24),
    )
    return _create_swin_transformer(
        'swin_small_patch4_window7_224',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
    """Swin-T @ 224x224, trained ImageNet-1k.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=7,
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
    )
    return _create_swin_transformer(
        'swin_tiny_patch4_window7_224',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
    """Swin-B @ 384x384, trained ImageNet-22k.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=12,
        embed_dim=128,
        depths=(2, 2, 18, 2),
        num_heads=(4, 8, 16, 32),
    )
    return _create_swin_transformer(
        'swin_base_patch4_window12_384_in22k',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
    """Swin-B @ 224x224, trained ImageNet-22k.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=7,
        embed_dim=128,
        depths=(2, 2, 18, 2),
        num_heads=(4, 8, 16, 32),
    )
    return _create_swin_transformer(
        'swin_base_patch4_window7_224_in22k',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
    """Swin-L @ 384x384, trained ImageNet-22k.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=12,
        embed_dim=192,
        depths=(2, 2, 18, 2),
        num_heads=(6, 12, 24, 48),
    )
    return _create_swin_transformer(
        'swin_large_patch4_window12_384_in22k',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
    """Swin-L @ 224x224, trained ImageNet-22k.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=7,
        embed_dim=192,
        depths=(2, 2, 18, 2),
        num_heads=(6, 12, 24, 48),
    )
    return _create_swin_transformer(
        'swin_large_patch4_window7_224_in22k',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_s3_tiny_224(pretrained=False, **kwargs):
    """Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725

    Uses a per-stage window size.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=(7, 7, 14, 7),
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
    )
    return _create_swin_transformer(
        'swin_s3_tiny_224',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_s3_small_224(pretrained=False, **kwargs):
    """Swin-S3-S @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725

    Uses a per-stage window size.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=(14, 14, 14, 7),
        embed_dim=96,
        depths=(2, 2, 18, 2),
        num_heads=(3, 6, 12, 24),
    )
    return _create_swin_transformer(
        'swin_s3_small_224',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def swin_s3_base_224(pretrained=False, **kwargs):
    """Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725

    Uses a per-stage window size.

    Args:
        pretrained (bool): Forwarded to the model builder to request pretrained weights.
        **kwargs: Extra model arguments; any that name an architecture setting
            below override it instead of raising a duplicate-keyword TypeError.
    """
    model_args = dict(
        patch_size=4,
        window_size=(7, 7, 14, 7),
        embed_dim=96,
        depths=(2, 2, 30, 2),
        num_heads=(3, 6, 12, 24),
    )
    return _create_swin_transformer(
        'swin_s3_base_224',
        pretrained=pretrained,
        **{**model_args, **kwargs},
    )
[docs]
def create_swin(name, **kwargs):
    """Instantiate the Swin variant ``name`` with its entry from ``default_cfgs``.

    Args:
        name (str): A key of ``default_cfgs``; a model factory function with the
            same name must exist in this module.
        **kwargs: Forwarded to the factory. A caller-supplied ``pretrained_cfg``
            takes precedence over the default entry.

    Returns:
        Whatever the factory returns (the built model).

    Raises:
        KeyError: If ``name`` is not a key of ``default_cfgs``.
    """
    # Validate, then look the factory up by name. Never ``eval`` the string,
    # which would execute arbitrary code before any check ran.
    if name not in default_cfgs:
        raise KeyError(f'Unknown Swin variant: {name!r}')
    factory = globals()[name]
    kwargs.setdefault('pretrained_cfg', default_cfgs[name])
    return factory(**kwargs)