# Source code for mmcv.runner.iter_based_runner

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform
import shutil
import time
import warnings

import torch
from torch.optim import Optimizer

import mmcv
from .base_runner import BaseRunner
from .builder import RUNNERS
from .checkpoint import save_checkpoint
from .hooks import IterTimerHook
from .utils import get_host_info

class IterLoader:
    """Wrap a dataloader so it can be consumed endlessly with ``next()``.

    When the wrapped dataloader is exhausted, the epoch counter is
    incremented, the sampler is notified of the new epoch (if it exposes
    ``set_epoch``), and a fresh iterator is created so iteration continues.

    Args:
        dataloader: A re-iterable object supporting ``len()``, typically a
            ``torch.utils.data.DataLoader``. A ``sampler`` attribute is
            optional.
    """

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        """int: How many times the wrapped dataloader has been exhausted."""
        return self._epoch

    def __iter__(self):
        """Return ``self`` so the wrapper can be used wherever an iterator
        is expected."""
        return self

    def __next__(self):
        """Return the next batch, restarting the dataloader when exhausted.

        Raises:
            StopIteration: If the dataloader yields nothing even right after
                being restarted (i.e. it is empty).
        """
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            # Samplers exposing ``set_epoch`` (e.g. torch's
            # DistributedSampler) use it to vary the shuffle order per epoch.
            sampler = getattr(self._dataloader, 'sampler', None)
            if hasattr(sampler, 'set_epoch'):
                sampler.set_epoch(self._epoch)
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __len__(self):
        """Return the number of batches in one pass of the dataloader."""
        return len(self._dataloader)

@RUNNERS.register_module()
class IterBasedRunner(BaseRunner):
    """Iteration-based Runner.

    This runner trains models iteration by iteration.
    """

    def train(self, data_loader, **kwargs):
        """Run a single training iteration.

        Args:
            data_loader (:obj:`IterLoader`): Endless loader to draw the next
                batch from; its ``epoch`` is mirrored into ``self._epoch``.
            **kwargs: Forwarded to ``self.model.train_step``.

        Raises:
            TypeError: If ``model.train_step()`` does not return a dict.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        data_batch = next(data_loader)
        self.call_hook('before_train_iter')
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.train_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        """Run a single validation iteration without gradient tracking.

        Unlike :meth:`train`, this does not advance ``self._iter``.

        Args:
            data_loader (:obj:`IterLoader`): Endless loader to draw the next
                batch from.
            **kwargs: Forwarded to ``self.model.val_step``.

        Raises:
            TypeError: If ``model.val_step()`` does not return a dict.
        """
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        data_batch = next(data_loader)
        self.call_hook('before_val_iter')
        outputs = self.model.val_step(data_batch, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.val_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_val_iter')
        self._inner_iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, iters) to specify the
                running order and iterations. E.g, [('train', 10000),
                ('val', 1000)] means running 10000 iterations for training and
                1000 iterations for validation, iteratively.
            max_iters (int, optional): Deprecated; specify ``max_iters`` when
                constructing the runner instead.
            **kwargs: Forwarded to every phase method (``train``/``val``).

        Raises:
            ValueError: If a workflow phase name is not a method of the
                runner.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_iters is not None:
            warnings.warn(
                'setting max_iters in run is deprecated, '
                'please set max_iters in runner_config', DeprecationWarning)
            self._max_iters = max_iters
        assert self._max_iters is not None, (
            'max_iters must be specified during instantiation')

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d iters', workflow,
                         self._max_iters)
        self.call_hook('before_run')

        # Wrap each loader so phases can pull batches endlessly.
        iter_loaders = [IterLoader(x) for x in data_loaders]

        self.call_hook('before_epoch')

        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        'runner has no method named "{}" to run a workflow'.
                        format(mode))
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    # Only training advances self.iter, so only training is
                    # cut short once the budget is reached.
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    iter_runner(iter_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume model from checkpoint.

        Args:
            checkpoint (str): Checkpoint to resume from.
            resume_optimizer (bool, optional): Whether resume the optimizer(s)
                if the checkpoint file includes optimizer(s). Default to True.
            map_location (str, optional): Same as :func:`torch.load`.
                Default to 'default', which maps tensors onto the current
                CUDA device when CUDA is available and onto the CPU
                otherwise.

        Raises:
            TypeError: If ``self.optimizer`` is neither a dict nor a
                ``torch.optim.Optimizer``.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                # torch.cuda.current_device() would fail without CUDA.
                checkpoint = self.load_checkpoint(
                    checkpoint, map_location='cpu')
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info(
            f'resumed from epoch: {self.epoch}, iter {self.iter}')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template.
                Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint.
                Defaults to None. The caller's dict is not modified.
            save_optimizer (bool, optional): Whether save optimizer.
                Defaults to True.
            create_symlink (bool, optional): Whether create symlink to the
                latest checkpoint file. Defaults to True.

        Raises:
            TypeError: If ``meta`` is neither a dict nor None.
        """
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        else:
            # Copy so the updates below do not leak into the caller's dict.
            meta = dict(meta)
        if self.meta is not None:
            meta.update(self.meta)
            # Note: meta.update(self.meta) must be done before
            # meta.update(epoch=self.epoch + 1, iter=self.iter); otherwise
            # stale 'epoch'/'iter' entries in self.meta (e.g. after resuming
            # from a checkpoint) would overwrite the current values.
        meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                custom_hooks_config=None):
        """Register default hooks for iter-based training.

        Checkpoint hook, optimizer stepper hook and logger hooks will be set to
        `by_epoch=False` by default.

        Default hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointSaverHook  | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have same priority with default hooks, custom hooks
        will be triggered after default hooks.

        Note: the given ``checkpoint_config``, ``lr_config`` and
        ``log_config['hooks']`` entries are updated in place via
        ``setdefault('by_epoch', False)``.
        """
        if checkpoint_config is not None:
            checkpoint_config.setdefault('by_epoch', False)
        if lr_config is not None:
            lr_config.setdefault('by_epoch', False)
        if log_config is not None:
            for info in log_config['hooks']:
                info.setdefault('by_epoch', False)
        super(IterBasedRunner, self).register_training_hooks(
            lr_config=lr_config,
            momentum_config=momentum_config,
            optimizer_config=optimizer_config,
            checkpoint_config=checkpoint_config,
            log_config=log_config,
            timer_config=IterTimerHook(),
            custom_hooks_config=custom_hooks_config)
# Read the Docs v: v1.3.15
# On Read the Docs
# Project Home
#
# Free document hosting provided by Read the Docs.