Shortcuts

Source code for mmcv.ops.rotated_feature_align

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

# Handle to the `_ext` extension; the two names requested here are exactly
# the entry points called by RotatedFeatureAlignFunction.forward/backward.
ext_module = ext_loader.load_ext('_ext', [
    'rotated_feature_align_forward',
    'rotated_feature_align_backward',
])


class RotatedFeatureAlignFunction(Function):
    """Use feature interpolation to obtain the position information
    corresponding to the refined rotated anchors, and reconstruct the feature
    maps in a pixel-wise manner to achieve feature alignment.

    The details are described in the paper
    `R3Det: Refined Single-Stage Detector with Feature Refinement for Rotating
    Object <https://arxiv.org/abs/1908.05612>`_.
    """

    @staticmethod
    def _check_points(points):
        """Reject any ``points`` value other than 1 or 5.

        Raises:
            AssertionError: If ``points`` is neither 1 nor 5.
        """
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``. The exception type stays
        # AssertionError so existing callers observe the same error.
        if points not in (1, 5):
            raise AssertionError(
                f'points must be 1 or 5, but got {points!r}')

    @staticmethod
    def symbolic(g, features, best_rbboxes, spatial_scale, points):
        """Emit the ``mmcv::MMCVRotatedFeatureAlign`` op on graph ``g``.

        ``spatial_scale`` and ``points`` are passed as the ``spatial_scale_f``
        and ``points_i`` keyword arguments of ``g.op``.

        Raises:
            AssertionError: If ``points`` is neither 1 nor 5.
        """
        RotatedFeatureAlignFunction._check_points(points)
        return g.op(
            'mmcv::MMCVRotatedFeatureAlign',
            features,
            best_rbboxes,
            spatial_scale_f=spatial_scale,
            points_i=points)

    @staticmethod
    def forward(ctx: Any, features: torch.Tensor, best_rbboxes: torch.Tensor,
                spatial_scale: float, points: int) -> torch.Tensor:
        """
        Args:
            features (torch.Tensor): Input features with shape [N,C,H,W].
            best_rbboxes (torch.Tensor): Refined rotate anchors with
                shape [N,H,W,5]. Coordinate format (cx,cy,h,w,a).
                NOTE(review): the second entry was previously written as
                ``cx``; ``cy`` is assumed - confirm against the kernel.
            spatial_scale (float): The scale of feature map size and
                input image size.
            points (int): The number of sample points.
                Only 1 and 5 are supported.

        Returns:
            torch.Tensor: Refined features with shape [N,C,H,W].

        Raises:
            AssertionError: If ``points`` is neither 1 nor 5.
        """
        # Validate before touching ``ctx`` or allocating the output so an
        # invalid call fails without any side effects.
        RotatedFeatureAlignFunction._check_points(points)
        ctx.spatial_scale = spatial_scale
        ctx.points = points
        # Only the anchors are passed to the backward kernel.
        ctx.save_for_backward(best_rbboxes)
        # The extension writes its result into this pre-allocated buffer.
        output = torch.zeros_like(features)
        ext_module.rotated_feature_align_forward(
            features,
            best_rbboxes,
            output,
            spatial_scale=spatial_scale,
            points=points)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
        """
        Args:
            grad_output (torch.Tensor): The gradient of output features
                with shape [N,C,H,W].

        Returns:
            tuple: ``(grad_input, None, None, None)``. ``grad_input`` is the
            gradient of the input features, with the same shape as
            ``grad_output``, or None when ``features`` does not require a
            gradient. The remaining entries correspond to ``best_rbboxes``,
            ``spatial_scale`` and ``points``, which receive no gradient.
        """
        best_rbboxes = ctx.saved_tensors[0]
        points = ctx.points
        spatial_scale = ctx.spatial_scale
        grad_input = None
        # Skip the kernel entirely when no gradient is requested for features.
        if ctx.needs_input_grad[0]:
            grad_input = torch.zeros_like(grad_output)
            ext_module.rotated_feature_align_backward(
                grad_output.contiguous(),
                best_rbboxes,
                grad_input,
                spatial_scale=spatial_scale,
                points=points)
        return grad_input, None, None, None


def rotated_feature_align(features: torch.Tensor,
                          best_rbboxes: torch.Tensor,
                          spatial_scale: float = 1 / 8,
                          points: int = 1) -> torch.Tensor:
    """Apply :class:`RotatedFeatureAlignFunction` to ``features``.

    All four arguments are forwarded unchanged to
    ``RotatedFeatureAlignFunction.apply``.

    Args:
        features (torch.Tensor): Input features with shape [N,C,H,W].
        best_rbboxes (torch.Tensor): Refined rotate anchors with
            shape [N,H,W,5].
        spatial_scale (float): The scale of feature map size and
            input image size. Defaults to 1 / 8.
        points (int): The number of sample points.
            Only 1 and 5 are supported. Defaults to 1.

    Returns:
        torch.Tensor: Refined features with shape [N,C,H,W].
    """
    return RotatedFeatureAlignFunction.apply(features, best_rbboxes,
                                             spatial_scale, points)
Read the Docs v: latest
Versions
master
latest
2.x
1.x
v1.7.0
v1.6.2
v1.6.1
v1.6.0
v1.5.3
v1.5.2_a
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
v1.3.12
v1.3.11
v1.3.10
v1.3.9
v1.3.8
v1.3.7
v1.3.6
v1.3.5
v1.3.4
v1.3.3
v1.3.2
v1.3.1
v1.3.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.