SyncBatchNorm
- class mmcv.ops.SyncBatchNorm(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, group: Optional[int] = None, stats_mode: str = 'default')
Synchronized Batch Normalization.
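For reference, the per-channel computation can be written as below. This is a hedged sketch that assumes SyncBatchNorm uses the same formulas as PyTorch's torch.nn.BatchNorm2d. The statistics are pooled over the N values of channel c taken across the batch, the spatial positions, and all workers in group. This exact pooling corresponds to stats_mode='N', or to 'default' when every worker has the same batch size. Here ε is eps and m is momentum. γ and β are the learnable affine parameters, with γ = 1 and β = 0 when affine=False. \hat{\mu}_c stands for running_mean. When track_running_stats=True, running_var is updated in the same way.

```latex
\mu_c = \frac{1}{N}\sum_{i=1}^{N} x_{i,c}, \qquad
\sigma_c^2 = \frac{1}{N}\sum_{i=1}^{N} \left(x_{i,c} - \mu_c\right)^2

y_{i,c} = \gamma_c \, \frac{x_{i,c} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c

\hat{\mu}_c \leftarrow (1 - m)\,\hat{\mu}_c + m\,\mu_c
```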
- Parameters
num_features (int) – number of features/channels in the input tensor.
eps (float, optional) – a value added to the denominator for numerical stability. Defaults to 1e-5.
momentum (float, optional) – the value used for the running_mean and running_var computation. Defaults to 0.1.
affine (bool, optional) – whether to use learnable affine parameters. Defaults to True.
track_running_stats (bool, optional) – whether to track the running mean and variance during training. When set to False, this module does not track such statistics and initializes the statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics in both training and eval modes. Defaults to True.
group (int, optional) – the process group within which statistics are synchronized. Synchronization happens within each process group individually. By default, statistics are synchronized across the whole world, i.e., all processes. Defaults to None.
stats_mode (str, optional) – The statistical mode. Available options include 'default' and 'N'. Defaults to 'default'. When stats_mode=='default', it computes the overall statistics using those from each worker with equal weight, i.e., the statistics are synchronized and simply divided by the size of group. This mode will produce inaccurate statistics when empty tensors occur. When stats_mode=='N', it computes the overall statistics using the total batch size across all workers, regardless of the number of workers in group, i.e., the statistics are synchronized and then divided by the total batch size N. This mode is beneficial when empty tensors occur during training, as it normalizes the summed statistics by the actual total batch size.
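The difference between the two stats_mode options can be illustrated for a single channel's mean with the minimal pure-Python sketch below. It simulates the cross-worker reduction on plain lists instead of distributed tensors, so it is not mmcv's implementation. The helper names local_mean and sync_mean are hypothetical. The sketch also assumes that a worker holding an empty batch contributes a mean of 0 to the reduction.

```python
from typing import List


def local_mean(samples: List[float]) -> float:
    # Assumption: a worker with an empty batch contributes a mean of 0.
    return sum(samples) / len(samples) if samples else 0.0


def sync_mean(worker_batches: List[List[float]], stats_mode: str = "default") -> float:
    """Simulate synchronizing one channel's mean across workers (hypothetical helper)."""
    means = [local_mean(batch) for batch in worker_batches]
    sizes = [len(batch) for batch in worker_batches]
    if stats_mode == "default":
        # Equal weight per worker: sum the per-worker means, divide by the group size.
        return sum(means) / len(worker_batches)
    if stats_mode == "N":
        # Weight each worker's mean by its batch size, divide by the total batch size N.
        total = sum(sizes)
        return sum(m * n for m, n in zip(means, sizes)) / max(total, 1)
    raise ValueError(f"unsupported stats_mode: {stats_mode!r}")


workers = [[1.0, 5.0], [7.0, 11.0], []]  # the third worker received an empty tensor
print(sync_mean(workers, "default"))  # 4.0, i.e. (3 + 9 + 0) / 3, pulled toward 0
print(sync_mean(workers, "N"))        # 6.0, i.e. (1 + 5 + 7 + 11) / 4, the true mean
```

When no worker is empty and all workers have the same batch size, the two modes give the same result. They diverge only when batch sizes differ, and an empty tensor is the extreme case. The variance is synchronized along the same lines.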
- forward(input: torch.Tensor) → torch.Tensor
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
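Constructing SyncBatchNorm requires an initialized process group. The note is therefore illustrated here with torch.nn.Linear as a stand-in, which is an assumption made only for illustration. The hook mechanics come from torch.nn.Module and apply to any module. Calling the instance runs the registered forward hooks, while calling forward() directly skips them:

```python
import torch
from torch import nn

calls = []

layer = nn.Linear(4, 2)  # stand-in module; any nn.Module behaves the same way
# Record each time the forward hook fires; returning None leaves the output unchanged.
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(3, 4)
y1 = layer(x)          # __call__ runs forward() and then the registered hooks
y2 = layer.forward(x)  # calling forward() directly bypasses the hooks

print(calls)                # ['hook']
print(torch.equal(y1, y2))  # True: same computation, only the hook behavior differs
```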