# Copyright (c) Open-MMLab. All rights reserved.
import os

from ..dist_utils import allreduce_params, master_only
from .hook import HOOKS, Hook


@HOOKS.register_module()
class CheckpointHook(Hook):
"""Save checkpoints periodically.
Args:
interval (int): The saving period. If ``by_epoch=True``, interval
indicates epochs, otherwise it indicates iterations.
Default: -1, which means "never".
by_epoch (bool): Saving checkpoints by epoch or by iteration.
Default: True.
save_optimizer (bool): Whether to save optimizer state_dict in the
checkpoint. It is usually used for resuming experiments.
Default: True.
out_dir (str, optional): The directory to save checkpoints. If not
specified, ``runner.work_dir`` will be used by default.
max_keep_ckpts (int, optional): The maximum checkpoints to keep.
In some cases we want only the latest few checkpoints and would
like to delete old ones to save the disk space.
Default: -1, which means unlimited.
        save_last (bool): Whether to force the last checkpoint to be saved
            regardless of interval. Default: True.
        sync_buffer (bool): Whether to synchronize buffers in different
            GPUs before saving. Default: False.
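
    Example:
        >>> # An illustrative sketch: save every 5 epochs and keep only
        >>> # the 3 most recent checkpoints. ``runner`` is assumed to be
        >>> # an already-constructed mmcv runner.
        >>> hook = CheckpointHook(interval=5, by_epoch=True, max_keep_ckpts=3)
        >>> runner.register_hook(hook)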
"""
def __init__(self,
interval=-1,
by_epoch=True,
save_optimizer=True,
out_dir=None,
max_keep_ckpts=-1,
save_last=True,
sync_buffer=False,
**kwargs):
self.interval = interval
self.by_epoch = by_epoch
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.max_keep_ckpts = max_keep_ckpts
self.save_last = save_last
self.args = kwargs
        self.sync_buffer = sync_buffer

    def before_run(self, runner):
if not self.out_dir:
            self.out_dir = runner.work_dir

    def after_train_epoch(self, runner):
if not self.by_epoch:
return
# save checkpoint for following cases:
# 1. every ``self.interval`` epochs
# 2. reach the last epoch of training
if self.every_n_epochs(
runner, self.interval) or (self.save_last
and self.is_last_epoch(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.epoch + 1} epochs')
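            # allreduce model buffers (e.g. BN running statistics) so the
            # checkpoint saved by the master rank reflects all GPUs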
if self.sync_buffer:
allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)

    @master_only
    def _save_checkpoint(self, runner):
"""Save the current checkpoint and delete unwanted checkpoint."""
runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
if runner.meta is not None:
if self.by_epoch:
cur_ckpt_filename = self.args.get(
'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
else:
cur_ckpt_filename = self.args.get(
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
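            # record the path of the latest checkpoint in the runner meta,
            # so resumed runs can locate it via ``hook_msgs['last_ckpt']``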
runner.meta.setdefault('hook_msgs', dict())
runner.meta['hook_msgs']['last_ckpt'] = os.path.join(
self.out_dir, cur_ckpt_filename)
# remove other checkpoints
if self.max_keep_ckpts > 0:
if self.by_epoch:
name = 'epoch_{}.pth'
current_ckpt = runner.epoch + 1
else:
name = 'iter_{}.pth'
current_ckpt = runner.iter + 1
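            # enumerate checkpoints older than the ``max_keep_ckpts`` most
            # recent ones, walking backwards in steps of ``self.interval``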
redundant_ckpts = range(
current_ckpt - self.max_keep_ckpts * self.interval, 0,
-self.interval)
filename_tmpl = self.args.get('filename_tmpl', name)
for _step in redundant_ckpts:
ckpt_path = os.path.join(self.out_dir,
filename_tmpl.format(_step))
if os.path.exists(ckpt_path):
os.remove(ckpt_path)
else:
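                        # an absent file implies older checkpoints were
                        # already removed, so stop scanning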
                        break

    def after_train_iter(self, runner):
if self.by_epoch:
return
# save checkpoint for following cases:
# 1. every ``self.interval`` iterations
# 2. reach the last iteration of training
if self.every_n_iters(
runner, self.interval) or (self.save_last
and self.is_last_iter(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.iter + 1} iterations')
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)