Shortcuts

mmcv.cnn.bricks.conv_ws 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.registry import MODELS


def conv_ws_2d(input: torch.Tensor,
               weight: torch.Tensor,
               bias: Optional[torch.Tensor] = None,
               stride: Union[int, Tuple[int, int]] = 1,
               padding: Union[int, Tuple[int, int]] = 0,
               dilation: Union[int, Tuple[int, int]] = 1,
               groups: int = 1,
               eps: float = 1e-5) -> torch.Tensor:
    """Apply a 2D convolution using a standardized copy of ``weight``.

    For every filter (slice along dim 0 of ``weight``) the mean and
    standard deviation over all remaining elements are computed; the
    filter is shifted by the mean and divided by ``std + eps`` before
    being passed to ``F.conv2d``.

    Args:
        input (torch.Tensor): Input passed unchanged to ``F.conv2d``.
        weight (torch.Tensor): 4D convolution weight to standardize.
        bias (torch.Tensor, optional): Bias passed to ``F.conv2d``.
            Default: None
        stride (int or tuple): Stride of the convolution. Default: 1
        padding (int or tuple): Padding of the convolution. Default: 0
        dilation (int or tuple): Dilation of the convolution. Default: 1
        groups (int): Number of convolution groups. Default: 1
        eps (float): Value added to the standard deviation before the
            division. Default: 1e-5

    Returns:
        torch.Tensor: Result of ``F.conv2d`` with the standardized weight.
    """
    num_filters = weight.size(0)
    # ``reshape`` instead of ``view``: ``view`` raises a RuntimeError when
    # the weight is not contiguous in the default layout (for example a
    # channels_last weight), whereas ``reshape`` copies only when needed.
    weight_flat = weight.reshape(num_filters, -1)
    mean = weight_flat.mean(dim=1, keepdim=True).view(num_filters, 1, 1, 1)
    std = weight_flat.std(dim=1, keepdim=True).view(num_filters, 1, 1, 1)
    weight = (weight - mean) / (std + eps)
    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
@MODELS.register_module('ConvWS')
class ConvWS2d(nn.Conv2d):
    """2D convolution that standardizes its weight on every forward pass.

    All convolution arguments are forwarded to ``nn.Conv2d``; the forward
    pass delegates to ``conv_ws_2d`` with the layer's own parameters.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel
        stride (int or tuple, optional): Stride of the convolution.
            Default: 1
        padding (int or tuple, optional): Padding passed to ``nn.Conv2d``.
            Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of convolution groups. Default: 1
        bias (bool, optional): Whether ``nn.Conv2d`` creates a bias.
            Default: True
        eps (float, optional): Value handed to ``conv_ws_2d`` as ``eps``.
            Default: 1e-5
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 eps: float = 1e-5):
        conv_options = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        super().__init__(in_channels, out_channels, kernel_size,
                         **conv_options)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the standardized version of ``self.weight``."""
        return conv_ws_2d(
            x,
            self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            eps=self.eps)
@MODELS.register_module(name='ConvAWS')
class ConvAWS2d(nn.Conv2d):
    """AWS (Adaptive Weight Standardization)

    This is a variant of Weight Standardization
    (https://arxiv.org/pdf/1903.10520.pdf)
    It is used in DetectoRS to avoid NaN
    (https://arxiv.org/pdf/2006.02334.pdf)

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel
        stride (int or tuple, optional): Stride of the convolution.
            Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides
            of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: True
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # Per-filter scale and shift applied after standardization. They
        # are buffers, so they are saved in the state dict but are not
        # parameters of the module.
        self.register_buffer('weight_gamma',
                             torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta',
                             torch.zeros(self.out_channels, 1, 1, 1))

    @staticmethod
    def _weight_mean_std(
            weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the per-filter mean and ``sqrt(var + 1e-5)`` of ``weight``.

        Both results are shaped ``(weight.size(0), 1, 1, 1)``.
        """
        # ``reshape`` instead of ``view``: ``view`` raises a RuntimeError
        # when the weight is not contiguous in the default layout (for
        # example after converting the module to channels_last).
        weight_flat = weight.reshape(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        return mean, std

    def _get_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Standardize ``weight`` per filter, then apply gamma and beta."""
        mean, std = self._weight_mean_std(weight)
        weight = (weight - mean) / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the adaptively standardized weight."""
        weight = self._get_weight(self.weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
                              local_metadata: Dict, strict: bool,
                              missing_keys: List[str],
                              unexpected_keys: List[str],
                              error_msgs: List[str]) -> None:
        """Override default load function.

        AWS overrides the function _load_from_state_dict to recover
        weight_gamma and weight_beta if they are missing. If weight_gamma and
        weight_beta are found in the checkpoint, this function will return
        after super()._load_from_state_dict. Otherwise, it will compute the
        mean and std of the pretrained weights and store them in weight_beta
        and weight_gamma.
        """
        # Sentinel: gamma is filled with -1 before loading, so a positive
        # mean afterwards means the checkpoint overwrote it.
        self.weight_gamma.data.fill_(-1)
        local_missing_keys: List = []
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, local_missing_keys,
                                      unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            missing_keys.extend(local_missing_keys)
            return

        # gamma was not loaded: derive gamma/beta from the loaded weight.
        mean, std = self._weight_mean_std(self.weight.data)
        self.weight_beta.data.copy_(mean)
        self.weight_gamma.data.copy_(std)

        # The recovered buffers are no longer missing; report the rest.
        missing_keys.extend(
            k for k in local_missing_keys
            if not k.endswith(('weight_gamma', 'weight_beta')))
Read the Docs v: latest
Versions
latest
stable
2.x
v2.1.0
v2.0.1
v2.0.0
1.x
v1.7.1
v1.7.0
v1.6.2
v1.6.1
v1.6.0
v1.5.3
v1.5.2_a
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.