# Source code for pyiqa.archs.laion_aes_arch

r"""LAION-Aesthetics Predictor

Introduced by: https://github.com/christophschuhmann/improved-aesthetic-predictor
"""

import torch.nn as nn

import clip
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network, clip_preprocess_tensor


import torchvision.transforms as T
from .constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from pyiqa.archs.arch_util import get_url_from_name


# URL of the weights loaded into the MLP head when no explicit checkpoint
# path is supplied; resolved from the file name by ``get_url_from_name``.
default_model_urls = {
    'url': get_url_from_name('sac+logos+ava1-l14-linearMSE.pth'),
}
class MLP(nn.Module):
    """Regression head mapping an embedding vector to a single score.

    The head is a stack of ``nn.Linear`` layers (input_size -> 1024 -> 128
    -> 64 -> 16 -> 1) with ``nn.Dropout`` after the first three of them.
    No activation functions are applied between the layers.

    Args:
        input_size (int): Size of the last dimension of the input.
        xcol (str): Stored on the instance as ``self.xcol``; not used by
            this class itself.
        ycol (str): Stored on the instance as ``self.ycol``; not used by
            this class itself.
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol

        # The position of each module inside the Sequential fixes its
        # state-dict key (layers.0, layers.2, ...), so the order below must
        # not change or previously saved weights will no longer load.
        stack = [
            nn.Linear(input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result.

        Args:
            x (torch.Tensor): Input whose last dimension is ``input_size``.

        Returns:
            torch.Tensor: Output with a last dimension of size 1.
        """
        out = self.layers(x)
        return out
@ARCH_REGISTRY.register()
class LAIONAes(nn.Module):
    """LAION aesthetics predictor: CLIP ViT-L/14 image encoder plus MLP head.

    The CLIP model is obtained with ``clip.load('ViT-L/14')`` and an ``MLP``
    sized to the encoder's ``visual.output_dim`` turns the L2-normalised
    image embedding into a score.

    Args:
        pretrained (bool): When True and ``pretrained_model_path`` is None,
            load the weights at ``default_model_urls['url']`` into the MLP
            head only.
        pretrained_model_path (str | None): When not None, load this
            checkpoint into the whole module using ``weight_keys='params'``.
            This takes precedence over ``pretrained``.
    """

    def __init__(
        self,
        pretrained=True,
        pretrained_model_path=None,
    ) -> None:
        super().__init__()

        encoder, _ = clip.load('ViT-L/14')
        self.mlp = MLP(encoder.visual.output_dim)

        # Wrapping the CLIP model in a plain list means nn.Module does not
        # register it as a submodule; only ``self.mlp`` is registered.
        self.clip_model = [encoder]

        if pretrained_model_path is not None:
            # Third positional argument is presumably the ``strict`` flag of
            # load_pretrained_network -- confirm against arch_util.
            load_pretrained_network(
                self, pretrained_model_path, True, weight_keys='params'
            )
        elif pretrained:
            load_pretrained_network(self.mlp, default_model_urls['url'])

    def forward(self, x):
        """Predict a score for the image tensor ``x``.

        Args:
            x (torch.Tensor): Input images. Assumed to be a batched image
                tensor that has not yet been CLIP-normalised -- TODO confirm
                the expected shape and value range with callers.

        Returns:
            torch.Tensor: The output of the MLP head for each input.
        """
        # ``.to(x)`` matches the CLIP model's device and dtype to the input.
        encoder = self.clip_model[0].to(x)

        if self.training:
            # Training path: only mean/std normalisation is applied here.
            img = T.functional.normalize(x, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)
        else:
            # Evaluation path: delegate all preprocessing to the shared helper.
            img = clip_preprocess_tensor(x, encoder)

        embedding = encoder.encode_image(img).float()
        embedding = nn.functional.normalize(embedding, p=2, dim=-1)
        return self.mlp(embedding)