Shortcuts

Source code for mmcv.runner.hooks.checkpoint

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Optional

from mmcv.fileio import FileClient
from ..dist_utils import allreduce_params, master_only
from .hook import HOOKS, Hook


@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, optional): The root directory to save checkpoints. If
            not specified, ``runner.work_dir`` will be used by default. If
            specified, the ``out_dir`` will be the concatenation of
            ``out_dir`` and the last level directory of ``runner.work_dir``.
            `Changed in version 1.3.16.`
        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Default: -1, which means unlimited.
        save_last (bool, optional): Whether to force the last checkpoint to
            be saved regardless of interval. Default: True.
        sync_buffer (bool, optional): Whether to synchronize buffers in
            different gpus. Default: False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.
            `New in version 1.3.16.`
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``runner.save_checkpoint`` (``filename_tmpl`` and
            ``create_symlink`` are also read by this hook).

    .. warning::
        Before v1.3.16, the ``out_dir`` argument indicates the path where the
        checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates
        the root directory and the final path to save checkpoint is the
        concatenation of ``out_dir`` and the last level directory of
        ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
        and the value of ``runner.work_dir`` is "/path/of/B", then the final
        path will be "/path/of/A/B".
    """

    def __init__(self,
                 interval: int = -1,
                 by_epoch: bool = True,
                 save_optimizer: bool = True,
                 out_dir: Optional[str] = None,
                 max_keep_ckpts: int = -1,
                 save_last: bool = True,
                 sync_buffer: bool = False,
                 file_client_args: Optional[dict] = None,
                 **kwargs):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        self.args = kwargs
        self.sync_buffer = sync_buffer
        self.file_client_args = file_client_args

    def before_run(self, runner):
        """Resolve the output directory, build the file client and settle the
        ``create_symlink`` option before training starts.

        Args:
            runner: The runner driving training; its ``work_dir`` and
                ``logger`` attributes are read here.
        """
        if not self.out_dir:
            self.out_dir = runner.work_dir

        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
        # `self.out_dir` is set so the final `self.out_dir` is the
        # concatenation of `self.out_dir` and the last level directory of
        # `runner.work_dir`
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_client.join_path(self.out_dir, basename)

        runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
                           f'{self.file_client.name}.')

        # disable the create_symlink option because some file backends do not
        # allow to create a symlink
        if 'create_symlink' in self.args:
            if (self.args['create_symlink']
                    and not self.file_client.allow_symlink):
                self.args['create_symlink'] = False
                # Adjacent string literals are joined verbatim, so every
                # fragment must carry its own separating space.
                warnings.warn(
                    'create_symlink is set as True by the user but is changed '
                    'to be False because creating symbolic link is not '
                    f'allowed in {self.file_client.name}')
        else:
            # The user expressed no preference: follow the backend capability.
            self.args['create_symlink'] = self.file_client.allow_symlink

    def after_train_epoch(self, runner):
        """Save a checkpoint at the end of an epoch when one is due.

        Does nothing when the hook is configured to work by iteration.
        """
        if not self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` epochs
        # 2. reach the last epoch of training
        if self.every_n_epochs(runner, self.interval) or (
                self.save_last and self.is_last_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)

    @master_only
    def _save_checkpoint(self, runner):
        """Save the current checkpoint and delete unwanted checkpoint.

        After saving, the path of the new checkpoint is recorded in
        ``runner.meta['hook_msgs']['last_ckpt']`` (when ``runner.meta`` is not
        None), and older checkpoints beyond ``max_keep_ckpts`` are removed.
        """
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)

        # Checkpoints are numbered from 1, by epoch or by iteration.
        if self.by_epoch:
            default_tmpl = 'epoch_{}.pth'
            current_ckpt = runner.epoch + 1
        else:
            default_tmpl = 'iter_{}.pth'
            current_ckpt = runner.iter + 1
        filename_tmpl = self.args.get('filename_tmpl', default_tmpl)

        if runner.meta is not None:
            runner.meta.setdefault('hook_msgs', dict())
            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
                self.out_dir, filename_tmpl.format(current_ckpt))

        # remove other checkpoints
        # A non-positive interval leaves nothing to walk back over: a negative
        # one yields an empty range and zero would make ``range`` raise
        # ValueError, so both are skipped here.
        if self.max_keep_ckpts > 0 and self.interval > 0:
            redundant_ckpts = range(
                current_ckpt - self.max_keep_ckpts * self.interval, 0,
                -self.interval)
            # NOTE(review): the walk assumes ``current_ckpt`` is a multiple of
            # ``interval``. When ``save_last`` forces an off-interval save,
            # the first candidate may not exist and nothing is removed.
            for _step in redundant_ckpts:
                ckpt_path = self.file_client.join_path(
                    self.out_dir, filename_tmpl.format(_step))
                if self.file_client.isfile(ckpt_path):
                    self.file_client.remove(ckpt_path)
                else:
                    # Older checkpoints were already cleaned up earlier.
                    break

    def after_train_iter(self, runner):
        """Save a checkpoint at the end of an iteration when one is due.

        Does nothing when the hook is configured to work by epoch.
        """
        if self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` iterations
        # 2. reach the last iteration of training
        if self.every_n_iters(runner, self.interval) or (
                self.save_last and self.is_last_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)
Read the Docs v: master
Versions
master
latest
2.x
v1.7.0
v1.6.2
v1.6.1
v1.6.0
v1.5.3
v1.5.2_a
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
v1.3.12
v1.3.11
v1.3.10
v1.3.9
v1.3.8
v1.3.7
v1.3.6
v1.3.5
v1.3.4
v1.3.3
v1.3.2
v1.3.1
v1.3.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.