mmcv.runner.base_module 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler

import torch.nn as nn

from mmcv.runner.dist_utils import master_only
from mmcv.utils.logging import get_logger, logger_initialized, print_log


class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in openmmlab.

    ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
    functionality of parameter initialization. Compared with
    ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.

    - ``init_cfg``: the config to control the initialization.
    - ``init_weights``: The function of parameter initialization and
      recording initialization information.
    - ``_params_init_info``: Used to track the parameter initialization
      information. This attribute only exists during executing the
      ``init_weights``.

    Args:
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, init_cfg=None):
        """Initialize BaseModule, inherited from `torch.nn.Module`"""

        # NOTE init_cfg can be defined in different levels, but init_cfg
        # in low levels has a higher priority.

        super(BaseModule, self).__init__()
        # define default value of init_cfg instead of hard code
        # in init_weights() function
        self._is_init = False

        # Deep-copied so later mutation of the caller's config object does
        # not affect this module.
        self.init_cfg = copy.deepcopy(init_cfg)

        # Backward compatibility in derived classes
        # if pretrained is not None:
        #     warnings.warn('DeprecationWarning: pretrained is a deprecated \
        #         key, please consider using init_cfg')
        #     self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

    @property
    def is_init(self):
        """bool: Value of the ``_is_init`` flag set by ``init_weights``."""
        return self._is_init

    def init_weights(self):
        """Initialize the weights.

        The call that finds no ``_params_init_info`` attribute on ``self``
        is treated as the top-level call: it creates the shared record,
        attaches it to every submodule, and, once initialization is done,
        dumps the record and removes the attribute from every submodule
        again.

        If ``init_cfg`` is a dict whose ``type`` is ``'Pretrained'``, the
        children's ``init_weights`` are not called, so the loaded
        parameters are not overwritten. On that path ``_is_init`` is left
        unchanged.
        """
        is_top_level_module = False
        # check if it is top-level module
        if not hasattr(self, '_params_init_info'):
            # The `_params_init_info` is used to record the initialization
            # information of the parameters
            # the key should be the obj:`nn.Parameter` of model and the value
            # should be a dict containing
            # - init_info (str): The string that describes the initialization.
            # - tmp_mean_value (FloatTensor): The mean of the parameter,
            #       which indicates whether the parameter has been modified.
            # this attribute would be deleted after all parameters
            # is initialized.
            self._params_init_info = defaultdict(dict)
            is_top_level_module = True

            # Seed the record with a default message and the current mean
            # of every parameter.
            for param in self.parameters():
                param_info = self._params_init_info[param]
                param_info['init_info'] = (
                    f'The value is the same before and '
                    f'after calling `init_weights` '
                    f'of {self.__class__.__name__} ')
                param_info['tmp_mean_value'] = param.data.mean()

            # pass `params_init_info` to all submodules
            # All submodules share the same `params_init_info`,
            # so it will be updated when parameters are
            # modified at any level of the model.
            for sub_module in self.modules():
                sub_module._params_init_info = self._params_init_info

        # Get the initialized logger, if not exist,
        # create a logger named `mmcv`
        logger_names = list(logger_initialized.keys())
        logger_name = logger_names[0] if logger_names else 'mmcv'

        from ..cnn import initialize
        from ..cnn.utils.weight_init import update_init_info
        module_name = self.__class__.__name__
        if not self._is_init:
            loaded_pretrained = False
            if self.init_cfg:
                print_log(
                    f'initialize {module_name} with init_cfg {self.init_cfg}',
                    logger=logger_name)
                initialize(self, self.init_cfg)
                if isinstance(self.init_cfg, dict):
                    # prevent the parameters of
                    # the pre-trained model
                    # from being overwritten by
                    # the `init_weights`
                    loaded_pretrained = (
                        self.init_cfg['type'] == 'Pretrained')

            # A flag is used instead of an early `return` so that the
            # top-level dump and cleanup below always run; returning here
            # left `_params_init_info` attached to every submodule.
            if not loaded_pretrained:
                for m in self.children():
                    if hasattr(m, 'init_weights'):
                        m.init_weights()
                        # users may overload the `init_weights`
                        update_init_info(
                            m,
                            init_info=f'Initialized by '
                            f'user-defined `init_weights`'
                            f' in {m.__class__.__name__} ')

                self._is_init = True
        else:
            warnings.warn(f'init_weights of {self.__class__.__name__} has '
                          f'been called more than once.')

        if is_top_level_module:
            self._dump_init_info(logger_name)

            for sub_module in self.modules():
                del sub_module._params_init_info

    @master_only
    def _dump_init_info(self, logger_name):
        """Dump the recorded initialization information of all parameters.

        The information is written directly to the stream of every
        ``FileHandler`` attached to the logger. If the logger has no
        ``FileHandler``, it is emitted through ``print_log`` instead.

        Args:
            logger_name (str): The name of logger.
        """

        logger = get_logger(logger_name)

        with_file_handler = False
        # dump the information to the logger file if there is a `FileHandler`
        for handler in logger.handlers:
            if isinstance(handler, FileHandler):
                handler.stream.write(
                    'Name of parameter - Initialization information\n')
                for name, param in self.named_parameters():
                    handler.stream.write(
                        f'\n{name} - {param.shape}: '
                        f"\n{self._params_init_info[param]['init_info']} \n")
                handler.stream.flush()
                with_file_handler = True
        if not with_file_handler:
            for name, param in self.named_parameters():
                print_log(
                    f'\n{name} - {param.shape}: '
                    f"\n{self._params_init_info[param]['init_info']} \n ",
                    logger=logger_name)

    def __repr__(self):
        """Return the ``nn.Module`` repr, plus ``init_cfg`` when truthy."""
        s = super().__repr__()
        if self.init_cfg:
            s += f'\ninit_cfg={self.init_cfg}'
        return s
class Sequential(BaseModule, nn.Sequential):
    """``nn.Sequential`` container that also carries an ``init_cfg``.

    Args:
        *args: Forwarded unchanged to ``nn.Sequential.__init__``.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, *args, init_cfg=None):
        # Each base initializer is invoked explicitly: first the one that
        # stores `init_cfg`, then the container one that receives `*args`.
        BaseModule.__init__(self, init_cfg=init_cfg)
        nn.Sequential.__init__(self, *args)
class ModuleList(BaseModule, nn.ModuleList):
    """``nn.ModuleList`` container that also carries an ``init_cfg``.

    Args:
        modules (iterable, optional): an iterable of modules to add.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, modules=None, init_cfg=None):
        # Each base initializer is invoked explicitly: first the one that
        # stores `init_cfg`, then the container one that receives `modules`.
        BaseModule.__init__(self, init_cfg=init_cfg)
        nn.ModuleList.__init__(self, modules=modules)