Shortcuts

Source code for mmcv.ops.saconv

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import constant_init
from mmengine.registry import MODELS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from mmcv.cnn import ConvAWS2d
from mmcv.ops.deform_conv import deform_conv2d


@MODELS.register_module(name='SAC')
class SAConv2d(ConvAWS2d):
    """SAC (Switchable Atrous Convolution)

    This is an implementation of `DetectoRS: Detecting Objects with Recursive
    Feature Pyramid and Switchable Atrous Convolution
    <https://arxiv.org/abs/2006.02334>`_.

    The output is a per-location blend of two convolutions that share the
    (weight-standardized) kernel: a "small" branch using the configured
    padding/dilation, and a "large" branch using 3x the padding and dilation
    plus a learned kernel offset ``weight_diff``. A 1x1 ``switch`` conv,
    applied to a 5x5 local average of the input, produces the blending gate.
    Global-average-pooled context is added both before and after the blend.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution.
            Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, a learnable bias parameter is
            created. Default: ``True``
            NOTE(review): ``forward`` passes an all-zero bias on torch >= 1.8
            and passes no bias to ``deform_conv2d``, so this parameter is not
            added to the output on those paths. Left as-is to keep outputs of
            existing checkpoints unchanged.
        use_deform (bool, optional): If ``True``, replace convolution with
            deformable convolution. Default: ``False``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 use_deform=False):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.use_deform = use_deform
        # Per-location gate between the small and large branches; its
        # stride matches the main conv so the gate aligns with the outputs.
        self.switch = nn.Conv2d(
            self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
        # Learned additive offset applied to the kernel of the large branch.
        # Zero-filled here (and again in init_weights) instead of using the
        # legacy uninitialized ``torch.Tensor(size)`` constructor.
        self.weight_diff = nn.Parameter(torch.zeros_like(self.weight))
        self.pre_context = nn.Conv2d(
            self.in_channels, self.in_channels, kernel_size=1, bias=True)
        self.post_context = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=1, bias=True)
        if self.use_deform:
            # deform_conv2d (deform_groups=1) expects one (dy, dx) pair per
            # kernel element. This equals the previously hard-coded 18 for
            # a 3x3 kernel and also works for other kernel sizes.
            kh, kw = self.kernel_size
            offset_channels = 2 * kh * kw
            self.offset_s = nn.Conv2d(
                self.in_channels,
                offset_channels,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
            self.offset_l = nn.Conv2d(
                self.in_channels,
                offset_channels,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
        self.init_weights()

    def init_weights(self):
        """Initialize the SAC-specific submodules and parameters.

        The switch conv gets weight 0 and bias 1, so the gate starts at a
        constant 1 and the output initially comes entirely from the small
        branch. ``weight_diff``, the context convs and the offset convs start
        at zero.
        """
        constant_init(self.switch, 0, bias=1)
        self.weight_diff.data.zero_()
        constant_init(self.pre_context, 0)
        constant_init(self.post_context, 0)
        if self.use_deform:
            constant_init(self.offset_s, 0)
            constant_init(self.offset_l, 0)

    def _versioned_conv_forward(self, x, weight, zero_bias):
        """Call the parent conv with ``weight`` using the torch-appropriate API.

        Uses the module's current ``self.padding`` and ``self.dilation``.
        """
        if (TORCH_VERSION == 'parrots'
                or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
            return super().conv2d_forward(x, weight)
        if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
            # bias is a required argument of _conv_forward in torch 1.8.0
            return super()._conv_forward(x, weight, zero_bias)
        return super()._conv_forward(x, weight)

    def _sac_branch(self, x, avg_x, weight, zero_bias, offset_conv=None):
        """Compute one SAC branch with the current padding/dilation.

        Args:
            x (Tensor): Input after pre-context.
            avg_x (Tensor): 5x5 locally averaged input; feeds the offset conv.
            weight (Tensor): Kernel to convolve with.
            zero_bias (Tensor | None): All-zero bias for the torch >= 1.8
                conv path. Unused when ``offset_conv`` is given.
            offset_conv (nn.Module | None): Offset predictor. When given,
                deformable convolution is used instead of a regular conv.
        """
        if offset_conv is not None:
            offset = offset_conv(avg_x)
            return deform_conv2d(x, offset, weight, self.stride, self.padding,
                                 self.dilation, self.groups, 1)
        return self._versioned_conv_forward(x, weight, zero_bias)

    def forward(self, x):
        """Apply switchable atrous convolution to ``x``.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ``switch * out_s + (1 - switch) * out_l`` plus a
            post-context term.
        """
        # pre-context: add a globally pooled, 1x1-projected context vector
        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x

        # switch: 5x5 reflect-padded local average -> per-location gate
        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)

        # sac
        weight = self._get_weight(self.weight)
        # Only the non-deform torch >= 1.8 path consumes this tensor.
        zero_bias = None if self.use_deform else torch.zeros(
            self.out_channels, device=weight.device, dtype=weight.dtype)

        # Small branch: configured padding/dilation.
        out_s = self._sac_branch(
            x, avg_x, weight, zero_bias,
            self.offset_s if self.use_deform else None)

        # Large branch: 3x padding/dilation. These are module attributes that
        # the parent conv reads, so they are swapped temporarily. The
        # try/finally guarantees they are restored even if the conv raises;
        # otherwise the module would stay corrupted for later calls.
        ori_p = self.padding
        ori_d = self.dilation
        self.padding = tuple(3 * p for p in self.padding)
        self.dilation = tuple(3 * d for d in self.dilation)
        try:
            out_l = self._sac_branch(
                x, avg_x, weight + self.weight_diff, zero_bias,
                self.offset_l if self.use_deform else None)
        finally:
            self.padding = ori_p
            self.dilation = ori_d

        out = switch * out_s + (1 - switch) * out_l

        # post-context
        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return out
Read the Docs v: latest
Versions
master
latest
2.x
1.x
v1.7.0
v1.6.2
v1.6.1
v1.6.0
v1.5.3
v1.5.2_a
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
v1.3.12
v1.3.11
v1.3.10
v1.3.9
v1.3.8
v1.3.7
v1.3.6
v1.3.5
v1.3.4
v1.3.3
v1.3.2
v1.3.1
v1.3.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.