Source code for pyiqa.data.pieapp_dataset

from PIL import Image
import os

import torch
from torch.utils import data as data
from pyiqa.utils.registry import DATASET_REGISTRY

import pandas as pd
from .general_fr_dataset import GeneralFRDataset


@DATASET_REGISTRY.register()
[docs] class PieAPPDataset(GeneralFRDataset): """The PieAPP Dataset introduced by: Prashnani, Ekta and Cai, Hong and Mostofi, Yasamin and Sen, Pradeep PieAPP: Perceptual Image-Error Assessment Through Pairwise Preference CVPR2018 url: http://civc.ucsb.edu/graphics/Papers/CVPR2018_PieAPP/ Args: opt (dict): Config for train datasets with the following keys: phase (str): 'train' or 'val'. """
[docs] def init_path_mos(self, opt): self.dataroot = opt['dataroot_target'] if self.phase == 'test': metadata = pd.read_csv( opt['meta_info_file'], usecols=[ 'ref_img_path', 'dist_imgB_path', 'per_img score for dist_imgB', ], ) else: metadata = pd.read_csv(opt['meta_info_file']) self.paths_mos = metadata.values.tolist()
[docs] def get_split(self, opt): super().get_split(opt) # remove duplicates if self.phase == 'test': temp = [] for item in self.paths_mos: if item not in temp: temp.append(item) self.paths_mos = temp
def __getitem__(self, index): ref_path = os.path.join(self.dataroot, self.paths_mos[index][0]) if self.phase == 'test': distB_path = os.path.join(self.dataroot, self.paths_mos[index][1]) else: distA_path = os.path.join(self.dataroot, self.paths_mos[index][1]) distB_path = os.path.join(self.dataroot, self.paths_mos[index][2]) distB_pil = Image.open(distB_path).convert('RGB') ref_img_pil = Image.open(ref_path).convert('RGB') if self.phase != 'test': distA_pil = Image.open(distA_path).convert('RGB') distA_pil, distB_pil, ref_img_pil = self.paired_trans( [distA_pil, distB_pil, ref_img_pil] ) distA_tensor, distB_tensor, ref_tensor = self.common_trans( [distA_pil, distB_pil, ref_img_pil] ) else: distB_pil, ref_img_pil = self.paired_trans([distB_pil, ref_img_pil]) distB_tensor, ref_tensor = self.common_trans([distB_pil, ref_img_pil]) if self.phase == 'train': score = self.paths_mos[index][4] mos_label_tensor = torch.Tensor([score]) distB_score = torch.Tensor([-1]) elif self.phase == 'val': score = self.paths_mos[index][4] mos_label_tensor = torch.Tensor([score]) distB_score = torch.Tensor([-1]) elif self.phase == 'test': per_img_score = self.paths_mos[index][2] distB_score = torch.Tensor([per_img_score]) if self.phase == 'test': return { 'img': distB_tensor, 'ref_img': ref_tensor, 'mos_label': distB_score, 'img_path': distB_path, 'ref_img_path': ref_path, } else: return { 'distB_img': distB_tensor, 'ref_img': ref_tensor, 'distA_img': distA_tensor, 'mos_label': mos_label_tensor, 'distB_per_img_score': distB_score, 'distB_path': distB_path, 'ref_img_path': ref_path, 'distA_path': distA_path, }