Shortcuts

mmcv.ops.roipoint_pool3d 源代码

from torch import nn as nn
from torch.autograd import Function

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['roipoint_pool3d_forward'])


class RoIPointPool3d(nn.Module):
    """Encode the geometry-specific features of each 3D proposal.

    Please refer to `Paper of PartA2 <https://arxiv.org/pdf/1907.03670.pdf>`_
    for more details.

    Args:
        num_sampled_points (int, optional): Number of samples in each roi.
            Default: 512.
    """

    def __init__(self, num_sampled_points: int = 512):
        super().__init__()
        self.num_sampled_points = num_sampled_points

    def forward(self, points, point_features, boxes3d):
        """Pool the points and point features that belong to each 3D box.

        The work is delegated to ``RoIPointPool3dFunction`` using the
        ``num_sampled_points`` configured on this module.

        Args:
            points (torch.Tensor): Input points whose shape is (B, N, 3).
            point_features (torch.Tensor): Features of input points whose
                shape is (B, N, C).
            boxes3d (torch.Tensor): Input bounding boxes whose shape is
                (B, M, 7).

        Returns:
            tuple[torch.Tensor, torch.Tensor]:

            - pooled_features: The output pooled features whose shape is
              (B, M, num_sampled_points, 3 + C).
            - pooled_empty_flag: Empty flag whose shape is (B, M).
        """
        return RoIPointPool3dFunction.apply(points, point_features, boxes3d,
                                            self.num_sampled_points)
class RoIPointPool3dFunction(Function):
    """Autograd function wrapping the ``roipoint_pool3d_forward`` extension op.

    Only the forward pass is implemented; ``backward`` always raises
    ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
        """Pool points and their features for every 3D box.

        Args:
            points (torch.Tensor): Input points whose shape is (B, N, 3).
            point_features (torch.Tensor): Features of input points whose
                shape is (B, N, C).
            boxes3d (torch.Tensor): Input bounding boxes whose shape is
                (B, M, 7).
            num_sampled_points (int, optional): The num of sampled points.
                Default: 512.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:

            - pooled_features: The output pooled features whose shape is
              (B, M, num_sampled_points, 3 + C).
            - pooled_empty_flag: Empty flag (int tensor) whose shape is
              (B, M).

        Raises:
            AssertionError: If ``points`` is not a 3-dim tensor whose last
                dimension is 3.
        """
        # Raised explicitly instead of via ``assert`` so the shape check is
        # still performed under ``python -O``; the exception type is the same
        # one the original ``assert`` produced.
        if len(points.shape) != 3 or points.shape[2] != 3:
            raise AssertionError(
                'points must have shape (B, N, 3), but got '
                f'{tuple(points.shape)}')

        batch_size = points.shape[0]
        boxes_num = boxes3d.shape[1]
        feature_len = point_features.shape[2]

        pooled_boxes3d = boxes3d.view(batch_size, -1, 7)
        # Each pooled point carries its 3 coordinates followed by its
        # ``feature_len`` features.
        pooled_features = point_features.new_zeros(
            (batch_size, boxes_num, num_sampled_points, 3 + feature_len))
        pooled_empty_flag = point_features.new_zeros(
            (batch_size, boxes_num)).int()

        # The extension op receives the two preallocated output tensors and
        # is expected to write its results into them.
        ext_module.roipoint_pool3d_forward(points.contiguous(),
                                           pooled_boxes3d.contiguous(),
                                           point_features.contiguous(),
                                           pooled_features, pooled_empty_flag)

        return pooled_features, pooled_empty_flag

    @staticmethod
    def backward(ctx, grad_out, grad_empty_flag=None):
        """Backward pass, which is not implemented for this op.

        Autograd passes one gradient per output of ``forward``. Since
        ``forward`` returns two tensors, two gradient arguments are accepted
        so that the call reaches the ``NotImplementedError`` below instead of
        failing earlier with a ``TypeError`` about the argument count.

        Args:
            grad_out: Gradient w.r.t. ``pooled_features`` (unused).
            grad_empty_flag: Gradient slot for ``pooled_empty_flag``
                (unused). Defaults to None.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            'RoIPointPool3dFunction does not implement backward')
Read the Docs v: v1.4.0
Versions
latest
stable
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.