Source code for pyiqa.data
import importlib
import numpy as np
import random
import re
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp
from pathlib import Path
from pyiqa.data.prefetch_dataloader import PrefetchDataLoader
from pyiqa.utils import get_root_logger, scandir
from pyiqa.utils.dist_util import get_dist_info
from pyiqa.utils.registry import DATASET_REGISTRY
__all__ = ['build_dataset', 'build_dataloader']
_DATA_PACKAGE = __package__
_DATA_FOLDER = Path(__file__).resolve().parent
_ALL_DATASET_IMPORTED = False
def _camel_to_snake(name: str) -> str:
s1 = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
def _lazy_import_dataset(dataset_type: str) -> None:
global _ALL_DATASET_IMPORTED
stem = dataset_type[:-8] if dataset_type.endswith('_dataset') else dataset_type
candidates = (
f'{_DATA_PACKAGE}.{stem}_dataset',
f'{_DATA_PACKAGE}.{stem.lower()}_dataset',
f'{_DATA_PACKAGE}.{_camel_to_snake(stem)}_dataset',
)
for module_name in dict.fromkeys(candidates):
try:
importlib.import_module(module_name)
except ModuleNotFoundError as error:
if error.name != module_name:
raise
if dataset_type in DATASET_REGISTRY:
return
if _ALL_DATASET_IMPORTED:
return
# Compatibility fallback: import all modules once.
for file_path in _DATA_FOLDER.glob('*_dataset.py'):
module_name = f'{_DATA_PACKAGE}.{file_path.stem}'
try:
importlib.import_module(module_name)
except Exception:
continue
_ALL_DATASET_IMPORTED = True
[docs]
def build_dataset(dataset_opt):
"""Build dataset from options.
Args:
dataset_opt (dict): Configuration for dataset. It must contain:
name (str): Dataset name.
type (str): Dataset type.
"""
dataset_opt = deepcopy(dataset_opt)
if dataset_opt['type'] not in DATASET_REGISTRY:
_lazy_import_dataset(dataset_opt['type'])
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
logger = get_root_logger()
logger.info(
f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.'
)
return dataset
[docs]
def build_dataloader(
dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None
):
"""Build dataloader.
Args:
dataset (torch.utils.data.Dataset): Dataset.
dataset_opt (dict): Dataset options. It contains the following keys:
phase (str): 'train' or 'val'.
num_worker_per_gpu (int): Number of workers for each GPU.
batch_size_per_gpu (int): Training batch size for each GPU.
num_gpu (int): Number of GPUs. Used only in the train phase.
Default: 1.
dist (bool): Whether in distributed training. Used only in the train
phase. Default: False.
sampler (torch.utils.data.sampler): Data sampler. Default: None.
seed (int | None): Seed. Default: None
"""
phase = dataset_opt['phase']
rank, _ = get_dist_info()
if phase == 'train':
if dist: # distributed training
batch_size = dataset_opt['batch_size_per_gpu']
num_workers = dataset_opt['num_worker_per_gpu']
else: # non-distributed training
multiplier = 1 if num_gpu == 0 else num_gpu
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
dataloader_args = dict(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
drop_last=True,
)
if sampler is None:
dataloader_args['shuffle'] = True
dataloader_args['worker_init_fn'] = (
partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
if seed is not None
else None
)
elif phase in ['val', 'test']: # validation
batch_size = dataset_opt.get('batch_size_per_gpu', 1)
num_workers = dataset_opt.get('num_worker_per_gpu', 0)
dataloader_args = dict(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
)
else:
raise ValueError(
f'Wrong dataset phase: {phase}. '
"Supported ones are 'train', 'val' and 'test'."
)
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
prefetch_mode = dataset_opt.get('prefetch_mode')
if prefetch_mode == 'cpu': # CPUPrefetcher
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
logger = get_root_logger()
logger.info(
f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}'
)
return PrefetchDataLoader(
num_prefetch_queue=num_prefetch_queue, **dataloader_args
)
else:
# prefetch_mode=None: Normal dataloader
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
return torch.utils.data.DataLoader(**dataloader_args)
def worker_init_fn(worker_id, num_workers, rank, seed):
# Set the worker seed to num_workers * rank + worker_id + seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)