Shortcuts

mmcv.ops.masked_conv 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['masked_im2col_forward', 'masked_col2im_forward'])


class MaskedConv2dFunction(Function):

    @staticmethod
    def symbolic(g, features, mask, weight, bias, padding, stride=1):
        return g.op(
            'mmcv::MMCVMaskedConv2d',
            features,
            mask,
            weight,
            bias,
            padding_i=padding,
            stride_i=stride)

    @staticmethod
    def forward(ctx,
                features: torch.Tensor,
                mask: torch.Tensor,
                weight: torch.nn.Parameter,
                bias: torch.nn.Parameter,
                padding: int = 0,
                stride: int = 1) -> torch.Tensor:
        assert mask.dim() == 3 and mask.size(0) == 1
        assert features.dim() == 4 and features.size(0) == 1
        assert features.size()[2:] == mask.size()[1:]
        pad_h, pad_w = _pair(padding)
        stride_h, stride_w = _pair(stride)
        if stride_h != 1 or stride_w != 1:
            raise ValueError(
                'Stride could not only be 1 in masked_conv2d currently.')
        out_channel, in_channel, kernel_h, kernel_w = weight.size()

        if features.device.type == 'npu':
            import torch_npu
            output = torch_npu.npu_conv2d(
                features,
                weight,
                bias,
                stride=(stride_h, stride_w),
                padding=(pad_h, pad_w),
                dilation=(1, 1),
                groups=1)
            if mask.size()[1:] != output.size()[2:]:
                raise ValueError(
                    'The mask is inconsistent with the shape of output_conv.')
            mask = mask > 0
            mask = mask.type(output.dtype)
            output = output * mask
            return output

        batch_size = features.size(0)
        out_h = int(
            math.floor(
                torch.true_divide((features.size(2) + 2 * pad_h -
                                   (kernel_h - 1) - 1), stride_h) + 1))
        out_w = int(
            math.floor(
                torch.true_divide((features.size(3) + 2 * pad_w -
                                   (kernel_w - 1) - 1), stride_w) + 1))
        mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
        output = features.new_zeros(batch_size, out_channel, out_h, out_w)
        if mask_inds.numel() > 0:
            mask_h_idx = mask_inds[:, 0].contiguous()
            mask_w_idx = mask_inds[:, 1].contiguous()
            data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
                                          mask_inds.size(0))
            ext_module.masked_im2col_forward(
                features,
                mask_h_idx,
                mask_w_idx,
                data_col,
                kernel_h=kernel_h,
                kernel_w=kernel_w,
                pad_h=pad_h,
                pad_w=pad_w)
            masked_output = torch.addmm(1, bias[:, None], 1,
                                        weight.view(out_channel, -1), data_col)
            ext_module.masked_col2im_forward(
                masked_output,
                mask_h_idx,
                mask_w_idx,
                output,
                height=out_h,
                width=out_w,
                channels=out_channel)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        return (None, ) * 5


masked_conv2d = MaskedConv2dFunction.apply


[文档]class MaskedConv2d(nn.Conv2d): """A MaskedConv2d which inherits the official Conv2d. The masked forward doesn't implement the backward function and only supports the stride parameter to be 1 currently. """ def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, ...]], stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, bias: bool = True): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
[文档] def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: if mask is None: # fallback to the normal Conv2d return super().forward(input) else: return masked_conv2d(input, mask, self.weight, self.bias, self.padding)
Read the Docs v: stable
Versions
latest
stable
2.x
v2.0.1
v2.0.0
1.x
v1.7.1
v1.7.0
v1.6.2
v1.6.1
v1.6.0
v1.5.3
v1.5.2_a
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.