# Source code for mmcv.ops.sparse_structure
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
def scatter_nd(indices: torch.Tensor, updates: torch.Tensor,
               shape: Union[List, Tuple, torch.Tensor]) -> torch.Tensor:
    """PyTorch edition of TensorFlow ``scatter_nd``.

    Builds a zero-filled tensor of size ``shape`` and writes ``updates`` into
    it at the positions given by ``indices``.

    This function performs no error handling, so use it carefully. When an
    index is repeated the values are NOT accumulated (TensorFlow adds them);
    one of the conflicting updates simply wins.

    Args:
        indices: Integer tensor whose last dimension holds the coordinates
            of each update; ``indices.shape[-1]`` is the number of leading
            dimensions of the output that are indexed.
        updates: Values to scatter. Must be viewable as
            ``(num_indices, *shape[indices.shape[-1]:])``.
        shape: Size of the output tensor. May be a list, tuple,
            ``torch.Size`` or 1-D integer tensor.

    Returns:
        A tensor of size ``shape`` with the dtype and device of ``updates``.
    """
    # Normalise to plain ints so that list/tuple/torch.Size/tensor shapes all
    # behave the same in the slicing and concatenation below.
    dims = [int(s) for s in shape]
    ret = torch.zeros(*dims, dtype=updates.dtype, device=updates.device)
    ndim = indices.shape[-1]
    flatted_indices = indices.view(-1, ndim)
    # The indices are flattened to (num_indices, ndim), so the updates must
    # be flattened the same way; for 2-D indices this equals
    # list(indices.shape[:-1]) + dims[ndim:].
    output_shape = [flatted_indices.shape[0]] + dims[ndim:]
    # Multidimensional indexing needs a tuple: indexing with a list is
    # deprecated and will be interpreted differently in the future.
    slices = tuple(flatted_indices[:, i] for i in range(ndim)) + (Ellipsis, )
    ret[slices] = updates.view(*output_shape)
    return ret
class SparseConvTensor:
    """Container for a sparse tensor: active features plus their coordinates.

    Args:
        features: Feature tensor; ``features.shape[1]`` is used as the
            channel count when densifying.
        indices: Integer coordinates of the active entries. ``dense()``
            scatters into ``[batch_size, *spatial_shape, channels]``, so each
            row is expected to be ``(batch_idx, *spatial_coords)``. Stored as
            ``torch.int32``.
        spatial_shape: Size of the spatial dimensions.
        batch_size: Number of samples in the batch.
        grid: Optional tensor kept on the instance; it is not used by this
            class itself.
    """

    def __init__(self,
                 features: torch.Tensor,
                 indices: torch.Tensor,
                 spatial_shape: Union[List, Tuple],
                 batch_size: int,
                 grid: Optional[torch.Tensor] = None):
        self.features = features
        self.indices = indices
        if self.indices.dtype != torch.int32:
            # Tensor.int() returns a converted copy rather than casting in
            # place, so the result has to be assigned back.
            self.indices = self.indices.int()
        self.spatial_shape = spatial_shape
        self.batch_size = batch_size
        # Cache of indice pairs, looked up through find_indice_pair().
        self.indice_dict: dict = {}
        self.grid = grid

    @property
    def spatial_size(self):
        """Total number of spatial locations (product of spatial_shape)."""
        return np.prod(self.spatial_shape)

    def find_indice_pair(self, key):
        """Return the cached entry stored under ``key``.

        Returns ``None`` when ``key`` is ``None`` or is not in the cache.
        """
        if key is None:
            return None
        return self.indice_dict.get(key)

    def dense(self, channels_first: bool = True) -> torch.Tensor:
        """Convert to a dense tensor.

        Args:
            channels_first: If True (default) the result is laid out as
                ``(batch, channels, *spatial_shape)``; otherwise as
                ``(batch, *spatial_shape, channels)``.

        Returns:
            The dense tensor, zero wherever no feature is present.
        """
        output_shape = [self.batch_size] + list(
            self.spatial_shape) + [self.features.shape[1]]
        res = scatter_nd(self.indices.long(), self.features, output_shape)
        if not channels_first:
            return res
        ndim = len(self.spatial_shape)
        # Move the trailing channel axis (index ndim + 1) to position 1.
        trans_params = list(range(0, ndim + 1))
        trans_params.insert(1, ndim + 1)
        return res.permute(*trans_params).contiguous()

    @property
    def sparity(self):
        """Fraction of active entries over all batch * spatial locations.

        The property name (historical spelling) is kept for compatibility.
        """
        return (self.indices.shape[0] / np.prod(self.spatial_shape) /
                self.batch_size)