Source code for mmcv.ops.group_points
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import torch
from torch import nn as nn
from torch.autograd import Function
from ..utils import ext_loader
from .ball_query import ball_query
from .knn import knn
# Handle to the compiled ``_ext`` extension; only the four grouping entry
# points listed here are requested from it. They are used by
# ``GroupingOperation`` for the batched and stacked input layouts.
ext_module = ext_loader.load_ext(
    '_ext',
    [
        'group_points_forward',
        'group_points_backward',
        'stack_group_points_forward',
        'stack_group_points_backward',
    ],
)
class QueryAndGroup(nn.Module):
    """Groups points with a ball query of radius.

    For every centroid in ``center_xyz`` this module gathers ``sample_num``
    neighbouring point indices from ``points_xyz`` (with a ball query, or
    with kNN when ``max_radius`` is None) and returns the neighbours'
    coordinates relative to the centroid and/or their features.

    Args:
        max_radius (float | None): The maximum radius of the balls.
            If None is given, we will use kNN sampling instead of ball query.
        sample_num (int): Maximum number of features to gather in the ball.
        min_radius (float, optional): The minimum radius of the balls.
            Default: 0.
        use_xyz (bool, optional): Whether to concatenate the relative xyz
            offsets in front of the grouped features. Default: True.
        return_grouped_xyz (bool, optional): Whether to return grouped xyz.
            Default: False.
        normalize_xyz (bool, optional): Whether to divide the relative xyz
            offsets by ``max_radius``. Requires ``max_radius`` to be set.
            Default: False.
        uniform_sample (bool, optional): Whether to replace duplicated
            neighbour indices with uniform draws from the unique ones.
            Default: False
        return_unique_cnt (bool, optional): Whether to return the count of
            unique samples. Requires ``uniform_sample``. Default: False.
        return_grouped_idx (bool, optional): Whether to return grouped idx.
            Default: False.
    """

    def __init__(self,
                 max_radius: Optional[float],
                 sample_num: int,
                 min_radius: float = 0.,
                 use_xyz: bool = True,
                 return_grouped_xyz: bool = False,
                 normalize_xyz: bool = False,
                 uniform_sample: bool = False,
                 return_unique_cnt: bool = False,
                 return_grouped_idx: bool = False):
        super().__init__()
        self.max_radius = max_radius
        self.min_radius = min_radius
        self.sample_num = sample_num
        self.use_xyz = use_xyz
        self.return_grouped_xyz = return_grouped_xyz
        self.normalize_xyz = normalize_xyz
        self.uniform_sample = uniform_sample
        self.return_unique_cnt = return_unique_cnt
        self.return_grouped_idx = return_grouped_idx
        if self.return_unique_cnt:
            # The unique counts are only computed by the uniform resampling.
            assert self.uniform_sample, \
                'uniform_sample should be True when ' \
                'returning the count of unique samples'
        if self.max_radius is None:
            # There is no radius to normalize by on the kNN path.
            assert not self.normalize_xyz, \
                'can not normalize grouped xyz when max_radius is None'

    def _resample_uniformly(self, idx: torch.Tensor) -> torch.Tensor:
        """Replace duplicate neighbour indices by uniform random draws.

        For each (batch, region) row of ``idx`` the sorted unique indices
        are kept first, and the remaining ``sample_num - num_unique`` slots
        are filled with indices drawn uniformly (with replacement) from
        those unique indices. ``idx`` is modified in place.

        Args:
            idx (torch.Tensor): (B, npoint, sample_num) neighbour indices.

        Returns:
            torch.Tensor: (B, npoint) counts of unique indices per region,
            created by ``torch.zeros`` with the default dtype and device
            (which may differ from ``idx.device``).
        """
        unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
        for i_batch in range(idx.shape[0]):
            for i_region in range(idx.shape[1]):
                unique_ind = torch.unique(idx[i_batch, i_region, :])
                num_unique = unique_ind.shape[0]
                unique_cnt[i_batch, i_region] = num_unique
                sample_ind = torch.randint(
                    0,
                    num_unique, (self.sample_num - num_unique, ),
                    dtype=torch.long)
                all_ind = torch.cat((unique_ind, unique_ind[sample_ind]))
                idx[i_batch, i_region, :] = all_ind
        return unique_cnt

    def forward(
        self,
        points_xyz: torch.Tensor,
        center_xyz: torch.Tensor,
        features: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple]:
        """
        Args:
            points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of the
                points.
            center_xyz (torch.Tensor): (B, npoint, 3) coordinates of the
                centroids.
            features (torch.Tensor, optional): (B, C, N) features of the
                points to be grouped. Default: None.

        Returns:
            Tuple | torch.Tensor: (B, 3 + C, npoint, sample_num) Grouped
            concatenated coordinates and features of points. If any of
            ``return_grouped_xyz``, ``return_unique_cnt`` or
            ``return_grouped_idx`` is enabled, a tuple is returned instead:
            this tensor first, followed by the grouped xyz, the unique
            counts and the grouped indices (in that order) for each
            enabled flag.
        """
        # if self.max_radius is None, we will perform kNN instead of ball query
        # idx is of shape [B, npoint, sample_num]
        if self.max_radius is None:
            idx = knn(self.sample_num, points_xyz, center_xyz, False)
            # Transposed so idx matches the [B, npoint, sample_num] layout.
            idx = idx.transpose(1, 2).contiguous()
        else:
            idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
                             points_xyz, center_xyz)

        if self.uniform_sample:
            # Rewrites idx in place and returns per-region unique counts.
            unique_cnt = self._resample_uniformly(idx)

        xyz_trans = points_xyz.transpose(1, 2).contiguous()
        # (B, 3, npoint, sample_num)
        grouped_xyz = grouping_operation(xyz_trans, idx)
        # Offsets of every neighbour relative to its own centroid.
        grouped_xyz_diff = grouped_xyz - \
            center_xyz.transpose(1, 2).unsqueeze(-1)
        if self.normalize_xyz:
            grouped_xyz_diff /= self.max_radius

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                # (B, 3 + C, npoint, sample_num)
                new_features = torch.cat([grouped_xyz_diff, grouped_features],
                                         dim=1)
            else:
                new_features = grouped_features
        else:
            assert (self.use_xyz
                    ), 'Cannot have not features and not use xyz as a feature!'
            new_features = grouped_xyz_diff

        ret = [new_features]
        if self.return_grouped_xyz:
            ret.append(grouped_xyz)
        if self.return_unique_cnt:
            # Defined above: __init__ guarantees uniform_sample is True here.
            ret.append(unique_cnt)
        if self.return_grouped_idx:
            ret.append(idx)
        if len(ret) == 1:
            return ret[0]
        else:
            return tuple(ret)
class GroupAll(nn.Module):
    """Treat all input points as one single group.

    Args:
        use_xyz (bool): Whether to put the xyz coordinates in front of the
            features when features are given. Default: True.
    """

    def __init__(self, use_xyz: bool = True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self,
                xyz: torch.Tensor,
                new_xyz: torch.Tensor,
                features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            new_xyz (Tensor): new xyz coordinates of the features. Accepted
                for interface compatibility; it is not read.
            features (Tensor): (B, C, N) features to group.

        Returns:
            Tensor: (B, C + 3, 1, N) Grouped feature. Without ``features``
            only the (B, 3, 1, N) coordinates are returned; with
            ``use_xyz=False`` only the (B, C, 1, N) features are returned.
        """
        # (B, N, 3) -> (B, 3, N) -> (B, 3, 1, N)
        xyz_group = xyz.transpose(1, 2).unsqueeze(2)
        if features is None:
            return xyz_group
        feat_group = features.unsqueeze(2)
        if not self.use_xyz:
            return feat_group
        # (B, 3 + C, 1, N)
        return torch.cat([xyz_group, feat_group], dim=1)
class GroupingOperation(Function):
    """Group feature with given index.

    Two input layouts are supported: batched inputs ((B, C, N) features with
    (B, npoint, nsample) indices) and stacked inputs ((N1 + N2 ..., C)
    features with (M1 + M2 ..., nsample) indices plus per-batch counts).
    """

    @staticmethod
    def forward(
            ctx,
            features: torch.Tensor,
            indices: torch.Tensor,
            features_batch_cnt: Optional[torch.Tensor] = None,
            indices_batch_cnt: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            features (Tensor): Tensor of features to group, input shape is
                (B, C, N) or stacked inputs (N1 + N2 ..., C).
            indices (Tensor): The indices of features to group with, input
                shape is (B, npoint, nsample) or stacked inputs
                (M1 + M2 ..., nsample).
            features_batch_cnt (Tensor, optional): Input features nums in
                each batch, just like (N1, N2, ...). Defaults to None.
                New in version 1.7.0.
            indices_batch_cnt (Tensor, optional): Input indices nums in
                each batch, just like (M1, M2, ...). Defaults to None.
                New in version 1.7.0.

        Returns:
            Tensor: Grouped features, the shape is (B, C, npoint, nsample)
            or (M1 + M2 ..., C, nsample).
        """
        features = features.contiguous()
        indices = indices.contiguous()
        if features_batch_cnt is not None and indices_batch_cnt is not None:
            # Stacked layout: both batch counts must be int32 tensors.
            assert features_batch_cnt.dtype == torch.int
            assert indices_batch_cnt.dtype == torch.int
            M, nsample = indices.size()
            N, C = features.size()
            B = indices_batch_cnt.shape[0]
            output = features.new_zeros((M, C, nsample))
            ext_module.stack_group_points_forward(
                features,
                features_batch_cnt,
                indices,
                indices_batch_cnt,
                output,
                b=B,
                m=M,
                c=C,
                nsample=nsample)
            # A 5-tuple tells ``backward`` that the stacked layout was used.
            ctx.for_backwards = (B, N, indices, features_batch_cnt,
                                 indices_batch_cnt)
        else:
            B, nfeatures, nsample = indices.size()
            _, C, N = features.size()
            output = features.new_zeros(B, C, nfeatures, nsample)
            ext_module.group_points_forward(
                features,
                indices,
                output,
                b=B,
                c=C,
                n=N,
                npoints=nfeatures,
                nsample=nsample)
            # A 2-tuple tells ``backward`` that the batched layout was used.
            ctx.for_backwards = (indices, N)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple:
        """
        Args:
            grad_out (Tensor): Gradient of the output of ``forward``, of
                shape (B, C, npoint, nsample) for batched inputs or
                (M1 + M2 ..., C, nsample) for stacked inputs.

        Returns:
            Tuple: The gradient of ``features`` ((B, C, N) or
            (N1 + N2 ..., C)), followed by ``None`` for ``indices``,
            ``features_batch_cnt`` and ``indices_batch_cnt``.
        """
        if len(ctx.for_backwards) != 5:
            idx, N = ctx.for_backwards
            B, C, npoint, nsample = grad_out.size()
            grad_features = grad_out.new_zeros(B, C, N)
            grad_out_data = grad_out.data.contiguous()
            ext_module.group_points_backward(
                grad_out_data,
                idx,
                grad_features.data,
                b=B,
                c=C,
                n=N,
                npoints=npoint,
                nsample=nsample)
        else:
            B, N, idx, features_batch_cnt, idx_batch_cnt = ctx.for_backwards
            M, C, nsample = grad_out.size()
            grad_features = grad_out.new_zeros(N, C)
            grad_out_data = grad_out.data.contiguous()
            ext_module.stack_group_points_backward(
                grad_out_data,
                idx,
                idx_batch_cnt,
                features_batch_cnt,
                grad_features.data,
                b=B,
                c=C,
                m=M,
                n=N,
                nsample=nsample)
        # One gradient slot per possible ``forward`` input. The batched path
        # may also be reached with 3 or 4 inputs (e.g. explicit ``None`` batch
        # counts). Autograd drops surplus trailing ``None`` gradients but
        # rejects too few, so always return four values.
        return grad_features, None, None, None


grouping_operation = GroupingOperation.apply