Source code for pyiqa.utils.registry

# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py  # noqa: E501
import difflib
import importlib
import warnings

# importlib.metadata is used for entry-point (plugin) discovery. When it
# cannot be imported the name is bound to None, which the registry checks
# before attempting any discovery.
try:
    from importlib import metadata as importlib_metadata
except ImportError:  # pragma: no cover
    importlib_metadata = None
class Registry:
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)

    Objects can also be registered lazily by import path (see
    :meth:`register_lazy`); such entries are imported on first lookup.
    """

    def __init__(self, name, entry_point_group=None):
        """
        Args:
            name (str): the name of this registry.
            entry_point_group (str | None): Optional plugin group name.
        """
        self._name = name
        # Resolved registrations: registered name -> object.
        self._obj_map = {}
        # Pending lazy registrations: registered name -> import path string.
        self._lazy_map = {}
        self._entry_point_group = entry_point_group
        # Flipped on the first discovery attempt so it runs at most once.
        self._entry_points_loaded = False

    def _do_register(self, name, obj, override=False):
        """Store ``obj`` under ``name`` and drop any pending lazy entry for it.

        Raises:
            AssertionError: If ``name`` is already registered and ``override``
                is false.
        """
        if (not override) and (name in self._obj_map):
            raise AssertionError(
                f"An object named '{name}' was already registered "
                f"in '{self._name}' registry!"
            )
        self._obj_map[name] = obj
        # A resolved object supersedes a lazy placeholder of the same name.
        self._lazy_map.pop(name, None)

    def _register_with_aliases(self, obj, name=None, aliases=None, override=False):
        """Register ``obj`` under its primary name and every alias.

        The primary name is ``name`` when given, otherwise ``obj.__name__``.
        """
        primary_name = name or obj.__name__
        names = [primary_name]
        if aliases:
            names.extend(aliases)
        for item_name in names:
            self._do_register(item_name, obj, override=override)

    @staticmethod
    def _split_import_path(import_path):
        """Split ``'module:attr'`` into ``(module, attr)``.

        A path without ``':'`` yields ``(module, None)``. Both parts are
        stripped of surrounding whitespace.
        """
        if ':' in import_path:
            module_name, attr_name = import_path.split(':', 1)
            return module_name.strip(), attr_name.strip()
        return import_path.strip(), None

    def _resolve_lazy(self, name):
        """Import the module behind a pending lazy entry for ``name``, if any.

        For a ``'module:attr'`` path the attribute is registered under
        ``name`` unless importing the module already registered that name.
        For a module-only path the import itself is expected to perform the
        registration (e.g. through ``register`` decorators in that module).
        """
        import_path = self._lazy_map.get(name)
        if import_path is None:
            return

        module_name, attr_name = self._split_import_path(import_path)
        module = importlib.import_module(module_name)

        if attr_name:
            obj = getattr(module, attr_name)
            # Keep whatever the import registered under this name, if anything.
            if name not in self._obj_map:
                self._do_register(name, obj, override=True)

    def _load_entry_points_once(self):
        """Load plugins from the configured entry-point group, at most once.

        The flag is set before any loading happens, so a failed attempt is
        not retried and a lookup made while plugins are loading does not
        start a second discovery. Failures are reported as ``RuntimeWarning``
        instead of being raised.
        """
        if self._entry_points_loaded:
            return
        self._entry_points_loaded = True

        if not self._entry_point_group:
            return
        if importlib_metadata is None:  # pragma: no cover
            return

        try:
            entry_points = importlib_metadata.entry_points()
            if hasattr(entry_points, 'select'):
                group_eps = entry_points.select(group=self._entry_point_group)
            else:  # pragma: no cover
                # Fallback for results that are a mapping of group -> entries.
                group_eps = entry_points.get(self._entry_point_group, [])
        except Exception as error:
            warnings.warn(
                f"Failed to query entry points for '{self._entry_point_group}': {error}",
                RuntimeWarning,
            )
            return

        for ep in group_eps:
            try:
                loaded = ep.load()
                # A callable entry point is invoked so it can do its own
                # registration; anything else is only loaded (imported).
                if callable(loaded):
                    loaded()
            except Exception as error:
                warnings.warn(
                    f"Failed to load registry plugin '{ep.name}' from "
                    f"'{self._entry_point_group}': {error}",
                    RuntimeWarning,
                )

    def register_lazy(self, name, import_path, aliases=None, override=False):
        """Register an object lazily from an import path.

        Args:
            name (str): Name used in this registry.
            import_path (str): 'module.submodule' or 'module.submodule:ClassOrFn'.
            aliases (list[str] | tuple[str] | None): Optional alias names.
            override (bool): Whether to override existing registrations.

        Raises:
            TypeError: If ``aliases`` is a bare string.
            AssertionError: If a name is already registered (eagerly or
                lazily) and ``override`` is false.
        """
        # A bare string would otherwise be expanded character by character.
        if isinstance(aliases, str):
            raise TypeError('aliases must be a list or tuple of strings.')

        names = [name]
        if aliases:
            names.extend(aliases)

        # Only a 'module:attr' path can replace an already resolved object:
        # the attribute is fetched explicitly on lookup. A module-only path
        # relies on import side effects, which do not run again for a module
        # that is already imported, so the resolved object is kept for those.
        replaces_resolved = (
            override
            and isinstance(import_path, str)
            and bool(self._split_import_path(import_path)[1])
        )

        for item_name in names:
            if (not override) and (
                item_name in self._obj_map or item_name in self._lazy_map
            ):
                raise AssertionError(
                    f"An object named '{item_name}' was already registered "
                    f"in '{self._name}' registry!"
                )
            if replaces_resolved:
                # Without this, lookup would keep returning the old object.
                self._obj_map.pop(item_name, None)
            self._lazy_map[item_name] = import_path

    def register(self, obj=None, name=None, aliases=None, override=False):
        """Register the given object.

        Supports decorator and direct-call forms.

        Args:
            obj: Object to register. When omitted, a decorator is returned.
            name (str | None): Registered name; defaults to ``obj.__name__``.
            aliases (list[str] | tuple[str] | None): Optional alias names.
            override (bool): Whether to override existing registrations.

        Returns:
            The registered object, or a decorator that registers its argument
            and returns it unchanged.

        Raises:
            TypeError: If ``name`` or ``aliases`` has the wrong type.
            AssertionError: If a name is already registered and ``override``
                is false.
        """
        if aliases is not None and not isinstance(aliases, (list, tuple)):
            raise TypeError('aliases must be a list or tuple of strings.')
        if aliases is not None and not all(isinstance(alias, str) for alias in aliases):
            raise TypeError('Each alias must be a string.')
        if name is not None and not isinstance(name, str):
            raise TypeError('name must be a string.')

        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                self._register_with_aliases(
                    func_or_class,
                    name=name,
                    aliases=aliases,
                    override=override,
                )
                return func_or_class

            return deco

        # used as a function call
        self._register_with_aliases(obj, name=name, aliases=aliases, override=override)
        return obj

    def get(self, name):
        """Return the object registered under ``name``.

        Pending lazy entries are resolved first, then entry-point plugins are
        loaded (once), then lazy resolution is retried in case a plugin added
        a lazy entry for ``name``.

        Raises:
            KeyError: If nothing is found; the message suggests close matches
                among the known names when there are any.
        """
        self._resolve_lazy(name)
        self._load_entry_points_once()
        self._resolve_lazy(name)

        ret = self._obj_map.get(name)
        if ret is None:
            candidates = list(self._obj_map.keys()) + list(self._lazy_map.keys())
            hint = ''
            if candidates:
                close = difflib.get_close_matches(name, candidates, n=5)
                if close:
                    hint = f" Did you mean one of: {', '.join(close)}?"
            raise KeyError(
                f"No object named '{name}' found in '{self._name}' registry!" + hint
            )
        return ret

    def __contains__(self, name):
        """Return whether ``name`` is registered, eagerly or lazily."""
        if name in self._obj_map or name in self._lazy_map:
            return True
        # Not known yet: plugins may still provide it.
        self._load_entry_points_once()
        return name in self._obj_map or name in self._lazy_map

    def __iter__(self):
        """Iterate over ``(name, object)`` pairs of resolved registrations.

        Pending lazy entries are not included.
        """
        self._load_entry_points_once()
        return iter(self._obj_map.items())

    def keys(self):
        """Return the names of resolved registrations (lazy entries excluded)."""
        self._load_entry_points_once()
        return self._obj_map.keys()
# Module-level registries, one per component kind. Each is created with the
# entry-point group name that Registry uses for plugin discovery.
DATASET_REGISTRY = Registry('dataset', entry_point_group='pyiqa.datasets')
ARCH_REGISTRY = Registry('arch', entry_point_group='pyiqa.archs')
MODEL_REGISTRY = Registry('model', entry_point_group='pyiqa.models')
LOSS_REGISTRY = Registry('loss', entry_point_group='pyiqa.losses')
METRIC_REGISTRY = Registry('metric', entry_point_group='pyiqa.metrics')