Source code for mmcv.runner.hooks.logger.wandb
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Dict, Optional, Union
from mmcv.utils import scandir
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
[docs]@HOOKS.register_module()
class WandbLoggerHook(LoggerHook):
"""Class to log metrics with wandb.
It requires `wandb`_ to be installed.
Args:
init_kwargs (dict): A dict contains the initialization keys. Check
https://docs.wandb.ai/ref/python/init for more init arguments.
interval (int): Logging interval (every k iterations).
Default 10.
ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`.
Default: True.
reset_flag (bool): Whether to clear the output buffer after logging.
Default: False.
commit (bool): Save the metrics dict to the wandb server and increment
the step. If false ``wandb.log`` just updates the current metrics
dict with the row argument and metrics won't be saved until
``wandb.log`` is called with ``commit=True``.
Default: True.
by_epoch (bool): Whether EpochBasedRunner is used.
Default: True.
with_step (bool): If True, the step will be logged from
``self.get_iters``. Otherwise, step will not be logged.
Default: True.
log_artifact (bool): If True, artifacts in {work_dir} will be uploaded
to wandb after training ends.
Default: True
`New in version 1.4.3.`
out_suffix (str or tuple[str], optional): Those filenames ending with
``out_suffix`` will be uploaded to wandb.
Default: ('.log.json', '.log', '.py').
`New in version 1.4.3.`
define_metric_cfg (dict, optional): A dict of metrics and summaries for
wandb.define_metric. The key is metric and the value is summary.
The summary should be in ["min", "max", "mean" ,"best", "last",
"none"].
For example, if setting
``define_metric_cfg={'coco/bbox_mAP': 'max'}``, the maximum value
of ``coco/bbox_mAP`` will be logged on wandb UI. See
`wandb docs <https://docs.wandb.ai/ref/python/run#define_metric>`_
for details.
Defaults to None.
`New in version 1.6.3.`
.. _wandb:
https://docs.wandb.ai
"""
def __init__(self,
init_kwargs: Optional[Dict] = None,
interval: int = 10,
ignore_last: bool = True,
reset_flag: bool = False,
commit: bool = True,
by_epoch: bool = True,
with_step: bool = True,
log_artifact: bool = True,
out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py'),
define_metric_cfg: Optional[Dict] = None):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_wandb()
self.init_kwargs = init_kwargs
self.commit = commit
self.with_step = with_step
self.log_artifact = log_artifact
self.out_suffix = out_suffix
self.define_metric_cfg = define_metric_cfg
def import_wandb(self) -> None:
try:
import wandb
except ImportError:
raise ImportError(
'Please run "pip install wandb" to install wandb')
self.wandb = wandb
@master_only
def before_run(self, runner) -> None:
super().before_run(runner)
if self.wandb is None:
self.import_wandb()
if self.init_kwargs:
self.wandb.init(**self.init_kwargs) # type: ignore
else:
self.wandb.init() # type: ignore
summary_choice = ['min', 'max', 'mean', 'best', 'last', 'none']
if self.define_metric_cfg is not None:
for metric, summary in self.define_metric_cfg.items():
if summary not in summary_choice:
warnings.warn(
f'summary should be in {summary_choice}. '
f'metric={metric}, summary={summary} will be skipped.')
self.wandb.define_metric( # type: ignore
metric, summary=summary)
@master_only
def log(self, runner) -> None:
tags = self.get_loggable_tags(runner)
if tags:
if self.with_step:
self.wandb.log(
tags, step=self.get_iter(runner), commit=self.commit)
else:
tags['global_step'] = self.get_iter(runner)
self.wandb.log(tags, commit=self.commit)
@master_only
def after_run(self, runner) -> None:
if self.log_artifact:
wandb_artifact = self.wandb.Artifact(
name='artifacts', type='model')
for filename in scandir(runner.work_dir, self.out_suffix, True):
local_filepath = osp.join(runner.work_dir, filename)
wandb_artifact.add_file(local_filepath)
self.wandb.log_artifact(wandb_artifact)
self.wandb.join()