Source code for pyiqa.archs.clipscore_arch
r"""CLIPScore for no reference image caption matching.
Reference:
@inproceedings{hessel2021clipscore,
title={{CLIPScore:} A Reference-free Evaluation Metric for Image Captioning},
author={Hessel, Jack and Holtzman, Ari and Forbes, Maxwell and Bras, Ronan Le and Choi, Yejin},
booktitle={EMNLP},
year={2021}
}
Reference url: https://github.com/jmhessel/clipscore
Re-implemented by: Chaofeng Chen (https://github.com/chaofengc)
"""
import torch
import torch.nn as nn
from .clip_imports import clip
from pyiqa.utils.registry import ARCH_REGISTRY
from .arch_util import clip_preprocess_tensor
@ARCH_REGISTRY.register()
class CLIPScore(nn.Module):
    """Compute CLIPScore between images and their paired captions.

    The implementation follows the original CLIPScore formulation and returns a
    non-negative image-text similarity score:

    .. math::

        s = w \\cdot \\max(\\cos(f_{img}, f_{txt}), 0)

    Args:
        backbone (str): CLIP backbone name accepted by :mod:`clip`, for example
            ``"ViT-B/32"``.
        w (float): Multiplicative scaling factor applied to the clamped cosine
            similarity.
        prefix (str): Text prefix prepended (followed by a single space) to
            each caption before tokenization.

    Example:
        >>> metric = CLIPScore(backbone='ViT-B/32')
        >>> img = torch.rand(2, 3, 224, 224)
        >>> score = metric(img, ['a dog on grass', 'a city street'])
        >>> score.shape
        torch.Size([2])
    """

    def __init__(self, backbone='ViT-B/32', w=2.5, prefix='A photo depicts') -> None:
        super().__init__()
        # ``clip.load`` returns a pair; only the model (first item) is kept,
        # the second return value is not used by this module.
        self.clip_model, _ = clip.load(backbone)
        self.prefix = prefix
        self.w = w

    def forward(self, img, caption_list=None):
        """Compute CLIPScore for each image-caption pair.

        Args:
            img (torch.Tensor): Input tensor with shape ``(N, 3, H, W)``.
            caption_list (list[str] | str | None): List of length ``N``
                containing the caption paired with each image. A single
                ``str`` is treated as one caption rather than being iterated
                character by character.

        Returns:
            torch.Tensor: Score tensor with shape ``(N,)``.

        Raises:
            AssertionError: If ``caption_list`` is not provided.
        """
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with ``-O``; the exception type is kept
        # so existing callers observe the same error.
        if caption_list is None:
            raise AssertionError('caption_list is None')

        # A bare string is iterable, so without this guard every character
        # would be tokenized as its own caption.
        if isinstance(caption_list, str):
            caption_list = [caption_list]

        prompts = [self.prefix + ' ' + caption for caption in caption_list]
        # Over-long prompts are truncated by the tokenizer (``truncate=True``);
        # tokens are moved to the device of the input images.
        text = clip.tokenize(prompts, truncate=True).to(img.device)

        img_features = self.clip_model.encode_image(
            clip_preprocess_tensor(img, self.clip_model)
        )
        text_features = self.clip_model.encode_text(text)

        # L2-normalize along the feature axis so the dot product below is the
        # cosine similarity.
        img_features = img_features / img_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Negative similarities are clamped to zero before scaling by ``w``.
        cosine = (img_features * text_features).sum(dim=-1)
        return self.w * torch.relu(cosine)