Source code for pyiqa.api_helpers
import fnmatch
import re
from pyiqa.default_model_configs import DEFAULT_CONFIGS
import yaml
import os
from pyiqa.utils import get_root_logger
from pyiqa.models.inference_model import InferenceModel
[docs]
def create_metric(metric_name, as_loss=False, device=None, **kwargs):
    """Instantiate the metric registered under ``metric_name``.

    Args:
        metric_name (str): Key of the metric in ``DEFAULT_CONFIGS``.
        as_loss (bool): Forwarded unchanged to ``InferenceModel``.
        device: Forwarded unchanged to ``InferenceModel`` (accepted values are
            defined there; presumably a device string/object — TODO confirm).
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``InferenceModel``.

    Returns:
        InferenceModel: The constructed metric wrapper.

    Raises:
        AssertionError: If ``metric_name`` is not a key of ``DEFAULT_CONFIGS``.
            Because this is an ``assert``, the check is skipped under
            ``python -O``.
    """
    is_registered = metric_name in DEFAULT_CONFIGS
    assert is_registered, f'Metric {metric_name} not implemented yet.'

    model = InferenceModel(metric_name, as_loss=as_loss, device=device, **kwargs)

    # Log the concrete network class that InferenceModel built.
    net_class_name = model.net.__class__.__name__
    get_root_logger().info(f'Metric [{net_class_name}] is created.')
    return model
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
[docs]
def list_models(metric_mode=None, filter='', exclude_filters=''):
    """Return the names of available metrics in natural sort order.

    Args:
        metric_mode (str | None): If given, keep only metrics whose
            ``'metric_mode'`` entry in ``DEFAULT_CONFIGS`` equals this value.
            It must be ``'FR'`` or ``'NR'``. ``None`` (the default) keeps all
            metrics.
        filter (str | list[str] | tuple[str]): Wildcard pattern(s) for
            :func:`fnmatch.filter`. A metric is kept if it matches any
            pattern. An empty value (the default) keeps every metric.
        exclude_filters (str | list[str] | tuple[str]): Wildcard pattern(s).
            Metrics matching any of them are removed after ``filter`` has been
            applied.

    Returns:
        list[str]: Matching metric names, sorted case-insensitively with
        embedded numbers compared numerically (via ``_natural_key``).

    Raises:
        AssertionError: If ``metric_mode`` is not ``None``, ``'FR'`` or
            ``'NR'``. Because this is an ``assert``, the check is skipped under
            ``python -O``.

    Example::

        list_models(filter='*ssim*')  # all metrics whose name contains 'ssim'
        list_models('FR', exclude_filters='*ssim*')  # FR metrics except *ssim*
    """
    if metric_mode is None:
        all_models = DEFAULT_CONFIGS.keys()
    else:
        assert metric_mode in ('FR', 'NR'), (
            f'Metric mode only support [FR, NR], but got {metric_mode}'
        )
        all_models = [
            name
            for name, config in DEFAULT_CONFIGS.items()
            if config['metric_mode'] == metric_mode
        ]

    if filter:
        include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
        # Union of every pattern's matches; a set also removes duplicates
        # when several patterns match the same metric.
        models = set()
        for pattern in include_filters:
            models.update(fnmatch.filter(all_models, pattern))
    else:
        models = all_models

    if exclude_filters:
        if not isinstance(exclude_filters, (tuple, list)):
            exclude_filters = [exclude_filters]
        for pattern in exclude_filters:
            excluded = fnmatch.filter(models, pattern)
            if excluded:
                models = set(models).difference(excluded)

    return sorted(models, key=_natural_key)
[docs]
def get_dataset_info(dataset_name=None):
    """Load the bundled dataset registry, optionally returning a single entry.

    The registry is read with :func:`yaml.safe_load` from
    ``default_dataset_configs.yml``, which must sit in the same directory as
    this module.

    Args:
        dataset_name (str | None): Top-level key of the registry to return.
            ``None`` (the default) returns the whole parsed registry.

    Returns:
        The full parsed registry when ``dataset_name`` is ``None``, otherwise
        the value stored under ``dataset_name``.

    Raises:
        AssertionError: If ``dataset_name`` is not a key of the registry.
            Because this is an ``assert``, the check is skipped under
            ``python -O``.
    """
    config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'default_dataset_configs.yml'
    )
    # The context manager closes the handle deterministically; the previous
    # bare ``open()`` left it open until garbage collection. The explicit
    # UTF-8 encoding avoids locale-dependent decoding (e.g. cp1252 on Windows).
    with open(config_path, 'r', encoding='utf-8') as config_file:
        dataset_info = yaml.safe_load(config_file)

    if dataset_name is None:
        return dataset_info

    assert dataset_name in dataset_info, (
        f'Dataset {dataset_name} not implemented yet.'
    )
    return dataset_info[dataset_name]