Shortcuts

mmcv.runner.iter_based_runner 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform
import shutil
import time
import warnings

import torch
from torch.optim import Optimizer

import mmcv
from .base_runner import BaseRunner
from .builder import RUNNERS
from .checkpoint import save_checkpoint
from .hooks import IterTimerHook
from .utils import get_host_info


class IterLoader:
    """Wrap a dataloader so it can be drawn from indefinitely.

    Each ``next()`` call returns the next batch. When the underlying loader
    is exhausted, the epoch counter is incremented, the loader's sampler is
    informed of the new epoch (if it exposes ``set_epoch``), and a fresh
    iterator over the loader is created.

    Args:
        dataloader: A re-iterable object supporting ``iter()`` and ``len()``,
            typically a ``torch.utils.data.DataLoader``. A ``sampler``
            attribute is optional.
        epoch_transition_delay (float): Seconds to sleep at each epoch
            boundary before re-creating the iterator. Defaults to 2, the
            value used to prevent possible deadlocks during the epoch
            transition. Values <= 0 disable the sleep.
    """

    def __init__(self, dataloader, epoch_transition_delay=2):
        self._dataloader = dataloader
        self._epoch_transition_delay = epoch_transition_delay
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        """int: Number of times the underlying loader has been exhausted."""
        return self._epoch

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iter_loader)
        except StopIteration:
            pass

        # The current pass is exhausted: start a new epoch.
        self._epoch += 1
        # Not every iterable has a sampler; only notify one that supports it.
        sampler = getattr(self._dataloader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(self._epoch)
        if self._epoch_transition_delay > 0:
            # Prevent possible deadlock during epoch transition
            time.sleep(self._epoch_transition_delay)
        self.iter_loader = iter(self._dataloader)
        # For an empty loader this raises StopIteration again.
        return next(self.iter_loader)

    def __len__(self):
        return len(self._dataloader)


@RUNNERS.register_module()
class IterBasedRunner(BaseRunner):
    """Iteration-based Runner.

    This runner trains models iteration by iteration.
    """

    def train(self, data_loader, **kwargs):
        """Run a single training iteration.

        Args:
            data_loader (IterLoader): Loader to draw one batch from; its
                ``epoch`` is copied onto the runner before drawing.
            **kwargs: Forwarded to ``model.train_step``.

        Raises:
            TypeError: If ``model.train_step`` does not return a dict.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        data_batch = next(data_loader)
        self.call_hook('before_train_iter')
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.train_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        # Counters advance only after the hooks have run, so hooks observe
        # the pre-increment values.
        self._inner_iter += 1
        self._iter += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        """Run a single validation iteration without gradient tracking.

        Only the inner iteration counter advances; the global ``iter`` is
        left unchanged.

        Args:
            data_loader (IterLoader): Loader to draw one batch from.
            **kwargs: Forwarded to ``model.val_step``.

        Raises:
            TypeError: If ``model.val_step`` does not return a dict.
        """
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        data_batch = next(data_loader)
        self.call_hook('before_val_iter')
        outputs = self.model.val_step(data_batch, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.val_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_val_iter')
        self._inner_iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation, one per workflow entry.
            workflow (list[tuple]): A list of (phase, iters) to specify the
                running order and iterations. E.g., [('train', 10000),
                ('val', 1000)] means running 10000 iterations for training and
                1000 iterations for validation, iteratively.
            max_iters (int, optional): Deprecated. If given, it overrides the
                runner's ``max_iters`` and a DeprecationWarning is issued.
            **kwargs: Forwarded to each phase method (e.g. ``train``/``val``).
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_iters is not None:
            warnings.warn(
                'setting max_iters in run is deprecated, '
                'please set max_iters in runner_config', DeprecationWarning)
            self._max_iters = max_iters
        assert self._max_iters is not None, (
            'max_iters must be specified during instantiation')

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d iters', workflow,
                         self._max_iters)
        self.call_hook('before_run')

        iter_loaders = [IterLoader(x) for x in data_loaders]

        self.call_hook('before_epoch')

        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        'runner has no method named "{}" to run a workflow'.
                        format(mode))
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    # Only the train phase is capped by max_iters; other
                    # phases run their full iteration count.
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    iter_runner(iter_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume model from checkpoint.

        Args:
            checkpoint (str): Checkpoint to resume from.
            resume_optimizer (bool, optional): Whether resume the optimizer(s)
                if the checkpoint file includes optimizer(s). Defaults to True.
            map_location (str, optional): Same as :func:`torch.load`.
                With 'default', tensors are mapped to the current CUDA
                device when CUDA is available, otherwise to the CPU.
                Defaults to 'default'.

        Raises:
            TypeError: If the runner's optimizer is neither a
                ``torch.optim.Optimizer`` nor a dict of them.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                # torch.cuda.current_device() would fail without a GPU.
                checkpoint = self.load_checkpoint(
                    checkpoint, map_location='cpu')
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template.
                Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint. The
                caller's dict is copied, not modified. Defaults to None.
            save_optimizer (bool, optional): Whether save optimizer.
                Defaults to True.
            create_symlink (bool, optional): Whether create symlink to the
                latest checkpoint file. Defaults to True.

        Raises:
            TypeError: If ``meta`` is neither a dict nor None.
        """
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        else:
            # Work on a copy so the updates below never leak into the
            # caller's dict (which may be reused across saves).
            meta = dict(meta)
        if self.meta is not None:
            meta.update(self.meta)
            # Note: meta.update(self.meta) should be done before
            # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
            # there will be problems with resumed checkpoints.
            # More details in https://github.com/open-mmlab/mmcv/pull/1108
        meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                custom_hooks_config=None):
        """Register default hooks for iter-based training.

        Checkpoint hook, optimizer stepper hook and logger hooks will be set
        to `by_epoch=False` by default.

        Default hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointSaverHook  | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have the same priority as default hooks, custom hooks
        will be triggered after default hooks.

        Note that ``by_epoch`` is set via ``setdefault`` directly on the
        passed config dicts, so an explicit user value is kept.
        """
        if checkpoint_config is not None:
            checkpoint_config.setdefault('by_epoch', False)
        if lr_config is not None:
            lr_config.setdefault('by_epoch', False)
        if log_config is not None:
            for info in log_config['hooks']:
                info.setdefault('by_epoch', False)
        super(IterBasedRunner, self).register_training_hooks(
            lr_config=lr_config,
            momentum_config=momentum_config,
            optimizer_config=optimizer_config,
            checkpoint_config=checkpoint_config,
            log_config=log_config,
            timer_config=IterTimerHook(),
            custom_hooks_config=custom_hooks_config)
Read the Docs v: v1.5.1
Versions
latest
stable
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.