Source code for pyiqa.models
import importlib
from copy import deepcopy
import re
from pathlib import Path
from pyiqa.utils import get_root_logger
from pyiqa.utils.registry import MODEL_REGISTRY
__all__ = ['build_model']
_MODEL_PACKAGE = __package__
_MODEL_FOLDER = Path(__file__).resolve().parent
_ALL_MODEL_IMPORTED = False
def _camel_to_snake(name: str) -> str:
s1 = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
def _lazy_import_model(model_type: str) -> None:
global _ALL_MODEL_IMPORTED
stem = model_type[:-6] if model_type.endswith('_model') else model_type
candidates = (
f'{_MODEL_PACKAGE}.{stem}_model',
f'{_MODEL_PACKAGE}.{stem.lower()}_model',
f'{_MODEL_PACKAGE}.{_camel_to_snake(stem)}_model',
)
for module_name in dict.fromkeys(candidates):
try:
importlib.import_module(module_name)
except ModuleNotFoundError as error:
if error.name != module_name:
raise
if model_type in MODEL_REGISTRY:
return
if _ALL_MODEL_IMPORTED:
return
# Compatibility fallback: import all modules once.
for file_path in _MODEL_FOLDER.glob('*_model.py'):
module_name = f'{_MODEL_PACKAGE}.{file_path.stem}'
try:
importlib.import_module(module_name)
except Exception:
continue
_ALL_MODEL_IMPORTED = True
[docs]
def build_model(opt):
"""Build model from options.
Args:
opt (dict): Configuration. It must contain:
model_type (str): Model type.
"""
opt = deepcopy(opt)
if opt['model_type'] not in MODEL_REGISTRY:
_lazy_import_model(opt['model_type'])
model = MODEL_REGISTRY.get(opt['model_type'])(opt)
logger = get_root_logger()
logger.info(f'Model [{model.__class__.__name__}] is created.')
return model