r"""Q-ReAlign: a modern, model-agnostic Q-Align visual-quality scorer.
Q-ReAlign reproduces the Q-Align recipe on a Qwen3.5-VL backbone (``model_type:
qwen3_5``): the model is asked to rate quality, and the probability mass it places
on the discrete words ``excellent / good / fair / poor / bad`` is collapsed -- via
the fixed weighting ``[1.0, 0.75, 0.5, 0.25, 0.0]`` -- into a single scalar in
``[0, 1]`` (higher = better).
Three public checkpoints (selected with ``model=``):
mini -> q-future/Q-ReAlign-Mini-0.8B (fast, throughput champion)
lite -> q-future/Q-ReAlign-Lite-4B
pro -> q-future/Q-ReAlign-Pro-9B (best average SRCC/PLCC)
Requires ``transformers>=5.0`` (native qwen3_5 from 5.2; 5.0/5.1 run via the
vendored modeling + shim in ``pyiqa.archs.qrealign``).
Reference:
Wu et al., "Q-Align: Teaching LMMs for Visual Scoring via Discrete Text-Defined
Levels", arXiv:2312.17090. Q-ReAlign: https://github.com/Q-Future/Q-ReAlign
"""
import torch
from torch import nn
import torchvision.transforms.functional as TF
from pyiqa.utils.registry import ARCH_REGISTRY
# Discrete quality levels, best -> worst. WEIGHTS[i] is the score assigned to
# LEVELS[i]; the two lists must stay the same length and in the same order.
LEVELS = ["excellent", "good", "fair", "poor", "bad"]
WEIGHTS = [1.0, 0.75, 0.5, 0.25, 0.0]
# Short checkpoint names accepted by ``QReAlign(model=...)`` -> HuggingFace repo id.
# Any other value passed as ``model`` is used verbatim as a repo id / local path.
MODELS = {
    "mini": "q-future/Q-ReAlign-Mini-0.8B",
    "lite": "q-future/Q-ReAlign-Lite-4B",
    "pro": "q-future/Q-ReAlign-Pro-9B",
}
# task name -> (user prompt, answer stem). The stem is appended after the chat
# template's generation prompt, so the model's next token is the level word.
# Matches the Q-ReAlign toolkit's default inference prompts.
TASKS = {
    "quality": ("Can you rate the quality of this picture?", "The quality of the image is"),
    "aesthetic": ("Can you rate the aesthetics of this picture?", "The aesthetics of the image is"),
}
def _level_token_ids(tokenizer, names):
"""First-token id of each level word, space-prefixed as it appears after the stem."""
ids = []
for w in names:
tok = None
for cand in (" " + w, w):
t = tokenizer(cand, add_special_tokens=False)["input_ids"]
if t:
tok = t[0]
break
if tok is None:
raise ValueError(f"level word {w!r} did not tokenize to any id")
ids.append(tok)
return ids
@ARCH_REGISTRY.register()
class QReAlign(nn.Module):
    """Q-ReAlign multimodal visual-quality scorer (Qwen3.5-VL backbone).

    Args:
        model (str): one of ``'mini'`` / ``'lite'`` / ``'pro'`` (see ``MODELS``),
            or any HuggingFace repo id / local path to a Q-ReAlign checkpoint.
        dtype (str): torch dtype passed to ``from_pretrained`` (default ``'auto'``
            -> bfloat16 weights).
        task (str): default scoring task, ``'quality'`` (IQA) or ``'aesthetic'`` (IAA).

    Raises:
        ImportError: if the installed ``transformers`` is older than 5.0.
        KeyError: if ``task`` is not a key of ``TASKS``.
    """

    def __init__(self, model="mini", dtype="auto", task="quality"):
        super().__init__()
        # Check the version before importing the compat shim, so an old install gets
        # the readable ImportError from _check_transformers rather than whatever the
        # shim's own imports might raise.
        self._check_transformers()
        # Reject an unknown task before downloading / loading the checkpoint.
        self._resolve_task(task)
        # Make qwen3_5 loadable (native on transformers>=5.2; vendored shim on 5.0/5.1).
        # Imported here -- not at module import -- so merely importing pyiqa never
        # patches transformers; only constructing this metric does.
        from .qrealign import qrealign_compat as compat

        compat.ensure_qwen3_5()
        from transformers import AutoModelForImageTextToText

        self.default_task = task
        repo = MODELS.get(model, model)
        self.processor = compat.load_processor(repo)
        self.model = AutoModelForImageTextToText.from_pretrained(repo, dtype=dtype).eval()
        # Token ids of the level words and their scalar weights, used to collapse the
        # next-token logits into one score.
        ids = _level_token_ids(self.processor.tokenizer, LEVELS)
        self.register_buffer("level_ids", torch.tensor(ids, dtype=torch.long), persistent=False)
        self.register_buffer("level_weights", torch.tensor(WEIGHTS), persistent=False)
        # The model only does a single forward pass -> silence sampling-flag noise.
        gen = getattr(self.model, "generation_config", None)
        if gen is not None and not getattr(gen, "do_sample", False):
            gen.temperature = None
            gen.top_p = None

    @staticmethod
    def _check_transformers():
        """Raise ``ImportError`` unless ``transformers.__version__`` is at least 5.0."""
        import re

        import transformers

        m = re.match(r"(\d+)\.(\d+)", transformers.__version__ or "0.0")
        major, minor = (int(m.group(1)), int(m.group(2))) if m else (0, 0)
        if (major, minor) < (5, 0):
            raise ImportError(
                "The 'qrealign' metric needs transformers>=5.0 (Qwen3.5 backbone); "
                f"found {transformers.__version__}. Please `pip install -U \"transformers>=5.2\"`."
            )

    @staticmethod
    def _resolve_task(task):
        """Return the ``(prompt, answer_stem)`` pair registered in ``TASKS`` for ``task``.

        Raises:
            KeyError: if ``task`` is not a key of ``TASKS``; the message lists valid keys.
        """
        try:
            return TASKS[task]
        except KeyError:
            raise KeyError(f"unknown task {task!r}; expected one of {sorted(TASKS)}") from None

    def _score_one(self, image, prompt, stem, dev):
        """Score one image tensor (indexed out of the batch) and return a 0-dim tensor.

        The probability of each level word as the next token after ``stem`` is
        renormalized over the level words only, then averaged with ``level_weights``.
        """
        # Quantize explicitly with rounding: to_pil_image on a float tensor does
        # ``mul(255).byte()``, which truncates (e.g. 254.9999 -> 254).
        pixels = image.detach().cpu().clamp(0, 1).mul(255).round().to(torch.uint8)
        img = TF.to_pil_image(pixels)
        messages = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]}]
        text = self.processor.apply_chat_template(messages, add_generation_prompt=True) + stem
        inputs = self.processor(text=[text], images=[img], return_tensors="pt").to(dev)
        # Logits at the last position (the slot right after the stem), restricted to
        # the level-word token ids.
        logits = self.model(**inputs).logits[0, -1, self.level_ids]
        probs = logits.float().softmax(-1)
        return (probs * self.level_weights).sum()

    @torch.no_grad()
    def forward(self, x, task_=None):
        """Score a batch of images.

        Args:
            x (torch.Tensor): ``(B, 3, H, W)`` in ``[0, 1]``; values are clamped to that range.
            task_ (str): override the default task for this call.

        Returns:
            torch.Tensor: shape ``(B,)`` quality scores in ``[0, 1]`` (higher = better);
            an empty tensor when ``B == 0``.

        Raises:
            KeyError: if the effective task is not a key of ``TASKS``.
        """
        prompt, stem = self._resolve_task(task_ or self.default_task)
        dev = self.level_weights.device
        if x.shape[0] == 0:
            # torch.stack([]) raises; an empty batch yields an empty score vector.
            return torch.empty(0, dtype=torch.float32, device=dev)
        # Images are scored one at a time, each with its own processor call.
        scores = [self._score_one(image, prompt, stem, dev) for image in x]
        return torch.stack(scores).reshape(-1)