Source code for pyiqa.data.ava_dataset

import numpy as np
from PIL import Image
import os

import torch
from torch.utils import data as data

from pyiqa.utils.registry import DATASET_REGISTRY

from .base_iqa_dataset import BaseIQADataset

# Avoid possible image read errors in the AVA dataset: tell Pillow to load
# truncated image files instead of raising an error while decoding them.
# NOTE: this is a module-level Pillow flag, so once this module is imported it
# applies to every image opened through PIL in the same process.
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True


@DATASET_REGISTRY.register()
class AVADataset(BaseIQADataset):
    """AVA dataset, proposed by

    Murray, Naila, Luca Marchesotti, and Florent Perronnin.
    "AVA: A large-scale database for aesthetic visual analysis."
    In 2012 IEEE conference on computer vision and pattern recognition (CVPR),
    pp. 2408-2415. IEEE, 2012.

    Args:
        opt (dict): Config for train datasets with the following keys:
            phase (str): 'train' or 'val'.

    Options read directly by this class (all looked up in ``opt``):
        dataroot_target (str): Root folder that image paths are joined to.
        split_index (str, optional): Name of the ``meta_info`` entry holding
            the per-sample split flag (0 train, 1 val, 2 test).
        split_file (optional): If set, forces ``split_index`` to
            ``'official_split'`` (compatibility with older configs).
        val_num (int, optional): Target validation set size. Default: 2000.
        list_imgs (bool, optional): If true, ``__getitem__`` returns the image
            zero-padded onto a fixed 800x800 canvas. Default: False.
    """

    def init_path_mos(self, opt):
        """Run the base-class path/MOS setup and record the image root folder.

        Args:
            opt (dict): Dataset config; must contain ``'dataroot_target'``.
        """
        super().init_path_mos(opt)
        self.dataroot = opt['dataroot_target']

    def get_split(self, opt):
        """Restrict ``self.paths_mos`` to the current phase and set ``self.mean_mos``.

        Samples are first grouped by their split flag (0 train, 1 val, 2 test).
        The validation group is then resized to ``val_num`` entries: if it is
        too small it is topped up from the tail of the train group, otherwise
        its surplus head is moved into the train group.

        Args:
            opt (dict): Dataset config; reads ``split_index``, ``split_file``
                and ``val_num``.
        """
        split_index = opt.get('split_index', None)

        # compatible with previous version using split file
        # when using split file, previous version will use official_split or split_index=1
        if opt.get('split_file', None) is not None:
            split_index = 'official_split'

        if split_index is not None:
            # use val_num for validation
            val_num = opt.get('val_num', 2000)

            # Look the split column up once instead of up to three times per sample.
            split_flags = self.meta_info[split_index]

            train_split_paths_mos = []
            val_split_paths_mos = []
            test_split_paths_mos = []
            for i, item in enumerate(self.paths_mos):
                flag = split_flags[i]
                if flag == 0:  # 0 for train
                    train_split_paths_mos.append(item)
                elif flag == 1:  # 1 for val
                    val_split_paths_mos.append(item)
                elif flag == 2:  # 2 for test
                    test_split_paths_mos.append(item)

            # Explicit cut indices are used instead of negative slices:
            # ``x[:-0]`` is empty and ``x[-0:]`` is everything, so val_num == 0
            # would otherwise keep the whole validation set.
            shortfall = val_num - len(val_split_paths_mos)
            if shortfall > 0:
                # Too few val samples: borrow the missing ones from the end of train.
                cut = max(len(train_split_paths_mos) - shortfall, 0)
                val_split_paths_mos = val_split_paths_mos + train_split_paths_mos[cut:]
                train_split_paths_mos = train_split_paths_mos[:cut]
            else:
                # Enough val samples: keep the last val_num, give the rest to train.
                cut = len(val_split_paths_mos) - val_num
                train_split_paths_mos = train_split_paths_mos + val_split_paths_mos[:cut]
                val_split_paths_mos = val_split_paths_mos[cut:]

            # Any other phase value leaves self.paths_mos untouched.
            if self.phase == 'train':
                self.paths_mos = train_split_paths_mos
            elif self.phase == 'val':
                self.paths_mos = val_split_paths_mos
            elif self.phase == 'test':
                self.paths_mos = test_split_paths_mos

        # Mean MOS over whatever samples remain selected for this dataset.
        self.mean_mos = np.array([item[1] for item in self.paths_mos]).mean()

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Position in ``self.paths_mos``.

        Returns:
            dict: With keys
                ``img``: transformed image (zero-padded to 800x800 when
                    ``list_imgs`` is enabled),
                ``img2``: a second, independent pass of ``self.trans`` over the
                    same image (only when ``list_imgs`` is disabled),
                ``mos_label``: MOS divided by 10,
                ``mos_dist``: entries 2..11 of the record, normalised by their sum,
                ``org_size``: ``[height, width]`` of the image before ``self.trans``,
                ``img_path``: full path of the image file,
                ``mean_mos``: ``self.mean_mos`` as a tensor.
        """
        record = self.paths_mos[index]
        img_path = os.path.join(self.dataroot, record[0])
        mos_label = record[1]
        mos_dist = record[2:12]

        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(img_path) as img_file:
            img_pil = img_file.convert('RGB')
        width, height = img_pil.size

        # Both passes are always executed, in this order, so any randomness in
        # self.trans is consumed the same way regardless of list_imgs.
        img_tensor = self.trans(img_pil)
        img_tensor2 = self.trans(img_pil)

        common = {
            'mos_label': torch.Tensor([mos_label]) / 10.0,
            'mos_dist': torch.Tensor(mos_dist) / sum(mos_dist),
            'org_size': torch.tensor([height, width]),
            'img_path': img_path,
            'mean_mos': torch.tensor(self.mean_mos),
        }

        if self.opt.get('list_imgs', False):
            # Fixed-size canvas; the assignment below requires the transformed
            # image to be at most 800x800.
            tmp_tensor = torch.zeros((img_tensor.shape[0], 800, 800))
            h, w = img_tensor.shape[1:]
            tmp_tensor[..., :h, :w] = img_tensor
            return {'img': tmp_tensor, **common}

        return {'img': img_tensor, 'img2': img_tensor2, **common}