Source code for mmcv.cnn.rfsearch.operator

# Copyright (c) OpenMMLab. All rights reserved.
import copy

import numpy as np
import torch
import torch.nn as nn
from mmengine.logging import print_log
from mmengine.model import BaseModule
from torch import Tensor

from .utils import expand_rates, get_single_padding


class BaseConvRFSearchOp(BaseModule):
    """Based class of ConvRFSearchOp.

    Args:
        op_layer (nn.Module): PyTorch module, e.g., Conv2d
        global_config (dict): config dict.
    """

    def __init__(self, op_layer: nn.Module, global_config: dict):
        super().__init__()
        self.op_layer = op_layer
        self.global_config = global_config

    def normlize(self, weights: nn.Parameter) -> nn.Parameter:
        """Normalize weights.

        Args:
            weights (nn.Parameter): Weights to be normalized.

        Returns:
            nn.Parameter: Normalized weights.
        """
        abs_weights = torch.abs(weights)
        normalized_weights = abs_weights / torch.sum(abs_weights)
        return normalized_weights
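
# The normalization above is a plain L1 rescaling: absolute values divided
# by their sum, so the branch weights form a convex combination. A minimal
# sketch of the arithmetic (values are illustrative):
_w = torch.tensor([0.5, -1.0, 1.5])
_norm = torch.abs(_w) / torch.sum(torch.abs(_w))
# _norm == tensor([0.1667, 0.3333, 0.5000]); the entries sum to 1.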


class Conv2dRFSearchOp(BaseConvRFSearchOp):
    """Enable Conv2d with receptive field searching ability.

    Args:
        op_layer (nn.Module): PyTorch module, e.g., Conv2d
        global_config (dict): config dict. Defaults to None.
            By default this must include:

            - "init_alphas": The value for initializing weights of each
              branch.
            - "num_branches": The controller of the size of search space
              (the number of branches).
            - "exp_rate": The controller of the sparsity of search space.
            - "mmin": The minimum dilation rate.
            - "mmax": The maximum dilation rate.

            Extra keys may exist, but are used by RFSearchHook, e.g., "step",
            "max_step", "search_interval", and "skip_layer".
        verbose (bool): Determines whether to print rf-next related logging
            messages. Defaults to True.
    """

    def __init__(self,
                 op_layer: nn.Module,
                 global_config: dict,
                 verbose: bool = True):
        super().__init__(op_layer, global_config)
        assert global_config is not None, 'global_config is None'
        self.num_branches = global_config['num_branches']
        assert self.num_branches in [2, 3]
        self.verbose = verbose
        init_dilation = op_layer.dilation
        # Expand the initial dilation into one candidate rate per branch.
        self.dilation_rates = expand_rates(init_dilation, global_config)
        # Dimensions with kernel size 1 or an even kernel size are not
        # searched; they keep the layer's original dilation.
        if self.op_layer.kernel_size[
                0] == 1 or self.op_layer.kernel_size[0] % 2 == 0:
            self.dilation_rates = [(op_layer.dilation[0], r[1])
                                   for r in self.dilation_rates]
        if self.op_layer.kernel_size[
                1] == 1 or self.op_layer.kernel_size[1] % 2 == 0:
            self.dilation_rates = [(r[0], op_layer.dilation[1])
                                   for r in self.dilation_rates]

        self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches))
        if self.verbose:
            print_log(f'Expand as {self.dilation_rates}', 'current')
        nn.init.constant_(self.branch_weights, global_config['init_alphas'])
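
    # A minimal construction sketch (config values are illustrative, not
    # defaults from mmcv)::
    #
    #     conv = nn.Conv2d(16, 16, kernel_size=3, padding=1)
    #     config = dict(
    #         init_alphas=0.01, num_branches=3, exp_rate=0.5, mmin=1,
    #         mmax=24)
    #     op = Conv2dRFSearchOp(conv, config)
    #     # one candidate (h, w) dilation pair per branch; every branch
    #     # weight starts at init_alphas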

    def forward(self, input: Tensor) -> Tensor:
        norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
        if len(self.dilation_rates) == 1:
            outputs = [
                nn.functional.conv2d(
                    input,
                    weight=self.op_layer.weight,
                    bias=self.op_layer.bias,
                    stride=self.op_layer.stride,
                    padding=self.get_padding(self.dilation_rates[0]),
                    dilation=self.dilation_rates[0],
                    groups=self.op_layer.groups,
                )
            ]
        else:
            outputs = [
                nn.functional.conv2d(
                    input,
                    weight=self.op_layer.weight,
                    bias=self.op_layer.bias,
                    stride=self.op_layer.stride,
                    padding=self.get_padding(r),
                    dilation=r,
                    groups=self.op_layer.groups,
                ) * norm_w[i] for i, r in enumerate(self.dilation_rates)
            ]
        output = outputs[0]
        for i in range(1, len(self.dilation_rates)):
            output += outputs[i]
        return output
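
    # Forward computes a soft mixture over receptive fields: every branch
    # shares self.op_layer.weight but runs its own dilation with "same"
    # padding, so all branch outputs have one shape and the normalized
    # branch weights blend them. Continuing the construction sketch above::
    #
    #     x = torch.randn(2, 16, 32, 32)
    #     y = op(x)  # weighted sum over the dilation branches
    #     assert y.shape == (2, 16, 32, 32)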

    def estimate_rates(self) -> None:
        """Estimate new dilation rate based on trained branch_weights."""
        norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
        if self.verbose:
            print_log(
                'Estimate dilation {} with weight {}.'.format(
                    self.dilation_rates,
                    norm_w.detach().cpu().numpy().tolist()), 'current')

        sum0, sum1, w_sum = 0, 0, 0
        for i in range(len(self.dilation_rates)):
            sum0 += norm_w[i].item() * self.dilation_rates[i][0]
            sum1 += norm_w[i].item() * self.dilation_rates[i][1]
            w_sum += norm_w[i].item()
        estimated = [
            np.clip(
                int(round(sum0 / w_sum)), self.global_config['mmin'],
                self.global_config['mmax']).item(),
            np.clip(
                int(round(sum1 / w_sum)), self.global_config['mmin'],
                self.global_config['mmax']).item()
        ]
        self.op_layer.dilation = tuple(estimated)
        self.op_layer.padding = self.get_padding(self.op_layer.dilation)
        self.dilation_rates = [tuple(estimated)]
        if self.verbose:
            print_log(f'Estimate as {tuple(estimated)}', 'current')
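
    # The estimate is the weighted average of the candidate rates, rounded
    # and clipped to [mmin, mmax]; since normlize makes the weights sum to
    # 1, w_sum is 1.0. A worked example with illustrative numbers::
    #
    #     rates = [(1, 1), (2, 2), (3, 3)], norm_w = [0.2, 0.3, 0.5]
    #     sum0 = 0.2 * 1 + 0.3 * 2 + 0.5 * 3 = 2.3
    #     round(2.3 / 1.0) = 2  ->  the layer collapses to dilation (2, 2)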

    def expand_rates(self) -> None:
        """Expand dilation rate."""
        dilation = self.op_layer.dilation
        dilation_rates = expand_rates(dilation, self.global_config)
        if self.op_layer.kernel_size[
                0] == 1 or self.op_layer.kernel_size[0] % 2 == 0:
            dilation_rates = [(dilation[0], r[1]) for r in dilation_rates]
        if self.op_layer.kernel_size[
                1] == 1 or self.op_layer.kernel_size[1] % 2 == 0:
            dilation_rates = [(r[0], dilation[1]) for r in dilation_rates]

        self.dilation_rates = copy.deepcopy(dilation_rates)
        if self.verbose:
            print_log(f'Expand as {self.dilation_rates}', 'current')
        nn.init.constant_(self.branch_weights,
                          self.global_config['init_alphas'])
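
    # The expand_rates helper (imported from .utils) rebuilds the candidate
    # set around the current rate, with exp_rate controlling how far the
    # candidates spread. A hypothetical sketch of the idea only, not the
    # exact mmcv formula::
    #
    #     rate, exp_rate, mmin, mmax = 2, 0.5, 1, 24
    #     candidates = sorted({
    #         max(mmin, round(rate * (1 - exp_rate))),  # shrink
    #         rate,                                     # keep
    #         min(mmax, round(rate * (1 + exp_rate))),  # grow
    #     })
    #     # -> [1, 2, 3]: one dilation candidate per branch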

    def get_padding(self, dilation) -> tuple:
        """Return the "same" padding of each spatial dim for the given
        dilation."""
        padding = (get_single_padding(self.op_layer.kernel_size[0],
                                      self.op_layer.stride[0], dilation[0]),
                   get_single_padding(self.op_layer.kernel_size[1],
                                      self.op_layer.stride[1], dilation[1]))
        return padding
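
# get_single_padding (imported from .utils) returns, per spatial dim, the
# padding that preserves the feature-map size. A quick sanity check, assuming
# the helper uses the conventional "same"-padding arithmetic
# ((stride - 1) + dilation * (kernel_size - 1)) // 2, for a 3x3 kernel at
# stride 1:
for _dilation in (1, 2, 3):
    # The effective kernel extent is _dilation * (3 - 1) + 1, so a padding
    # equal to the dilation keeps the spatial size unchanged.
    assert ((1 - 1) + _dilation * (3 - 1)) // 2 == _dilation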