Shortcuts

Source code for mmcv.ops.sparse_structure

from typing import List, Optional, Tuple, Union

import numpy as np
import torch


def scatter_nd(indices: torch.Tensor, updates: torch.Tensor,
               shape: Union[List[int], Tuple[int, ...]]) -> torch.Tensor:
    """PyTorch edition of TensorFlow ``scatter_nd``.

    Creates a zero tensor of ``shape`` and writes ``updates`` into it at the
    positions addressed by ``indices``.

    This function contains no error-handling code, so use it carefully. When
    an index is repeated, the duplicate writes are NOT accumulated (unlike
    TensorFlow's ``scatter_nd``, which sums them).

    Args:
        indices (torch.Tensor): Integer index tensor of shape ``(*batch, K)``.
            The last dimension ``K`` is the number of leading output
            dimensions being indexed.
        updates (torch.Tensor): Values to scatter. They are reshaped to
            ``(*batch, *shape[K:])``, so the element count must match.
        shape (list[int] | tuple[int, ...] | torch.Size): Shape of the output
            tensor.

    Returns:
        torch.Tensor: Tensor of ``shape`` with the dtype and device of
        ``updates``. Positions not addressed by ``indices`` are zero.
    """
    # Normalise to a list of Python ints so tuples / torch.Size work too
    # (the original ``list + shape[k:]`` only worked for list inputs).
    shape = [int(dim) for dim in shape]
    ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
    ndim = indices.shape[-1]
    output_shape = list(indices.shape[:-1]) + shape[ndim:]
    # ``reshape`` (unlike ``view``) also handles non-contiguous inputs; the
    # cast to long keeps advanced indexing valid for int32 index tensors.
    flatted_indices = indices.reshape(-1, ndim).long()
    # One index tensor per indexed dimension; Ellipsis keeps the trailing
    # (non-indexed) dimensions whole. A tuple is the supported index form.
    slices = tuple(flatted_indices[:, i] for i in range(ndim)) + (Ellipsis, )
    ret[slices] = updates.reshape(*output_shape)
    return ret
class SparseConvTensor:
    """Sparse tensor container: active features plus their coordinates.

    Args:
        features (torch.Tensor): Feature rows. ``features.shape[1]`` is used
            as the channel count by :meth:`dense` (presumably shape
            ``(N, C)``).
        indices (torch.Tensor): Integer coordinates of each feature row.
            :meth:`dense` uses column 0 to index the batch dimension and the
            remaining columns to index the spatial dimensions, i.e. shape
            ``(N, 1 + len(spatial_shape))``. Converted to ``torch.int32`` if
            it has another dtype.
        spatial_shape (list | tuple): Size of each spatial dimension.
        batch_size (int): Number of samples in the batch.
        grid (torch.Tensor, optional): Stored as-is; not used by this class.
    """

    def __init__(self,
                 features: torch.Tensor,
                 indices: torch.Tensor,
                 spatial_shape: Union[List, Tuple],
                 batch_size: int,
                 grid: Optional[torch.Tensor] = None):
        self.features = features
        self.indices = indices
        if self.indices.dtype != torch.int32:
            # ``Tensor.int()`` returns a new tensor rather than converting in
            # place, so the result must be reassigned for the cast to stick.
            self.indices = self.indices.int()
        self.spatial_shape = spatial_shape
        self.batch_size = batch_size
        # Lookup table read by ``find_indice_pair``; it is populated by code
        # outside this class.
        self.indice_dict: dict = {}
        self.grid = grid

    @property
    def spatial_size(self):
        """Total number of spatial positions (product of ``spatial_shape``)."""
        return np.prod(self.spatial_shape)

    def find_indice_pair(self, key):
        """Return the entry cached in ``indice_dict`` under ``key``.

        Returns ``None`` when ``key`` is ``None`` or is not present.
        """
        if key is None:
            return None
        return self.indice_dict.get(key)

    def dense(self, channels_first: bool = True) -> torch.Tensor:
        """Scatter the sparse features into a dense tensor.

        Args:
            channels_first (bool): If True, return shape
                ``(batch_size, C, *spatial_shape)``. Otherwise return
                ``(batch_size, *spatial_shape, C)``.

        Returns:
            torch.Tensor: Dense tensor with zeros at inactive positions.
        """
        output_shape = [self.batch_size] + list(
            self.spatial_shape) + [self.features.shape[1]]
        res = scatter_nd(self.indices.long(), self.features, output_shape)
        if not channels_first:
            return res
        ndim = len(self.spatial_shape)
        # Move the trailing channel axis (index ndim + 1) to position 1:
        # [0, ndim + 1, 1, 2, ..., ndim].
        trans_params = list(range(0, ndim + 1))
        trans_params.insert(1, ndim + 1)
        return res.permute(*trans_params).contiguous()

    @property
    def sparity(self):
        """Fraction of positions that are active.

        Computed as the number of active sites divided by
        ``prod(spatial_shape) * batch_size``. The name is a misspelling kept
        for backward compatibility; see :attr:`sparsity`.
        """
        return (self.indices.shape[0] / np.prod(self.spatial_shape) /
                self.batch_size)

    @property
    def sparsity(self):
        """Correctly spelled alias of :attr:`sparity`."""
        return self.sparity