# Source code for mmcv.ops.saconv

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
from mmcv.ops.deform_conv import deform_conv2d
from mmcv.utils import TORCH_VERSION, digit_version


@CONV_LAYERS.register_module(name='SAC')
class SAConv2d(ConvAWS2d):
    """SAC (Switchable Atrous Convolution)

    This is an implementation of SAC in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf).

    The layer runs the same convolution twice on a context-augmented input:
    once with the configured padding/dilation and weight, and once with
    padding/dilation multiplied by 3 and weight ``weight + weight_diff``.
    The two results are blended per spatial location by a learned
    single-channel ``switch`` map, and a global-context term is added
    before and after the blend.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        use_deform (bool, optional): If ``True``, replace convolution with
            deformable convolution. Default: ``False``.

    Note:
        ``padding_mode`` is not an argument of this constructor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 use_deform=False):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.use_deform = use_deform
        # Single-channel 1x1 conv producing the blending map. It uses the
        # layer's stride, presumably so the map lines up spatially with the
        # strided convolution outputs it is multiplied with in ``forward``.
        self.switch = nn.Conv2d(
            self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
        # Additive weight delta for the large-dilation branch. Allocated
        # uninitialized here and zeroed in ``init_weights``.
        self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
        # 1x1 convs applied to globally average-pooled features.
        self.pre_context = nn.Conv2d(
            self.in_channels, self.in_channels, kernel_size=1, bias=True)
        self.post_context = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=1, bias=True)
        if self.use_deform:
            # One offset predictor per branch (small / large dilation).
            # NOTE(review): 18 output channels looks like 2 offsets for each
            # tap of a 3x3 kernel, i.e. this presumably assumes
            # kernel_size == 3 -- confirm against deform_conv2d.
            self.offset_s = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
            self.offset_l = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
        self.init_weights()

    def init_weights(self):
        """Initialize the SAC-specific parameters with constants.

        The switch conv gets weight 0 and bias 1, ``weight_diff`` and both
        context convs are set to 0, and (when deformable convolution is
        enabled) both offset convs are set to 0.
        """
        constant_init(self.switch, 0, bias=1)
        self.weight_diff.data.zero_()
        constant_init(self.pre_context, 0)
        constant_init(self.post_context, 0)
        if self.use_deform:
            constant_init(self.offset_s, 0)
            constant_init(self.offset_l, 0)

    def _plain_conv_forward(self, x, weight, zero_bias):
        """Run the parent's non-deformable convolution with ``weight``.

        Dispatches on the installed torch version because the name and
        signature of the parent's low-level conv method differ between
        versions. The call uses the current ``self.padding`` and
        ``self.dilation``.

        Args:
            x (Tensor): Input feature map.
            weight (Tensor): Convolution weight to use.
            zero_bias (Tensor): All-zero bias, passed only on
                torch >= 1.8.0 where a bias argument is required.

        Returns:
            Tensor: The convolution output.
        """
        if (TORCH_VERSION == 'parrots'
                or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
            return super().conv2d_forward(x, weight)
        if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
            # bias is a required argument of _conv_forward in torch 1.8.0
            return super()._conv_forward(x, weight, zero_bias)
        return super()._conv_forward(x, weight)

    def forward(self, x):
        """Apply switchable atrous convolution to ``x``.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ``switch * out_s + (1 - switch) * out_l`` plus a
            post-context term, where ``out_s`` / ``out_l`` are the small-
            and large-dilation convolution outputs.
        """
        # pre-context: add a 1x1-conv transform of the global average.
        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x

        # switch: reflect-pad by 2 then 5x5 average pool with stride 1, so
        # the smoothed map keeps the spatial size of ``x``.
        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)

        # sac: small-dilation branch with the configured padding/dilation.
        weight = self._get_weight(self.weight)
        zero_bias = torch.zeros(
            self.out_channels, device=weight.device, dtype=weight.dtype)
        if self.use_deform:
            offset = self.offset_s(avg_x)
            out_s = deform_conv2d(x, offset, weight, self.stride,
                                  self.padding, self.dilation, self.groups, 1)
        else:
            out_s = self._plain_conv_forward(x, weight, zero_bias)

        # Large-dilation branch: the parent's conv reads padding/dilation
        # from ``self``, so they are swapped in temporarily. Both tuples are
        # built before touching the attributes, and try/finally guarantees
        # the originals are restored even if the convolution raises.
        ori_p = self.padding
        ori_d = self.dilation
        large_p = tuple(3 * p for p in ori_p)
        large_d = tuple(3 * d for d in ori_d)
        try:
            self.padding = large_p
            self.dilation = large_d
            weight = weight + self.weight_diff
            if self.use_deform:
                offset = self.offset_l(avg_x)
                out_l = deform_conv2d(x, offset, weight, self.stride,
                                      self.padding, self.dilation,
                                      self.groups, 1)
            else:
                out_l = self._plain_conv_forward(x, weight, zero_bias)
        finally:
            self.padding = ori_p
            self.dilation = ori_d

        # Blend the branches; the 1-channel switch broadcasts over channels.
        out = switch * out_s + (1 - switch) * out_l

        # post-context: same global-average trick on the output.
        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return out