Source code for pyiqa.data.piq_dataset
import torch
import os
import csv
from PIL import Image
from pyiqa.utils.registry import DATASET_REGISTRY
from pyiqa.utils import get_root_logger
from .general_nr_dataset import GeneralNRDataset
@DATASET_REGISTRY.register()
class PIQDataset(GeneralNRDataset):
    """No-reference dataset for PIQ2023, driven by a CSV meta info file.

    Each CSV row is used positionally by this class:

    * column ``0``  -- image path, relative to ``opt['dataroot_target']``;
      it is also matched against the selected attribute name.
    * column ``1``  -- MOS label (parsed as ``float``).
    * column ``-4`` -- scene index (parsed as ``int``).
    * column ``-2`` -- ``'Train'`` / ``'Test'`` label for the device split.
    * column ``-1`` -- ``'Train'`` / ``'Test'`` label for the scene split.

    NOTE(review): ``self.phase``, ``self.trans`` and ``self.img_range`` are
    read here but not set in this class; presumably they come from
    ``GeneralNRDataset`` -- confirm against the base class.
    """

    def init_path_mos(self, opt):
        """Load the meta info CSV and keep the rows of one attribute.

        Args:
            opt (dict): Dataset options. Reads ``dataroot_target`` (image
                root folder), ``meta_info_file`` (CSV path) and the optional
                ``attribute`` (one of ``'Details'``, ``'Exposure'``,
                ``'Overall'``; defaults to ``'Overall'``).

        Raises:
            AssertionError: If ``attribute`` is not one of the three
                supported names.

        Side effects:
            Sets ``self.paths_mos`` to the list of selected CSV rows, with
            column 0 replaced by the path joined onto ``dataroot_target``.
        """
        logger = get_root_logger()
        target_img_folder = opt['dataroot_target']

        attr = opt.get('attribute', 'Overall')
        assert attr in ['Details', 'Exposure', 'Overall'], (
            f'attribute should be in [Details, Exposure, Overall], got {attr}'
        )
        logger.info(f'Training on PIQ2023 dataset with attribute [{attr}]')

        # newline='' lets the csv module do its own newline handling, as the
        # csv documentation requires for files passed to csv.reader.
        with open(opt['meta_info_file'], 'r', newline='') as fin:
            csvreader = csv.reader(fin)
            next(csvreader, None)  # drop the header row
            name_mos = list(csvreader)

        self.paths_mos = []
        for item in name_mos:
            # Blank lines come back from csv.reader as empty rows; skip them
            # instead of failing on item[0]. The attribute is selected by a
            # substring match on the path column.
            if item and attr in item[0]:
                item[0] = os.path.join(target_img_folder, item[0])
                self.paths_mos.append(item)

    def get_split(self, opt):
        """Restrict ``self.paths_mos`` to the current phase of a split.

        Split indexes for the PIQ2023 dataset:
            1: device split
            2: scene split

        Args:
            opt (dict): Dataset options. Reads the optional ``split_index``;
                when it is missing or ``None`` the row list is left as is.

        Raises:
            AssertionError: If ``split_index`` is not 1 or 2, or if
                ``self.phase`` is neither ``'train'`` nor ``'test'``.
        """
        logger = get_root_logger()
        split_index = opt.get('split_index', None)
        if split_index is None:
            return

        assert split_index in [1, 2], (
            'split indexes should be, 1: device split; 2: scene split'
        )
        assert self.phase in ['train', 'test'], (
            f'PIQDataset has no {self.phase} split'
        )
        logger.info(
            f'Training on PIQ2023 dataset with split [{split_index}](1: device split; 2: scene split)'
        )

        # split_index 1 -> column -2, split_index 2 -> column -1.
        split_col = split_index - 3
        # Label a row must carry to belong to the current phase. An unknown
        # phase maps to None, which matches no row (same outcome as before
        # if the assertions are stripped with -O).
        wanted = {'train': 'Train', 'test': 'Test'}.get(self.phase)
        self.paths_mos = [
            item for item in self.paths_mos if item[split_col] == wanted
        ]

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Position in ``self.paths_mos``.

        Returns:
            dict: ``'img'`` (the transformed image multiplied by
            ``self.img_range``), ``'mos_label'`` (one-element
            ``torch.Tensor``), ``'img_path'`` (str) and ``'scene_idx'``
            (int, taken from column ``-4`` of the row).
        """
        record = self.paths_mos[index]
        img_path = record[0]
        mos_label = float(record[1])

        # convert() loads the pixel data and returns a new image, so the
        # file handle can be released as soon as the block exits.
        with Image.open(img_path) as img:
            img_pil = img.convert('RGB')

        img_tensor = self.trans(img_pil) * self.img_range
        mos_label_tensor = torch.Tensor([mos_label])
        scene_idx = int(record[-4])

        return {
            'img': img_tensor,
            'mos_label': mos_label_tensor,
            'img_path': img_path,
            'scene_idx': scene_idx,
        }