Source code for mmcv.runner.dist_utils

# Copyright (c) OpenMMLab. All rights reserved.
import functools
import os
import subprocess
from collections import OrderedDict

import torch
import torch.multiprocessing as mp
from torch import distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)


def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')
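
For context, a hedged sketch of how a training script typically dispatches to ``init_dist``; the ``argparse`` plumbing and the ``main`` entry point below are illustrative assumptions, not part of this module:

# Hypothetical usage sketch: a launcher-agnostic training entry point.
# The argparse setup and main() are assumptions, not part of this module.
import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'mpi', 'slurm'],
        default='none')
    args = parser.parse_args()
    if args.launcher != 'none':
        init_dist(args.launcher, backend='nccl')
    # ... build the model, wrap it in DistributedDataParallel, train ...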


def _init_dist_pytorch(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If the ``port`` argument is not specified, the master port is read from
    the environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not set
    either, the default port ``29500`` is used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    # use MASTER_ADDR in the environment variable if it already exists
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
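
Under SLURM the rank, world size, and master address are all derived from the scheduler's environment, so a single ``srun`` invocation is usually enough. A hypothetical launch command (partition, task counts, and script name are placeholders):

# Hypothetical launch command (names and counts are placeholders):
#   srun -p my_partition --ntasks=16 --ntasks-per-node=8 --gres=gpu:8 \
#       python train.py --launcher slurm
# Each task then calls init_dist('slurm'), which fills in MASTER_ADDR,
# MASTER_PORT, RANK, WORLD_SIZE and LOCAL_RANK before init_process_group.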


def get_dist_info():
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size
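
``get_dist_info`` degrades gracefully to ``(0, 1)`` when distributed training is not initialized, so the same code path works in both single- and multi-process runs. A minimal sketch of the common rank-based sharding pattern, where ``dataset`` and ``process`` are illustrative placeholders:

# Illustrative sharding pattern: each rank handles a disjoint slice.
# 'dataset' and 'process' are hypothetical placeholders.
rank, world_size = get_dist_info()
for idx in range(rank, len(dataset), world_size):
    process(dataset[idx])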


def master_only(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper
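
Because the wrapper simply returns ``None`` on non-zero ranks, ``master_only`` is suited to side-effect-only functions such as logging or checkpointing. A hypothetical example:

@master_only
def log_message(msg):
    # Only rank 0 executes the body; all other ranks return None.
    print(msg)

log_message('training started')  # printed once, not once per process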


def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
    """Allreduce parameters.

    Args:
        params (list[torch.nn.Parameter]): List of parameters or buffers
            of a model.
        coalesce (bool, optional): Whether to allreduce parameters as a
            whole. Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
    """
    _, world_size = get_dist_info()
    if world_size == 1:
        return
    params = [param.data for param in params]
    if coalesce:
        _allreduce_coalesced(params, world_size, bucket_size_mb)
    else:
        for tensor in params:
            dist.all_reduce(tensor.div_(world_size))


def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    """Allreduce gradients.

    Args:
        params (list[torch.nn.Parameter]): List of parameters of a model.
        coalesce (bool, optional): Whether to allreduce gradients as a
            whole. Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
    """
    grads = [
        param.grad.data for param in params
        if param.requires_grad and param.grad is not None
    ]
    _, world_size = get_dist_info()
    if world_size == 1:
        return
    if coalesce:
        _allreduce_coalesced(grads, world_size, bucket_size_mb)
    else:
        for tensor in grads:
            dist.all_reduce(tensor.div_(world_size))


def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
    """Allreduce tensors in buckets to cut down on communication calls."""
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
        buckets = _take_tensors(tensors, bucket_size_bytes)
    else:
        # Group tensors by type so each bucket can be flattened together.
        buckets = OrderedDict()
        for tensor in tensors:
            tp = tensor.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(tensor)
        buckets = buckets.values()

    for bucket in buckets:
        # Flatten, allreduce once, average, then copy the results back.
        flat_tensors = _flatten_dense_tensors(bucket)
        dist.all_reduce(flat_tensors)
        flat_tensors.div_(world_size)
        for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
            tensor.copy_(synced)
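
A hedged sketch of where ``allreduce_grads`` fits in a manual training step, for a model that is not wrapped in ``DistributedDataParallel`` (which already synchronizes gradients on its own); ``model``, ``loss``, and ``optimizer`` are assumed placeholders:

# Assumed objects: model, loss and optimizer are placeholders.
loss.backward()
allreduce_grads(model.parameters(), coalesce=True, bucket_size_mb=16)
optimizer.step()
optimizer.zero_grad()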