Spaces:
Runtime error
Runtime error
| import weakref | |
| import numpy as np | |
class PlainCosineScheduler(object):
    """Linear-warmup + cosine-decay schedule written onto a single attribute.

    The whole value table is precomputed at construction time. Each call to
    :meth:`step` advances the iteration counter and assigns the matching value
    to ``klass.<key>`` via ``setattr``.

    Args:
        klass: Object whose attribute ``key`` is updated on every step.
        key: Name of the attribute to set on ``klass``.
        warmup_iters: Number of linear-warmup entries, ramping from
            ``init_value`` to ``base_value`` (both endpoints included).
        total_iters: Index of the last schedule entry; any later iteration is
            clamped to it.
        overwrite: Stored for interface parity with ``CosineScheduler``; not
            used by this class.
        init_value: Warmup start value; defaults to ``base_value`` when None.
        base_value: Value reached at the end of warmup (start of decay).
        final_value: Value reached at ``total_iters``.
        step_init: Initial iteration counter; the first :meth:`step` uses
            ``step_init + 1``.
    """

    def __init__(
        self,
        klass,
        key,
        warmup_iters,
        total_iters,
        overwrite=False,
        init_value=None,
        base_value=None,
        final_value=None,
        step_init=-1,
    ):
        super().__init__()
        self.iter = step_init
        self.overwrite = overwrite
        self.base_value = base_value
        self.init_value = init_value if init_value is not None else base_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.key = key
        self.klass = klass
        # Kept as a one-element list so __getitem__ mirrors CosineScheduler.
        self.schedulers = [self.get_scheduler()]

    def get_scheduler(self):
        """Build the per-iteration schedule as a 1-D numpy array.

        The result is the warmup ramp (``warmup_iters`` entries) followed by
        the cosine decay (``total_iters - warmup_iters + 1`` entries), i.e.
        ``total_iters + 1`` entries when ``total_iters >= warmup_iters``.
        Both the last warmup entry and the first decay entry equal
        ``base_value``.

        Returns:
            numpy.ndarray: schedule values indexed by iteration.
        """
        init_value = self.init_value
        base_value = self.base_value
        final_value = self.final_value
        warmup_iters = self.warmup_iters
        total_iters = self.total_iters

        # Linear warmup: progress in [0, 1] mapped onto [init_value, base_value].
        ramp = np.linspace(0, 1, warmup_iters, endpoint=True)
        warmup_schedule = (base_value - init_value) * ramp + init_value

        # Cosine decay from base_value (at step 0) to final_value (last step).
        iters = np.arange(total_iters - warmup_iters + 1)
        # With a single decay point (total_iters == warmup_iters) the natural
        # denominator is 0, which produced NaN; clamp it so that point is
        # evaluated at progress 0 and yields base_value instead.
        denom = max(len(iters) - 1, 1)
        schedule = final_value + 0.5 * (base_value - final_value) * (
            1 + np.cos(np.pi * iters / denom)
        )
        return np.concatenate((warmup_schedule, schedule))

    def step(self):
        """Advance one iteration and write the new value onto ``klass``."""
        self.iter = self.iter + 1
        for val in self[self.iter]:
            setattr(self.klass, self.key, val)

    def __getitem__(self, it):
        """Return ``[value]`` at iteration ``it``, clamped to ``total_iters``.

        NOTE(review): negative ``it`` is passed straight to numpy indexing and
        therefore counts from the end of the schedule.
        """
        it = min(it, self.total_iters)
        return [scheduler[it] for scheduler in self.schedulers]
class CosineScheduler(object):
    """Per-param-group warmup / flat / cosine-decay schedule for an optimizer.

    For each dict in ``optimizer.param_groups`` a value table is precomputed
    at construction time. It has ``total_iters + 1`` entries when
    ``total_iters >= warmup_iters + flat_iters``. Each :meth:`step` advances
    the iteration counter and writes the matching value into ``group[key]``.

    Per-group overrides: a group may define ``key + "_init"``,
    ``key + "_base"`` and ``key + "_final"``. These replace the constructor's
    ``init_value``, ``base_value`` and ``final_value`` for that group.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dict-like
            groups (presumably a ``torch.optim.Optimizer``; only
            ``param_groups`` is used).
        warmup_iters: Number of linear-warmup entries, ramping from the init
            value to the base value (both endpoints included).
        total_iters: Index of the last schedule entry; any later iteration is
            clamped to it.
        key: Name of the param-group entry to schedule (e.g. ``"lr"``).
        overwrite: If True, ignore any per-group ``key + "_final"`` and use
            ``final_value`` for every group.
        init_value: Warmup start value. When None and the group has no
            ``key + "_init"``, it defaults to that group's base value.
        base_value: Value reached at the end of warmup and held during the
            flat phase.
        final_value: Value reached at ``total_iters``.
        flat_iters: Number of entries held at the base value between warmup
            and decay.
        step_init: Initial iteration counter; the first :meth:`step` uses
            ``step_init + 1``.
    """

    def __init__(
        self,
        optimizer,
        warmup_iters,
        total_iters,
        key,
        overwrite=False,
        init_value=None,
        base_value=None,
        final_value=None,
        flat_iters=0,
        step_init=-1,
    ):
        super().__init__()
        self.iter = step_init
        self.overwrite = overwrite
        self.optimizer = optimizer
        self.base_value = base_value
        self.init_value = init_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.flat_iters = flat_iters
        self.key = key
        # One precomputed schedule per param group, in param_groups order.
        self.schedulers = [
            self.get_schedulers(group) for group in optimizer.param_groups
        ]

    def get_schedulers(self, group):
        """Build the full schedule for one param group as a 1-D numpy array.

        Layout: warmup (``warmup_iters`` entries), then flat (``flat_iters``
        entries at the base value), then cosine decay
        (``total_iters - warmup_iters - flat_iters + 1`` entries, from the
        base value to the final value).

        Args:
            group: A param-group mapping that may carry per-group overrides.

        Returns:
            numpy.ndarray: schedule values indexed by iteration.
        """
        base_value = group.get(self.key + "_base", self.base_value)
        init_value = group.get(self.key + "_init", self.init_value)
        if init_value is None:
            # Previously `base_value - None` raised TypeError; start the warmup
            # at the base value, matching PlainCosineScheduler's default.
            init_value = base_value
        final_value = group.get(self.key + "_final", self.final_value)
        if self.overwrite:
            final_value = self.final_value
        warmup_iters = self.warmup_iters
        total_iters = self.total_iters
        flat_iters = self.flat_iters

        # Linear warmup: progress in [0, 1] mapped onto [init_value, base_value].
        ramp = np.linspace(0, 1, warmup_iters, endpoint=True)
        warmup_schedule = (base_value - init_value) * ramp + init_value

        # Flat phase held at base_value.
        flat_schedule = np.ones(flat_iters) * base_value

        # Cosine decay from base_value (at step 0) to final_value (last step).
        decay_iters = np.arange(total_iters - warmup_iters - flat_iters + 1)
        # With a single decay point the natural denominator is 0, which
        # produced NaN; clamp it so that point evaluates to base_value.
        denom = max(len(decay_iters) - 1, 1)
        decay_schedule = final_value + 0.5 * (base_value - final_value) * (
            1 + np.cos(np.pi * decay_iters / denom)
        )
        return np.concatenate((warmup_schedule, flat_schedule, decay_schedule))

    def step(self):
        """Advance one iteration and write each group's new value.

        If ``group[key]`` currently holds a tuple or list, only its first
        element is replaced and the remaining elements are kept. The result
        is always stored back as a tuple.
        """
        self.iter = self.iter + 1
        vals = self[self.iter]
        for group, val in zip(self.optimizer.param_groups, vals):
            current = group[self.key]
            if isinstance(current, (tuple, list)):
                val = (val, *current[1:])
            group[self.key] = val

    def __getitem__(self, it):
        """Return one value per group at iteration ``it`` (clamped to total_iters).

        NOTE(review): negative ``it`` is passed straight to numpy indexing and
        therefore counts from the end of each schedule.
        """
        it = min(it, self.total_iters)
        return [scheduler[it] for scheduler in self.schedulers]

    def get(self):
        """Return the current ``group[key]`` value for every param group."""
        return [group[self.key] for group in self.optimizer.param_groups]