import os
import os.path as osp
import sys
import copy

import numpy as np
import torch.distributed as dist
import torch.multiprocessing as mp

from lib.cfg_holder import cfg_unique_holder as cfguh
from lib.cfg_helper import (
    get_command_line_args,
    cfg_initiates,
)
from lib.model_zoo.sd import version
from lib.utils import get_obj_from_str

if __name__ == "__main__":
    # Parse command-line arguments and resolve the full run configuration.
    cfg = get_command_line_args()
    cfg = cfg_initiates(cfg)

    if 'train' in cfg:
        # Build the trainer and its training stage from dotted import paths
        # given in the config.
        trainer = get_obj_from_str(cfg.train.main)(cfg)
        tstage = get_obj_from_str(cfg.train.stage)()
        if 'eval' in cfg:
            # Attach an evaluation stage to run nested inside training.
            tstage.nested_eval_stage = get_obj_from_str(cfg.eval.stage)()
        trainer.register_stage(tstage)
        if cfg.env.gpu_count == 1:
            # Single GPU: run directly as rank 0.
            trainer(0)
        else:
            # Multiple GPUs: spawn one worker process per GPU.
            mp.spawn(trainer,
                     args=(),
                     nprocs=cfg.env.gpu_count,
                     join=True)
        trainer.destroy()
    else:
        # Evaluation-only run.
        evaler = get_obj_from_str(cfg.eval.main)(cfg)
        estage = get_obj_from_str(cfg.eval.stage)()
        evaler.register_stage(estage)
        if cfg.env.gpu_count == 1:
            evaler(0)
        else:
            mp.spawn(evaler,
                     args=(),
                     nprocs=cfg.env.gpu_count,
                     join=True)
        evaler.destroy()