# Source code for pyiqa.archs.compare2score_arch

r"""Adaptive Image Quality Assessment via Teaching Large Multimodal Model to Compare

Reference:
@inproceedings{zhu2024adaptive,
  title={Adaptive Image Quality Assessment via Teaching Large Multimodal Model to Compare},
  author={Zhu, Hanwei and Wu, Haoning and Li, Yixuan and Zhang, Zicheng and Chen, Baoliang and Zhu, Lingyu and Fang, Yuming and Zhai, Guangtao and Lin, Weisi and Wang, Shiqi},
  booktitle={Conference on Neural Information Processing Systems},
  year={2024},
}

Reference url: https://github.com/Q-Future/Compare2Score
"""

import torch
from torch import nn
import warnings
from .q_align.cmp_modelling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
from transformers import BitsAndBytesConfig

from .constants import OPENAI_CLIP_MEAN
from pyiqa.utils.registry import ARCH_REGISTRY
import torchvision.transforms.functional as F
from PIL import Image


def expand2square(pil_img):
    """Center ``pil_img`` on a square canvas filled with the CLIP mean color.

    The canvas side equals the longer edge of the input. The shorter dimension
    is padded on both sides; any odd leftover pixel goes to the right/bottom
    because the offset uses floor division.

    Args:
        pil_img (PIL.Image.Image): Image to pad.

    Returns:
        PIL.Image.Image: New square image in the same mode as ``pil_img``.
    """
    w, h = pil_img.size
    side = w if w >= h else h
    # Each OPENAI_CLIP_MEAN entry is scaled by 255 to get the fill value
    # (presumably the means are in [0, 1]).
    # NOTE(review): this yields one value per mean entry; confirm it matches
    # pil_img.mode's band count for non-RGB inputs.
    fill = tuple(int(c * 255) for c in OPENAI_CLIP_MEAN)
    canvas = Image.new(pil_img.mode, (side, side), fill)
    offset = ((side - w) // 2, (side - h) // 2)
    canvas.paste(pil_img, offset)
    return canvas
@ARCH_REGISTRY.register()
class Compare2Score(nn.Module):
    """Compare2Score large multimodal IQA model wrapper.

    Loads the ``q-future/Compare2Score`` checkpoint through
    ``MPLUGOwl2LlamaForCausalLM.from_pretrained`` and scores images with the
    loaded model's ``score`` method.

    Args:
        dtype (str): Inference precision mode. Supported values are
            ``'fp16'``, ``'4bit'``, and ``'8bit'``. The quantized modes build a
            ``BitsAndBytesConfig``. If building it fails, a ``RuntimeWarning``
            is emitted and the model is loaded in fp16 without quantization.

    Notes:
        Current implementation supports batch size ``1`` in preprocessing.
    """

    def __init__(self, dtype='fp16') -> None:
        super().__init__()
        # The listed choices must match the values actually accepted.
        assert dtype in ['fp16', '4bit', '8bit'], (
            f"Invalid dtype {dtype}. Choose from 'fp16', '4bit', or '8bit'."
        )

        # fp16 is the load dtype on every path: plain fp16, the quantized
        # modes, and the fallback taken when the quantization config cannot be
        # built. Setting it up front keeps the fallback consistent with the
        # "Falling back to fp16" warning below.
        model_kwargs = {
            'trust_remote_code': True,
            'torch_dtype': torch.float16,
        }
        if dtype in ['4bit', '8bit']:
            quant_kwargs = {
                'load_in_4bit': dtype == '4bit',
                'load_in_8bit': dtype == '8bit',
            }
            if dtype == '4bit':
                quant_kwargs.update(
                    {
                        'bnb_4bit_quant_type': 'nf4',
                        'bnb_4bit_compute_dtype': torch.float16,
                    }
                )
            try:
                model_kwargs['quantization_config'] = BitsAndBytesConfig(**quant_kwargs)
            except Exception as err:
                # Deliberate best effort: keep loading without quantization
                # instead of failing construction.
                warnings.warn(
                    f"Failed to enable {dtype} quantization ({err}). Falling back to fp16.",
                    RuntimeWarning,
                )

        # Silence generation-flag warnings that may be emitted while the
        # checkpoint is loaded.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                'ignore',
                message=r"The following generation flags are not valid and may be ignored: .*",
            )
            warnings.filterwarnings(
                'ignore',
                message=r"`do_sample` is set to `False`\. However, `temperature` is set to .*",
            )
            warnings.filterwarnings(
                'ignore',
                message=r"`do_sample` is set to `False`\. However, `top_p` is set to .*",
            )
            self.model = MPLUGOwl2LlamaForCausalLM.from_pretrained('q-future/Compare2Score', **model_kwargs)

        # When sampling is disabled, temperature/top_p are unused. Clear them,
        # presumably so later generation does not repeat the warnings above.
        if getattr(self.model, 'generation_config', None) is not None:
            gen_cfg = self.model.generation_config
            if not getattr(gen_cfg, 'do_sample', False):
                gen_cfg.temperature = None
                gen_cfg.top_p = None

    def preprocess(self, x):
        """Convert a single-image tensor batch to a PIL image.

        Args:
            x (torch.Tensor): Input tensor with shape ``(1, 3, H, W)``. Only the
                batch size is checked here; the value range expected by
                ``to_pil_image`` is presumably ``[0, 1]`` for float input.

        Returns:
            PIL.Image.Image: The first (and only) image in the batch.

        Raises:
            AssertionError: If batch size is not ``1``.
        """
        assert x.shape[0] == 1, 'Currently, only support batch size 1.'
        images = F.to_pil_image(x[0])
        return images

    def forward(self, x):
        """Run Compare2Score model inference.

        Args:
            x (torch.Tensor): Input image tensor with shape ``(1, 3, H, W)``.

        Returns:
            torch.Tensor: Predicted quality score, as returned by
            ``self.model.score`` (return type not checked here).
        """
        # preprocess yields a PIL image, not a tensor.
        image = self.preprocess(x)
        score = self.model.score(image)
        return score