# Source code for pyiqa.archs.maclip_arch

r"""Beyond Cosine Similarity: Magnitude-Aware CLIP for No-Reference Image Quality Assessment

@article{liao2025beyond,
  title={Beyond Cosine Similarity: Magnitude-Aware CLIP for No-Reference Image Quality Assessment},
  author={Liao, Zhicheng and Wu, Dongxu and Shi, Zhenshan and Mai, Sijie and Zhu, Hanwei and Zhu, Lingyu and Jiang, Yuncheng and Chen, Baoliang},
  journal={arXiv preprint arXiv:2511.09948},
  year={2025}
}

Accepted by AAAI 2026.

Reference:
    - Arxiv link: https://arxiv.org/abs/2511.09948
    - Official Github: https://github.com/zhix000/MA-CLIP
"""

import torch
import torch.nn as nn
import clip
import torch.nn.functional as F
from torchvision.transforms import Normalize
import torchvision

from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from pyiqa.archs.clip_model import load


class CustomCLIP(nn.Module):
    """CLIP wrapper whose forward pass also returns the raw image embeddings.

    Besides the usual scaled cosine-similarity logits, the un-normalized image
    features are returned so that callers can use their magnitude.
    """

    def __init__(self, backbone: str, device="cpu"):
        super().__init__()
        model = load(backbone, device)
        self.clip_model = model
        # Expose the underlying encoders / temperature directly on the wrapper.
        self.encode_image = model.encode_image
        self.encode_text = model.encode_text
        self.logit_scale = model.logit_scale

    def forward(self, image, text, pos_embedding=False, text_features=None):
        """Compute image/text logits plus the raw image embeddings.

        Args:
            image: Image batch passed to ``encode_image``.
            text: Tokenized prompts; only encoded when ``text_features`` is None.
            pos_embedding: Forwarded unchanged to ``encode_image``.
            text_features: Optional precomputed text embeddings that bypass
                ``encode_text``.

        Returns:
            Tuple ``(logits_per_image, logits_per_text, image_features)``. The
            first is ``exp(logit_scale)`` times the cosine similarity between
            L2-normalized image and text features, the second is its transpose,
            and the third is the image embedding *before* normalization.
        """
        raw_feats = self.encode_image(image, pos_embedding)
        txt_feats = self.encode_text(text) if text_features is None else text_features

        # Unit-length features so the matmul below yields cosine similarity.
        img_unit = raw_feats / raw_feats.norm(dim=-1, keepdim=True)
        txt_unit = txt_feats / txt_feats.norm(dim=-1, keepdim=True)

        scale = self.logit_scale.exp()
        per_image = (scale * img_unit) @ txt_unit.t()
        return per_image, per_image.t(), raw_feats
@ARCH_REGISTRY.register()
class MACLIP(nn.Module):
    """Magnitude-aware CLIP (MA-CLIP) no-reference image quality model.

    Each image is scored with two cues computed on a set of 224x224 views:
    a CLIP-IQA style score (softmax over antonym prompt pairs, taking the
    positive prompt's probability) and a magnitude score derived from the
    absolute values of the un-normalized CLIP image embedding. The two are
    blended per view by ``fusion`` and averaged over views.
    """

    def __init__(self, model_type='clipiqa', backbone='RN50', pos_embedding=False) -> None:
        '''
        Args:
            model_type: Stored on the instance; not otherwise used by this class.
            backbone: CLIP backbone model (default: `RN50`, optional: `ViT-B/32`, `RN101` etc., from `clip_model.py`).
            pos_embedding: Forwarded to the CLIP image encoder on every call.
        '''
        super().__init__()
        # Previously hard-coded to 'cuda', which crashed on CPU-only machines.
        # Use CUDA when available (same as before) and fall back to CPU otherwise.
        self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.clip_model = CustomCLIP(backbone=backbone, device=self._device)
        # Five (positive, negative) antonym pairs; order matters because
        # forward() reshapes the logits into (..., 2) and takes index 0.
        self.prompt_pairs = clip.tokenize([
            'Good image', 'bad image',
            'Sharp image', 'blurry image',
            'sharp edges', 'blurry edges',
            'High resolution image', 'low resolution image',
            'Noise-free image', 'noisy image',
        ])
        self.default_mean = torch.Tensor(OPENAI_CLIP_MEAN).view(1, 3, 1, 1)
        self.default_std = torch.Tensor(OPENAI_CLIP_STD).view(1, 3, 1, 1)
        self.model_type = model_type
        self.pos_embedding = pos_embedding
        # The CLIP backbone is used frozen.
        for p in self.clip_model.parameters():
            p.requires_grad = False

    def preprocess(self, img):
        '''Normalize a single image and cut it into 224x224 CLIP views.

        Args:
            img: Floating-point image tensor of shape (1, 3, H, W) with
                H, W >= 224 (required by the sliding window). Assumed to be in
                the value range the CLIP normalization expects — TODO confirm [0, 1].

        Returns:
            Tensor of shape (L + 1, 3, 224, 224) on the model's compute device:
            the L overlapping 224x224 crops (stride 128) followed by the whole
            image bilinearly resized to 224x224.

        Raises:
            ValueError: if ``img`` is not 4-D with batch size 1. The reshape of
                the unfolded patches below is only valid for a single image.
        '''
        if img.dim() != 4 or img.shape[0] != 1:
            raise ValueError(
                f'MACLIP.preprocess expects a (1, 3, H, W) tensor, got shape {tuple(img.shape)}')
        raw_image = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(img)
        unfold = nn.Unfold(kernel_size=(224, 224), stride=128)
        # (1, 3*224*224, L) -> (3, 224, 224, L) -> (L, 3, 224, 224)
        crops = unfold(raw_image).view(1, 3, 224, 224, -1)[0]
        crops = crops.permute(3, 0, 1, 2).to(self._device)
        img_s = F.interpolate(
            raw_image, size=(224, 224), mode='bilinear', align_corners=False).to(self._device)
        return torch.cat([crops, img_s], dim=0)

    def box_cox(self, x, lam=0.5, epsilon=1e-6):
        '''Scale rows by their std, then apply a Box-Cox power transform to ``x + 1``.

        Args:
            x: 2-D tensor [B, D]; each row is divided by its standard deviation
                along dim 1 (``epsilon`` guards against division by zero).
            lam: Box-Cox lambda. ``0`` selects ``log(x + 1)``; otherwise
                ``((x + 1) ** lam - 1) / lam``.
            epsilon: Added to the std before dividing.

        Returns:
            Transformed tensor with the same shape as ``x``.
        '''
        x = x / (x.std(dim=1, keepdim=True) + epsilon)  # [B, D]
        if lam == 0:
            return torch.log(x + 1)
        return ((x + 1) ** lam - 1) / lam

    def fusion(self, cos, norm, base_cos=1.0, base_norm=0.6, alpha=1.0):
        '''Adaptively blend the cosine-similarity and magnitude scores.

        The gap ``d = cos - norm`` shifts the fusion logits toward the cosine cue
        (``base_cos + alpha * d``) and away from the magnitude cue
        (``base_norm - alpha * d``); a softmax over the two yields the weights.

        Args:
            cos: Cosine-similarity (CLIP-IQA) scores.
            norm: Magnitude scores, same shape as ``cos``.
            base_cos/base_norm: Base weights for fusion of cosine similarity and magnitude cues (default: 1.0/0.6).
            alpha: Fusion coefficient (default: 1.0).

        Returns:
            Tuple ``(weighted_metric, w_cos, w_norm)``, each shaped like ``cos``.
        '''
        d = cos - norm
        cos_param = base_cos + alpha * d
        norm_param = base_norm - alpha * d
        weights = F.softmax(torch.stack([cos_param, norm_param], dim=-1), dim=-1)
        w_cos, w_norm = weights.unbind(dim=-1)
        weighted_metric = w_cos * cos + w_norm * norm
        return weighted_metric, w_cos, w_norm

    def forward(self, x, box_lam=0.5, base_cos=1.0, base_norm=0.6, alpha=1.0):
        '''Score a batch of images.

        Args:
            x: Image batch of shape (B, 3, H, W), B >= 1, H, W >= 224.
            box_lam: Lambda parameter for the Box-Cox transformation (default: 0.5).
            base_cos/base_norm/alpha: Passed through to ``fusion``.

        Returns:
            A 0-dim tensor when B == 1, matching the previous behaviour.
            Otherwise a tensor of shape (B,) with one score per image. Before
            this change, batches were silently mixed into one scalar.

        Raises:
            ValueError: if ``x`` is not a non-empty 4-D tensor.
        '''
        if x.dim() != 4 or x.shape[0] == 0:
            raise ValueError(
                f'MACLIP expects a non-empty (B, 3, H, W) tensor, got shape {tuple(x.shape)}')
        scores = [
            self._score_single(img, box_lam, base_cos, base_norm, alpha)
            for img in x.split(1)
        ]
        # squeeze(0) turns the (1,) result for a single image back into a 0-dim tensor.
        return torch.stack(scores).squeeze(0)

    def _score_single(self, img, box_lam, base_cos, base_norm, alpha):
        '''Score one (1, 3, H, W) image; returns a 0-dim tensor.'''
        views = self.preprocess(img)
        clip_model = self.clip_model.to(views.device)
        prompts = self.prompt_pairs.to(views.device)
        logits_per_image, _, image_features_org = clip_model(
            views, prompts, pos_embedding=self.pos_embedding)
        # (V, 10) -> (V, 5, 2): softmax within each antonym pair, keep the positive.
        probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(dim=-1)
        clipiqa = probs[..., 0].mean(dim=1)  # (V,)

        # Magnitude cue: Box-Cox of |un-normalized embedding|, averaged over features.
        magnitude = self.box_cox(torch.abs(image_features_org), lam=box_lam)
        nrm_score = magnitude.mean(dim=-1)  # (V,)

        comb, _, _ = self.fusion(clipiqa, nrm_score, base_cos, base_norm, alpha)
        return comb.mean()