Shortcuts

Source code for mmcv.ops.roipoint_pool3d

from typing import Any, Tuple

import torch
from torch import nn as nn
from torch.autograd import Function

from ..utils import ext_loader

# Handle to the compiled '_ext' extension. Only the 'roipoint_pool3d_forward'
# entry point is requested; it is called by RoIPointPool3dFunction.forward.
ext_module = ext_loader.load_ext(
    '_ext',
    ['roipoint_pool3d_forward'],
)


class RoIPointPool3d(nn.Module):
    """Encode the geometry-specific features of each 3D proposal.

    Please refer to `Paper of PartA2 <https://arxiv.org/pdf/1907.03670.pdf>`_
    for more details.

    Args:
        num_sampled_points (int, optional): Number of samples in each roi.
            Default: 512.
    """

    def __init__(self, num_sampled_points: int = 512):
        super().__init__()
        self.num_sampled_points = num_sampled_points

    def forward(self, points: torch.Tensor, point_features: torch.Tensor,
                boxes3d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pool the points (and their features) that fall in each 3D box.

        Args:
            points (torch.Tensor): Input point coordinates whose shape is
                (B, N, 3).
            point_features (torch.Tensor): Features of input points whose
                shape is (B, N, C).
            boxes3d (torch.Tensor): Input bounding boxes whose shape is
                (B, M, 7).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple of two elements. The
            first one is the pooled features whose shape is
            (B, M, num_sampled_points, 3 + C). The second one is an empty
            flag whose shape is (B, M).
        """
        return RoIPointPool3dFunction.apply(points, point_features, boxes3d,
                                            self.num_sampled_points)
class RoIPointPool3dFunction(Function):
    """Autograd wrapper around the compiled ``roipoint_pool3d_forward`` op.

    Only the forward pass is implemented; ``backward`` raises
    ``NotImplementedError``.
    """

    @staticmethod
    def forward(
        ctx: Any,
        points: torch.Tensor,
        point_features: torch.Tensor,
        boxes3d: torch.Tensor,
        num_sampled_points: int = 512
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pool the points (and their features) that fall in each 3D box.

        Args:
            ctx (Any): Autograd context (not used).
            points (torch.Tensor): Input point coordinates whose shape is
                (B, N, 3).
            point_features (torch.Tensor): Features of input points whose
                shape is (B, N, C).
            boxes3d (torch.Tensor): Input bounding boxes whose shape is
                (B, M, 7). Any layout with B * M * 7 elements (e.g.
                (B * M, 7)) is reshaped to (B, M, 7).
            num_sampled_points (int, optional): The num of sampled points.
                Default: 512.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple of two elements. The
            first one is the pooled features whose shape is
            (B, M, num_sampled_points, 3 + C), with the dtype/device of
            ``point_features``. The second one is an int32 empty flag whose
            shape is (B, M).
        """
        assert points.dim() == 3 and points.shape[2] == 3, (
            f'points must have shape (B, N, 3), got {tuple(points.shape)}')
        # The native op receives points and features as separate tensors, so
        # their (B, N) leading dimensions must agree.
        assert (point_features.dim() == 3
                and point_features.shape[:2] == points.shape[:2]), (
            'point_features must have shape (B, N, C) matching points, got '
            f'{tuple(point_features.shape)} vs {tuple(points.shape)}')

        batch_size = points.shape[0]
        feature_len = point_features.shape[2]
        # Normalize the boxes to (B, M, 7) first and read M from the result,
        # so M is correct even when boxes arrive flattened (e.g. (B * M, 7)).
        # reshape (not view) also tolerates non-contiguous input.
        pooled_boxes3d = boxes3d.reshape(batch_size, -1, 7).contiguous()
        boxes_num = pooled_boxes3d.shape[1]

        pooled_features = point_features.new_zeros(
            (batch_size, boxes_num, num_sampled_points, 3 + feature_len))
        # Allocate directly as int32 instead of float-then-cast.
        pooled_empty_flag = point_features.new_zeros(
            (batch_size, boxes_num), dtype=torch.int)

        # The op's return value is discarded. It is expected to fill the two
        # preallocated output tensors in place.
        ext_module.roipoint_pool3d_forward(points.contiguous(),
                                           pooled_boxes3d,
                                           point_features.contiguous(),
                                           pooled_features, pooled_empty_flag)
        return pooled_features, pooled_empty_flag

    @staticmethod
    def backward(ctx: Any, grad_out: torch.Tensor) -> torch.Tensor:
        """Gradients are not supported for this op."""
        raise NotImplementedError
Read the Docs v: latest
Versions
master
latest
2.x
1.x
v1.7.0
v1.6.2
v1.6.1
v1.6.0
v1.5.3
v1.5.2_a
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
v1.3.12
v1.3.11
v1.3.10
v1.3.9
v1.3.8
v1.3.7
v1.3.6
v1.3.5
v1.3.4
v1.3.3
v1.3.2
v1.3.1
v1.3.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.