Source code for mmcv.engine.test

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import pickle
import shutil
import tempfile
import time
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader

import mmcv
from mmcv.runner import get_dist_info


def single_gpu_test(model: nn.Module, data_loader: DataLoader) -> list:
    """Test model with a single GPU.

    This method tests a model with a single GPU and displays a test
    progress bar.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (DataLoader): PyTorch data loader.

    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, **data)
        results.extend(result)

        # Assume result has the same length as batch_size
        # refer to https://github.com/open-mmlab/mmcv/issues/985
        batch_size = len(result)
        for _ in range(batch_size):
            prog_bar.update()
    return results
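
A minimal usage sketch (the toy dataset and model below are illustrative, not part of mmcv; the only contract ``single_gpu_test`` relies on is that the model's ``forward`` accepts ``return_loss=False`` and returns one prediction per sample):

# --- usage sketch, not part of the original module ---
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class _ToyDataset(Dataset):  # illustrative only
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {'img': torch.randn(3)}


class _ToyModel(nn.Module):  # illustrative only
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, img, return_loss=True):
        # test-mode forward: one prediction per sample in the batch
        return self.fc(img).argmax(dim=1).tolist()


loader = DataLoader(_ToyDataset(), batch_size=4)
predictions = single_gpu_test(_ToyModel(), loader)
assert len(predictions) == 8
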
def multi_gpu_test(model: nn.Module,
                   data_loader: DataLoader,
                   tmpdir: Optional[str] = None,
                   gpu_collect: bool = False) -> Optional[list]:
    """Test model with multiple GPUs.

    This method tests a model with multiple GPUs and collects the results
    under two different modes: gpu and cpu. By setting
    ``gpu_collect=True``, it encodes results to gpu tensors and uses gpu
    communication for results collection. In cpu mode it saves the results
    from different gpus to ``tmpdir``, and the rank 0 worker collects them.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (DataLoader): PyTorch data loader.
        tmpdir (str | None): Path of the directory to save the temporary
            results from different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect
            results.

    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)
        results.extend(result)

        if rank == 0:
            batch_size = len(result)
            batch_size_all = batch_size * world_size
            if batch_size_all + prog_bar.completed > len(dataset):
                batch_size_all = len(dataset) - prog_bar.completed
            for _ in range(batch_size_all):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        result_from_ranks = collect_results_gpu(results, len(dataset))
    else:
        result_from_ranks = collect_results_cpu(results, len(dataset), tmpdir)
    return result_from_ranks
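
A distributed usage sketch, assuming the process group is already initialized (e.g. by ``torchrun``), one CUDA device per process, and a ``model``/``dataset`` pair following the same conventions as above (both are placeholders here):

# --- usage sketch, not part of the original module ---
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from mmcv.parallel import MMDistributedDataParallel

sampler = DistributedSampler(dataset, shuffle=False)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
model = MMDistributedDataParallel(
    model.cuda(), device_ids=[torch.cuda.current_device()])
# gpu_collect=True gathers through gpu communication; the default pickles
# each rank's results to ``tmpdir`` and lets rank 0 merge them
outputs = multi_gpu_test(model, loader, tmpdir='./tmp_results',
                         gpu_collect=False)
if dist.get_rank() == 0:
    print(f'{len(outputs)} predictions collected on rank 0')
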
def collect_results_cpu(result_part: list,
                        size: int,
                        tmpdir: Optional[str] = None) -> Optional[list]:
    """Collect results under cpu mode.

    In cpu mode, this function saves the results from different gpus to
    ``tmpdir``, and the rank 0 worker collects them.

    Args:
        result_part (list): Result list containing result parts
            to be collected.
        size (int): Size of the results, commonly equal to the length of
            the results.
        tmpdir (str | None): Temporary directory in which to store the
            collected results. If set to None, a random temporary
            directory is created.

    Returns:
        list: The collected results.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    part_file = osp.join(tmpdir, f'part_{rank}.pkl')  # type: ignore
    mmcv.dump(result_part, part_file)
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')  # type: ignore
            part_result = mmcv.load(part_file)
            # When data is severely insufficient, an empty part_result
            # on a certain gpu could make the overall outputs empty.
            if part_result:
                part_list.append(part_result)
        # interleave the parts to restore the original dataset order
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)  # type: ignore
        return ordered_results
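
The collectors can also be called directly. A sketch under an initialized process group (note that even with ``tmpdir=None`` a GPU is required, since the directory name is broadcast through a CUDA tensor); it shows that ``zip(*part_list)`` interleaves the parts, matching the round-robin sample assignment of ``DistributedSampler``:

# --- usage sketch, not part of the original module ---
rank, world_size = get_dist_info()
# e.g. with world_size == 2: rank 0 holds [0, 2], rank 1 holds [1, 3]
part = list(range(rank, 4, world_size))
merged = collect_results_cpu(part, size=4)
if rank == 0:
    assert merged == [0, 1, 2, 3]  # interleaved back into dataset order
else:
    assert merged is None  # non-zero ranks receive None
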
def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
    """Collect results under gpu mode.

    In gpu mode, this function encodes results to gpu tensors and uses gpu
    communication for results collection.

    Args:
        result_part (list): Result list containing result parts
            to be collected.
        size (int): Size of the results, commonly equal to the length of
            the results.

    Returns:
        list: The collected results.
    """
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)),
        dtype=torch.uint8,
        device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_result = pickle.loads(
                recv[:shape[0]].cpu().numpy().tobytes())
            # When data is severely insufficient, an empty part_result
            # on a certain gpu could make the overall outputs empty.
            if part_result:
                part_list.append(part_result)
        # interleave the parts to restore the original dataset order
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
    else:
        return None
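
From the caller's point of view the gpu-mode collector behaves the same; unequal part sizes are handled by padding each rank's pickled byte tensor to the maximum length across ranks before ``all_gather``. The same sketch as above:

# --- usage sketch, not part of the original module ---
rank, world_size = get_dist_info()
part = list(range(rank, 4, world_size))
merged = collect_results_gpu(part, size=4)
if rank == 0:
    assert merged == [0, 1, 2, 3]
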