Source code for mmcv.runner.base_runner

# Copyright (c) Open-MMLab. All rights reserved.
import copy
import logging
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod

import torch
from torch.optim import Optimizer

import mmcv
from ..parallel import is_module_wrapper
from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook
from .log_buffer import LogBuffer
from .priority import Priority, get_priority
from .utils import get_time_str


class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            require more than one optimizer, e.g., GAN).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
            Defaults to None. (The default value is just for backward
            compatibility)
        meta (dict | None): A dict that records some important information
            such as environment info and seed, which will be logged in the
            logger hook. Defaults to None.
        max_epochs (int, optional): Total training epochs.
        max_iters (int, optional): Total training iterations.
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement '
                          'train_step() and val_step() in the model instead.')
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta
        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()

    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass
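    # Illustrative sketch (not part of the original mmcv source): a concrete
    # subclass such as ``EpochBasedRunner`` is typically constructed with a
    # model that implements ``train_step()``, a ``torch.optim.Optimizer`` and
    # a ``logging.Logger``; ``max_epochs`` and ``max_iters`` are mutually
    # exclusive, e.g.
    #
    #     >>> from mmcv.runner import EpochBasedRunner
    #     >>> from mmcv.utils import get_logger
    #     >>> runner = EpochBasedRunner(
    #     ...     model,                 # model providing train_step()
    #     ...     optimizer=optimizer,   # a torch.optim.Optimizer
    #     ...     work_dir='./work_dir',
    #     ...     logger=get_logger('demo'),
    #     ...     max_epochs=12)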
    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of
                all param groups. If the runner has a dict of optimizers,
                this method will return a dict.
        """
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return lr
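    # Example (illustrative, not from the original source): with a single
    # ``torch.optim.SGD(model.parameters(), lr=0.01)`` the call returns one
    # value per param group,
    #
    #     >>> runner.current_lr()
    #     [0.01]
    #
    # while a dict of optimizers yields a dict keyed by optimizer name.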
    def current_momentum(self):
        """Get current momentums.

        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group.keys():
                    momentums.append(group['momentum'])
                elif 'betas' in group.keys():
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        if self.optimizer is None:
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        elif isinstance(self.optimizer, torch.optim.Optimizer):
            momentums = _get_momentum(self.optimizer)
        elif isinstance(self.optimizer, dict):
            momentums = dict()
            for name, optim in self.optimizer.items():
                momentums[name] = _get_momentum(optim)
        return momentums
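    # Example (illustrative, not from the original source): SGD param groups
    # report ``group['momentum']``, Adam-style groups report ``betas[0]``,
    # and groups with neither report 0, e.g. with
    # ``torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))``:
    #
    #     >>> runner.current_momentum()
    #     [0.9]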
    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)
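    # Example (illustrative sketch, not from the original source): hooks with
    # a numerically lower priority value are placed earlier in ``self._hooks``
    # and therefore triggered first; equal priorities keep registration order.
    # ``MyLrHook`` and ``MyTimerHook`` below are hypothetical Hook subclasses.
    #
    #     >>> runner.register_hook(MyTimerHook(), priority='LOW')      # 70
    #     >>> runner.register_hook(MyLrHook(), priority='VERY_HIGH')   # 10
    #     >>> [h.__class__.__name__ for h in runner.hooks]
    #     ['MyLrHook', 'MyTimerHook']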
    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.

        Args:
            hook_cfg (dict): Hook config. It should have at least keys 'type'
                and 'priority' indicating its type and priority.

        Note:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)
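    # Example (illustrative, not from the original source): the config is
    # resolved against the ``HOOKS`` registry, so a registered hook such as
    # ``CheckpointHook`` can be added directly from a dict, e.g.
    #
    #     >>> runner.register_hook_from_cfg(
    #     ...     dict(type='CheckpointHook', priority='NORMAL', interval=1))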
    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such
                as "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)
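    # Example (illustrative, not from the original source): concrete runners
    # invoke the registered hooks at fixed points of the training loop, e.g.
    #
    #     >>> runner.call_hook('before_run')   # calls hook.before_run(runner)
    #     >>> runner.call_hook('after_train_iter')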
    def get_hook_info(self):
        # Get hooks info in each stage
        stage_hook_map = {stage: [] for stage in Hook.stages}
        for hook in self.hooks:
            try:
                priority = Priority(hook.priority).name
            except ValueError:
                priority = hook.priority
            classname = hook.__class__.__name__
            hook_info = f'({priority:<12}) {classname:<35}'
            for trigger_stage in hook.get_triggered_stages():
                stage_hook_map[trigger_stage].append(hook_info)

        stage_hook_infos = []
        for stage in Hook.stages:
            hook_infos = stage_hook_map[stage]
            if len(hook_infos) > 0:
                info = f'{stage}:\n'
                info += '\n'.join(hook_infos)
                info += '\n -------------------- '
                stage_hook_infos.append(info)
        return '\n'.join(stage_hook_infos)

    def load_checkpoint(self,
                        filename,
                        map_location='cpu',
                        strict=False,
                        revise_keys=[(r'^module.', '')]):
        self.logger.info('load checkpoint from %s', filename)
        return load_checkpoint(
            self.model,
            filename,
            map_location,
            strict,
            self.logger,
            revise_keys=revise_keys)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if self.meta is None:
            self.meta = {}
        self.meta.setdefault('hook_msgs', {})
        # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
        self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))

        # Re-calculate the number of iterations when resuming
        # models with different number of GPUs
        if 'config' in checkpoint['meta']:
            config = mmcv.Config.fromstring(
                checkpoint['meta']['config'], file_format='.py')
            previous_gpu_ids = config.get('gpu_ids', None)
            if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
                    previous_gpu_ids) != self.world_size:
                self._iter = int(self._iter * len(previous_gpu_ids) /
                                 self.world_size)
                self.logger.info('the iteration number is changed due to '
                                 'change of GPU number')

        # resume meta information
        self.meta = checkpoint['meta']

        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of Lr updater.
            # Since this is not applicable for
            # `CosineAnnealingLrUpdater`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook, priority='VERY_HIGH')

    def register_momentum_hook(self, momentum_config):
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            assert 'policy' in momentum_config
            policy_type = momentum_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of momentum updater.
            # Since this is not applicable for
            # `CosineAnnealingMomentumUpdater`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'MomentumUpdaterHook'
            momentum_config['type'] = hook_type
            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook, priority='HIGH')

    def register_optimizer_hook(self, optimizer_config):
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook, priority='ABOVE_NORMAL')

    def register_checkpoint_hook(self, checkpoint_config):
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook, priority='NORMAL')

    def register_logger_hooks(self, log_config):
        if log_config is None:
            return
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

    def register_timer_hook(self, timer_config):
        if timer_config is None:
            return
        if isinstance(timer_config, dict):
            timer_config_ = copy.deepcopy(timer_config)
            hook = mmcv.build_from_cfg(timer_config_, HOOKS)
        else:
            hook = timer_config
        self.register_hook(hook, priority='LOW')

    def register_custom_hooks(self, custom_config):
        if custom_config is None:
            return
        if not isinstance(custom_config, list):
            custom_config = [custom_config]
        for item in custom_config:
            if isinstance(item, dict):
                self.register_hook_from_cfg(item)
            else:
                self.register_hook(item, priority='NORMAL')

    def register_profiler_hook(self, profiler_config):
        if profiler_config is None:
            return
        if isinstance(profiler_config, dict):
            profiler_config.setdefault('type', 'ProfilerHook')
            hook = mmcv.build_from_cfg(profiler_config, HOOKS)
        else:
            hook = profiler_config
        self.register_hook(hook)
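    # Example (illustrative, not from the original source): a lower-case
    # policy name is title-cased and suffixed with 'LrUpdaterHook', so
    #
    #     >>> runner.register_lr_hook(dict(policy='step', step=[8, 11]))
    #
    # builds a ``StepLrUpdaterHook`` from the registry and registers it with
    # VERY_HIGH priority.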
    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        """Register default and custom hooks for training.

        Default and custom hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointSaverHook  | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have the same priority as default hooks, custom hooks
        will be triggered after default hooks.
        """
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)
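# Illustrative usage sketch (not part of the original mmcv module): the
# default training hooks are usually registered in one call on a concrete
# runner, e.g.
#
#     >>> runner.register_training_hooks(
#     ...     lr_config=dict(policy='step', step=[8, 11]),
#     ...     optimizer_config=dict(grad_clip=None),
#     ...     checkpoint_config=dict(interval=1),
#     ...     log_config=dict(
#     ...         interval=50, hooks=[dict(type='TextLoggerHook')]))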