# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou ([email protected])
# --------------------------------------------------------
import logging
import os
import json
import random
import copy
import itertools
from typing import Any, Dict, List, Set, Union
from datetime import datetime

from mpi4py import MPI
import numpy as np
import torch
from torch.utils.data import DataLoader

from detectron2.projects.deeplab import build_lr_scheduler
from fvcore.common.config import CfgNode
from infinibatch import iterators

from utilities.distributed import is_main_process, get_world_size
from .default_trainer import DefaultTrainer
from .utils.serialization import JSONEncoder, filter_jsonable

logger = logging.getLogger(__name__)


class XDecoder_Trainer(DefaultTrainer):
    """
    Construct XDecoder_Trainer for optimizer and lr_scheduler
    """
    def create_optimizer_and_scheduler(self):
        """
        Set up self.optimizers and self.lr_schedulers

        This method initializes self.optimizers and self.lr_schedulers as dictionaries of
        instances of the classes that OPTIMIZER and LR_SCHEDULER in the config file point to,
        with one optimizer and one lr scheduler for each model in self.raw_models. They have
        the same keys as self.raw_models.
        """
        self.opt['init_optimizer_in_deepspeed'] = False
        self.opt['init_lr_scheduler_in_deepspeed'] = False

        self.optimizers = {module_name: None for module_name in self.model_names}
        self.lr_schedulers = {module_name: None for module_name in self.model_names}
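
        # Default hyperparameters plus weight-decay overrides for normalization layers,
        # embeddings, and biases, all read from the solver config (WEIGHT_DECAY_BIAS
        # falls back to 0.0 when unset).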
        cfg_solver = self.opt['SOLVER']
        weight_decay_norm = cfg_solver['WEIGHT_DECAY_NORM']
        weight_decay_embed = cfg_solver['WEIGHT_DECAY_EMBED']
        weight_decay_bias = cfg_solver.get('WEIGHT_DECAY_BIAS', 0.0)

        defaults = {}
        defaults["lr"] = cfg_solver['BASE_LR']
        defaults["weight_decay"] = cfg_solver['WEIGHT_DECAY']

        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            # NaiveSyncBatchNorm inherits from BatchNorm2d
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
        )

        fix_param = self.opt['SOLVER'].get('FIX_PARAM', {})
        ignore_fix = self.opt['SOLVER'].get('IGNORE_FIX', [])
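
        # Freeze parameters whose names match a FIX_PARAM key that is set to True,
        # unless the name also matches an IGNORE_FIX entry (which skips the parameter
        # entirely); tally the matched parameter counts per key for the log line below.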
        for _module_name in self.model_names:
            flag_continue = False
            module_params = {}
            for name, param in self.raw_models[_module_name].named_parameters():
                for ig in ignore_fix:
                    if ig in name:
                        flag_continue = True
                        break

                if flag_continue:
                    flag_continue = False
                    continue

                for key, value in fix_param.items():
                    if key in name and value == True:
                        param.requires_grad = False

                    if key in name:
                        if key not in module_params:
                            module_params[key] = 0
                        module_params[key] += param.numel()

            logger.info(f"Module {_module_name} has parameters: {module_params}")
        # raise NotImplementedError("Please check the fix_param and ignore_fix in the config file")

        lr_multiplier = self.opt['SOLVER']['LR_MULTIPLIER']
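
        # Build one optimizer param group per trainable parameter so that the LR
        # multipliers and the weight-decay overrides (position embeddings, norm layers,
        # embeddings, biases) can be applied to individual parameters.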
        for _module_name in self.model_names:
            # parameters = self.raw_models[module_name].get_training_parameters()
            # self.optimizers[module_name] = optimizer_class(parameters, **optimizer_parameters)
            # params = []
            # for module_param_name, value in self.raw_models[module_name].named_parameters(recurse=True):
            params: List[Dict[str, Any]] = []
            memo: Set[torch.nn.parameter.Parameter] = set()
            for module_name, module in self.raw_models[_module_name].named_modules():
                for module_param_name, value in module.named_parameters(recurse=False):
                    if not value.requires_grad:
                        continue
                    # Avoid duplicating parameters
                    if value in memo:
                        continue
                    memo.add(value)

                    hyperparams = copy.copy(defaults)
                    for key, lr_mul in lr_multiplier.items():
                        if key in "{}.{}".format(module_name, module_param_name):
                            hyperparams["lr"] = hyperparams["lr"] * lr_mul
                            if is_main_process():
                                logger.info("Modify Learning rate of {}: {}".format("{}.{}".format(module_name, module_param_name), lr_mul))

                    if (
                        "relative_position_bias_table" in module_param_name
                        or "absolute_pos_embed" in module_param_name
                    ):
                        hyperparams["weight_decay"] = 0.0
                    if isinstance(module, norm_module_types):
                        hyperparams["weight_decay"] = weight_decay_norm
                    if isinstance(module, torch.nn.Embedding):
                        hyperparams["weight_decay"] = weight_decay_embed
                    if "bias" in module_name:
                        hyperparams["weight_decay"] = weight_decay_bias
                    params.append({"params": [value], **hyperparams})
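
            # Optionally wrap the optimizer class so that the gradient norm is clipped
            # over all of the model's parameters on every step; this is what the
            # "full_model" CLIP_TYPE means, which detectron2 does not provide natively.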
            def maybe_add_full_model_gradient_clipping(optim):
                # detectron2 doesn't have full model gradient clipping now
                clip_norm_val = cfg_solver['CLIP_GRADIENTS']['CLIP_VALUE']
                enable = (
                    cfg_solver['CLIP_GRADIENTS']['ENABLED']
                    and cfg_solver['CLIP_GRADIENTS']['CLIP_TYPE'] == "full_model"
                    and clip_norm_val > 0.0
                )

                class FullModelGradientClippingOptimizer(optim):
                    def step(self, closure=None):
                        all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                        torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                        super().step(closure=closure)

                return FullModelGradientClippingOptimizer if enable else optim

            optimizer_type = cfg_solver['OPTIMIZER']
            if optimizer_type == "SGD":
                optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                    params, cfg_solver['BASE_LR'], momentum=cfg_solver['MOMENTUM']
                )
            elif optimizer_type == "ADAMW":
                optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                    params, cfg_solver['BASE_LR']
                )
            else:
                raise NotImplementedError(f"no optimizer type {optimizer_type}")

            self.optimizers[_module_name] = optimizer
            self.optimizers[_module_name].zero_grad()
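
        # Convert the epoch-based schedule into iterations: MAX_ITER is the number of
        # epochs times the updates per epoch, and the STEPS entries are interpreted as
        # fractions of MAX_ITER before being handed to the detectron2 lr scheduler.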
        num_epoch = self.opt['SOLVER']['MAX_NUM_EPOCHS']
        cfg_solver['MAX_ITER'] = num_epoch * self.train_params['updates_per_epoch']
        cfg_solver['STEPS'] = [int(x * cfg_solver['MAX_ITER']) for x in cfg_solver['STEPS']]
        logger.info(f"Calculate MAX_ITER @ {cfg_solver['MAX_ITER']} and STEPS @ {cfg_solver['STEPS']}")

        for module_name in self.model_names:
            scheduler_cfg = CfgNode({'SOLVER': cfg_solver})
            self.lr_schedulers[module_name] = build_lr_scheduler(scheduler_cfg, self.optimizers[module_name])
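
        # Log total vs. trainable parameter counts for each module as a sanity check
        # on the freezing logic above.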
        for module_name in self.model_names:
            num_params = 0
            num_trainable_params = 0
            for name, param in self.raw_models[module_name].named_parameters():
                num_params += param.numel()
                if param.requires_grad:
                    num_trainable_params += param.numel()
            logger.info(f"Total number of parameters in {module_name} module (on each GPU): {num_params}")
            logger.info(f"Number of trainable parameters in {module_name} module (on each GPU): {num_trainable_params}")