r"""CLIP-IQA metric, proposed by
Exploring CLIP for Assessing the Look and Feel of Images.
Jianyi Wang, Kelvin C.K. Chan, Chen Change Loy.
AAAI 2023.
Ref url: https://github.com/IceClear/CLIP-IQA
Re-implemented by: Chaofeng Chen (https://github.com/chaofengc) with the following modification:
- We ensemble multiple prompt pairs to improve the results of the clipiqa model.
"""
import torch
import torch.nn as nn
from .constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from .clip_imports import clip
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_file_from_url, load_pretrained_network
from .clip_model import load
from pyiqa.archs.arch_util import get_url_from_name
# Download URLs of pretrained weights, keyed by ``model_type``.
# The 'clipiqa+' file is loaded with ``torch.load`` and assigned to the prompt
# learner's ``ctx``; the remaining entries are loaded into ``CLIPIQA`` through
# ``load_pretrained_network`` using the 'params' weight key.
default_model_urls = {
    model_name: get_url_from_name(weight_file)
    for model_name, weight_file in (
        ('clipiqa+', 'CLIP-IQA+_learned_prompts-603f3273.pth'),
        ('clipiqa+_rn50_512', 'CLIPIQA+_RN50_512-89f5d940.pth'),
        ('clipiqa+_vitL14_512', 'CLIPIQA+_ViTL14_512-e66488f2.pth'),
    )
}
class PromptLearner(nn.Module):
    """Learnable text-prompt context used by CLIP-IQA+.

    Two fixed prompts, ``"<ctx> Good photo.."`` and ``"<ctx> Bad photo.."``, are
    tokenized once. The embeddings of their ``n_ctx`` placeholder tokens become the
    trainable ``ctx`` parameter; the remaining token embeddings are kept as
    (non-trainable) buffers.

    Disclaimer:
        This follows the official code at https://github.com/IceClear/CLIP-IQA
        exactly. Some of its tricks are reproduced without a known rationale:

        1. ``n_ctx`` placeholder "X" tokens form the prompt prefix.
        2. An extra "." is appended to each prompt.
        3. The class-name embedding is inserted in the middle of the context.
    """

    def __init__(self, clip_model, n_ctx=16) -> None:
        """Initialize the learnable context from CLIP's token embedding.

        Args:
            clip_model (nn.Module): CLIP model; its ``token_embedding`` provides
                the initial embeddings.
            n_ctx (int): Number of learnable context tokens. Default: 16.
        """
        super().__init__()
        # Prompt strings are built exactly as in the official code so results
        # match, e.g. n_ctx=3 -> 'X X X Good photo..'.
        placeholder = ' '.join('X' * n_ctx) + ' '
        prompt_texts = [placeholder + 'Good photo..', placeholder + 'Bad photo..']

        with torch.no_grad():
            tokens = clip.tokenize(prompt_texts)
            # Plain attribute (not a registered buffer); forward() uses it to
            # locate each prompt's EOT position.
            self.tokenized_prompts = tokens
            embedding = clip_model.token_embedding(tokens)

        # Position 0 holds SOS; positions 1..n_ctx hold the "X" placeholders.
        self.ctx = nn.Parameter(embedding[:, 1 : 1 + n_ctx])
        self.n_ctx = n_ctx
        self.n_cls = len(prompt_texts)
        # Hard-coded class-name token counts; the trailing extra "." is excluded.
        self.name_lens = [3, 3]
        self.register_buffer('token_prefix', embedding[:, :1, :])  # SOS
        # Class-name tokens, EOT and every position after them.
        self.register_buffer('token_suffix', embedding[:, 1 + n_ctx :, :])

    def get_prompts_with_middle_class(self):
        """Assemble prompt embeddings with the class tokens placed mid-context.

        Prompt ``i`` is laid out along the sequence axis as::

            SOS | ctx[:n_ctx // 2] | class tokens | ctx[n_ctx // 2:] | rest of suffix

        A 2-D ``ctx`` is shared: the same context is expanded to every class.

        Returns:
            torch.Tensor: Prompt embeddings of shape ``(n_cls, seq_len, dim)``, where
            ``seq_len`` equals the length of the tokenized prompts.
        """
        context = self.ctx.to(self.token_prefix)
        if context.dim() == 2:
            context = context.unsqueeze(0).expand(self.n_cls, -1, -1)
        split = self.n_ctx // 2

        def _assemble(idx):
            # Build the (1, seq_len, dim) embedding sequence of prompt ``idx``.
            n_name = self.name_lens[idx]
            rows = slice(idx, idx + 1)
            pieces = (
                self.token_prefix[rows],
                context[rows, :split],
                self.token_suffix[rows, :n_name],
                context[rows, split:],
                self.token_suffix[rows, n_name:],
            )
            return torch.cat(pieces, dim=1)

        return torch.cat([_assemble(idx) for idx in range(self.n_cls)], dim=0)

    def forward(self, clip_model):
        """Encode the learned prompts with CLIP's text transformer.

        Args:
            clip_model (nn.Module): CLIP model providing ``positional_embedding``,
                ``transformer``, ``ln_final``, ``text_projection`` and ``dtype``.

        Returns:
            torch.Tensor: One projected text feature per prompt, of shape
            ``(n_cls, text_projection.shape[-1])``.
        """
        seq = self.get_prompts_with_middle_class()
        seq = seq + clip_model.positional_embedding.type(clip_model.dtype)
        # The text transformer is fed (L, N, D); convert from (N, L, D) and back.
        seq = clip_model.transformer(seq.permute(1, 0, 2)).permute(1, 0, 2)
        seq = clip_model.ln_final(seq).type(clip_model.dtype)
        # seq has shape (n_cls, seq_len, width). Take the hidden state at the EOT
        # token, which has the highest token id in each sequence.
        rows = torch.arange(seq.shape[0])
        eot_index = self.tokenized_prompts.argmax(dim=-1)
        return seq[rows, eot_index] @ clip_model.text_projection
@ARCH_REGISTRY.register()
class CLIPIQA(nn.Module):
    """CLIP-IQA / CLIP-IQA+ no-reference image quality metric.

    ``'clipiqa'`` scores images against a fixed set of antonym prompt pairs
    (e.g. "Good image" / "bad image"); model types containing ``'clipiqa+'``
    use the learned prompts of :class:`PromptLearner` instead.

    Args:
        model_type (str): ``'clipiqa'`` or a name containing ``'clipiqa+'``.
            Default: ``'clipiqa'``.
        backbone (str): CLIP backbone name passed to ``load``. Default: ``'RN50'``.
        pretrained (bool): Load pretrained prompt weights for ``'clipiqa+'``
            variants. Default: True.
        pos_embedding (bool): Forwarded as ``pos_embedding`` to the CLIP model's
            forward call (positional-embedding switch). Default: False.
    """

    def __init__(
        self,
        model_type='clipiqa',
        backbone='RN50',
        pretrained=True,
        pos_embedding=False,
    ) -> None:
        super().__init__()
        # Wrapped in a list so CLIP is not registered as a submodule and its
        # weights are not saved in this module's state dict.
        self.clip_model = [load(backbone, 'cpu')]

        # Different from the original paper, we ensemble multiple antonym prompt
        # pairs to improve performance. Each pair lists the positive prompt first.
        self.prompt_pairs = clip.tokenize(
            [
                'Good image',
                'bad image',
                'Sharp image',
                'blurry image',
                'sharp edges',
                'blurry edges',
                'High resolution image',
                'low resolution image',
                'Noise-free image',
                'noisy image',
            ]
        )

        self.model_type = model_type
        self.pos_embedding = pos_embedding

        if 'clipiqa+' in model_type:
            self.prompt_learner = PromptLearner(self.clip_model[0])

        # OpenAI CLIP normalization statistics, broadcastable over (N, 3, H, W).
        self.default_mean = torch.Tensor(OPENAI_CLIP_MEAN).view(1, 3, 1, 1)
        self.default_std = torch.Tensor(OPENAI_CLIP_STD).view(1, 3, 1, 1)

        # Freeze CLIP; only the learned prompt context (if any) receives gradients.
        for p in self.clip_model[0].parameters():
            p.requires_grad = False

        if pretrained and 'clipiqa+' in model_type:
            if model_type == 'clipiqa+' and backbone == 'RN50':
                # This checkpoint holds only the learned context tensor.
                # NOTE(review): weights_only=False unpickles arbitrary objects;
                # acceptable only because the URL is a fixed, trusted asset.
                self.prompt_learner.ctx.data = torch.load(
                    load_file_from_url(default_model_urls['clipiqa+']),
                    weights_only=False,
                )
            elif model_type in default_model_urls:
                # NOTE(review): model_type 'clipiqa+' with a non-RN50 backbone also
                # lands here, loading the context-only file via the 'params' key;
                # confirm this combination is intended to be supported.
                load_pretrained_network(
                    self, default_model_urls[model_type], True, 'params'
                )
            else:
                raise ValueError(f'No pretrained model for {model_type}')

    def forward(self, x):
        """Predict a quality score for each image.

        Args:
            x (torch.Tensor): Input images of shape (N, 3, H, W). Presumably RGB in
                [0, 1] before CLIP normalization -- TODO confirm against callers.

        Returns:
            torch.Tensor: Scores of shape (N, 1): the softmax probability of the
            positive prompt within each pair, averaged over all pairs.

        Raises:
            ValueError: If ``model_type`` is neither ``'clipiqa'`` nor a
                ``'clipiqa+'`` variant.
        """
        # Normalize with the OpenAI CLIP statistics.
        x = (x - self.default_mean.to(x)) / self.default_std.to(x)
        # nn.Module.to moves the (unregistered) CLIP model in place to x's
        # device and dtype; it is a no-op after the first call.
        clip_model = self.clip_model[0].to(x)

        if self.model_type == 'clipiqa':
            prompts = self.prompt_pairs.to(x.device)
            logits_per_image, _ = clip_model(
                x, prompts, pos_embedding=self.pos_embedding
            )
        elif 'clipiqa+' in self.model_type:
            learned_prompt_feature = self.prompt_learner(clip_model)
            logits_per_image, _ = clip_model(
                x,
                None,
                text_features=learned_prompt_feature,
                pos_embedding=self.pos_embedding,
            )
        else:
            # Previously this fell through and failed with an opaque
            # UnboundLocalError on `logits_per_image`.
            raise ValueError(
                f"Unsupported model_type '{self.model_type}': expected 'clipiqa' "
                "or a 'clipiqa+' variant."
            )

        # Logits come as consecutive (positive, negative) prompt pairs per image:
        # softmax within each pair, keep the positive probability, average pairs.
        probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(
            dim=-1
        )
        return probs[..., 0].mean(dim=1, keepdim=True)