r"""FGResQ metric implementation.
Reference:
Sheng, X., Pan, X., Yang, Z., Chen, P., and Li, L.
Fine-grained Image Quality Assessment for Perceptual Image Restoration.
AAAI 2026.
Reference URL:
https://github.com/sxfly99/FGResQ
"""
import warnings
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPVisionModel
from transformers import logging as hf_logging
from torchvision.transforms import functional as TF
from pyiqa.archs.arch_util import clean_state_dict, get_url_from_name
from pyiqa.utils.download_util import load_file_from_url
from pyiqa.utils.registry import ARCH_REGISTRY
from .constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
[docs]
# Download URLs of the released checkpoints, keyed by the branch they feed:
# 'fgresq' holds the main FGResQ weights, 'degradation' the weights of the
# degradation-aware CLIP branch.
default_model_urls = dict(
    fgresq=get_url_from_name('FGResQ.pth'),
    degradation=get_url_from_name('FGResQ_Degradation.pth'),
)
[docs]
def load_checkpoint(model_path):
    """Read a checkpoint file and return its cleaned state dict.

    The file is deserialized onto the CPU. When the loaded dict wraps the
    weights under a ``'model'`` or ``'state_dict'`` entry (checked in that
    order) and that entry is itself a dict, the inner dict is used; otherwise
    the loaded dict is treated as the state dict directly.

    Args:
        model_path (str): Path to the checkpoint file.

    Returns:
        dict: Result of ``clean_state_dict`` applied to the state dict.

    Raises:
        TypeError: If the deserialized object is not a dict.
    """
    # NOTE(review): ``weights_only=False`` lets torch unpickle arbitrary
    # objects, so only checkpoints from trusted sources should be passed in.
    state = torch.load(
        model_path,
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    if not isinstance(state, dict):
        raise TypeError('Checkpoint does not contain a valid state dict.')

    # Unwrap at most one level of nesting, preferring 'model' to 'state_dict'.
    for wrapper_key in ('model', 'state_dict'):
        nested = state.get(wrapper_key)
        if isinstance(nested, dict):
            state = nested
            break

    return clean_state_dict(state)
[docs]
def get_pooler_output(model, x):
    """Run ``model`` on ``x`` and return its pooled output.

    The model is called as ``model(pixel_values=x)``. The pooled features are
    read from the ``pooler_output`` attribute of the result when it has one,
    and from the ``'pooler_output'`` item otherwise.

    Args:
        model (CLIPVisionModel): CLIP vision backbone.
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: Pooled feature tensor.
    """
    result = model(pixel_values=x)
    try:
        return result.pooler_output
    except AttributeError:
        # Mapping-style outputs expose the value by key instead of attribute.
        return result['pooler_output']
@ARCH_REGISTRY.register()
class FGResQ(nn.Module):
    """FGResQ no-reference image quality model.

    The network returns a single quality score when only ``x0`` is given,
    and returns ``quality0``, ``quality1``, ``rank``, and ``rank_prob``
    when both ``x0`` and ``x1`` are provided.

    Args:
        clip_model (str): HuggingFace CLIP vision backbone id.
        task_clip_model (str): CLIP backbone for the degradation-aware branch.
            Its hidden size must equal that of ``clip_model``.
        clip_freeze (bool): Whether to freeze the main CLIP backbone.
        pretrained (bool): Whether to load official pretrained weights. For
            each set of weights an explicitly given path takes precedence
            over the download.
        pretrained_model_path (str | None): Optional local checkpoint path for
            the main FGResQ weights.
        degradation_model_path (str | None): Optional local checkpoint path for
            the degradation branch weights.
        input_size (int): Final center crop size.
        resize_size (int): Resize size before center crop.
        default_mean (tuple[float, float, float]): Input normalization mean.
        default_std (tuple[float, float, float]): Input normalization std.
        score_scale (float): Factor multiplied into the head output before
            the sigmoid.

    Raises:
        ValueError: If the two CLIP backbones have different hidden sizes.
    """

    def __init__(
        self,
        clip_model='openai/clip-vit-base-patch16',
        task_clip_model='openai/clip-vit-base-patch16',
        clip_freeze=True,
        pretrained=True,
        pretrained_model_path=None,
        degradation_model_path=None,
        input_size=224,
        resize_size=256,
        default_mean=OPENAI_CLIP_MEAN,
        default_std=OPENAI_CLIP_STD,
        score_scale=0.3,
    ):
        super().__init__()
        self.input_size = input_size
        self.resize_size = resize_size
        self.score_scale = score_scale

        self.clip_model = self._load_clip_vision_model(clip_model)
        if task_clip_model == clip_model:
            # Avoid re-loading the same checkpoint twice.
            self.task_cls_clip = deepcopy(self.clip_model)
        else:
            self.task_cls_clip = self._load_clip_vision_model(task_clip_model)

        hidden_size = self.clip_model.config.hidden_size
        task_hidden_size = self.task_cls_clip.config.hidden_size
        if hidden_size != task_hidden_size:
            raise ValueError(
                'FGResQ requires matching CLIP hidden sizes for both branches, '
                f'but got {hidden_size} and {task_hidden_size}.'
            )

        if clip_freeze:
            for param in self.clip_model.parameters():
                param.requires_grad = False

        # The heads consume the concatenation built in get_quality_features
        # (three hidden-size chunks per image, six for an image pair).
        self.head = nn.Linear(hidden_size * 3, 1)
        self.compare_head = nn.Linear(hidden_size * 6, 3)
        self.prompt = nn.Parameter(torch.rand(1, hidden_size))
        self.task_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(False),
            nn.Linear(hidden_size, hidden_size),
        )
        self.prompt_mlp = nn.Linear(hidden_size, hidden_size)

        # Both projection modules start from all-zero weights and biases.
        with torch.no_grad():
            self.task_mlp[0].weight.zero_()
            self.task_mlp[0].bias.zero_()
            self.task_mlp[2].weight.zero_()
            self.task_mlp[2].bias.zero_()
            self.prompt_mlp.weight.zero_()
            self.prompt_mlp.bias.zero_()

        # The degradation branch is frozen except for its last two encoder
        # layers, independently of ``clip_freeze``.
        for param in self.task_cls_clip.parameters():
            param.requires_grad = False
        for layer in self.task_cls_clip.vision_model.encoder.layers[-2:]:
            for param in layer.parameters():
                param.requires_grad = True

        self.register_buffer('default_mean', torch.tensor(default_mean).view(1, 3, 1, 1))
        self.register_buffer('default_std', torch.tensor(default_std).view(1, 3, 1, 1))

        # Degradation weights are loaded first; the main checkpoint is loaded
        # afterwards into the whole module.
        if degradation_model_path is not None:
            self.load_degradation_weights(degradation_model_path)
        elif pretrained:
            self.load_degradation_weights(
                load_file_from_url(default_model_urls['degradation'])
            )

        if pretrained_model_path is not None:
            self.load_pretrained_weights(pretrained_model_path)
        elif pretrained:
            self.load_pretrained_weights(load_file_from_url(default_model_urls['fgresq']))

    @staticmethod
    def _load_clip_vision_model(model_name):
        """Load CLIP vision weights without verbose task-mismatch reports.

        CLIP vision backbones are often loaded from full CLIP checkpoints that
        also contain text-branch keys. These are expected to be unused.

        The HuggingFace logging verbosity is lowered to errors only for the
        duration of the load and restored afterwards, even if loading fails.

        Args:
            model_name (str): Identifier passed to
                ``CLIPVisionModel.from_pretrained``.

        Returns:
            CLIPVisionModel: The loaded vision backbone.
        """
        previous_verbosity = hf_logging.get_verbosity()
        hf_logging.set_verbosity_error()
        try:
            return CLIPVisionModel.from_pretrained(model_name)
        finally:
            hf_logging.set_verbosity(previous_verbosity)

    @staticmethod
    def _summarize_keys(keys, max_items=8):
        """Format a key collection as a short, possibly truncated, string.

        Args:
            keys: Collection of key names.
            max_items (int): Maximum number of keys spelled out.

        Returns:
            str: ``'[]'`` for an empty collection, ``'[a, b]'`` when all keys
            fit, or ``'[a, b, ...] (total=N)'`` when the list is truncated.
        """
        if not keys:
            return '[]'
        keys = list(keys)
        head = ', '.join(keys[:max_items])
        if len(keys) > max_items:
            return f'[{head}, ...] (total={len(keys)})'
        return f'[{head}]'

    @staticmethod
    def _extract_vision_state_dict(state_dict):
        """Extract CLIP vision-only keys from mixed checkpoints.

        Keys starting with ``'clip_model.'`` are preferred and have that
        prefix stripped; if there are none, the full state dict is used. From
        that candidate set only ``'vision_model.'`` keys are kept, unless none
        exist, in which case the candidate set is returned unchanged.

        Args:
            state_dict (dict): Checkpoint state dict.

        Returns:
            dict: State dict to load into a CLIP vision model.
        """
        prefix = 'clip_model.'
        clip_prefixed = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        candidate = clip_prefixed if clip_prefixed else state_dict
        vision_only = {
            key: value
            for key, value in candidate.items()
            if key.startswith('vision_model.')
        }
        return vision_only if vision_only else candidate

    @staticmethod
    def _filter_expected_missing_keys(missing):
        """Drop the normalization buffer names from a missing-key list.

        Args:
            missing: Iterable of missing key names.

        Returns:
            list: Missing keys other than ``default_mean`` / ``default_std``.
        """
        expected_missing = {'default_mean', 'default_std'}
        return [key for key in missing if key not in expected_missing]

    def _warn_load_mismatch(self, name, missing, unexpected):
        """Emit a ``RuntimeWarning`` summarizing a non-exact checkpoint load.

        Nothing is emitted when, after removing the expected missing keys,
        there are neither missing nor unexpected keys.

        Args:
            name (str): Label used at the start of the warning message.
            missing: Keys reported missing by ``load_state_dict``.
            unexpected: Keys reported unexpected by ``load_state_dict``.
        """
        missing = self._filter_expected_missing_keys(missing)
        if not missing and not unexpected:
            return
        warnings.warn(
            f'{name} loaded with '
            f'missing={len(missing)} {self._summarize_keys(missing)}; '
            f'unexpected={len(unexpected)} {self._summarize_keys(unexpected)}',
            RuntimeWarning,
        )

    def load_degradation_weights(self, model_path):
        """Load degradation-branch weights into ``task_cls_clip``.

        Loading is non-strict; key mismatches are reported as a warning.

        Args:
            model_path (str): Path to the degradation checkpoint.
        """
        state_dict = load_checkpoint(model_path)
        clip_state_dict = self._extract_vision_state_dict(state_dict)
        missing, unexpected = self.task_cls_clip.load_state_dict(
            clip_state_dict,
            strict=False,
        )
        self._warn_load_mismatch('FGResQ degradation weights', missing, unexpected)

    def load_pretrained_weights(self, model_path):
        """Load main FGResQ weights into the whole module.

        Loading is non-strict; key mismatches are reported as a warning.

        Args:
            model_path (str): Path to the FGResQ checkpoint.
        """
        state_dict = load_checkpoint(model_path)
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        self._warn_load_mismatch('FGResQ weights', missing, unexpected)

    def preprocess(self, x):
        """Preprocess an input image tensor.

        Single-channel inputs are repeated to three channels. The image is
        then resized bilinearly to ``resize_size`` x ``resize_size``,
        center-cropped to ``input_size`` and normalized with the stored mean
        and std.

        Args:
            x (torch.Tensor): Input tensor with shape ``(N, C, H, W)``.

        Returns:
            torch.Tensor: Preprocessed tensor.

        Raises:
            ValueError: If ``x`` is not 4D or has a channel count other than
                1 or 3.
        """
        if x.dim() != 4:
            raise ValueError(
                f'FGResQ expects a 4D tensor, but got shape {tuple(x.shape)}.'
            )
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        elif x.shape[1] != 3:
            raise ValueError(
                f'FGResQ expects 1 or 3 channels, but got {x.shape[1]}.'
            )
        x = F.interpolate(
            x,
            size=(self.resize_size, self.resize_size),
            mode='bilinear',
            align_corners=False,
        )
        x = TF.center_crop(x, self.input_size)
        x = (x - self.default_mean) / self.default_std
        return x

    def get_quality_features(self, x):
        """Extract FGResQ quality features.

        Pooled features from the main backbone are combined with a task
        embedding derived from the degradation branch: the branch output goes
        through ``task_mlp``, a softmax over dim 1, a multiplication with the
        learned ``prompt`` and finally ``prompt_mlp``.

        Args:
            x (torch.Tensor): Preprocessed image tensor.

        Returns:
            torch.Tensor: Concatenation along dim 1 of the main features, the
            task embedding and their sum.
        """
        features = get_pooler_output(self.clip_model, x)
        task_features = get_pooler_output(self.task_cls_clip, x)
        task_embedding = torch.softmax(self.task_mlp(task_features), dim=1) * self.prompt
        task_embedding = self.prompt_mlp(task_embedding)
        return torch.cat(
            [features, task_embedding, features + task_embedding],
            dim=1,
        )

    def _predict_score(self, features):
        """Map quality features to a score in (0, 1) via the scaled sigmoid head."""
        return torch.sigmoid(self.head(features) * self.score_scale)

    def forward_single(self, x):
        """Predict single-image quality.

        Args:
            x (torch.Tensor): Preprocessed image tensor.

        Returns:
            torch.Tensor: Predicted quality score.
        """
        return self._predict_score(self.get_quality_features(x))

    def forward_pair(self, x0, x1):
        """Predict pairwise quality and comparison logits.

        Args:
            x0 (torch.Tensor): First preprocessed image tensor.
            x1 (torch.Tensor): Second preprocessed image tensor.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Single-image scores
            for both inputs and pairwise comparison logits.
        """
        features0 = self.get_quality_features(x0)
        features1 = self.get_quality_features(x1)
        quality0 = self._predict_score(features0)
        quality1 = self._predict_score(features1)
        compare_logits = self.compare_head(torch.cat([features0, features1], dim=1))
        return quality0, quality1, compare_logits

    def get_pair_rank(self, compare_logits):
        """Convert comparison logits to a discrete rank label.

        Args:
            compare_logits (torch.Tensor): Pairwise comparison logits.

        Returns:
            torch.Tensor: Rank tensor with shape ``(N, 1)`` where
            ``0=image2_better``, ``1=image1_better``, and
            ``2=similar_quality``.
        """
        compare_probs = torch.softmax(compare_logits, dim=-1)
        return compare_probs.argmax(dim=-1, keepdim=True)

    def get_pair_result(self, quality0, quality1, compare_logits):
        """Format pairwise prediction outputs.

        Args:
            quality0 (torch.Tensor): Quality of the first image.
            quality1 (torch.Tensor): Quality of the second image.
            compare_logits (torch.Tensor): Pairwise comparison logits.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            ``quality0``, ``quality1``, ``rank``, and ``rank_prob``, where
            ``rank_prob`` is the softmax probability of the selected rank.
        """
        # One softmax serves both the rank (its argmax, exactly as in
        # get_pair_rank) and the probability gathered at that rank.
        compare_probs = torch.softmax(compare_logits, dim=-1)
        rank = compare_probs.argmax(dim=-1, keepdim=True)
        rank_prob = compare_probs.gather(dim=-1, index=rank)
        return quality0, quality1, rank, rank_prob

    def forward(self, x0, x1=None):
        """Forward pass for FGResQ.

        Inputs are preprocessed here, so raw ``(N, C, H, W)`` image tensors
        are expected.

        Args:
            x0 (torch.Tensor): First input tensor.
            x1 (torch.Tensor | None): Optional second input tensor.

        Returns:
            torch.Tensor | tuple: The quality score of ``x0`` when ``x1`` is
            ``None``; otherwise the tuple ``(quality0, quality1, rank,
            rank_prob)``.
        """
        x0 = self.preprocess(x0)
        if x1 is None:
            return self.forward_single(x0)
        x1 = self.preprocess(x1)
        quality0, quality1, compare_logits = self.forward_pair(x0, x1)
        return self.get_pair_result(quality0, quality1, compare_logits)