# Source code for mmcv.runner.base_module

# Copyright (c) Open-MMLab. All rights reserved.
import warnings
from abc import ABCMeta

import torch.nn as nn

from mmcv import ConfigDict


class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in OpenMMLab.

    Extends ``torch.nn.Module`` with config-driven weight initialization.
    :meth:`init_weights` first applies ``init_cfg`` to this module via
    ``mmcv.cnn.initialize``. It then calls ``init_weights`` on every direct
    child that defines it.

    Derived classes that still accept a deprecated ``pretrained`` argument
    can stay backward compatible: emit a deprecation warning, then set
    ``self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)``.

    Args:
        init_cfg (dict, optional): Initialization config dict. Defaults to
            None, in which case only the children are initialized.
    """

    def __init__(self, init_cfg=None):
        """Initialize BaseModule, inherited from ``torch.nn.Module``.

        Args:
            init_cfg (dict, optional): Initialization config dict.
        """
        # NOTE: init_cfg can be defined at different levels of the module
        # tree. A lower-level (child) init_cfg has higher priority, because
        # children are initialized after their parent in init_weights().
        super().__init__()
        # Set once init_weights() has completed for this module.
        self._is_init = False
        self.init_cfg = init_cfg

    @property
    def is_init(self):
        """bool: Whether :meth:`init_weights` has completed."""
        return self._is_init

    def init_weights(self):
        """Initialize the weights of this module and its children.

        ``init_cfg`` is applied to this module first. Then ``init_weights``
        is called on each direct child that has that method. When
        ``init_cfg`` is a ``Pretrained`` dict, the children are skipped so
        the loaded parameters are not overwritten.

        Calling this again after it has completed only emits a warning.
        """
        # Imported lazily, presumably to avoid a circular import between
        # mmcv.runner and mmcv.cnn.
        from ..cnn import initialize

        if self._is_init:
            warnings.warn(f'init_weights of {self.__class__.__name__} has '
                          f'been called more than once.')
            return

        if self.init_cfg:
            initialize(self, self.init_cfg)
            if (isinstance(self.init_cfg, (dict, ConfigDict))
                    and self.init_cfg['type'] == 'Pretrained'):
                # Keep the children's init_weights from overwriting the
                # pretrained parameters. Still mark this module as
                # initialized: otherwise is_init stays False and a repeated
                # call would silently reload the checkpoint instead of
                # warning.
                self._is_init = True
                return

        for child in self.children():
            if hasattr(child, 'init_weights'):
                child.init_weights()
        self._is_init = True

    def __repr__(self):
        """Return the ``nn.Module`` repr, plus ``init_cfg`` when it is set."""
        s = super().__repr__()
        if self.init_cfg:
            s += f'\ninit_cfg={self.init_cfg}'
        return s
class Sequential(BaseModule, nn.Sequential):
    """An ``nn.Sequential`` container that is also a :class:`BaseModule`.

    Args:
        *args: Positional arguments forwarded unchanged to
            ``nn.Sequential.__init__``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self, *args, init_cfg=None):
        # BaseModule.__init__ chains to the nn.Module constructor through
        # super(), so run it first. Running it after nn.Sequential has
        # registered the children would likely discard them.
        BaseModule.__init__(self, init_cfg=init_cfg)
        nn.Sequential.__init__(self, *args)
class ModuleList(BaseModule, nn.ModuleList):
    """An ``nn.ModuleList`` container that is also a :class:`BaseModule`.

    Args:
        modules (iterable, optional): An iterable of modules to add.
            Defaults to None.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self, modules=None, init_cfg=None):
        # BaseModule.__init__ chains to the nn.Module constructor through
        # super(), so it must run before nn.ModuleList registers ``modules``.
        BaseModule.__init__(self, init_cfg=init_cfg)
        nn.ModuleList.__init__(self, modules)