
Source code for mmcv.engine.test

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import pickle
import shutil
import tempfile
import time

import torch
import torch.distributed as dist

import mmcv
from mmcv.runner import get_dist_info


def single_gpu_test(model, data_loader):
    """Test model with a single gpu.

    This method tests model with a single gpu and displays a test
    progress bar.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (DataLoader): PyTorch data loader.

    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, **data)
        results.extend(result)

        # Assume result has the same length as the batch size.
        # Refer to https://github.com/open-mmlab/mmcv/issues/985
        batch_size = len(result)
        for _ in range(batch_size):
            prog_bar.update()
    return results
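Since the loop calls ``model(return_loss=False, **data)``, anything passed to ``single_gpu_test`` must accept the batch keys plus a ``return_loss`` flag and return one prediction per sample. A minimal runnable sketch, using hypothetical ``ToyDataset`` and ``ToyModel`` classes that are illustrative only, not part of mmcv:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from mmcv.engine import single_gpu_test


class ToyDataset(Dataset):
    """Ten samples, each a dict so the batch can be unpacked as kwargs."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return dict(img=torch.randn(3))


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, img, return_loss=True):
        # The test loop passes return_loss=False plus the batch keys.
        out = self.linear(img)
        # Return one prediction per sample, as results.extend() expects.
        return [o.item() for o in out]


loader = DataLoader(ToyDataset(), batch_size=2)
results = single_gpu_test(ToyModel(), loader)
assert len(results) == 10

The assertion holds because ``results`` grows by ``batch_size`` items per iteration, which is also what drives the progress bar.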
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test model with multiple gpus.

    This method tests model with multiple gpus and collects the results
    under two different modes: gpu and cpu modes. By setting
    ``gpu_collect=True``, it encodes results to gpu tensors and uses gpu
    communication for results collection. In cpu mode it saves the results
    on different gpus to ``tmpdir`` and collects them by the rank 0 worker.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (DataLoader): PyTorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect
            results.

    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problems in some cases.
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)
        results.extend(result)

        if rank == 0:
            batch_size = len(result)
            batch_size_all = batch_size * world_size
            if batch_size_all + prog_bar.completed > len(dataset):
                batch_size_all = len(dataset) - prog_bar.completed
            for _ in range(batch_size_all):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results
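A sketch of how this might be driven from a distributed script, assuming a launch via ``torchrun --nproc_per_node=2 test_script.py``, reusing the hypothetical ``ToyDataset``/``ToyModel`` from the previous sketch, and using plain ``DistributedDataParallel`` where downstream codebases typically use ``MMDistributedDataParallel``:

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler

from mmcv.engine import multi_gpu_test

# ToyDataset and ToyModel as defined in the single-gpu sketch above.

dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = DistributedDataParallel(
    ToyModel().cuda(), device_ids=[local_rank])

dataset = ToyDataset()
# shuffle=False keeps the round-robin split that the collect functions
# rely on to restore dataset order by interleaving the per-rank parts.
sampler = DistributedSampler(dataset, shuffle=False)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

results = multi_gpu_test(model, loader, gpu_collect=True)
if dist.get_rank() == 0:
    print(len(results))  # 10: any sampler padding is trimmed to len(dataset)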
def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect results under cpu mode.

    In cpu mode, this function saves the results on different gpus to
    ``tmpdir`` and collects them by the rank 0 worker.

    Args:
        result_part (list): Result list containing result parts
            to be collected.
        size (int): Size of the results, commonly equal to the length of
            the whole dataset.
        tmpdir (str | None): Temporary directory for the collected results
            to be stored in. If set to None, a random temporary directory
            will be created for it.

    Returns:
        list: The collected results.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_result = mmcv.load(part_file)
            # When data is severely insufficient, an empty part_result
            # on a certain gpu could make the overall outputs empty.
            if part_result:
                part_list.append(part_result)
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results
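A sketch of calling ``collect_results_cpu`` directly, assuming an already-initialized process group; ``./my_tmp`` is a hypothetical path. Note that an explicit ``tmpdir`` must sit on a filesystem visible to every rank in multi-node runs, and that the ``tmpdir=None`` branch above still needs CUDA because the generated directory name is broadcast through a gpu tensor:

import torch.distributed as dist

from mmcv.engine import collect_results_cpu

# Assumes an initialized process group (e.g. launched with torchrun).
rank = dist.get_rank()
world_size = dist.get_world_size()

# Round-robin parts, mirroring a DistributedSampler split: rank r holds
# samples r, r + world_size, r + 2 * world_size, ... This toy case
# assumes world_size divides the 8 samples evenly; a real
# DistributedSampler pads so that every rank holds the same count.
part = list(range(rank, 8, world_size))

merged = collect_results_cpu(part, size=8, tmpdir='./my_tmp')
if rank == 0:
    assert merged == list(range(8))  # zip(*parts) restores dataset order
else:
    assert merged is None  # non-zero ranks return None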
def collect_results_gpu(result_part, size):
    """Collect results under gpu mode.

    In gpu mode, this function encodes results to gpu tensors and uses gpu
    communication for results collection.

    Args:
        result_part (list): Result list containing result parts
            to be collected.
        size (int): Size of the results, commonly equal to the length of
            the whole dataset.

    Returns:
        list: The collected results.
    """
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)),
        dtype=torch.uint8,
        device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_result = pickle.loads(
                recv[:shape[0]].cpu().numpy().tobytes())
            # When data is severely insufficient, an empty part_result
            # on a certain gpu could make the overall outputs empty.
            if part_result:
                part_list.append(part_result)
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
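A sketch of calling ``collect_results_gpu`` directly, under the same assumptions as the cpu example. Because every rank's pickled part is padded to the longest part and all-gathered as gpu tensors, all ``world_size`` padded parts must fit in each gpu's memory at once:

import torch.distributed as dist

from mmcv.engine import collect_results_gpu

# Assumes an initialized NCCL process group and a CUDA device per rank,
# since the pickled parts travel through dist.all_gather on gpu tensors.
rank = dist.get_rank()
world_size = dist.get_world_size()

# Same round-robin layout as the cpu example (world_size divides 8).
part = [dict(idx=i) for i in range(rank, 8, world_size)]

merged = collect_results_gpu(part, size=8)
if rank == 0:
    assert [r['idx'] for r in merged] == list(range(8))
else:
    assert merged is None  # only rank 0 receives the merged list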