mmcv.runner.base_runner 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod

import torch
from torch.optim import Optimizer

import mmcv
from ..parallel import is_module_wrapper
from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook
from .log_buffer import LogBuffer
from .priority import Priority, get_priority
from .utils import get_time_str

class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`.
            Deprecated: implement ``train_step()``/``val_step()`` in the
            model instead.
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            require more than one optimizer, e.g., GAN).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. It is created if it does not exist. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
            The default of None exists only for backward compatibility of
            the signature; anything that is not a ``logging.Logger`` raises
            a TypeError.
        meta (dict | None): A dict that records some important information
            such as environment info and seed, which will be logged in logger
            hook. Defaults to None.
        max_epochs (int, optional): Total training epochs.
        max_iters (int, optional): Total training iterations. Only one of
            ``max_epochs`` and ``max_iters`` may be set.
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement '
                          'train_step() and val_step() in the model instead.')
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta
        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()

    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
            param groups. If the runner has a dict of optimizers, this method
            will return a dict.

        Raises:
            RuntimeError: If the runner has no usable optimizer.
        """
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return lr

    def current_momentum(self):
        """Get current momentums.

        For each param group the value of ``'momentum'`` is used; if absent,
        the first element of ``'betas'`` (Adam-style optimizers); otherwise 0.

        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
            param groups. If the runner has a dict of optimizers, this method
            will return a dict.

        Raises:
            RuntimeError: If the runner has no usable optimizer.
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group.keys():
                    momentums.append(group['momentum'])
                elif 'betas' in group.keys():
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        if isinstance(self.optimizer, torch.optim.Optimizer):
            momentums = _get_momentum(self.optimizer)
        elif isinstance(self.optimizer, dict):
            momentums = dict()
            for name, optim in self.optimizer.items():
                momentums[name] = _get_momentum(optim)
        else:
            # Covers ``None`` as before, and also any other unexpected type,
            # which previously fell through to an UnboundLocalError.
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        return momentums

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.

        Raises:
            ValueError: If ``hook`` already has a ``priority`` attribute
                (e.g. it was registered before).
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # Insert after the last hook whose priority is <= the new one, so the
        # list stays sorted and equal priorities keep registration order.
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.

        Args:
            hook_cfg (dict): Hook config. It must have a 'type' key; an
                optional 'priority' key sets the hook priority and defaults
                to 'NORMAL'. ``hook_cfg`` itself is not modified.

        Notes:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)

    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def get_hook_info(self):
        """Summarize the registered hooks, grouped by triggering stage.

        Each entry shows the hook priority (its :class:`Priority` name when
        the value is a member, the raw value otherwise) and the hook class
        name. Stages with no hooks are omitted.

        Returns:
            str: A multi-line, human-readable description.
        """
        # Get hooks info in each stage
        stage_hook_map = {stage: [] for stage in Hook.stages}
        for hook in self.hooks:
            try:
                priority = Priority(hook.priority).name
            except ValueError:
                priority = hook.priority
            classname = hook.__class__.__name__
            hook_info = f'({priority:<12}) {classname:<35}'
            for trigger_stage in hook.get_triggered_stages():
                stage_hook_map[trigger_stage].append(hook_info)

        stage_hook_infos = []
        for stage in Hook.stages:
            hook_infos = stage_hook_map[stage]
            if len(hook_infos) > 0:
                info = f'{stage}:\n'
                info += '\n'.join(hook_infos)
                info += '\n -------------------- '
                stage_hook_infos.append(info)
        return '\n'.join(stage_hook_infos)

    def load_checkpoint(self,
                        filename,
                        map_location='cpu',
                        strict=False,
                        revise_keys=[(r'^module.', '')]):
        """Load a checkpoint file into ``self.model``.

        Args:
            filename (str): Checkpoint path, forwarded to
                :func:`load_checkpoint`.
            map_location: Forwarded to :func:`load_checkpoint`.
                Defaults to 'cpu'.
            strict (bool): Forwarded to :func:`load_checkpoint`.
            revise_keys (list[tuple[str, str]]): (pattern, replacement)
                pairs forwarded to :func:`load_checkpoint`. The default
                appears intended to strip a leading ``module.`` prefix; note
                the ``.`` is unescaped. The default list is never mutated
                here.

        Returns:
            The value returned by :func:`load_checkpoint`. ``resume`` indexes
            it as a dict (``checkpoint['meta']``).
        """
        self.logger.info('load checkpoint from %s', filename)
        return load_checkpoint(
            self.model,
            filename,
            map_location,
            strict,
            self.logger,
            revise_keys=revise_keys)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume training state from a checkpoint.

        Args:
            checkpoint (str): Checkpoint to load via :meth:`load_checkpoint`.
            resume_optimizer (bool): Whether to restore the optimizer state
                when the checkpoint contains one. Defaults to True.
            map_location: With 'default', tensors are mapped to the current
                CUDA device if CUDA is available, otherwise the
                :meth:`load_checkpoint` default ('cpu') is used. Any other
                value is forwarded unchanged.

        Raises:
            TypeError: If the optimizer state should be restored but
                ``self.optimizer`` is neither an Optimizer nor a dict.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        # assumes checkpoint['meta'] carries 'epoch' and 'iter' keys
        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if self.meta is None:
            self.meta = {}
        self.meta.setdefault('hook_msgs', {})
        # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
        self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))

        # Re-calculate the number of iterations when resuming
        # models with different number of GPUs
        if 'config' in checkpoint['meta']:
            config = mmcv.Config.fromstring(
                checkpoint['meta']['config'], file_format='.py')
            previous_gpu_ids = config.get('gpu_ids', None)
            if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
                    previous_gpu_ids) != self.world_size:
                self._iter = int(self._iter * len(previous_gpu_ids) /
                                 self.world_size)
                self.logger.info('the iteration number is changed due to '
                                 'change of GPU number')

        # resume meta information meta
        # NOTE(review): this replaces the runner's own meta (including the
        # hook_msgs merge above) with the checkpoint's meta; kept as-is for
        # backward compatibility.
        self.meta = checkpoint['meta']

        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

    @staticmethod
    def _policy_cfg_to_hook_cfg(config, type_suffix):
        """Turn a ``policy``-style updater config into a hook config.

        Args:
            config (dict): Config that must contain a ``policy`` key.
            type_suffix (str): Suffix appended to the policy name to build
                the hook type, e.g. ``'LrUpdaterHook'``.

        Returns:
            dict: A shallow copy of ``config`` where ``policy`` has been
            replaced by ``type``. ``config`` itself is left untouched, so the
            same config can be registered more than once.
        """
        assert 'policy' in config
        config = config.copy()
        policy_type = config.pop('policy')
        # If the type of policy is all in lower case, e.g., 'cyclic',
        # then its first letter will be capitalized, e.g., to be 'Cyclic'.
        # This is not applicable for names like `CosineAnnealing`, so the
        # string is kept unchanged if it contains capital letters.
        if policy_type == policy_type.lower():
            policy_type = policy_type.title()
        config['type'] = policy_type + type_suffix
        return config

    def register_lr_hook(self, lr_config):
        """Register the LR updater hook with priority 'VERY_HIGH'.

        Args:
            lr_config (dict | :obj:`Hook` | None): A config with a ``policy``
                key (mapped to ``<Policy>LrUpdaterHook``), a ready hook, or
                None to skip. A dict config is not modified.
        """
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            hook_cfg = self._policy_cfg_to_hook_cfg(lr_config,
                                                    'LrUpdaterHook')
            hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook, priority='VERY_HIGH')

    def register_momentum_hook(self, momentum_config):
        """Register the momentum updater hook with priority 'HIGH'.

        Args:
            momentum_config (dict | :obj:`Hook` | None): A config with a
                ``policy`` key (mapped to ``<Policy>MomentumUpdaterHook``),
                a ready hook, or None to skip. A dict config is not modified.
        """
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            hook_cfg = self._policy_cfg_to_hook_cfg(momentum_config,
                                                    'MomentumUpdaterHook')
            hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook, priority='HIGH')

    def register_optimizer_hook(self, optimizer_config):
        """Register the optimizer hook with priority 'ABOVE_NORMAL'.

        Args:
            optimizer_config (dict | :obj:`Hook` | None): A config whose
                ``type`` defaults to 'OptimizerHook', a ready hook, or None
                to skip. A dict config is not modified.
        """
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            hook_cfg = optimizer_config.copy()
            hook_cfg.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook, priority='ABOVE_NORMAL')

    def register_checkpoint_hook(self, checkpoint_config):
        """Register the checkpoint hook with priority 'NORMAL'.

        Args:
            checkpoint_config (dict | :obj:`Hook` | None): A config whose
                ``type`` defaults to 'CheckpointHook', a ready hook, or None
                to skip. A dict config is not modified.
        """
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            hook_cfg = checkpoint_config.copy()
            hook_cfg.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook, priority='NORMAL')

    def register_logger_hooks(self, log_config):
        """Register every logger hook in ``log_config['hooks']``.

        Each hook is built with ``interval=log_config['interval']`` as a
        default argument and registered with priority 'VERY_LOW'.

        Args:
            log_config (dict | None): Logger config, or None to skip.
        """
        if log_config is None:
            return
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

    def register_timer_hook(self, timer_config):
        """Register the timer hook with priority 'LOW'.

        Args:
            timer_config (dict | :obj:`Hook` | None): A hook config (deep
                copied before building), a ready hook, or None to skip.
        """
        if timer_config is None:
            return
        if isinstance(timer_config, dict):
            timer_config_ = copy.deepcopy(timer_config)
            hook = mmcv.build_from_cfg(timer_config_, HOOKS)
        else:
            hook = timer_config
        self.register_hook(hook, priority='LOW')

    def register_custom_hooks(self, custom_config):
        """Register user-defined hooks.

        Args:
            custom_config (dict | :obj:`Hook` | list | None): A single item
                or a list of items. Dict items go through
                :meth:`register_hook_from_cfg`; other items are registered
                with priority 'NORMAL'. None skips registration.
        """
        if custom_config is None:
            return

        if not isinstance(custom_config, list):
            custom_config = [custom_config]

        for item in custom_config:
            if isinstance(item, dict):
                self.register_hook_from_cfg(item)
            else:
                self.register_hook(item, priority='NORMAL')

    def register_profiler_hook(self, profiler_config):
        """Register the profiler hook with the default priority.

        Args:
            profiler_config (dict | :obj:`Hook` | None): A config whose
                ``type`` defaults to 'ProfilerHook', a ready hook, or None to
                skip. A dict config is not modified.
        """
        if profiler_config is None:
            return
        if isinstance(profiler_config, dict):
            hook_cfg = profiler_config.copy()
            hook_cfg.setdefault('type', 'ProfilerHook')
            hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        else:
            hook = profiler_config
        self.register_hook(hook)

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        """Register default and custom hooks for training.

        Default and custom hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointSaverHook  | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have same priority with default hooks, custom hooks
        will be triggered after default hooks.
        """
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)
Read the Docs v: v1.3.15
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.