Shortcuts

Source code for mmcv.ops.points_sampler

from typing import List, Optional

import torch
from torch import Tensor
from torch import nn as nn

from .furthest_point_sample import (furthest_point_sample,
                                    furthest_point_sample_with_dist)


def calc_square_dist(point_feat_a: Tensor,
                     point_feat_b: Tensor,
                     norm: bool = True) -> Tensor:
    """Calculate pairwise distances between the points of a and b.

    Despite the function name, what is returned depends on ``norm``:

    - ``norm=False``: the squared Euclidean distance ``||a_i - b_j||^2``.
    - ``norm=True``: the (non-squared) Euclidean distance ``||a_i - b_j||``
      divided by the number of channels ``C``.

    Args:
        point_feat_a (torch.Tensor): (B, N, C) Feature vector of each point.
        point_feat_b (torch.Tensor): (B, M, C) Feature vector of each point.
        norm (bool, optional): If True, return the Euclidean distance
            normalized by the channel count instead of the squared
            distance. Default: True.

    Returns:
        torch.Tensor: (B, N, M) Distance between each point pair, in the
        form selected by ``norm`` (see above).
    """
    num_channel = point_feat_a.shape[-1]
    # ``torch.cdist`` (default p=2) yields the plain Euclidean distance,
    # not its square.
    dist = torch.cdist(point_feat_a, point_feat_b)
    if norm:
        # Channel-normalized Euclidean distance; kept as-is because callers
        # rely on this exact value.
        return dist / num_channel
    return torch.square(dist)


def get_sampler_cls(sampler_type: str) -> type:
    """Look up the points sampler class registered for ``sampler_type``.

    Args:
        sampler_type (str): The type of points sampler.
            The valid values are "D-FPS", "F-FPS", or "FS".

    Returns:
        type: The sampler class itself (not an instance).

    Raises:
        KeyError: If ``sampler_type`` is not one of the supported types.
    """
    sampler_mappings = {
        'D-FPS': DFPSSampler,
        'F-FPS': FFPSSampler,
        'FS': FSSampler,
    }
    try:
        return sampler_mappings[sampler_type]
    except KeyError:
        # Built from adjacent literals: a backslash continuation inside the
        # f-string would embed the next line's indentation in the message.
        raise KeyError(
            f'Supported `sampler_type` are {list(sampler_mappings)}, '
            f'but got {sampler_type}') from None


class PointsSampler(nn.Module):
    """Points sampling.

    Args:
        num_point (list[int]): Number of sample points.
        fps_mod_list (list[str], optional): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS'].
            F-FPS: using feature distances for FPS.
            D-FPS: using Euclidean distances of points for FPS.
            FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (list[int], optional):
            Range of points to apply FPS. Default: [-1].
    """

    def __init__(self,
                 num_point: List[int],
                 fps_mod_list: Optional[List[str]] = None,
                 fps_sample_range_list: Optional[List[int]] = None) -> None:
        super().__init__()
        # ``None`` stands in for the documented defaults so that no default
        # list object is shared between instances (the range list is stored
        # on ``self`` below).
        if fps_mod_list is None:
            fps_mod_list = ['D-FPS']
        if fps_sample_range_list is None:
            fps_sample_range_list = [-1]
        # FPS would be applied to different fps_mod in the list,
        # so the length of the num_point should be equal to
        # fps_mod_list and fps_sample_range_list.
        assert len(num_point) == len(fps_mod_list) == len(
            fps_sample_range_list)
        self.num_point = num_point
        self.fps_sample_range_list = fps_sample_range_list
        self.samplers = nn.ModuleList()
        for fps_mod in fps_mod_list:
            self.samplers.append(get_sampler_cls(fps_mod)())
        self.fp16_enabled = False

    def forward(self, points_xyz: Tensor,
                features: Optional[Tensor]) -> Tensor:
        """Sample points with each configured sampler over its range.

        Sampler ``i`` runs on the slice of points from the previous range's
        end up to ``fps_sample_range_list[i]`` (``-1`` meaning "to the
        end"). Its indices are shifted back into the full point set's frame
        and concatenated along dim 1.

        Args:
            points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of the points.
            features (torch.Tensor, optional): (B, C, N) features of the
                points, or None.

        Returns:
            torch.Tensor: Indices of sampled points, concatenated along dim 1
            over all samplers. Presumably (B, sum of sampled counts) — confirm
            against the shapes ``furthest_point_sample`` returns.
        """
        # Half-precision inputs are upcast to float32 before sampling
        # (presumably the FPS ops expect fp32 — confirm).
        if points_xyz.dtype == torch.half:
            points_xyz = points_xyz.to(torch.float32)
        if features is not None and features.dtype == torch.half:
            features = features.to(torch.float32)

        indices = []
        last_fps_end_index = 0
        for fps_sample_range, sampler, npoint in zip(
                self.fps_sample_range_list, self.samplers, self.num_point):
            assert fps_sample_range < points_xyz.shape[1]

            # -1 selects everything up to the end of the point set.
            # NOTE(review): -1 is only meaningful as the last range; a
            # later range would start slicing at index -1.
            end = None if fps_sample_range == -1 else fps_sample_range
            sample_points_xyz = points_xyz[:, last_fps_end_index:end]
            sample_features = (features[:, :, last_fps_end_index:end]
                               if features is not None else None)

            fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
                              npoint)
            # Indices are relative to the slice; shift them back to the
            # frame of the full point set.
            indices.append(fps_idx + last_fps_end_index)
            last_fps_end_index = fps_sample_range
        return torch.cat(indices, dim=1)
class DFPSSampler(nn.Module):
    """Using Euclidean distances of points for FPS."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, points: Tensor, features: Optional[Tensor],
                npoint: int) -> Tensor:
        """Sampling points with D-FPS.

        Args:
            points (torch.Tensor): Point coordinates, presumably (B, N, 3).
            features (torch.Tensor, optional): Ignored; accepted so every
                sampler shares one call signature.
            npoint (int): Number of points to sample.

        Returns:
            torch.Tensor: Indices produced by ``furthest_point_sample``.
        """
        return furthest_point_sample(points.contiguous(), npoint)


class FFPSSampler(nn.Module):
    """Using feature distances for FPS."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, points: Tensor, features: Optional[Tensor],
                npoint: int) -> Tensor:
        """Sampling points with F-FPS.

        Concatenates coordinates and features per point, then runs FPS on
        the squared pairwise distances in that joint space.

        Args:
            points (torch.Tensor): Point coordinates, presumably (B, N, 3).
            features (torch.Tensor): Point features, assumed (B, C, N); must
                not be None.
            npoint (int): Number of points to sample.

        Returns:
            torch.Tensor: Indices produced by
            ``furthest_point_sample_with_dist``.
        """
        assert features is not None, \
            'feature input to FFPS_Sampler should not be None'
        # Move the channel axis last so it lines up with ``points`` on dim 2.
        features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
        features_dist = calc_square_dist(
            features_for_fps, features_for_fps, norm=False)
        return furthest_point_sample_with_dist(features_dist, npoint)


class FSSampler(nn.Module):
    """Using F-FPS and D-FPS simultaneously."""

    def __init__(self) -> None:
        super().__init__()
        # Built once here instead of on every forward call.
        self.ffps_sampler = FFPSSampler()
        self.dfps_sampler = DFPSSampler()

    def forward(self, points: Tensor, features: Optional[Tensor],
                npoint: int) -> Tensor:
        """Sampling points with FS_Sampling.

        Args:
            points (torch.Tensor): Point coordinates, presumably (B, N, 3).
            features (torch.Tensor): Point features; must not be None.
            npoint (int): Number of points each sub-sampler selects.

        Returns:
            torch.Tensor: F-FPS indices followed by D-FPS indices,
            concatenated along dim 1.
        """
        assert features is not None, \
            'feature input to FS_Sampler should not be None'
        fps_idx_ffps = self.ffps_sampler(points, features, npoint)
        fps_idx_dfps = self.dfps_sampler(points, features, npoint)
        return torch.cat([fps_idx_ffps, fps_idx_dfps], dim=1)
Read the Docs v: latest
Versions
master
latest
2.x
1.x
v1.7.0
v1.6.2
v1.6.1
v1.6.0
v1.5.3
v1.5.2_a
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
v1.3.12
v1.3.11
v1.3.10
v1.3.9
v1.3.8
v1.3.7
v1.3.6
v1.3.5
v1.3.4
v1.3.3
v1.3.2
v1.3.1
v1.3.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.