# mmcv.cnn.bricks.plugin source code

import inspect
import platform

from .registry import PLUGIN_LAYERS

# On Windows the third-party ``regex`` module is bound to the name ``re``;
# on every other platform the standard-library ``re`` is used. The ``else``
# branch is required: without it ``re`` is never bound on non-Windows
# platforms (and on Windows the stdlib import would override ``regex``).
if platform.system() == 'Windows':
    import regex as re
else:
    import re

def infer_abbr(class_type):
    """Infer an abbreviation from a class.

    This method infers the abbreviation used to map class types to
    abbreviations.

    Rule 1: If the class has the attribute ``_abbr_``, return it.
    Rule 2: Otherwise, the abbreviation falls back to the snake case of the
    class name, e.g. the abbreviation of ``FancyBlock`` will be
    ``fancy_block``.

    Args:
        class_type (type): The class whose abbreviation is inferred
            (e.g. a plugin layer type).

    Returns:
        str: The inferred abbreviation.

    Raises:
        TypeError: If ``class_type`` is not a class.
    """

    def camel2snake(word):
        """Convert a camel case word into snake case.

        Modified from ``underscore`` in the ``inflection`` library.

        Example::

            >>> camel2snake("FancyBlock")
            'fancy_block'
        """
        # Separate an acronym from a following capitalised word,
        # e.g. "HTTPServer" -> "HTTP_Server".
        word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
        # Separate a lowercase letter or digit from a following capital,
        # e.g. "FancyBlock" -> "Fancy_Block".
        word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
        word = word.replace('-', '_')
        return word.lower()

    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    if hasattr(class_type, '_abbr_'):
        # An explicit abbreviation on the class takes precedence.
        return class_type._abbr_
    # Fall back to the snake-cased class name.
    return camel2snake(class_type.__name__)

def build_plugin_layer(cfg, postfix='', **kwargs):
    """Build a plugin layer from a config dict.

    Args:
        cfg (dict): Config dict. It must contain:

            - type (str): Registered name of the plugin layer type.
            - Any other keys are passed as keyword arguments to the plugin
              layer constructor.
        postfix (int | str): Appended to the inferred abbreviation of the
            plugin layer type to create the layer name. Default: ''.
        **kwargs: Extra keyword arguments forwarded to the plugin layer
            constructor together with the remaining ``cfg`` entries. A key
            present in both raises ``TypeError``.

    Returns:
        tuple[str, nn.Module]: A ``(name, layer)`` pair, where ``name`` is
        the abbreviation followed by ``postfix`` and ``layer`` is the
        created plugin layer.

    Raises:
        TypeError: If ``cfg`` is not a dict or ``postfix`` is neither an
            ``int`` nor a ``str``.
        KeyError: If ``cfg`` has no ``"type"`` key or the type is not
            registered in ``PLUGIN_LAYERS``.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    # Shallow copy so that popping 'type' leaves the caller's cfg intact.
    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')
    if layer_type not in PLUGIN_LAYERS:
        raise KeyError(f'Unrecognized plugin type {layer_type}')
    plugin_layer = PLUGIN_LAYERS.get(layer_type)

    abbr = infer_abbr(plugin_layer)
    # Explicit check instead of ``assert``: asserts are stripped under -O.
    if not isinstance(postfix, (int, str)):
        raise TypeError(
            f'postfix must be an int or str, but got {type(postfix)}')
    name = abbr + str(postfix)

    layer = plugin_layer(**kwargs, **cfg_)
    return name, layer