Source code for pyiqa.train_nsplits
import torch
import os
import numpy as np
from os import path as osp
from pyiqa.utils.options import parse_options, make_paths
from pyiqa.train import train_pipeline
[docs]
def train_nsplits(root_path):
    """Train one model per data split and write an aggregated results table.

    Options are parsed via ``parse_options(root_path, is_train=True)``. For each
    of ``opt['split_num']`` splits the experiment name gets a ``_SplitXX``
    suffix, experiment paths are created with ``make_paths``, every entry of
    ``opt['datasets']`` gets ``split_index = i + 1`` (1-based), and
    ``train_pipeline`` is run. The per-split results it returns are written as
    a tab-separated summary to ``opt['save_final_results_path']``.

    Args:
        root_path: Project root directory forwarded to ``parse_options``,
            ``make_paths`` and ``train_pipeline``.

    Raises:
        ValueError: If ``opt['split_num']`` is smaller than 1 (nothing would be
            trained and no summary could be written).
    """
    torch.backends.cudnn.benchmark = True
    opt, args = parse_options(root_path, is_train=True)
    n_splits = opt['split_num']
    if n_splits < 1:
        raise ValueError(f'split_num must be >= 1, got {n_splits}')

    save_path = opt['save_final_results_path']
    save_dir = osp.dirname(save_path)
    # dirname() returns '' for a bare filename, and os.makedirs('') raises.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    all_split_results = []
    prefix_name = opt['name']
    for i in range(n_splits):
        # update split specific options
        opt['name'] = prefix_name + f'_Split{i:02d}'
        make_paths(opt, root_path)
        for dataset_opt in opt['datasets'].values():
            dataset_opt['split_index'] = i + 1
        all_split_results.append(train_pipeline(root_path, opt, args))

    _write_split_summary(save_path, all_split_results)


def _write_split_summary(save_path, all_split_results):
    """Write per-split validation results plus their mean/std to ``save_path``.

    ``all_split_results[i][dataset][metric]`` is read as a mapping with a
    ``'val'`` entry (formatted as a float) and an ``'iter'`` entry (formatted
    as an integer). Dataset and metric names are taken from the first split.
    """
    with open(save_path, 'w') as sf:
        datasets = list(all_split_results[0].keys())
        metrics = list(all_split_results[0][datasets[0]].keys())
        print(datasets, metrics)
        sf.write('Val Datasets\tSplits\t{}\n'.format('\t'.join(metrics)))
        # The '@iter' column reports the iteration stored with the last metric.
        iter_metric = metrics[-1]
        for ds in datasets:
            all_results = []
            for i, split_results in enumerate(all_split_results):
                ds_results = split_results[ds]
                metric_vals = [ds_results[mt]['val'] for mt in metrics]
                # NOTE: ':04f' is width 4 with the default 6-digit precision.
                val_cols = ''.join(f'{v:04f}\t' for v in metric_vals)
                best_iter = ds_results[iter_metric]['iter']
                sf.write(f'{ds}\t{i:02d}\t{val_cols}@{best_iter:05d}\n')
                all_results.append(metric_vals)
            all_results = np.array(all_results)
            results_avg = all_results.mean(axis=0)
            # np.std defaults to ddof=0 (population standard deviation).
            results_std = all_results.std(axis=0)
            sf.write(f'Overall results in {ds}: {results_avg}\t{results_std}\n')
if __name__ == '__main__':
    # Project root is two levels above this file (pardir of the file, then
    # pardir of its package directory).
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_nsplits(root_path)