Source code for pyiqa.archs.clip_model

import hashlib
import os
import urllib
import warnings
from tqdm import tqdm
from typing import Tuple, Union, List
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from torch.hub import get_dir


# Registry of downloadable CLIP checkpoints, keyed by model name.
# Every URL has the form ``.../<sha256>/<filename>``: `_download` takes the
# second-to-last path component as the expected SHA256 digest and the last
# component as the local file name, so the URL layout must be preserved.
_MODELS = {
    'RN50': 'https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt',
    'RN101': 'https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt',
    'RN50x4': 'https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt',
    'RN50x16': 'https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt',
    'RN50x64': 'https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt',
    'ViT-B/32': 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt',
    'ViT-B/16': 'https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt',
    'ViT-L/14': 'https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt',
    'ViT-L/14@336px': 'https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt',
}


def _download(url: str, root: str):
    """Fetch ``url`` into ``root`` and verify the file's SHA256 checksum.

    The expected digest is the second-to-last path component of ``url`` and
    the local file name is its last component. If a file with that name and
    a matching digest already exists in ``root``, nothing is downloaded.

    Parameters
    ----------
    url : str
        Download location, laid out as ``.../<sha256>/<filename>``.
    root : str
        Directory to store the file in; created if it does not exist.

    Returns
    -------
    str
        Path of the verified local file.

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file, or if the
        downloaded file's checksum does not match the expected digest.
    """
    # A bare ``import urllib`` does not guarantee that the ``request``
    # submodule is loaded, so import it explicitly where it is used.
    import urllib.request

    def _sha256_of(path, chunk_size=1 << 20):
        # Hash in chunks: checkpoints are large, and the handle must be closed.
        digest = hashlib.sha256()
        with open(path, 'rb') as stream:
            for chunk in iter(lambda: stream.read(chunk_size), b''):
                digest.update(chunk)
        return digest.hexdigest()

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split('/')[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f'{download_target} exists and is not a regular file')

    if os.path.isfile(download_target):
        if _sha256_of(download_target) == expected_sha256:
            return download_target
        warnings.warn(
            f'{download_target} exists, but the SHA256 checksum does not match; re-downloading the file'
        )

    with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
        # Servers may omit Content-Length; tqdm accepts total=None in that case.
        content_length = source.info().get('Content-Length')
        with tqdm(
            total=int(content_length) if content_length is not None else None,
            ncols=80,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _sha256_of(download_target) != expected_sha256:
        raise RuntimeError(
            'Model has been downloaded but the SHA256 checksum does not match'
        )

    return download_target


def available_models() -> List[str]:
    """Return the model names registered in the ``_MODELS`` table."""
    return [model_name for model_name in _MODELS]
def load(
    name: str,
    device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu',
    jit: bool = False,
    download_root: str = None,
):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, the ``clip`` subdirectory of
        the directory returned by ``torch.hub.get_dir()``

    Returns
    -------
    model : torch.nn.Module
        The CLIP model. Unlike the docstring this was adapted from, no
        preprocessing transform is returned.

    Raises
    ------
    RuntimeError
        If ``name`` is neither a registered model name nor an existing file.
    """
    # Resolve ``name`` to a local checkpoint path, downloading if it is a registered model.
    if name in _MODELS:
        model_path = _download(
            _MODELS[name], download_root or os.path.join(get_dir(), 'clip')
        )
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f'Model {name} not found; available models = {available_models()}'
        )

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(
                opened_file, map_location=device if jit else 'cpu'
            ).eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(
                    f'File {model_path} is not a JIT archive. Loading as a state dict instead'
                )
                jit = False
            state_dict = torch.load(opened_file, map_location='cpu')

    if not jit:
        # Rebuild a regular nn.Module from the weights (taken from the JIT
        # archive when the file was one) and move it to the target device.
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == 'cpu':
            # build_model converts weights to fp16; switch back to fp32 on CPU.
            model.float()
        return model

    # patch the device names
    # Trace a tiny function to obtain a constant node carrying the target device.
    device_holder = torch.jit.trace(
        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
    )
    device_node = [
        n
        for n in device_holder.graph.findAllNodes('prim::Constant')
        if 'Device' in repr(n)
    ][-1]

    def patch_device(module):
        # Overwrite every constant whose value starts with 'cuda' with the target device node.
        try:
            graphs = [module.graph] if hasattr(module, 'graph') else []
        except RuntimeError:
            graphs = []

        if hasattr(module, 'forward1'):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes('prim::Constant'):
                if 'value' in node.attributeNames() and str(node['value']).startswith(
                    'cuda'
                ):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == 'cpu':
        # Trace ``.float()`` to obtain the constant node used as the float32 dtype argument.
        float_holder = torch.jit.trace(
            lambda: torch.ones([]).float(), example_inputs=[]
        )
        float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, 'graph') else []
            except RuntimeError:
                graphs = []

            if hasattr(module, 'forward1'):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes('aten::to'):
                    inputs = list(node.inputs())
                    for i in [
                        1,
                        2,
                    ]:  # dtype can be the second or third argument to aten::to()
                        # NOTE(review): 5 is presumably the dtype code for float16 — confirm.
                        if inputs[i].node()['value'] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model
class Bottleneck(nn.Module):
    """Residual bottleneck block (1x1 -> 3x3 -> 1x1 convolutions).

    Every convolution uses stride 1. When ``stride > 1`` the spatial
    reduction is done by an average pool after the second convolution, and
    by an average pool placed in front of the shortcut convolution.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion
        strided = stride > 1

        # Main path. Submodule names and creation order are kept as-is
        # because checkpoints address these layers by name.
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if strided:
            self.avgpool = nn.AvgPool2d(stride)
        else:
            self.avgpool = nn.Identity()

        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        # A projection shortcut is needed whenever the block changes the
        # resolution or the channel count.
        if strided or inplanes != planes * Bottleneck.expansion:
            shortcut = OrderedDict()
            shortcut['-1'] = nn.AvgPool2d(stride)
            shortcut['0'] = nn.Conv2d(
                inplanes, out_planes, 1, stride=1, bias=False
            )
            shortcut['1'] = nn.BatchNorm2d(out_planes)
            self.downsample = nn.Sequential(shortcut)

    def forward(self, x: torch.Tensor):
        """Apply the block to ``x`` and return the activated residual sum."""
        out = self.conv1(x)
        out = self.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.relu(self.bn2(out))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        residual = x if self.downsample is None else self.downsample(x)

        out += residual
        return self.relu(out)
class AttentionPool2d(nn.Module):
    """Attention pooling over a 2D feature map.

    The flattened feature map is prefixed with its own mean as an extra
    token, and one multi-head self-attention step is run over the sequence.
    The output at the mean-token position is the pooled feature.
    """

    def __init__(
        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
    ):
        super().__init__()
        num_positions = spacial_dim**2 + 1
        self.positional_embedding = nn.Parameter(
            torch.randn(num_positions, embed_dim) / embed_dim**0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads
        self.spacial_dim = spacial_dim
        self.embed_dim = embed_dim

    def forward(self, x, return_token=False, pos_embedding=False):
        """Pool an ``NCHW`` feature map.

        When ``pos_embedding`` is true, the positional embedding is resized
        with bicubic interpolation to the current sequence length and added
        to the tokens; otherwise no positional embedding is added.

        Returns the pooled feature, or ``(pooled, tokens)`` when
        ``return_token`` is true, where ``tokens`` holds the attention
        outputs of the remaining positions.
        """
        n, c, h, w = x.shape
        seq = x.reshape(n, c, h * w).permute(2, 0, 1)  # NCHW -> (HW)NC
        mean_token = seq.mean(dim=0, keepdim=True)
        seq = torch.cat([mean_token, seq], dim=0)  # (HW+1)NC

        if pos_embedding:
            resized = F.interpolate(
                self.positional_embedding[None, None],
                size=(seq.size(0), seq.size(2)),
                mode='bicubic',
            )[0, 0]
            seq = seq + resized[:, None, :].to(seq.dtype)  # (HW+1)NC

        qkv_bias = torch.cat(
            [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
        )
        attended, _ = F.multi_head_attention_forward(
            query=seq,
            key=seq,
            value=seq,
            embed_dim_to_check=seq.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )

        pooled = attended[0]
        if not return_token:
            return pooled
        return pooled, attended[1:]
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution
        half_width = width // 2

        # the 3-layer stem
        self.conv1 = nn.Conv2d(
            3, half_width, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(half_width)
        self.conv2 = nn.Conv2d(
            half_width, half_width, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(half_width)
        self.conv3 = nn.Conv2d(half_width, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers; ``_inplanes`` is advanced by ``_make_layer`` as stages are built
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        # channel counts of the stem output and of each of the four stages
        self.feature_dim_list = [width * scale for scale in (1, 4, 8, 16, 32)]

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(
            input_resolution // 32, embed_dim, heads, output_dim
        )

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage of ``blocks`` bottlenecks; only the first one is strided."""
        stage = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        stage.extend(
            Bottleneck(self._inplanes, planes) for _ in range(blocks - 1)
        )
        return nn.Sequential(*stage)

    def _stem(self, x):
        """Cast ``x`` to the weight dtype and run the three stem convs plus the avgpool."""
        x = x.type(self.conv1.weight.dtype)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return self.avgpool(x)

    def forward_features(self, x, return_token=False, pos_embedding=False):
        """Return the stem output followed by the outputs of layer1..layer4.

        ``return_token`` and ``pos_embedding`` are accepted but not used here.
        """
        feats = [self._stem(x)]
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats.append(stage(feats[-1]))
        return feats

    def forward(self, x, return_token=False, pos_embedding=False):
        """Run the full backbone and the attention pool.

        Returns the pooled feature, or ``(pooled, tokens)`` when
        ``return_token`` is true.
        """
        x = self._stem(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        if not return_token:
            return self.attnpool(x, return_token, pos_embedding)

        pooled, tokens = self.attnpool(x, return_token, pos_embedding)
        return pooled, tokens
class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the input's dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` in float32, then cast the result back to ``x.dtype``."""
        input_dtype = x.dtype
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid-gated activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        """Apply the activation element-wise."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention then a 4x-wide MLP, each with a residual."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        hidden_dim = d_model * 4

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        # Layer names are part of the checkpoint key layout; keep them stable.
        feed_forward = OrderedDict()
        feed_forward['c_fc'] = nn.Linear(d_model, hidden_dim)
        feed_forward['gelu'] = QuickGELU()
        feed_forward['c_proj'] = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(feed_forward)

        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attend over ``x``, first moving the stored mask to ``x``'s dtype and device."""
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        attended, _ = self.attn(
            x, x, x, need_weights=False, attn_mask=self.attn_mask
        )
        return attended

    def forward(self, x: torch.Tensor):
        """Apply the attention and MLP sub-layers with residual connections."""
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
class Transformer(nn.Module):
    """A sequential stack of ``layers`` residual attention blocks of the given width."""

    def __init__(
        self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        """Pass ``x`` through every block in order."""
        return self.resblocks(x)
class VisionTransformer(nn.Module):
    """ViT image encoder: patch-embedding conv, class token, transformer, linear projection."""

    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
    ):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim

        # Non-overlapping patches: kernel size and stride both equal the patch size.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width**-0.5
        num_tokens = (input_resolution // patch_size) ** 2 + 1
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(num_tokens, width)
        )
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor, return_token=False, pos_embedding=False):
        """Encode a batch of images.

        When ``pos_embedding`` is true, the positional embedding is resized
        with bicubic interpolation to the current token count and added;
        otherwise no positional embedding is added.

        Returns the projected class-token feature, or ``(feature, tokens)``
        when ``return_token`` is true. ``tokens`` are the layer-normed patch
        tokens; they are not multiplied by ``proj``.
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        batch, channels = x.shape[0], x.shape[1]
        x = x.reshape(batch, channels, -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # Broadcast the class embedding to one token per batch element.
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(
            batch, 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_tokens, x], dim=1)  # shape = [*, grid ** 2 + 1, width]

        if pos_embedding:
            resized = F.interpolate(
                self.positional_embedding[None, None],
                size=(x.size(1), x.size(2)),
                mode='bicubic',
            )[0, 0]
            x = x + resized.to(x.dtype)

        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        token = self.ln_post(x[:, 1:, :])
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        if not return_token:
            return x
        return x, token
class CLIP(nn.Module):
    """CLIP model: an image encoder and a text transformer sharing one embedding space.

    The visual backbone is a `ModifiedResNet` when ``vision_layers`` is a
    tuple or list of per-stage block counts, and a `VisionTransformer` when
    it is a single integer.
    """

    def __init__(
        self,
        embed_dim: int,
        # vision
        image_resolution: int,
        vision_layers: Union[Tuple[int, int, int, int], int],
        vision_width: int,
        vision_patch_size: int,
        # text
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
    ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            # ResNet: 64 channels per head at the final feature width (vision_width * 32).
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width,
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim,
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width)
        )
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialize embeddings, attention-pool, text-transformer and projection weights."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features**-0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            # Zero the last BatchNorm scale of every bottleneck in all four stages.
            for resnet_block in [
                self.visual.layer1,
                self.visual.layer2,
                self.visual.layer3,
                self.visual.layer4,
            ]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith('bn3.weight'):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers) ** -0.5
        )
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

    def build_attention_mask(self):
        """Return the additive causal mask for the text transformer.

        The result is a ``(context_length, context_length)`` tensor holding
        ``-inf`` strictly above the diagonal and ``0`` on and below it.
        """
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # keep -inf above the diagonal; zero the diagonal and below
        return mask

    @property
    def dtype(self):
        """Dtype of the visual backbone's first convolution weight."""
        return self.visual.conv1.weight.dtype

    def encode_image(self, image, pos_embedding=False):
        """Encode ``image`` with the visual backbone.

        ``pos_embedding`` is forwarded to the backbone; it defaults to
        ``False`` to match `forward` and the backbones' own defaults.
        """
        return self.visual(image.type(self.dtype), pos_embedding=pos_embedding)

    def encode_text(self, text):
        """Encode token ids ``text`` into projected text features."""
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features at the position of the largest token id in each
        # sequence (assumed to be the eot token)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text, pos_embedding=False, text_features=None):
        """Return scaled cosine-similarity logits between images and texts.

        When ``text_features`` is given it is used as-is and ``text`` is not
        encoded. Returns ``(logits_per_image, logits_per_text)``, where the
        second is the transpose of the first.
        """
        image_features = self.encode_image(image, pos_embedding)

        if text_features is None:
            text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
    """Cast applicable parameters of ``model`` to fp16, in place.

    Affected: weights and biases of Conv1d/Conv2d/Linear layers, the
    projection weights and biases of MultiheadAttention, and any
    ``text_projection`` / ``proj`` attribute found on a submodule.
    """
    attention_params = (
        'in_proj_weight',
        'q_proj_weight',
        'k_proj_weight',
        'v_proj_weight',
        'in_proj_bias',
        'bias_k',
        'bias_v',
    )
    projection_names = ('text_projection', 'proj')

    def _to_half(layer):
        if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            layer.weight.data = layer.weight.data.half()
            if layer.bias is not None:
                layer.bias.data = layer.bias.data.half()

        if isinstance(layer, nn.MultiheadAttention):
            for param_name in attention_params:
                param = getattr(layer, param_name)
                if param is None:
                    continue
                param.data = param.data.half()

        for proj_name in projection_names:
            projection = getattr(layer, proj_name, None)
            if projection is not None:
                projection.data = projection.data.half()

    model.apply(_to_half)
def build_model(state_dict: dict):
    """Build a `CLIP` model whose architecture is inferred from ``state_dict``.

    The visual backbone is treated as a ViT when the key ``visual.proj`` is
    present, and as a ResNet otherwise. All remaining hyper-parameters are
    read from tensor shapes and key names. The model's weights are converted
    to fp16, the checkpoint is loaded, and the model is returned in eval mode.

    Parameters
    ----------
    state_dict : dict
        CLIP checkpoint mapping parameter names to tensors. It is not
        modified by this function.

    Returns
    -------
    CLIP
        The constructed model, in eval mode.
    """
    vit = 'visual.proj' in state_dict

    if vit:
        vision_width = state_dict['visual.conv1.weight'].shape[0]
        # one in_proj_weight per visual transformer block
        vision_layers = sum(
            1
            for k in state_dict
            if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
        )
        vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
        grid_size = round(
            (state_dict['visual.positional_embedding'].shape[0] - 1) ** 0.5
        )
        image_resolution = vision_patch_size * grid_size
    else:
        # number of distinct block indices under each visual.layer{b}
        vision_layers = tuple(
            len(
                {
                    k.split('.')[2]
                    for k in state_dict
                    if k.startswith(f'visual.layer{b}')
                }
            )
            for b in (1, 2, 3, 4)
        )
        vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
        output_width = round(
            (state_dict['visual.attnpool.positional_embedding'].shape[0] - 1) ** 0.5
        )
        vision_patch_size = None
        # the attention-pool embedding must cover a square grid plus one extra token
        assert (
            output_width**2 + 1
            == state_dict['visual.attnpool.positional_embedding'].shape[0]
        )
        image_resolution = output_width * 32

    embed_dim = state_dict['text_projection'].shape[1]
    context_length = state_dict['positional_embedding'].shape[0]
    vocab_size = state_dict['token_embedding.weight'].shape[0]
    transformer_width = state_dict['ln_final.weight'].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        {
            k.split('.')[2]
            for k in state_dict
            if k.startswith('transformer.resblocks')
        }
    )

    model = CLIP(
        embed_dim,
        image_resolution,
        vision_layers,
        vision_width,
        vision_patch_size,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
    )

    # Drop the non-parameter entries from a copy, so the caller's dict is
    # left untouched.
    non_parameter_keys = ('input_resolution', 'context_length', 'vocab_size')
    weights = OrderedDict(
        (k, v) for k, v in state_dict.items() if k not in non_parameter_keys
    )
    # Carry over ``_metadata`` when present, since load_state_dict reads it.
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        weights._metadata = metadata

    convert_weights(model)
    model.load_state_dict(weights)
    return model.eval()