# Source code for pyiqa.archs

import importlib
import copy
import re
from pathlib import Path

from pyiqa.utils import get_root_logger
from pyiqa.utils.registry import ARCH_REGISTRY


# Public API of this package: only the network factory is exported by ``import *``.
__all__ = ['build_network']


# Package name, used as the import prefix for ``<package>.<name>_arch`` submodules.
_ARCH_PACKAGE = __package__
# Directory of this package; scanned for ``*_arch.py`` files by the import fallback.
_ARCH_FOLDER = Path(__file__).resolve().parent
# Becomes True after the one-time "import every *_arch.py module" fallback has run.
_ALL_ARCH_IMPORTED = False


def _camel_to_snake(name: str) -> str:
    s1 = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


def _lazy_import_arch(network_type: str) -> None:
    """Import the arch module(s) expected to register ``network_type``.

    Registration in ``ARCH_REGISTRY`` happens as a side effect of importing an
    arch module, so this function only imports; the caller does the lookup.

    Strategy:
        1. Try module names derived from ``network_type`` (as given,
           lower-cased, and CamelCase -> snake_case), each suffixed with
           ``_arch``, stopping as soon as ``network_type`` is registered.
        2. Otherwise, once per process, import every ``*_arch.py`` file in this
           package, for classes defined in a module whose name does not follow
           the class name. Files are imported in sorted order so the outcome
           does not depend on filesystem listing order.

    Args:
        network_type (str): Registered network name. A trailing ``_arch``
            suffix is tolerated and stripped when deriving module names.

    Raises:
        ModuleNotFoundError: If a candidate module exists but one of the
            modules it imports is missing.
        Exception: Any other error raised while importing a candidate module.
            Errors during the step-2 fallback are logged, not raised.
    """
    global _ALL_ARCH_IMPORTED

    stem = network_type[:-5] if network_type.endswith('_arch') else network_type
    candidates = (
        f'{_ARCH_PACKAGE}.{stem}_arch',
        f'{_ARCH_PACKAGE}.{stem.lower()}_arch',
        f'{_ARCH_PACKAGE}.{_camel_to_snake(stem)}_arch',
    )

    # dict.fromkeys drops duplicate candidates while keeping their order.
    for module_name in dict.fromkeys(candidates):
        try:
            importlib.import_module(module_name)
        except ModuleNotFoundError as error:
            # Ignore only when the missing module is the candidate itself or one
            # of its parent packages (e.g. a dotted ``network_type``); a missing
            # dependency inside an existing arch module is re-raised.
            missing = error.name or ''
            is_candidate_path = module_name == missing or module_name.startswith(missing + '.')
            if not missing or not is_candidate_path:
                raise
        if network_type in ARCH_REGISTRY:
            return

    if _ALL_ARCH_IMPORTED:
        return

    # Compatibility fallback: import all arch modules once for shared-module class names.
    # Sorted so that import (and therefore registration) order is reproducible.
    for file_path in sorted(_ARCH_FOLDER.glob('*_arch.py')):
        module_name = f'{_ARCH_PACKAGE}.{file_path.stem}'
        try:
            importlib.import_module(module_name)
        except Exception as error:
            # Some optional modules may fail to import; record why instead of
            # hiding it, then continue loading the others (best effort).
            get_root_logger().warning(
                'Skipped arch module %s during fallback import: %s: %s',
                module_name,
                type(error).__name__,
                error,
            )

    _ALL_ARCH_IMPORTED = True


def build_network(opt):
    """Build a network based on the provided options.

    Args:
        opt (dict): Dictionary containing network options. Must include the
            'type' key, which is looked up in ``ARCH_REGISTRY``; all remaining
            keys are passed to the network constructor as keyword arguments.
            ``opt`` is deep-copied first, so the caller's dict is not modified.

    Returns:
        nn.Module: The constructed network.

    Raises:
        KeyError: If ``opt`` has no 'type' key. Errors from the registry
            lookup or the network constructor propagate unchanged.

    Example:
        >>> net = build_network(opt)
        >>> print(net)
    """
    # Work on a copy so popping 'type' does not mutate the caller's options.
    opt = copy.deepcopy(opt)
    try:
        network_type = opt.pop('type')
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError(
            f"Network options must include a 'type' key; got keys {list(opt)}."
        ) from None
    logger = get_root_logger()

    # Deterministic lazy import without class-name cache files: arch modules are
    # only imported when the requested type is not registered yet.
    if network_type not in ARCH_REGISTRY:
        _lazy_import_arch(network_type)

    net = ARCH_REGISTRY.get(network_type)(**opt)
    logger.info(f'Network [{net.__class__.__name__}] is created.')
    return net