Shortcuts

Source code for mmcv.runner.hooks.ema

# Copyright (c) OpenMMLab. All rights reserved.
from ...parallel import is_module_wrapper
from ..hooks.hook import HOOKS, Hook


@HOOKS.register_module()
class EMAHook(Hook):
    r"""Exponential Moving Average Hook.

    Use Exponential Moving Average on all parameters of model in training
    process. Every parameter has an EMA backup, which is updated by the
    formula below. EMAHook takes priority over EvalHook and
    CheckpointSaverHook.

        .. math::

            \text{Xema\_{t+1}} = (1 - \text{momentum}) \times
            \text{Xema\_{t}} + \text{momentum} \times X_t

    Args:
        momentum (float): The momentum used for updating ema parameter.
            Must lie strictly between 0 and 1. Defaults to 0.0002.
        interval (int): Update ema parameter every interval iteration.
            Must be a positive int. Defaults to 1.
        warm_up (int): During first warm_up steps, we may use smaller
            momentum to update ema parameters more slowly. The momentum
            actually applied at a step is
            ``min(momentum, (1 + step) / (warm_up + step))``.
            Defaults to 100.
        resume_from (str): The checkpoint path. When not None, it is passed
            to ``runner.resume`` at the end of ``before_run``.
            Defaults to None.
    """

    def __init__(self,
                 momentum=0.0002,
                 interval=1,
                 warm_up=100,
                 resume_from=None):
        assert isinstance(interval, int) and interval > 0
        self.warm_up = warm_up
        self.interval = interval
        assert momentum > 0 and momentum < 1
        # NOTE(review): raising a momentum in (0, 1) to the power
        # ``interval`` makes it *smaller* as the interval grows. If the
        # intent is to compensate for skipped iterations,
        # ``1 - (1 - momentum)**interval`` may be what was meant -- confirm
        # before changing, since it alters training results.
        self.momentum = momentum**interval
        self.checkpoint = resume_from

    def before_run(self, runner):
        """Register an EMA buffer on the model for every parameter.

        Registering the EMA copies as buffers of the model makes it easier
        to resume a model together with its ema parameters. If a checkpoint
        path was given at construction, ``runner.resume`` is called with it
        after the buffers exist.
        """
        model = runner.model
        if is_module_wrapper(model):
            # Work on the wrapped module so names carry no wrapper prefix.
            model = model.module
        # Maps parameter name -> name of the buffer holding its EMA copy.
        self.param_ema_buffer = {}
        self.model_parameters = dict(model.named_parameters(recurse=True))
        for name, value in self.model_parameters.items():
            # "." is not allowed in module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        self.model_buffers = dict(model.named_buffers(recurse=True))
        if self.checkpoint is not None:
            runner.resume(self.checkpoint)

    def after_train_iter(self, runner):
        """Update ema parameter every self.interval iterations."""
        curr_step = runner.iter
        if curr_step % self.interval != 0:
            return
        # We warm up the momentum considering the instability at beginning.
        warm_up_steps = self.warm_up + curr_step
        if warm_up_steps == 0:
            # warm_up == 0 on the very first iteration: there is nothing to
            # warm up, and dividing would raise ZeroDivisionError.
            momentum = self.momentum
        else:
            momentum = min(self.momentum, (1 + curr_step) / warm_up_steps)
        for name, parameter in self.model_parameters.items():
            buffer_name = self.param_ema_buffer[name]
            buffer_parameter = self.model_buffers[buffer_name]
            # ``add_(scalar, tensor)`` is a deprecated overload; pass the
            # scale through the ``alpha`` keyword instead.
            buffer_parameter.mul_(1 - momentum).add_(
                parameter.data, alpha=momentum)

    def after_train_epoch(self, runner):
        """We load parameter values from ema backup to model before the
        EvalHook."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """We recover model's parameter from ema backup after last epoch's
        EvalHook."""
        self._swap_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the parameter of model with parameter in ema_buffer."""
        for name, value in self.model_parameters.items():
            # Keep a copy of the live weights so both sides can be
            # overwritten in place.
            temp = value.data.clone()
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
            ema_buffer.data.copy_(temp)
Read the Docs v: v1.3.14
Versions
master
latest
v1.3.14
v1.3.13
v1.3.11
v1.3.10
v1.3.9
v1.3.8
v1.3.7
v1.3.6
v1.3.5
v1.3.4
v1.3.3
v1.3.2
v1.3.1
v1.3.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.