Source code for pyiqa.data.general_nr_dataset
from PIL import Image
from os import path as osp
import torch
from torch.utils import data as data
from pyiqa.utils.registry import DATASET_REGISTRY
from .base_iqa_dataset import BaseIQADataset
@DATASET_REGISTRY.register()
class GeneralNRDataset(BaseIQADataset):
    """General No Reference dataset with meta info file.

    Each row of the meta info table supplies, in its first column, an image
    file name relative to ``opt['dataroot_target']`` and, in its second
    column, the image's MOS label.
    """

    def init_path_mos(self, opt):
        """Populate ``self.paths_mos`` with ``[image_path, mos_label]`` entries.

        Args:
            opt (dict): Dataset options. ``opt['dataroot_target']`` is the
                folder that each meta-info image name is joined onto. ``opt``
                is also forwarded to ``super().init_path_mos``, which is
                assumed to set ``self.meta_info`` (presumably a pandas
                DataFrame, since ``.values`` is read) -- TODO confirm in
                BaseIQADataset.
        """
        super().init_path_mos(opt)
        target_img_folder = opt['dataroot_target']
        # Entries stay lists (not tuples) so existing code indexing or
        # mutating them keeps working.
        self.paths_mos = [
            [osp.join(target_img_folder, row[0]), float(row[1])]
            for row in self.meta_info.values
        ]

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Args:
            index (int): Position in ``self.paths_mos``.

        Returns:
            dict: ``'img'`` -- the image converted to RGB, passed through
            ``self.trans`` (presumably producing a tensor) and multiplied by
            ``self.img_range``; ``'mos_label'`` -- a one-element tensor
            holding the MOS; ``'img_path'`` -- the image file path.
        """
        # Index (rather than unpack) so entries carrying extra fields still work.
        img_path = self.paths_mos[index][0]
        mos_label = float(self.paths_mos[index][1])
        # The context manager closes the file deterministically. A bare
        # Image.open() leaves closing to Pillow, which keeps multi-frame files
        # open. convert() returns a fully loaded copy, so it stays usable
        # after the source image is closed.
        with Image.open(img_path) as img:
            img_pil = img.convert('RGB')
        img_tensor = self.trans(img_pil) * self.img_range
        mos_label_tensor = torch.tensor([mos_label])
        return {'img': img_tensor, 'mos_label': mos_label_tensor, 'img_path': img_path}