| import torch.distributed as dist | |
| import torch.multiprocessing as mp | |
| import os | |
| import os.path as osp | |
| import sys | |
| import numpy as np | |
| import copy | |
| from lib.cfg_holder import cfg_unique_holder as cfguh | |
| from lib.cfg_helper import \ | |
| get_command_line_args, \ | |
| cfg_initiates | |
| from lib.model_zoo.sd import version | |
| from lib.utils import get_obj_from_str | |
def _launch_stage(runner, gpu_count):
    """Run a registered trainer/evaluator, then always call its cleanup.

    With ``gpu_count == 1`` the runner is called in-process as ``runner(0)``.
    Otherwise ``torch.multiprocessing.spawn`` starts ``gpu_count`` processes,
    each calling ``runner(rank)``, and blocks until they finish
    (``join=True``).

    Args:
        runner: Callable taking a process rank and exposing ``destroy()``
            (the trainer/evaluator object built from the config).
        gpu_count: Number of processes to run; read from
            ``cfg.env.gpu_count`` by the caller.

    Raises:
        Whatever the run raises; ``runner.destroy()`` is still called first.
    """
    try:
        if gpu_count == 1:
            runner(0)
        else:
            mp.spawn(runner,
                     args=(),
                     nprocs=gpu_count,
                     join=True)
    finally:
        # Clean up even when the run raises, so a failed run does not
        # skip destroy().
        runner.destroy()


if __name__ == "__main__":
    cfg = get_command_line_args()
    cfg = cfg_initiates(cfg)

    if 'train' in cfg:
        # Training mode: the trainer class and its stage are resolved from
        # dotted-path strings in cfg.train.
        trainer = get_obj_from_str(cfg.train.main)(cfg)
        tstage = get_obj_from_str(cfg.train.stage)()
        if 'eval' in cfg:
            # An eval section alongside train is attached as a nested
            # stage of the training stage.
            tstage.nested_eval_stage = get_obj_from_str(cfg.eval.stage)()
        trainer.register_stage(tstage)
        _launch_stage(trainer, cfg.env.gpu_count)
    else:
        # Evaluation-only mode: everything is resolved from cfg.eval.
        evaler = get_obj_from_str(cfg.eval.main)(cfg)
        estage = get_obj_from_str(cfg.eval.stage)()
        evaler.register_stage(estage)
        _launch_stage(evaler, cfg.env.gpu_count)