Source code for mmcv.cnn.bricks.wrappers

# Copyright (c) OpenMMLab. All rights reserved.
r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py  # noqa: E501

Wrap some nn modules to support empty tensor input. Currently, these wrappers
are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
heads are trained on only positive RoIs.
"""
import math

import torch
import torch.nn as nn
from mmengine.registry import MODELS
from torch.nn.modules.utils import _pair, _triple

if torch.__version__ == 'parrots':
    TORCH_VERSION = torch.__version__
else:
    # torch.__version__ could be 1.3.1+cu92, we only need the first two
    # for comparison
    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
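
# Illustrative note (not part of the original module): the parsing above
# keeps only the major and minor components, so local build suffixes are
# ignored. For example, assuming torch.__version__ == '1.13.1+cu117':
#
#   >>> tuple(int(x) for x in '1.13.1+cu117'.split('.')[:2])
#   (1, 13)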


def obsolete_torch_version(torch_version, version_threshold) -> bool:
    return torch_version == 'parrots' or torch_version <= version_threshold
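
# A minimal sketch of how this guard evaluates (illustrative only, using
# plain tuple comparison):
#
#   >>> obsolete_torch_version((1, 3), (1, 4))   # old torch: take wrapper path
#   True
#   >>> obsolete_torch_version((1, 10), (1, 4))  # new torch: defer to PyTorch
#   False
#   >>> obsolete_torch_version('parrots', (1, 4))
#   True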


class NewEmptyTensorOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: torch.Tensor, new_shape: tuple) -> torch.Tensor:
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad: torch.Tensor) -> tuple:
        shape = ctx.shape
        return NewEmptyTensorOp.apply(grad, shape), None
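
# Hedged doctest (not in the original source): the op returns an
# uninitialized tensor of the requested shape while staying on the autograd
# graph; backward() hands back an empty gradient with the input's shape.
#
#   >>> x = torch.zeros((0, 3), requires_grad=True)
#   >>> y = NewEmptyTensorOp.apply(x, (0, 8))
#   >>> y.shape
#   torch.Size([0, 8])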


@MODELS.register_module('Conv', force=True)
class Conv2d(nn.Conv2d):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
                                     self.padding, self.stride,
                                     self.dilation):
                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)
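
# Hedged usage sketch (shapes are assumptions, not from the original page):
# on an affected PyTorch version the empty-tensor branch computes the same
# output shape a real convolution would produce.
#
#   >>> conv = Conv2d(3, 8, kernel_size=3, padding=1)
#   >>> conv(torch.zeros(0, 3, 32, 32)).shape
#   torch.Size([0, 8, 32, 32])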


@MODELS.register_module('Conv3d', force=True)
class Conv3d(nn.Conv3d):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
                                     self.padding, self.stride,
                                     self.dilation):
                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)


@MODELS.register_module()
@MODELS.register_module('deconv')
class ConvTranspose2d(nn.ConvTranspose2d):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
                                         self.padding, self.stride,
                                         self.dilation,
                                         self.output_padding):
                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)
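
# Worked example of the transposed-conv shape formula above (illustrative
# parameters): with i=8, k=4, s=2, p=1, d=1, op=0 each output edge is
# (8 - 1) * 2 - 2 * 1 + (1 * (4 - 1) + 1) + 0 = 16, i.e. 2x upsampling.
#
#   >>> deconv = ConvTranspose2d(8, 4, kernel_size=4, stride=2, padding=1)
#   >>> deconv(torch.zeros(0, 8, 8, 8)).shape
#   torch.Size([0, 4, 16, 16])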


@MODELS.register_module()
@MODELS.register_module('deconv3d')
class ConvTranspose3d(nn.ConvTranspose3d):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
                                         self.padding, self.stride,
                                         self.dilation,
                                         self.output_padding):
                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)


class MaxPool2d(nn.MaxPool2d):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # PyTorch 1.9 does not support empty tensor inference yet
        if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
            out_shape = list(x.shape[:2])
            for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
                                     _pair(self.padding), _pair(self.stride),
                                     _pair(self.dilation)):
                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
                o = math.ceil(o) if self.ceil_mode else math.floor(o)
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            return empty

        return super().forward(x)
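
# Illustrative check of the floor/ceil handling above (shapes assumed):
# pooling a 7x7 map with kernel 2 and stride 2 gives floor(3.5) = 3 per
# edge, or ceil(3.5) = 4 when ceil_mode=True.
#
#   >>> MaxPool2d(2, stride=2)(torch.zeros(0, 16, 7, 7)).shape
#   torch.Size([0, 16, 3, 3])
#   >>> MaxPool2d(2, stride=2, ceil_mode=True)(torch.zeros(0, 16, 7, 7)).shape
#   torch.Size([0, 16, 4, 4])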


class MaxPool3d(nn.MaxPool3d):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # PyTorch 1.9 does not support empty tensor inference yet
        if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
            out_shape = list(x.shape[:2])
            for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
                                     _triple(self.padding),
                                     _triple(self.stride),
                                     _triple(self.dilation)):
                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
                o = math.ceil(o) if self.ceil_mode else math.floor(o)
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            return empty

        return super().forward(x)


class Linear(torch.nn.Linear):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # empty tensor forward of Linear layer is supported in PyTorch 1.6
        if obsolete_torch_version(TORCH_VERSION, (1, 5)) and x.numel() == 0:
            out_shape = [x.shape[0], self.out_features]
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)
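
# Hedged end-to-end sketch (names and shapes are illustrative): the wrapped
# Linear preserves the zero batch dimension, and the registered aliases can
# be instantiated through the mmengine MODELS registry.
#
#   >>> fc = Linear(256, 80)
#   >>> fc(torch.zeros(0, 256)).shape
#   torch.Size([0, 80])
#   >>> conv = MODELS.build(
#   ...     dict(type='Conv', in_channels=3, out_channels=8, kernel_size=3))
#   >>> isinstance(conv, Conv2d)
#   True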