Shortcuts

Source code for mmcv.runner.hooks.checkpoint

# Copyright (c) OpenMMLab. All rights reserved.
import os

from ..dist_utils import allreduce_params, master_only
from .hook import HOOKS, Hook


[docs]@HOOKS.register_module() class CheckpointHook(Hook): """Save checkpoints periodically. Args: interval (int): The saving period. If ``by_epoch=True``, interval indicates epochs, otherwise it indicates iterations. Default: -1, which means "never". by_epoch (bool): Saving checkpoints by epoch or by iteration. Default: True. save_optimizer (bool): Whether to save optimizer state_dict in the checkpoint. It is usually used for resuming experiments. Default: True. out_dir (str, optional): The directory to save checkpoints. If not specified, ``runner.work_dir`` will be used by default. max_keep_ckpts (int, optional): The maximum checkpoints to keep. In some cases we want only the latest few checkpoints and would like to delete old ones to save the disk space. Default: -1, which means unlimited. save_last (bool): Whether to force the last checkpoint to be saved regardless of interval. sync_buffer (bool): Whether to synchronize buffers in different gpus. Default: False. """ def __init__(self, interval=-1, by_epoch=True, save_optimizer=True, out_dir=None, max_keep_ckpts=-1, save_last=True, sync_buffer=False, **kwargs): self.interval = interval self.by_epoch = by_epoch self.save_optimizer = save_optimizer self.out_dir = out_dir self.max_keep_ckpts = max_keep_ckpts self.save_last = save_last self.args = kwargs self.sync_buffer = sync_buffer def before_run(self, runner): if not self.out_dir: self.out_dir = runner.work_dir def after_train_epoch(self, runner): if not self.by_epoch: return # save checkpoint for following cases: # 1. every ``self.interval`` epochs # 2. reach the last epoch of training if self.every_n_epochs( runner, self.interval) or (self.save_last and self.is_last_epoch(runner)): runner.logger.info( f'Saving checkpoint at {runner.epoch + 1} epochs') if self.sync_buffer: allreduce_params(runner.model.buffers()) self._save_checkpoint(runner) @master_only def _save_checkpoint(self, runner): """Save the current checkpoint and delete unwanted checkpoint.""" runner.save_checkpoint( self.out_dir, save_optimizer=self.save_optimizer, **self.args) if runner.meta is not None: if self.by_epoch: cur_ckpt_filename = self.args.get( 'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1) else: cur_ckpt_filename = self.args.get( 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1) runner.meta.setdefault('hook_msgs', dict()) runner.meta['hook_msgs']['last_ckpt'] = os.path.join( self.out_dir, cur_ckpt_filename) # remove other checkpoints if self.max_keep_ckpts > 0: if self.by_epoch: name = 'epoch_{}.pth' current_ckpt = runner.epoch + 1 else: name = 'iter_{}.pth' current_ckpt = runner.iter + 1 redundant_ckpts = range( current_ckpt - self.max_keep_ckpts * self.interval, 0, -self.interval) filename_tmpl = self.args.get('filename_tmpl', name) for _step in redundant_ckpts: ckpt_path = os.path.join(self.out_dir, filename_tmpl.format(_step)) if os.path.exists(ckpt_path): os.remove(ckpt_path) else: break def after_train_iter(self, runner): if self.by_epoch: return # save checkpoint for following cases: # 1. every ``self.interval`` iterations # 2. reach the last iteration of training if self.every_n_iters( runner, self.interval) or (self.save_last and self.is_last_iter(runner)): runner.logger.info( f'Saving checkpoint at {runner.iter + 1} iterations') if self.sync_buffer: allreduce_params(runner.model.buffers()) self._save_checkpoint(runner)
Read the Docs v: v1.3.14
Versions
master
latest
v1.3.14
v1.3.13
v1.3.11
v1.3.10
v1.3.9
v1.3.8
v1.3.7
v1.3.6
v1.3.5
v1.3.4
v1.3.3
v1.3.2
v1.3.1
v1.3.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.