# Source code for pyiqa.data.general_fr_dataset
from PIL import Image
from os import path as osp
import torch
from torch.utils import data as data
import torchvision.transforms as tf
from pyiqa.data.transforms import transform_mapping, PairedToTensor
from pyiqa.utils.registry import DATASET_REGISTRY
from .base_iqa_dataset import BaseIQADataset
@DATASET_REGISTRY.register()
class GeneralFRDataset(BaseIQADataset):
    """General Full Reference dataset with meta info file.

    Every sample is a (reference image, target image, MOS label) triple.
    The triples are built from ``self.meta_info``, whose rows are read as
    ``row[0]`` = reference image path, ``row[1]`` = target image path and
    ``row[2]`` = MOS label.

    NOTE(review): ``meta_info``, ``paired_trans``, ``common_trans``,
    ``img_range`` and ``logger`` are not defined in this class; they are
    presumably provided by ``BaseIQADataset`` -- confirm against the base
    class.
    """

    def init_path_mos(self, opt):
        """Build ``self.paths_mos`` from the meta info table.

        Args:
            opt (dict): Dataset options. Reads ``dataroot_target`` (required)
                and ``dataroot_ref`` (optional; when missing or ``None`` the
                target folder is used for reference images as well).
        """
        super().init_path_mos(opt)

        target_img_folder = opt['dataroot_target']
        ref_img_folder = opt.get('dataroot_ref', None)
        if ref_img_folder is None:
            # No separate reference folder: both paths share the target root.
            ref_img_folder = target_img_folder

        self.paths_mos = [
            [
                osp.join(ref_img_folder, row[0]),
                osp.join(target_img_folder, row[1]),
                float(row[2]),
            ]
            for row in self.meta_info.values
        ]

    def mos_normalize(self, opt):
        """Rescale the MOS labels in ``self.paths_mos`` to [0, 1], higher better.

        Does nothing unless ``opt['mos_normalize']`` is truthy.

        Args:
            opt (dict): Dataset options. Reads ``mos_normalize`` (default
                ``False``), ``mos_range`` (indexed as ``[lower, upper]``) and
                ``lower_better`` (whether a lower raw score means better
                quality).

        Raises:
            AssertionError: If normalization is requested but ``mos_range``
                or ``lower_better`` is missing from ``opt``.
        """
        mos_range = opt.get('mos_range', None)
        mos_lower_better = opt.get('lower_better', None)

        if not opt.get('mos_normalize', False):
            return

        # NOTE(review): ``assert`` is stripped under ``python -O``; it is kept
        # so that callers still see the same AssertionError as before.
        assert mos_range is not None and mos_lower_better is not None, (
            "'mos_range' and 'lower_better' should be provided when mos_normalize is True"
        )

        lower, upper = mos_range[0], mos_range[1]

        def normalize(mos_label):
            scaled = (mos_label - lower) / (upper - lower)
            # Flip the scale so that a higher value always means better.
            return 1 - scaled if mos_lower_better else scaled

        self.paths_mos = [
            item[:2] + [normalize(item[2])] for item in self.paths_mos
        ]
        self.logger.info(
            f'mos_label is normalized from {mos_range}, lower_better[{mos_lower_better}] to [0, 1], higher better.'
        )

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy of it.

        The file is opened in a ``with`` block so the handle is released as
        soon as the pixel data has been converted.
        """
        with Image.open(path) as img_file:
            return img_file.convert('RGB')

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Position of the sample in ``self.paths_mos``.

        Returns:
            dict: ``img`` and ``ref_img`` (transformed images scaled by
            ``self.img_range``), ``mos_label`` (one-element tensor),
            ``img_path`` and ``ref_img_path``.
        """
        ref_path, img_path, mos_label = self.paths_mos[index][:3]

        # Target first, then reference, matching the original load order.
        img_pil = self._load_rgb(img_path)
        ref_pil = self._load_rgb(ref_path)

        # The paired transform receives both images in a single call.
        img_pil, ref_pil = self.paired_trans([img_pil, ref_pil])

        img_tensor = self.common_trans(img_pil) * self.img_range
        ref_tensor = self.common_trans(ref_pil) * self.img_range
        mos_label_tensor = torch.Tensor([mos_label])

        return {
            'img': img_tensor,
            'ref_img': ref_tensor,
            'mos_label': mos_label_tensor,
            'img_path': img_path,
            'ref_img_path': ref_path,
        }