Source code for mmcv.cnn.bricks.conv_ws

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .registry import CONV_LAYERS


def conv_ws_2d(input,
               weight,
               bias=None,
               stride=1,
               padding=0,
               dilation=1,
               groups=1,
               eps=1e-5):
    # Standardize the weight per output channel (zero mean, unit std, with
    # eps for numerical stability), then run a regular 2D convolution.
    c_in = weight.size(0)
    weight_flat = weight.view(c_in, -1)
    mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
    std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
    weight = (weight - mean) / (std + eps)
    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
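

A minimal usage sketch (illustrative, not part of the module source; the tensor shapes and layer sizes below are assumptions): conv_ws_2d takes the same arguments as F.conv2d, but standardizes the weight per output channel before convolving.

import torch
import torch.nn as nn

from mmcv.cnn.bricks.conv_ws import conv_ws_2d

x = torch.randn(2, 16, 32, 32)                   # (N, C_in, H, W)
ref = nn.Conv2d(16, 32, kernel_size=3, padding=1)

# Same call signature as F.conv2d, but the weight is standardized
# (zero mean, unit std per output channel) before the convolution.
out = conv_ws_2d(x, ref.weight, ref.bias, stride=1, padding=1)
assert out.shape == (2, 32, 32, 32)
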


@CONV_LAYERS.register_module('ConvWS')
class ConvWS2d(nn.Conv2d):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 eps=1e-5):
        super(ConvWS2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.eps = eps

    def forward(self, x):
        return conv_ws_2d(x, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups, self.eps)
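

A minimal sketch of how ConvWS2d might be used (illustrative, not part of the module source). It can be instantiated directly as a drop-in replacement for nn.Conv2d, or built from a config dict through the CONV_LAYERS registry under the name 'ConvWS'; the build_conv_layer call below assumes the usual mmcv config-based construction.

import torch

from mmcv.cnn import build_conv_layer
from mmcv.cnn.bricks.conv_ws import ConvWS2d

x = torch.randn(2, 16, 32, 32)

# Direct instantiation: drop-in replacement for nn.Conv2d.
conv = ConvWS2d(16, 32, kernel_size=3, padding=1)
out = conv(x)  # (2, 32, 32, 32)

# Config-based construction via the registry (type name 'ConvWS').
conv = build_conv_layer(dict(type='ConvWS'), 16, 32, 3, padding=1)
out = conv(x)
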
@CONV_LAYERS.register_module(name='ConvAWS')
class ConvAWS2d(nn.Conv2d):
    """AWS (Adaptive Weight Standardization)

    This is a variant of Weight Standardization
    (https://arxiv.org/pdf/1903.10520.pdf)
    It is used in DetectoRS to avoid NaN
    (https://arxiv.org/pdf/2006.02334.pdf)

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: True
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.register_buffer('weight_gamma',
                             torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta',
                             torch.zeros(self.out_channels, 1, 1, 1))

    def _get_weight(self, weight):
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        weight = (weight - mean) / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x):
        weight = self._get_weight(self.weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys,
                              error_msgs):
        """Override default load function.

        AWS overrides the function _load_from_state_dict to recover
        weight_gamma and weight_beta if they are missing. If weight_gamma and
        weight_beta are found in the checkpoint, this function will return
        after super()._load_from_state_dict. Otherwise, it will compute the
        mean and std of the pretrained weights and store them in weight_beta
        and weight_gamma.
        """
        self.weight_gamma.data.fill_(-1)
        local_missing_keys = []
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, local_missing_keys,
                                      unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            for k in local_missing_keys:
                missing_keys.append(k)
            return
        weight = self.weight.data
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        self.weight_beta.data.copy_(mean)
        self.weight_gamma.data.copy_(std)
        missing_gamma_beta = [
            k for k in local_missing_keys
            if k.endswith('weight_gamma') or k.endswith('weight_beta')
        ]
        for k in missing_gamma_beta:
            local_missing_keys.remove(k)
        for k in local_missing_keys:
            missing_keys.append(k)
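

A minimal sketch of the checkpoint-recovery behaviour implemented in _load_from_state_dict (illustrative, not part of the module source; layer sizes are assumptions). Loading a plain nn.Conv2d state dict into ConvAWS2d leaves weight_gamma and weight_beta missing, so they are recomputed from the pretrained weight:

import torch
import torch.nn as nn

from mmcv.cnn.bricks.conv_ws import ConvAWS2d

pretrained = nn.Conv2d(16, 32, kernel_size=3, padding=1)
state_dict = pretrained.state_dict()    # contains only 'weight' and 'bias'

conv = ConvAWS2d(16, 32, kernel_size=3, padding=1)
conv.load_state_dict(state_dict, strict=False)

# weight_gamma / weight_beta now hold the per-channel std / mean of the
# pretrained weight, so the standardized weight reproduces it after loading.
print(conv.weight_gamma.shape, conv.weight_beta.shape)  # (32, 1, 1, 1) each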