Shortcuts

Source code for mmcv.ops.roiaware_pool3d

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Tuple, Union

import mmengine
import torch
from torch import nn as nn
from torch.autograd import Function

from ..utils import ext_loader

# Load the `_ext` extension through the project's loader, naming the forward
# and backward RoI-aware pooling ops that this module calls.
ext_module = ext_loader.load_ext(
    '_ext', ['roiaware_pool3d_forward', 'roiaware_pool3d_backward'])


class RoIAwarePool3d(nn.Module):
    """Pool point features into a regular voxel grid inside each 3D RoI.

    This encodes the geometry-specific features of each 3D proposal. See
    `PartA2 <https://arxiv.org/pdf/1907.03670.pdf>`_ for more details.

    Args:
        out_size (int or tuple): The size of output features, given either
            as a single ``n`` or as ``[n1, n2, n3]``.
        max_pts_per_voxel (int, optional): The maximum number of points per
            voxel. Default: 128.
        mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'.
            Default: 'max'.
    """

    def __init__(self,
                 out_size: Union[int, tuple],
                 max_pts_per_voxel: int = 128,
                 mode: str = 'max'):
        super().__init__()

        self.out_size = out_size
        self.max_pts_per_voxel = max_pts_per_voxel

        # The autograd function takes the pooling method as an integer code
        # (0 for max pooling, 1 for average pooling), so the string is
        # translated once here.
        assert mode in ('max', 'avg')
        mode_codes = {'max': 0, 'avg': 1}
        self.mode = mode_codes[mode]

    def forward(self, rois: torch.Tensor, pts: torch.Tensor,
                pts_feature: torch.Tensor) -> torch.Tensor:
        """Run RoI-aware pooling with the settings stored on the module.

        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois.
            pts (torch.Tensor): [npoints, 3], coordinates of input points.
            pts_feature (torch.Tensor): [npoints, C], features of input
                points.

        Returns:
            torch.Tensor: Pooled features whose shape is
            [N, out_x, out_y, out_z, C].
        """
        pooled = RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
                                              self.out_size,
                                              self.max_pts_per_voxel,
                                              self.mode)
        return pooled
class RoIAwarePool3dFunction(Function):
    """Autograd function wrapping the RoI-aware pooling extension ops.

    Only ``pts_feature`` receives a gradient; ``backward`` returns ``None``
    for every other input.
    """

    @staticmethod
    def forward(ctx: Any, rois: torch.Tensor, pts: torch.Tensor,
                pts_feature: torch.Tensor, out_size: Union[int, tuple],
                max_pts_per_voxel: int, mode: int) -> torch.Tensor:
        """Allocate the output buffers and call the forward extension op.

        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois.
            pts (torch.Tensor): [npoints, 3], coordinates of input points.
            pts_feature (torch.Tensor): [npoints, C], features of input
                points.
            out_size (int or tuple): The size of output features. A single
                int is used for all three axes; otherwise a tuple of three
                ints is expected.
            max_pts_per_voxel (int): The maximum number of points per voxel.
            mode (int): Pooling method of RoIAware, 0 (max pool) or 1
                (average pool).

        Returns:
            torch.Tensor: Pooled features whose shape is
            [N, out_x, out_y, out_z, C].
        """
        if isinstance(out_size, int):
            out_x = out_y = out_z = out_size
        else:
            assert len(out_size) == 3
            assert mmengine.is_tuple_of(out_size, int)
            out_x, out_y, out_z = out_size

        num_rois = rois.shape[0]
        num_channels = pts_feature.shape[-1]
        num_pts = pts.shape[0]

        # All three buffers share the per-RoI voxel grid and differ only in
        # their last dimension and dtype.
        grid_shape = (num_rois, out_x, out_y, out_z)
        pooled_features = pts_feature.new_zeros(grid_shape + (num_channels, ))
        argmax = pts_feature.new_zeros(
            grid_shape + (num_channels, ), dtype=torch.int)
        pts_idx_of_voxels = pts_feature.new_zeros(
            grid_shape + (max_pts_per_voxel, ), dtype=torch.int)

        # The extension op receives the preallocated buffers; presumably it
        # fills them in place, since ``pooled_features`` is returned as-is.
        ext_module.roiaware_pool3d_forward(
            rois,
            pts,
            pts_feature,
            argmax,
            pts_idx_of_voxels,
            pooled_features,
            pool_method=mode)

        # Keep what backward needs to size and compute the feature gradient.
        ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
                                            num_pts, num_channels)
        return pooled_features

    @staticmethod
    def backward(
        ctx: Any, grad_out: torch.Tensor
    ) -> Tuple[None, None, torch.Tensor, None, None, None]:
        """Compute the gradient with respect to ``pts_feature``.

        Args:
            grad_out (torch.Tensor): Gradient of the pooled features.

        Returns:
            tuple: One entry per ``forward`` input; only the third entry
            (for ``pts_feature``, shape [npoints, C]) is a tensor, the rest
            are ``None``.
        """
        (pts_idx_of_voxels, argmax, mode, num_pts,
         num_channels) = ctx.roiaware_pool3d_for_backward

        grad_in = grad_out.new_zeros((num_pts, num_channels))
        # The op is handed a contiguous copy of the incoming gradient and
        # the zero-initialised ``grad_in`` buffer.
        ext_module.roiaware_pool3d_backward(
            pts_idx_of_voxels,
            argmax,
            grad_out.contiguous(),
            grad_in,
            pool_method=mode)

        return None, None, grad_in, None, None, None
Read the Docs v: latest
Versions
master
latest
2.x
1.x
v1.7.0
v1.6.2
v1.6.1
v1.6.0
v1.5.3
v1.5.2_a
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
v1.3.12
v1.3.11
v1.3.10
v1.3.9
v1.3.8
v1.3.7
v1.3.6
v1.3.5
v1.3.4
v1.3.3
v1.3.2
v1.3.1
v1.3.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.