"""A-FINE architecture for generalized image quality assessment.
Reference:
Chen, D., Wu, T., Ma, K., and Zhang, L.
Toward Generalized Image Quality Assessment: Relaxing the Perfect
Reference Quality Assumption. CVPR 2025.
Project page:
https://github.com/ChrisDud0257/AFINE
This implementation is intended for inference with pretrained checkpoints.
"""
import torch
import torch.nn as nn
import math
from torchvision.transforms.functional import normalize
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.utils.download_util import load_file_from_url
from copy import deepcopy
from .afineclip_model import load
import torch.nn.functional as F
from pyiqa.archs.arch_util import get_url_from_name
[docs]
# Default checkpoint locations, keyed by the ``url_key`` argument of ``AFINE``.
default_model_urls = dict(afine=get_url_from_name('afine.pth'))
[docs]
def scale_finalscore(score, yita1=100, yita2=0, yita3=-1.9710, yita4=-2.3734):
    """Map raw A-FINE score to a bounded, human-readable range.

    Applies the logistic mapping
    ``(yita1 - yita2) * sigmoid((score - yita3) / (|yita4| + 1e-10)) + yita2``
    element-wise, so the input may be a scalar tensor or a tensor of any
    shape (for example a batch of scores).

    Args:
        score (torch.Tensor): Raw score tensor.
        yita1 (float): Upper bound of target range.
        yita2 (float): Lower bound of target range.
        yita3 (float): Logistic midpoint parameter.
        yita4 (float): Logistic scale parameter (its absolute value is used).

    Returns:
        torch.Tensor: Scaled score tensor with the same shape as ``score``.
    """
    # The small constant keeps the division finite when ``yita4`` is zero.
    scale = math.fabs(yita4) + 1e-10
    # ``torch.sigmoid`` is the overflow-safe form of ``1 / (1 + exp(-z))``.
    # Unlike a Python ``if`` on the exponent, it also works for tensors that
    # hold more than one element.
    return (yita1 - yita2) * torch.sigmoid((score - yita3) / scale) + yita2
[docs]
class AFINEQhead(nn.Module):
    """Naturalness head used by A-FINE.

    For the de-normalized input image and for every feature level, the head
    takes the per-channel mean and variance over the token axis, projects the
    feature-level statistics to ``hidden_dim`` and feeds the concatenation of
    all levels to a small MLP that predicts the no-reference naturalness term.

    Args:
        chns (tuple[int, ...]): Channel dimensions for input image and feature
            levels. The first entry is the image channel count.
        feature_out_channel (int): Number of output channels for the score head.
        input_dim (int): Channel dimension of CLIP feature tokens.
        hidden_dim (int): Hidden width for projection layers.
        mean (tuple[float, float, float]): RGB normalization mean.
        std (tuple[float, float, float]): RGB normalization standard deviation.
    """

    def __init__(self, chns=(3, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768),
                 feature_out_channel=1, input_dim=768, hidden_dim=128,
                 mean=(0.48145466, 0.4578275, 0.40821073),
                 std=(0.26862954, 0.26130258, 0.27577711)):
        super().__init__()
        self.chns = chns
        self.feature_out_channel = feature_out_channel
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Normalization statistics, shaped to broadcast over (N, C, H, W).
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

        # Shared projection of the [mean, variance] vector of a feature level.
        self.proj_feat = nn.Linear(2 * input_dim, hidden_dim)

        # Image statistics enter unprojected (2 * chns[0] values); every other
        # level contributes ``hidden_dim`` projected values.
        head_in = 2 * self.chns[0] + hidden_dim * (len(self.chns) - 1)
        head_hidden = 6 * hidden_dim
        self.proj_head = nn.Sequential(
            nn.Linear(head_in, head_hidden),
            nn.GELU(),
            nn.Linear(head_hidden, self.feature_out_channel),
        )

    def forward(self, x, h_list_x):
        """Predict the naturalness term.

        Args:
            x (torch.Tensor): Normalized image; it is mapped back with the
                stored ``std``/``mean`` buffers before statistics are taken.
            h_list_x (list[torch.Tensor]): Feature maps, one per feature
                level; assumed to be shaped ``(N, tokens, input_dim)``.

        Returns:
            torch.Tensor: Output of the score MLP, one row per sample.
        """
        # Undo the input normalization and lay pixels out as (N, H*W, C).
        pixels = (x * self.std + self.mean).flatten(2).permute(0, 2, 1)
        levels = [pixels] + [F.relu(h) for h in h_list_x]

        descriptors = []
        for idx in range(len(self.chns)):
            tokens = levels[idx]
            mu = tokens.mean(1, keepdim=True)
            var = ((tokens - mu) ** 2).mean(1, keepdim=True)
            stats = torch.cat((mu.flatten(1), var.flatten(1)), dim=1)
            # Level 0 is the raw image and is used without projection.
            descriptors.append(stats if idx == 0 else self.proj_feat(stats))

        return self.proj_head(torch.cat(descriptors, dim=1))
[docs]
class AFINEDhead(nn.Module):
    """Fidelity head used by A-FINE.

    The module computes similarity statistics between distorted and reference
    features and outputs a full-reference fidelity term: one similarity on
    per-channel means and one on per-channel (co)variances, combined with
    learned, softplus-activated and jointly normalized weights.

    Args:
        chns (tuple[int, ...]): Channel dimensions for each feature level.
            The first entry is the image channel count.
        mean (tuple[float, float, float]): RGB normalization mean.
        std (tuple[float, float, float]): RGB normalization standard deviation.
    """

    def __init__(self, chns=(3, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768),
                 mean=(0.48145466, 0.4578275, 0.40821073),
                 std=(0.26862954, 0.26130258, 0.27577711)):
        super().__init__()
        self.chns = chns

        # One weight per channel across all levels, for each similarity term.
        total_channels = sum(self.chns)
        self.alpha = nn.Parameter(torch.randn(1, 1, total_channels), requires_grad=True)
        self.beta = nn.Parameter(torch.randn(1, 1, total_channels), requires_grad=True)
        self.alpha.data.normal_(0.1, 0.01)
        self.beta.data.normal_(0.1, 0.01)
        self.softplus = nn.Softplus()

        # Normalization statistics, shaped to broadcast over (N, C, H, W).
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x, y, h_list_x, h_list_y):
        """Compute the fidelity term between ``x`` and ``y``.

        Args:
            x (torch.Tensor): Normalized first image.
            y (torch.Tensor): Normalized second image.
            h_list_x (list[torch.Tensor]): Feature maps of ``x``, one per
                feature level; assumed to be shaped ``(N, tokens, channels)``.
            h_list_y (list[torch.Tensor]): Feature maps of ``y``, same layout.

        Returns:
            torch.Tensor: ``1 - (weighted mean similarity + weighted variance
            similarity)``, one row per sample.
        """
        def build_levels(image, hidden_states):
            # Map the image back to its original value range and lay pixels
            # out as (N, H*W, C) so they look like one more feature level.
            pixels = (image * self.std + self.mean).flatten(2).permute(0, 2, 1)
            return [pixels] + [F.relu(h) for h in hidden_states]

        levels_x = build_levels(x, h_list_x)
        levels_y = build_levels(y, h_list_y)

        eps_mean = 1e-10
        eps_var = 1e-10

        # Positive weights, normalized jointly so that both terms sum to ~1.
        pos_alpha = self.softplus(self.alpha)
        pos_beta = self.softplus(self.beta)
        weight_total = pos_alpha.sum() + pos_beta.sum() + 1e-10
        alpha_per_level = torch.split(pos_alpha / weight_total, self.chns, dim=2)
        beta_per_level = torch.split(pos_beta / weight_total, self.chns, dim=2)

        mean_term = 0
        var_term = 0
        for idx in range(len(self.chns)):
            fx = levels_x[idx]
            fy = levels_y[idx]

            mu_x = fx.mean(1, keepdim=True)
            mu_y = fy.mean(1, keepdim=True)
            mean_sim = (2 * mu_x * mu_y + eps_mean) / (mu_x ** 2 + mu_y ** 2 + eps_mean)
            mean_term = mean_term + (alpha_per_level[idx] * mean_sim).sum(2, keepdim=True)

            var_x = ((fx - mu_x) ** 2).mean(1, keepdim=True)
            var_y = ((fy - mu_y) ** 2).mean(1, keepdim=True)
            cov_xy = (fx * fy).mean(1, keepdim=True) - mu_x * mu_y
            var_sim = (2 * cov_xy + eps_var) / (var_x + var_y + eps_var)
            var_term = var_term + (beta_per_level[idx] * var_sim).sum(2, keepdim=True)

        return 1 - (mean_term + var_term).squeeze(2)
### Non-linear mapping that calibrates the naturalness (NR) score to a fixed, bounded range
[docs]
class AFINENLM_NR_Fit(nn.Module):
    """Nonlinear calibration layer for the naturalness branch.

    Maps a raw score into the open interval ``(yita2, yita1)`` with a logistic
    curve whose midpoint (``yita3``) and scale (``yita4``) are learnable.

    Args:
        yita1 (float): Upper bound of the output range.
        yita2 (float): Lower bound of the output range.
        yita3 (float): Initial logistic midpoint.
        yita4 (float): Initial logistic scale (its absolute value is used).
    """

    def __init__(self, yita1=2, yita2=-2, yita3=3.7833, yita4=7.5676):
        super().__init__()
        self.yita3 = nn.Parameter(torch.tensor(yita3, dtype=torch.float32), requires_grad=True)
        self.yita4 = nn.Parameter(torch.tensor(yita4, dtype=torch.float32), requires_grad=True)
        self.yita1 = yita1
        self.yita2 = yita2

    def forward(self, x):
        """Apply the logistic calibration element-wise.

        Args:
            x (torch.Tensor): Raw score tensor of any shape.

        Returns:
            torch.Tensor: Calibrated scores with the same shape as ``x``.
        """
        # The small constant keeps the division finite when ``yita4`` is zero.
        scale = torch.abs(self.yita4) + 1e-10
        # ``torch.sigmoid`` is the overflow-safe form of ``1 / (1 + exp(-z))``
        # and, unlike a Python ``if`` on the exponent, works for tensors with
        # more than one element.
        return (self.yita1 - self.yita2) * torch.sigmoid((x - self.yita3) / scale) + self.yita2
### Non-linear mapping that calibrates the fidelity (FR) score to a fixed, bounded range
[docs]
class AFINENLM_FR_Fit_with_limit(nn.Module):
    """Bounded nonlinear calibration layer for the fidelity branch.

    Maps a raw score into the open interval ``(yita2, yita1)`` with a logistic
    curve. The learnable midpoint (``yita3``) and scale (``yita4``) are clamped
    to the configured limits every time the layer is applied.

    Args:
        yita1 (float): Upper bound of the output range.
        yita2 (float): Lower bound of the output range.
        yita3 (float): Initial logistic midpoint.
        yita4 (float): Initial logistic scale (its absolute value is used).
        yita3_upper (float): Upper clamp limit for ``yita3``.
        yita3_lower (float): Lower clamp limit for ``yita3``.
        yita4_upper (float): Upper clamp limit for ``yita4``.
        yita4_lower (float): Lower clamp limit for ``yita4``.
    """

    def __init__(self, yita1=2, yita2=-2, yita3=-24.1335, yita4=8.1093,
                 yita3_upper=-21, yita3_lower=-27, yita4_upper=9, yita4_lower=7):
        super().__init__()
        self.yita3 = nn.Parameter(torch.tensor(yita3, dtype=torch.float32), requires_grad=True)
        self.yita4 = nn.Parameter(torch.tensor(yita4, dtype=torch.float32), requires_grad=True)
        self.yita1 = yita1
        self.yita2 = yita2
        self.yita3_upper = yita3_upper
        self.yita3_lower = yita3_lower
        self.yita4_upper = yita4_upper
        self.yita4_lower = yita4_lower

    def forward(self, x):
        """Apply the clamped logistic calibration element-wise.

        Args:
            x (torch.Tensor): Raw score tensor of any shape.

        Returns:
            torch.Tensor: Calibrated scores with the same shape as ``x``.
        """
        midpoint = torch.clamp(self.yita3, self.yita3_lower, self.yita3_upper)
        # The small constant keeps the division finite when the scale is zero.
        scale = torch.abs(torch.clamp(self.yita4, self.yita4_lower, self.yita4_upper)) + 1e-10
        # ``torch.sigmoid`` is the overflow-safe form of ``1 / (1 + exp(-z))``
        # and, unlike a Python ``if`` on the exponent, works for tensors with
        # more than one element.
        return (self.yita1 - self.yita2) * torch.sigmoid((x - midpoint) / scale) + self.yita2
### Adapter that fuses the naturalness and fidelity terms into one score
[docs]
class AFINELearnLambda(nn.Module):
    """Adaptive fusion layer for naturalness and fidelity terms.

    Args:
        k (float): Initial value of the learnable exponent gain; it is passed
            through a softplus before use, so the effective gain is positive.
    """

    def __init__(self, k=5):
        super().__init__()
        initial_gain = torch.tensor(k, dtype=torch.float32)
        self.k = nn.Parameter(initial_gain, requires_grad=True)

    def forward(self, x_nr, ref_nr, xref_fr):
        """Fuse the three terms into a single score.

        Args:
            x_nr (torch.Tensor): Naturalness term of the distorted image.
            ref_nr (torch.Tensor): Naturalness term of the reference image.
            xref_fr (torch.Tensor): Fidelity term between the two images.

        Returns:
            torch.Tensor: ``exp(softplus(k) * (ref_nr - x_nr)) * x_nr + xref_fr``.
        """
        gain = F.softplus(self.k)
        # The naturalness term is re-weighted by how far the reference's
        # naturalness lies from that of the distorted image.
        naturalness_weight = torch.exp(gain * (ref_nr - x_nr))
        return naturalness_weight * x_nr + xref_fr
@ARCH_REGISTRY.register()
class AFINE(nn.Module):
    """A-FINE inference model.

    Wraps a fine-tuned CLIP image encoder together with the naturalness head,
    the fidelity head, their nonlinear calibration layers and the adaptive
    fusion layer. All weights are restored from a single checkpoint and every
    sub-module is put in evaluation mode.

    Args:
        model_type (str): Output type.
            Supported values are ``"afine_all_scale"``, ``"afine_all"``,
            ``"afine_fr"``, and ``"afine_nr"``.
        clip_backbone (str): CLIP backbone identifier.
        step (int): Kept for compatibility with original config interface.
        num_patch (int): Kept for compatibility with original config interface.
        pretrained (bool): Kept for compatibility. This implementation expects
            a pretrained checkpoint.
        pretrained_model_path (str | None): Local checkpoint path. If ``None``,
            the location is resolved from ``url_key``: an http(s) URL is
            downloaded, anything else is used as a local path.
        url_key (str): Key used to resolve default checkpoint URL.

    Example:
        >>> metric = AFINE(model_type='afine_all_scale')
        >>> dis = torch.rand(1, 3, 224, 224)
        >>> ref = torch.rand(1, 3, 224, 224)
        >>> score = metric(dis, ref)
    """

    def __init__(
        self,
        model_type='afine_all_scale',
        clip_backbone='ViT-B/32',
        step=32,
        num_patch=15,
        pretrained=True,
        pretrained_model_path=None,
        url_key='afine',
    ) -> None:
        super().__init__()
        self.clip_backbone = clip_backbone

        # If the pretrained CLIP model cannot be downloaded online, fetch
        # "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
        # manually and pass its local path as ``clip_backbone``.
        clip_model, _ = load(self.clip_backbone, device="cpu", jit=False)

        # 'afine_all_scale' is the final scaled A-FINE score that adaptively
        # combines both terms, 'afine_all' is the same score before scaling,
        # 'afine_fr' is the fidelity term and 'afine_nr' the naturalness term.
        self.model_type = model_type
        self.mean = (0.48145466, 0.4578275, 0.40821073)
        self.std = (0.26862954, 0.26130258, 0.27577711)

        if pretrained_model_path is None:
            model_path = default_model_urls[url_key]
            if model_path.startswith(('https://', 'http://')):
                pretrained_model_path = load_file_from_url(model_path)
            else:
                # A non-URL entry is used as a local checkpoint path instead
                # of leaving the path unset.
                pretrained_model_path = model_path

        # Read the checkpoint once; every component below takes its own entry.
        checkpoint = torch.load(pretrained_model_path, map_location='cpu')

        # Fine-tuned CLIP backbone.
        clip_model.load_state_dict(checkpoint['finetuned_clip'])
        self.clip_model = clip_model

        # Naturalness term.
        net_qhead = AFINEQhead()
        net_qhead.load_state_dict(checkpoint['natural'], strict=True)
        self.net_qhead = net_qhead

        # Fidelity term.
        net_dhead = AFINEDhead()
        net_dhead.load_state_dict(checkpoint['fidelity'], strict=True)
        self.net_dhead = net_dhead

        # Nonlinear mapping for the fidelity term.
        net_scale_fr = AFINENLM_FR_Fit_with_limit(
            yita1=2, yita2=-2, yita3=0.5, yita4=0.15,
            yita3_upper=0.95, yita3_lower=0.05, yita4_upper=0.70, yita4_lower=0.01,
        )
        net_scale_fr.load_state_dict(checkpoint['fidelity_scale'], strict=True)
        self.net_scale_fr = net_scale_fr

        # Nonlinear mapping for the naturalness term.
        net_scale_nr = AFINENLM_NR_Fit(yita1=2, yita2=-2, yita3=4.9592, yita4=21.5968)
        net_scale_nr.load_state_dict(checkpoint['natural_scale'], strict=True)
        self.net_scale_nr = net_scale_nr

        # Adaptive fusion term.
        adapter = AFINELearnLambda(k=5)
        adapter.load_state_dict(checkpoint['adapter'], strict=True)
        self.adapter = adapter

        for module in (self.clip_model, self.net_qhead, self.net_dhead,
                       self.net_scale_fr, self.net_scale_nr, self.adapter):
            module.eval()

    def forward(self, dis, ref=None):
        """Run A-FINE scoring.

        The inputs are normalized out of place, so the caller's tensors are
        left untouched.

        Args:
            dis (torch.Tensor): Distorted image tensor with shape
                ``(N, 3, H, W)``; a single ``(3, H, W)`` image is also
                accepted and gets a batch axis.
            ref (torch.Tensor | None): Optional reference image tensor with the
                same shape as ``dis``. If ``None``, ``dis`` is reused.

        Returns:
            torch.Tensor: Score tensor according to ``model_type``.

        Raises:
            AssertionError: If image height or width is not divisible by 32.
            ValueError: If ``model_type`` is unsupported.

        Notes:
            Lower values indicate better quality for A-FINE terms.
            ``dis`` and ``ref`` are not interchangeable: ``dis`` must be the
            distorted image and ``ref`` the reference image.
        """
        if dis.dim() == 3:
            dis = dis.unsqueeze(0)
        dis = normalize(dis, self.mean, self.std)

        if ref is None:
            # Reuse the already normalized distorted image; normalizing a copy
            # of it again would apply the normalization twice.
            ref = dis.clone()
        else:
            if ref.dim() == 3:
                ref = ref.unsqueeze(0)
            ref = normalize(ref, self.mean, self.std)

        # Height and width must be divisible by 32 because the pretrained
        # CLIP ViT-B-32 model is used.
        assert dis.shape[2] % 32 == 0 and dis.shape[3] % 32 == 0, "The height and width of the input image must be divisible by 32."

        # Fallback that is only reachable when assertions are disabled
        # (``python -O``): resize up to the next multiple of 32.
        _, _, h, w = dis.shape
        pad_h = (-h) % 32
        pad_w = (-w) % 32
        if pad_h > 0 or pad_w > 0:
            target_size = (h + pad_h, w + pad_w)
            dis = F.interpolate(dis, size=target_size, mode='bicubic', align_corners=False)
            ref = F.interpolate(ref, size=target_size, mode='bicubic', align_corners=False)

        # For every term, including the final A-FINE score, lower values
        # indicate better quality. 'afine_all_scale' is the bounded version of
        # the fused score, used to avoid numerical overflow.
        with torch.no_grad():
            _, feat_dis = self.clip_model.encode_image(dis)
            _, feat_ref = self.clip_model.encode_image(ref)

            natural_dis = self.net_qhead(dis, feat_dis)
            natural_ref = self.net_qhead(ref, feat_ref)
            natural_dis_scale = self.net_scale_nr(natural_dis)
            natural_ref_scale = self.net_scale_nr(natural_ref)

            fidelity_disref = self.net_dhead(dis, ref, feat_dis, feat_ref)
            fidelity_disref_scale = self.net_scale_fr(fidelity_disref)

            afine_all = self.adapter(natural_dis_scale, natural_ref_scale, fidelity_disref_scale)
            afine_all_scale = scale_finalscore(score=afine_all)

        outputs = {
            'afine_nr': natural_dis_scale,
            'afine_fr': fidelity_disref_scale,
            'afine_all_scale': afine_all_scale,
            'afine_all': afine_all,
        }
        if self.model_type not in outputs:
            raise ValueError("self.model_type must be afine_nr, afine_fr, afine_all or afine_all_scale.")
        return outputs[self.model_type]