# Copyright (c) Open-MMLab. All rights reserved.
import warnings
from abc import ABCMeta
import torch.nn as nn
class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in openmmlab.

    Extends ``torch.nn.Module`` with an optional initialization config
    (``init_cfg``) and a lazy, one-shot weight-initialization entry point,
    :meth:`init_weight`.
    """

    def __init__(self, init_cfg=None):
        """Initialize BaseModule, inherited from `torch.nn.Module`.

        Args:
            init_cfg (dict, optional): Initialization config dict.
        """
        # NOTE: init_cfg can be defined in different levels, but init_cfg
        # in low levels has a higher priority.
        super(BaseModule, self).__init__()
        # Track whether init_weight() has already run, so a repeated call
        # only warns instead of re-initializing.
        self._is_init = False
        # Only set the attribute when a config is given; init_weight()
        # and __repr__() test for its presence with hasattr().
        if init_cfg is not None:
            self.init_cfg = init_cfg
        # NOTE(review): derived classes may translate a deprecated
        # `pretrained` argument into
        # `init_cfg=dict(type='Pretrained', checkpoint=pretrained)`
        # for backward compatibility.

    @property
    def is_init(self):
        """bool: Whether :meth:`init_weight` has been called."""
        return self._is_init

    def init_weight(self):
        """Initialize the weights.

        Applies ``init_cfg`` (if present) to this module, then recursively
        triggers ``init_weight()`` on every child that defines it. Warns
        instead of re-running if called more than once.
        """
        # Imported at call time rather than module level — presumably to
        # avoid a circular import with the cnn subpackage (confirm).
        from ..cnn import initialize
        if not self._is_init:
            if hasattr(self, 'init_cfg'):
                initialize(self, self.init_cfg)
            # Recurse into children that support the same protocol.
            for m in self.children():
                if hasattr(m, 'init_weight'):
                    m.init_weight()
            self._is_init = True
        else:
            warnings.warn(f'init_weight of {self.__class__.__name__} has '
                          f'been called more than once.')

    def __repr__(self):
        """Append the stored ``init_cfg`` (if any) to the default repr."""
        s = super().__repr__()
        if hasattr(self, 'init_cfg'):
            s += f'\ninit_cfg={self.init_cfg}'
        return s
class Sequential(BaseModule, nn.Sequential):
    """Sequential module in openmmlab.

    Behaves like ``torch.nn.Sequential`` while also carrying the
    ``init_cfg`` machinery of :class:`BaseModule`.

    Args:
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, *args, init_cfg=None):
        # Initialize both bases explicitly: BaseModule records init_cfg
        # (and runs nn.Module.__init__); nn.Sequential then registers the
        # submodules passed in *args.
        BaseModule.__init__(self, init_cfg)
        nn.Sequential.__init__(self, *args)
class ModuleList(BaseModule, nn.ModuleList):
    """ModuleList in openmmlab.

    Behaves like ``torch.nn.ModuleList`` while also carrying the
    ``init_cfg`` machinery of :class:`BaseModule`.

    Args:
        modules (iterable, optional): an iterable of modules to add.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, modules=None, init_cfg=None):
        # Initialize both bases explicitly: BaseModule records init_cfg
        # (and runs nn.Module.__init__); nn.ModuleList then registers the
        # given modules.
        BaseModule.__init__(self, init_cfg)
        nn.ModuleList.__init__(self, modules)