Source code for mmcv.runner.hooks.logger.wandb

# Copyright (c) Open-MMLab. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class WandbLoggerHook(LoggerHook):
    """Logger hook that forwards the runner's loggable tags to wandb.

    The ``wandb`` module is imported when the hook is constructed and kept
    on ``self.wandb``. ``before_run``, ``log`` and ``after_run`` are wrapped
    with ``master_only`` (imported from ``dist_utils``; presumably this
    restricts them to the master process -- confirm against that helper).

    Args:
        init_kwargs (dict, optional): Keyword arguments passed to
            ``wandb.init`` in :meth:`before_run`. If ``None`` or empty,
            ``wandb.init()`` is called without arguments.
        interval (int): Forwarded to ``LoggerHook.__init__``.
        ignore_last (bool): Forwarded to ``LoggerHook.__init__``.
        reset_flag (bool): Forwarded to ``LoggerHook.__init__``.
        commit (bool): Passed as ``commit`` to every ``wandb.log`` call.
        by_epoch (bool): Forwarded to ``LoggerHook.__init__``.
        with_step (bool): If true, ``self.get_iter(runner)`` is passed to
            ``wandb.log`` as ``step``. Otherwise it is stored in the logged
            dict under the ``'global_step'`` key and no ``step`` is given.
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 commit=True,
                 by_epoch=True,
                 with_step=True):
        super(WandbLoggerHook, self).__init__(interval, ignore_last,
                                              reset_flag, by_epoch)
        # Import eagerly so a missing dependency is reported at construction.
        self.import_wandb()
        self.init_kwargs = init_kwargs
        self.commit = commit
        self.with_step = with_step

    def import_wandb(self):
        """Import ``wandb`` and store the module on ``self.wandb``.

        Raises:
            ImportError: If ``wandb`` cannot be imported.
        """
        try:
            import wandb
        except ImportError:
            raise ImportError(
                'Please run "pip install wandb" to install wandb')
        self.wandb = wandb

    @master_only
    def before_run(self, runner):
        """Run the parent ``before_run`` and start a wandb run.

        Args:
            runner: The runner object, passed through to the parent hook.
        """
        super(WandbLoggerHook, self).before_run(runner)
        if self.wandb is None:
            self.import_wandb()
        # A missing or empty ``init_kwargs`` means "call init() bare".
        init_kwargs = self.init_kwargs or {}
        self.wandb.init(**init_kwargs)

    @master_only
    def log(self, runner):
        """Send the runner's loggable tags to ``wandb.log``.

        Nothing is logged when ``get_loggable_tags`` returns a falsy value.

        Args:
            runner: The runner object whose tags and iteration are logged.
        """
        tags = self.get_loggable_tags(runner)
        if not tags:
            return
        if self.with_step:
            self.wandb.log(
                tags, step=self.get_iter(runner), commit=self.commit)
        else:
            # Record the iteration as an ordinary metric instead of
            # passing it as wandb's own step argument.
            tags['global_step'] = self.get_iter(runner)
            self.wandb.log(tags, commit=self.commit)

    @master_only
    def after_run(self, runner):
        """Close the wandb run.

        Args:
            runner: The runner object (unused here).
        """
        # ``join`` is the legacy name for ending a run (wandb documents
        # ``finish`` as its replacement); prefer ``finish`` when the
        # installed wandb has it and fall back to ``join`` otherwise.
        finish = getattr(self.wandb, 'finish', None)
        if finish is None:
            finish = self.wandb.join
        finish()