Shortcuts

Source code for mmcv.ops.corner_pool

# Copyright (c) OpenMMLab. All rights reserved.
import re

import torch
from torch import nn
from torch.autograd import Function

# Integer codes passed as the ``mode`` attribute of the custom
# ``mmcv::MMCVCornerPool`` ONNX op by the ``*PoolFunction.symbolic`` methods.
_mode_dict = dict(top=0, bottom=1, left=2, right=3)


def _corner_pool(x, dim, flip):
    """Running maximum of ``x`` along ``dim`` without ``torch.cummax``.

    Uses a log-step doubling scheme: after the pass with offset ``step``
    every element holds the max of a window of up to ``2 * step`` elements,
    so ``ceil(log2(size))`` passes cover the whole axis. The work is done on
    a clone, so ``x`` itself is left untouched.

    Args:
        x (Tensor): Input tensor.
        dim (int): Dimension to scan along.
        flip (bool): If False, ``out[i] = max(x[0..i])`` along ``dim``
            (scan from the start). If True, ``out[i] = max(x[i..end])``
            (scan from the end).

    Returns:
        Tensor: Same shape as ``x``, holding the running maximum.
    """
    total = x.size(dim)
    result = x.clone()

    step = 1
    while step < total:
        span = total - step
        # Destination window is updated from the source window shifted by
        # ``step`` towards the end (flip) or towards the start (no flip).
        dst_start, src_start = (0, step) if flip else (step, 0)

        # Snapshot the destination before it is overwritten in place
        # (the clone is needed for backward computation).
        dst_snapshot = result.narrow(dim, dst_start, span).clone()
        dst_view = result.narrow(dim, dst_start, span)
        src_view = result.narrow(dim, src_start, span)

        dst_view[...] = torch.where(
            dst_snapshot > src_view, dst_snapshot, src_view)

        step *= 2

    return result


class TopPoolFunction(Function):
    """Top corner pooling wrapped as an autograd ``Function``.

    ``forward`` delegates to ``_corner_pool`` along dim 2 with ``flip=True``
    (running max taken from the end of the axis towards the start), and
    ``symbolic`` exports it as the ``mmcv::MMCVCornerPool`` ONNX op using the
    ``'top'`` code from ``_mode_dict``. No ``backward`` is defined here.
    """

    @staticmethod
    def symbolic(g, input):
        mode_code = int(_mode_dict['top'])
        return g.op('mmcv::MMCVCornerPool', input, mode_i=mode_code)

    @staticmethod
    def forward(ctx, input):
        return _corner_pool(input, dim=2, flip=True)


class BottomPoolFunction(Function):
    """Bottom corner pooling wrapped as an autograd ``Function``.

    ``forward`` delegates to ``_corner_pool`` along dim 2 with ``flip=False``
    (running max taken from the start of the axis towards the end), and
    ``symbolic`` exports it as the ``mmcv::MMCVCornerPool`` ONNX op using the
    ``'bottom'`` code from ``_mode_dict``. No ``backward`` is defined here.
    """

    @staticmethod
    def symbolic(g, input):
        mode_code = int(_mode_dict['bottom'])
        return g.op('mmcv::MMCVCornerPool', input, mode_i=mode_code)

    @staticmethod
    def forward(ctx, input):
        return _corner_pool(input, dim=2, flip=False)


class LeftPoolFunction(Function):
    """Left corner pooling wrapped as an autograd ``Function``.

    ``forward`` delegates to ``_corner_pool`` along dim 3 with ``flip=True``
    (running max taken from the end of the axis towards the start), and
    ``symbolic`` exports it as the ``mmcv::MMCVCornerPool`` ONNX op using the
    ``'left'`` code from ``_mode_dict``. No ``backward`` is defined here.
    """

    @staticmethod
    def symbolic(g, input):
        mode_code = int(_mode_dict['left'])
        return g.op('mmcv::MMCVCornerPool', input, mode_i=mode_code)

    @staticmethod
    def forward(ctx, input):
        return _corner_pool(input, dim=3, flip=True)


class RightPoolFunction(Function):
    """Right corner pooling wrapped as an autograd ``Function``.

    ``forward`` delegates to ``_corner_pool`` along dim 3 with ``flip=False``
    (running max taken from the start of the axis towards the end), and
    ``symbolic`` exports it as the ``mmcv::MMCVCornerPool`` ONNX op using the
    ``'right'`` code from ``_mode_dict``. No ``backward`` is defined here.
    """

    @staticmethod
    def symbolic(g, input):
        mode_code = int(_mode_dict['right'])
        return g.op('mmcv::MMCVCornerPool', input, mode_i=mode_code)

    @staticmethod
    def forward(ctx, input):
        return _corner_pool(input, dim=3, flip=False)


def _version_tuple(version):
    """Return the leading ``(major, minor, patch)`` of ``version`` as ints.

    A missing patch component counts as ``0``. Any trailing suffix such as
    ``'+cu113'``, ``'a0'`` or ``'.dev20230101'`` is ignored. Returns ``None``
    when ``version`` does not start with ``<int>.<int>`` (e.g. ``'parrots'``).
    """
    match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version)
    if match is None:
        return None
    return tuple(int(part) if part else 0 for part in match.groups())


class CornerPool(nn.Module):
    """Corner Pooling.

    Corner Pooling is a new type of pooling layer that helps a
    convolutional network better localize corners of bounding boxes.

    Please refer to `CornerNet: Detecting Objects as Paired Keypoints
    <https://arxiv.org/abs/1808.01244>`_ for more details.

    Code is modified from https://github.com/princeton-vl/CornerNet-Lite.

    Args:
        mode (str): Pooling orientation for the pooling layer

            - 'bottom': Bottom Pooling
            - 'left': Left Pooling
            - 'right': Right Pooling
            - 'top': Top Pooling

    Returns:
        Feature map after pooling.
    """

    pool_functions = {
        'bottom': BottomPoolFunction,
        'left': LeftPoolFunction,
        'right': RightPoolFunction,
        'top': TopPoolFunction,
    }

    # (dim, flip) per mode: top/bottom scan dim 2 and left/right scan dim 3
    # (presumably H and W of an NCHW tensor). ``flip`` makes the running max
    # go from the end of the axis towards the start.
    cummax_dim_flip = {
        'bottom': (2, False),
        'left': (3, True),
        'right': (3, False),
        'top': (2, True),
    }

    def __init__(self, mode):
        super(CornerPool, self).__init__()
        assert mode in self.pool_functions, \
            'mode must be one of {}, got {!r}'.format(
                sorted(self.pool_functions), mode)
        self.mode = mode
        self.corner_pool = self.pool_functions[mode]

    def forward(self, x):
        """Apply corner pooling to ``x``.

        Args:
            x (Tensor): Input feature map. Pooling runs along dim 2 or 3
                depending on ``mode``.

        Returns:
            Tensor: Pooled feature map with the same shape as ``x``.
        """
        # Compare versions numerically: a plain string comparison orders
        # '1.10.0' before '1.5.0'. Non-numeric versions such as 'parrots'
        # parse to None and take the fallback path below.
        torch_version = _version_tuple(torch.__version__)
        if torch_version is not None and torch_version >= (1, 5, 0):
            if torch.onnx.is_in_onnx_export():
                assert torch_version >= (1, 7, 0), (
                    'When `cummax` serves as an intermediate component whose '
                    'outputs is used as inputs for another modules, it\'s '
                    'expected that pytorch version must be >= 1.7.0, '
                    'otherwise Error appears like: `RuntimeError: tuple '
                    'appears in op that does not forward tuples, unsupported '
                    'kind: prim::PythonOp`.')

            dim, flip = self.cummax_dim_flip[self.mode]
            if flip:
                x = x.flip(dim)
            pool_tensor, _ = torch.cummax(x, dim=dim)
            if flip:
                pool_tensor = pool_tensor.flip(dim)
            return pool_tensor

        # PyTorch < 1.5 (no ``torch.cummax``) or parrots.
        if torch.onnx.is_in_onnx_export():
            # Export through the custom ``mmcv::MMCVCornerPool`` symbolic.
            return self.corner_pool.apply(x)
        dim, flip = self.cummax_dim_flip[self.mode]
        return _corner_pool(x, dim, flip)
Read the Docs v: v1.5.1
Versions
latest
stable
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.