Shortcuts

欢迎来到 MMCV 的中文文档!

您可以在页面左下角切换中英文文档。

介绍 MMCV

MMCV 是一个面向计算机视觉的基础库,它提供了以下功能:

MMCV 支持多种平台,包括:

  • Linux

  • Windows

  • macOS

它支持的 OpenMMLab 项目:

  • MMClassification: OpenMMLab 图像分类工具箱

  • MMDetection: OpenMMLab 目标检测工具箱

  • MMDetection3D: OpenMMLab 新一代通用 3D 目标检测平台

  • MMRotate: OpenMMLab 旋转框检测工具箱与测试基准

  • MMYOLO: OpenMMLab YOLO 系列工具箱与测试基准

  • MMSegmentation: OpenMMLab 语义分割工具箱

  • MMOCR: OpenMMLab 全流程文字检测识别理解工具箱

  • MMPose: OpenMMLab 姿态估计工具箱

  • MMHuman3D: OpenMMLab 人体参数化模型工具箱与测试基准

  • MMSelfSup: OpenMMLab 自监督学习工具箱与测试基准

  • MMRazor: OpenMMLab 模型压缩工具箱与测试基准

  • MMFewShot: OpenMMLab 少样本学习工具箱与测试基准

  • MMAction2: OpenMMLab 新一代视频理解工具箱

  • MMTracking: OpenMMLab 一体化视频目标感知平台

  • MMFlow: OpenMMLab 光流估计工具箱与测试基准

  • MMEditing: OpenMMLab 图像视频编辑工具箱

  • MMGeneration: OpenMMLab 图片视频生成模型工具箱

  • MMDeploy: OpenMMLab 模型部署框架

安装 MMCV

MMCV 有两个版本:

  • mmcv: 完整版,包含所有的特性以及丰富的开箱即用的 CPU 和 CUDA 算子。注意,完整版本可能需要更长时间来编译。

  • mmcv-lite: 精简版,不包含 CPU 和 CUDA 算子但包含其余所有特性和功能,类似 MMCV 1.0 之前的版本。如果你不需要使用算子的话,精简版可以作为一个考虑选项。

警告

请不要在同一个环境中安装两个版本,否则可能会遇到类似 ModuleNotFound 的错误。在安装一个版本之前,需要先卸载另一个。如果 CUDA 可用,强烈推荐安装 mmcv

安装 mmcv

在安装 mmcv 之前,请确保 PyTorch 已经成功安装在环境中,可以参考 PyTorch 官方安装文档。可使用以下命令验证

python -c 'import torch;print(torch.__version__)'

如果输出版本信息,则表示 PyTorch 已安装。

使用 mim 安装(推荐)

mim 是 OpenMMLab 项目的包管理工具,使用它可以很方便地安装 mmcv。

pip install -U openmim
mim install mmcv

如果发现上述的安装命令没有使用预编译包(以 .whl 结尾)而是使用源码包(以 .tar.gz 结尾)安装,则有可能是我们没有提供和当前环境的 PyTorch 版本、CUDA 版本相匹配的 mmcv 预编译包,此时,你可以源码安装 mmcv

使用预编译包的安装日志

Looking in links: https://download.openmmlab.com/mmcv/dist/cu102/torch1.8.0/index.html
Collecting mmcv
Downloading https://download.openmmlab.com/mmcv/dist/cu102/torch1.8.0/mmcv-2.0.0-cp38-cp38-manylinux1_x86_64.whl

使用源码包的安装日志

Looking in links: https://download.openmmlab.com/mmcv/dist/cu102/torch1.8.0/index.html
Collecting mmcv==2.0.0
Downloading mmcv-2.0.0.tar.gz

如需安装指定版本的 mmcv,例如安装 2.0.0 版本的 mmcv,可使用以下命令

mim install mmcv==2.0.0

注解

如果你打算使用 opencv-python-headless 而不是 opencv-python,例如在一个很小的容器环境或者没有图形用户界面的服务器中,你可以先安装 opencv-python-headless,这样在安装 mmcv 依赖的过程中会跳过 opencv-python

另外,如果安装依赖库的时间过长,可以指定 pypi 源

mim install "mmcv>=2.0.0rc1" -i https://pypi.tuna.tsinghua.edu.cn/simple

安装完成后可以运行 check_installation.py 脚本检查 mmcv 是否安装成功。

使用 pip 安装

使用以下命令查看 CUDA 和 PyTorch 的版本

python -c 'import torch;print(torch.__version__);print(torch.version.cuda)'

根据系统的类型、CUDA 版本、PyTorch 版本以及 MMCV 版本选择相应的安装命令





如果在上面的下拉框中没有找到对应的版本,则可能是没有对应 PyTorch 或者 CUDA 或者 mmcv 版本的预编译包,此时,你可以源码安装 mmcv

注解

PyTorch 在 1.x.0 和 1.x.1 之间通常是兼容的,故 mmcv 只提供 1.x.0 的编译包。如果你 的 PyTorch 版本是 1.x.1,你可以放心地安装在 1.x.0 版本编译的 mmcv。例如,如果你的 PyTorch 版本是 1.8.1,你可以放心选择 1.8.x。

注解

如果你打算使用 opencv-python-headless 而不是 opencv-python,例如在一个很小的容器环境或者没有图形用户界面的服务器中,你可以先安装 opencv-python-headless,这样在安装 mmcv 依赖的过程中会跳过 opencv-python

另外,如果安装依赖库的时间过长,可以指定 pypi 源

pip install mmcv -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html -i https://pypi.tuna.tsinghua.edu.cn/simple

安装完成后可以运行 check_installation.py 脚本检查 mmcv 是否安装成功。

使用 docker 镜像

先将算法库克隆到本地再构建镜像

git clone https://github.com/open-mmlab/mmcv.git && cd mmcv
docker build -t mmcv -f docker/release/Dockerfile .

也可以直接使用下面的命令构建镜像

docker build -t mmcv https://github.com/open-mmlab/mmcv.git#main:docker/release

Dockerfile 默认安装最新的 mmcv,如果你想要指定版本,可以使用下面的命令

docker image build -t mmcv -f docker/release/Dockerfile --build-arg MMCV=2.0.0 .

如果你想要使用其他版本的 PyTorch 和 CUDA,你可以在构建镜像时指定它们的版本。

例如指定 PyTorch 的版本是 1.11,CUDA 的版本是 11.3

docker build -t mmcv -f docker/release/Dockerfile \
    --build-arg PYTORCH=1.11.0 \
    --build-arg CUDA=11.3 \
    --build-arg CUDNN=8 \
    --build-arg MMCV=2.0.0 .

更多 PyTorch 和 CUDA 镜像可以点击 dockerhub/pytorch 查看。

安装 mmcv-lite

如果你需要使用和 PyTorch 相关的模块,请确保 PyTorch 已经成功安装在环境中,可以参考 PyTorch 官方安装文档

pip install mmcv-lite

从源码编译 MMCV

编译 mmcv

在编译 mmcv 之前,请确保 PyTorch 已经成功安装在环境中,可以参考 PyTorch 官方安装文档。可使用以下命令验证

python -c 'import torch;print(torch.__version__)'

注解

  • 如果克隆代码仓库的速度过慢,可以使用以下命令克隆(注意:gitee 的 mmcv 不一定和 github 的保持一致,因为每天只同步一次)

git clone https://gitee.com/open-mmlab/mmcv.git
  • 如果打算使用 opencv-python-headless 而不是 opencv-python,例如在一个很小的容器环境或者没有图形用户界面的服务器中,你可以先安装 opencv-python-headless,这样在安装 mmcv 依赖的过程中会跳过 opencv-python

  • 如果编译过程安装依赖库的时间过长,可以设置 pypi 源

pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

在 Linux 上编译 mmcv

| TODO: 视频教程

  1. 克隆代码仓库

    git clone https://github.com/open-mmlab/mmcv.git
    cd mmcv
    
  2. 安装 ninjapsutil 以加快编译速度

    pip install -r requirements/optional.txt
    
  3. 检查 nvcc 的版本(要求大于等于 9.2,如果没有 GPU,可以跳过)

    nvcc --version
    

    上述命令如果输出以下信息,表示 nvcc 的设置没有问题,否则需要设置 CUDA_HOME

    nvcc: NVIDIA (R) Cuda compiler driver
    Copyright (c) 2005-2020 NVIDIA Corporation
    Built on Mon_Nov_30_19:08:53_PST_2020
    Cuda compilation tools, release 11.2, V11.2.67
    Build cuda_11.2.r11.2/compiler.29373293_0
    

    注解

    如果想要支持 ROCm,可以参考 AMD ROCm 安装 ROCm。

  4. 检查 gcc 的版本(要求大于等于5.4

    gcc --version
    
  5. 开始编译(预估耗时 10 分钟)

    pip install -e . -v
    
  6. 验证安装

    python .dev_scripts/check_installation.py
    

    如果上述命令没有报错,说明安装成功。如有报错,请查看问题解决页面是否已经有解决方案。

    如果没有找到解决方案,欢迎提 issue

在 macOS 上编译 mmcv

| TODO: 视频教程

注解

如果你使用的是搭载 apple silicon 的 mac 设备,请安装 PyTorch 1.13+ 的版本,否则会遇到 issues#2218 中的问题。

  1. 克隆代码仓库

    git clone https://github.com/open-mmlab/mmcv.git
    cd mmcv
    
  2. 安装 ninjapsutil 以加快编译速度

    pip install -r requirements/optional.txt
    
  3. 开始编译

    pip install -e .
    
  4. 验证安装

    python .dev_scripts/check_installation.py
    

    如果上述命令没有报错,说明安装成功。如有报错,请查看问题解决页面是否已经有解决方案。

    如果没有找到解决方案,欢迎提 issue

在 Windows 上编译 mmcv

| TODO: 视频教程

在 Windows 上编译 mmcv 比 Linux 复杂,本节将一步步介绍如何在 Windows 上编译 mmcv。

依赖项

请先安装以下的依赖项:

  • Git:安装期间,请选择 add git to Path

  • Visual Studio Community 2019:用于编译 C++ 和 CUDA 代码

  • Miniconda:包管理工具

  • CUDA 10.2:如果只需要 CPU 版本可以不安装 CUDA,安装 CUDA 时,可根据需要进行自定义安装。如果已经安装新版本的显卡驱动,建议取消驱动程序的安装

注解

如果不清楚如何安装以上依赖,请参考Windows 环境从零安装 mmcv。 另外,你需要知道如何在 Windows 上设置变量环境,尤其是 “PATH” 的设置,以下安装过程都会用到。

通用步骤
  1. 从 Windows 菜单启动 Anaconda 命令行

    如 Miniconda 安装程序建议,不要使用原始的 cmd.exe 或是 powershell.exe。命令行有两个版本,一个基于 PowerShell,一个基于传统的 cmd.exe。请注意以下说明都是使用的基于 PowerShell

  2. 创建一个新的 Conda 环境

    (base) PS C:\Users\xxx> conda create --name mmcv python=3.7
    (base) PS C:\Users\xxx> conda activate mmcv  # 确保做任何操作前先激活环境
    
  3. 安装 PyTorch 时,可以根据需要安装支持 CUDA 或不支持 CUDA 的版本

    # CUDA version
    (mmcv) PS C:\Users\xxx> conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
    # CPU version
    (mmcv) PS C:\Users\xxx> conda install install pytorch torchvision cpuonly -c pytorch
    
  4. 克隆代码仓库

    (mmcv) PS C:\Users\xxx> git clone https://github.com/open-mmlab/mmcv.git
    (mmcv) PS C:\Users\xxx> cd mmcv
    
  5. 安装 ninjapsutil 以加快编译速度

    (mmcv) PS C:\Users\xxx\mmcv> pip install -r requirements/optional.txt
    
  6. 设置 MSVC 编译器

    设置环境变量。添加 C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.27.29110\bin\Hostx86\x64PATH,则 cl.exe 可以在命令行中运行,如下所示。

    (mmcv) PS C:\Users\xxx\mmcv> cl
    Microsoft (R) C/C++ Optimizing  Compiler Version 19.27.29111 for x64
    Copyright (C) Microsoft Corporation.   All rights reserved.
    
    usage: cl [ option... ] filename... [ / link linkoption... ]
    

    为了兼容性,我们使用 x86-hosted 以及 x64-targeted 版本,即路径中的 Hostx86\x64

    因为 PyTorch 将解析 cl.exe 的输出以检查其版本,只有 utf-8 将会被识别,你可能需要将系统语言更改为英语。控制面板 -> 地区-> 管理-> 非 Unicode 来进行语言转换。

编译与安装 mmcv

mmcv 有两个版本:

  • 只包含 CPU 算子的版本

    编译 CPU 算子,但只有 x86 将会被编译,并且编译版本只能在 CPU only 情况下运行

  • 既包含 CPU 算子,又包含 CUDA 算子的版本

    同时编译 CPU 和 CUDA 算子,ops 模块的 x86 与 CUDA 的代码都可以被编译。同时编译的版本可以在 CUDA 上调用 GPU

CPU 版本

编译安装

(mmcv) PS C:\Users\xxx\mmcv> python setup.py build_ext  # 如果成功, cl 将被启动用于编译算子
(mmcv) PS C:\Users\xxx\mmcv> python setup.py develop  # 安装
GPU 版本
  1. 检查 CUDA_PATH 或者 CUDA_HOME 环境变量已经存在在 envs 之中

    (mmcv) PS C:\Users\xxx\mmcv> ls env:
    
    Name                           Value
    ----                           -----
    CUDA_PATH                      C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2
    CUDA_PATH_V10_1                C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1
    CUDA_PATH_V10_2                C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2
    

    如果没有,你可以按照下面的步骤设置

    (mmcv) PS C:\Users\xxx\mmcv> $env:CUDA_HOME = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2"
    # 或者
    (mmcv) PS C:\Users\xxx\mmcv> $env:CUDA_HOME = $env:CUDA_PATH_V10_2  # CUDA_PATH_V10_2 已经在环境变量中
    
  2. 设置 CUDA 的目标架构

    # 这里需要改成你的显卡对应的目标架构
    (mmcv) PS C:\Users\xxx\mmcv> $env:TORCH_CUDA_ARCH_LIST="7.5"
    

    注解

    可以点击 cuda-gpus 查看 GPU 的计算能力,也可以通过 CUDA 目录下的 deviceQuery.exe 工具查看

    (mmcv) PS C:\Users\xxx\mmcv> &"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2\extras\demo_suite\deviceQuery.exe"
    Device 0: "NVIDIA GeForce GTX 1660 SUPER"
    CUDA Driver Version / Runtime Version          11.7 / 11.1
    CUDA Capability Major/Minor version number:    7.5
    

    上面的 7.5 表示目标架构。注意:需把上面命令的 v10.2 换成你的 CUDA 版本。

  3. 编译安装

    (mmcv) PS C:\Users\xxx\mmcv> python setup.py build_ext  # 如果成功, cl 将被启动用于编译算子
    (mmcv) PS C:\Users\xxx\mmcv> python setup.py develop # 安装
    

    注解

    如果你的 PyTorch 版本是 1.6.0,你可能会遇到一些 issue 提到的错误,你可以参考这个 pull request 修改本地环境的 PyTorch 源代码

验证安装
(mmcv) PS C:\Users\xxx\mmcv> python .dev_scripts/check_installation.py

如果上述命令没有报错,说明安装成功。如有报错,请查看问题解决页面是否已经有解决方案。 如果没有找到解决方案,欢迎提 issue

编译 mmcv-lite

如果你需要使用和 PyTorch 相关的模块,请确保 PyTorch 已经成功安装在环境中,可以参考 PyTorch 官方安装文档

  1. 克隆代码仓库

    git clone https://github.com/open-mmlab/mmcv.git
    cd mmcv
    
  2. 开始编译

    MMCV_WITH_OPS=0 pip install -e . -v
    
  3. 验证安装

    python -c 'import mmcv;print(mmcv.__version__)'
    

在寒武纪 MLU 机器编译 mmcv-full

安装 torch_mlu

选项1: 基于寒武纪 docker image 安装

首先请下载并且拉取寒武纪 docker (请向 service@cambricon.com 发邮件以获得最新的寒武纪 pytorch 发布 docker)。

docker pull ${docker image}

进入 docker, 编译 MMCV MLU进行验证

选项2:基于 cambricon pytorch 源码编译安装

请向 service@cambricon.com 发送邮件或联系 Cambricon 工程师以获取合适版本的 CATCH 软件包,在您获得合适版本的 CATCH 软件包后,请参照 ${CATCH-path}/CONTRIBUTING.md 中的步骤安装 CATCH。

编译 MMCV

克隆代码仓库

git clone https://github.com/open-mmlab/mmcv.git

算子库 mlu-ops 在编译 MMCV 时自动下载到默认路径(mmcv/mlu-ops),你也可以在编译前设置环境变量 MMCV_MLU_OPS_PATH 指向已经存在的 mlu-ops 算子库路径。

export MMCV_MLU_OPS_PATH=/xxx/xxx/mlu-ops

开始编译

cd mmcv
export MMCV_WITH_OPS=1
export FORCE_MLU=1
python setup.py install

验证是否成功安装

完成上述安装步骤之后,您可以尝试运行下面的 Python 代码以测试您是否成功在 MLU 设备上安装了 mmcv-full

import torch
import torch_mlu
from mmcv.ops import sigmoid_focal_loss
x = torch.randn(3, 10).mlu()
x.requires_grad = True
y = torch.tensor([1, 5, 3]).mlu()
w = torch.ones(10).float().mlu()
output = sigmoid_focal_loss(x, y, 2.0, 0.25, w, 'none')

在昇腾 NPU 机器编译 mmcv

在编译 mmcv 前,需要安装 torch_npu,完整安装教程详见 PyTorch 安装指南

选项 1: 使用 NPU 设备源码编译安装 mmcv (推荐方式)

git pull https://github.com/open-mmlab/mmcv.git
  • 编译

MMCV_WITH_OPS=1 MAX_JOBS=8 FORCE_NPU=1 python setup.py build_ext
  • 安装

MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py develop

选项 2: 使用 pip 安装 Ascend 编译版本的 mmcv

Ascend 编译版本的 mmcv 在 mmcv >= 1.7.0 时已经支持直接 pip 安装

pip install mmcv -f https://download.openmmlab.com/mmcv/dist/ascend/torch1.8.0/index.html

验证

import torch
import torch_npu
from mmcv.ops import softmax_focal_loss

# Init tensor to the NPU
x = torch.randn(3, 10).npu()
y = torch.tensor([1, 5, 3]).npu()
w = torch.ones(10).float().npu()

output = softmax_focal_loss(x, y, 2.0, 0.25, w, 'none')
print(output)

解读文章汇总

这篇文章汇总了 OpenMMLab 解读的部分文章(更多文章和视频见 OpenMMLabCourse),如果您有推荐的文章(不一定是 OpenMMLab 发布的文章,可以是自己写的文章),非常欢迎提 Pull Request 添加到这里。

下游算法库解读文章

数据处理

图像

图像模块提供了一些图像预处理的函数,该模块依赖 opencv

读取/保存/显示

使用 imreadimwrite 函数可以读取和保存图像。

import mmcv

img = mmcv.imread('test.jpg')
img = mmcv.imread('test.jpg', flag='grayscale')
img_ = mmcv.imread(img)  # 相当于什么也没做
mmcv.imwrite(img, 'out.jpg')

从二进制中读取图像

with open('test.jpg', 'rb') as f:
    data = f.read()
img = mmcv.imfrombytes(data)

显示图像文件或已读取的图像

mmcv.imshow('tests/data/color.jpg')

for i in range(10):
    img = np.random.randint(256, size=(100, 100, 3), dtype=np.uint8)
    mmcv.imshow(img, win_name='test image', wait_time=200)

色彩空间转换

支持的转换函数:

  • bgr2gray

  • gray2bgr

  • bgr2rgb

  • rgb2bgr

  • bgr2hsv

  • hsv2bgr

img = mmcv.imread('tests/data/color.jpg')
img1 = mmcv.bgr2rgb(img)
img2 = mmcv.rgb2gray(img1)
img3 = mmcv.bgr2hsv(img)

缩放

有三种缩放图像的方法。所有以 imresize_* 开头的函数都有一个 return_scale 参数,如果 该参数为 False ,函数的返回值只有调整之后的图像,否则是一个元组 (resized_img, scale)

# 缩放图像至给定的尺寸
mmcv.imresize(img, (1000, 600), return_scale=True)

# 缩放图像至与给定的图像同样的尺寸
mmcv.imresize_like(img, dst_img, return_scale=False)

# 以一定的比例缩放图像
mmcv.imrescale(img, 0.5)

# 缩放图像至最长的边不大于1000、最短的边不大于800并且没有改变图像的长宽比
mmcv.imrescale(img, (1000, 800))

旋转

我们可以使用 imrotate 旋转图像一定的角度。旋转的中心需要指定,默认值是原始图像的中心。有 两种旋转的模式,一种保持图像的尺寸不变,因此旋转后原始图像中的某些部分会被裁剪,另一种是扩大 图像的尺寸进而保留完整的原始图像。

img = mmcv.imread('tests/data/color.jpg')

# 顺时针旋转图像30度
img_ = mmcv.imrotate(img, 30)

# 逆时针旋转图像90度
img_ = mmcv.imrotate(img, -90)

# 顺时针旋转图像30度并且缩放图像为原始图像的1.5倍
img_ = mmcv.imrotate(img, 30, scale=1.5)

# 以坐标(100, 100)为中心顺时针旋转图像30度
img_ = mmcv.imrotate(img, 30, center=(100, 100))

# 顺时针旋转图像30度并扩大图像的尺寸
img_ = mmcv.imrotate(img, 30, auto_bound=True)

翻转

我们可以使用 imflip 翻转图像。

img = mmcv.imread('tests/data/color.jpg')

# 水平翻转图像
mmcv.imflip(img)

# 垂直翻转图像
mmcv.imflip(img, direction='vertical')

裁剪

imcrop 可以裁剪图像的一个或多个区域,每个区域用左上角和右下角坐标表示,形如(x1, y1, x2, y2)

import mmcv
import numpy as np

img = mmcv.imread('tests/data/color.jpg')

# 裁剪区域 (10, 10, 100, 120)
bboxes = np.array([10, 10, 100, 120])
patch = mmcv.imcrop(img, bboxes)

# 裁剪两个区域,分别是 (10, 10, 100, 120) 和 (0, 0, 50, 50)
bboxes = np.array([[10, 10, 100, 120], [0, 0, 50, 50]])
patches = mmcv.imcrop(img, bboxes)

# 裁剪两个区域并且缩放区域1.2倍
patches = mmcv.imcrop(img, bboxes, scale=1.2)

填充

impad and impad_to_multiple 可以用给定的值将图像填充至给定的尺寸。

img = mmcv.imread('tests/data/color.jpg')

# 用给定值将图像填充至 (1000, 1200)
img_ = mmcv.impad(img, shape=(1000, 1200), pad_val=0)

# 用给定值分别填充图像的3个通道至 (1000, 1200)
img_ = mmcv.impad(img, shape=(1000, 1200), pad_val=(100, 50, 200))

# 用给定值填充图像的左、右、上、下四条边
img_ = mmcv.impad(img, padding=(10, 20, 30, 40), pad_val=0)

# 用3个值分别填充图像的左、右、上、下四条边的3个通道
img_ = mmcv.impad(img, padding=(10, 20, 30, 40), pad_val=(100, 50, 200))

# 将图像的四条边填充至能够被给定值整除
img_ = mmcv.impad_to_multiple(img, 32)

视频

视频模块提供了以下的功能:

  • 一个 VideoReader 类,具有友好的 API 接口可以读取和转换视频

  • 一些编辑视频的方法,包括 cutconcatresize

  • 光流的读取/保存/变换

VideoReader

VideoReader 类提供了和序列一样的接口去获取视频帧。该类会缓存所有被访问过的帧。

video = mmcv.VideoReader('test.mp4')

# 获取基本的信息
print(len(video))
print(video.width, video.height, video.resolution, video.fps)

# 遍历所有的帧
for frame in video:
    print(frame.shape)

# 读取下一帧
img = video.read()

# 使用索引获取帧
img = video[100]

# 获取指定范围的帧
img = video[5:10]

将视频切成帧并保存至给定目录或者从给定目录中生成视频。

# 将视频切成帧并保存至目录
video = mmcv.VideoReader('test.mp4')
video.cvt2frames('out_dir')

# 从给定目录中生成视频
mmcv.frames2video('out_dir', 'test.avi')

编辑函数

有几个用于编辑视频的函数,这些函数是对 ffmpeg 的封装。

# 裁剪视频
mmcv.cut_video('test.mp4', 'clip1.mp4', start=3, end=10, vcodec='h264')

# 将多个视频拼接成一个视频
mmcv.concat_video(['clip1.mp4', 'clip2.mp4'], 'joined.mp4', log_level='quiet')

# 将视频缩放至给定的尺寸
mmcv.resize_video('test.mp4', 'resized1.mp4', (360, 240))

# 将视频缩放至给定的倍率
mmcv.resize_video('test.mp4', 'resized2.mp4', ratio=2)

光流

mmcv 提供了以下用于操作光流的函数:

  • 读取/保存

  • 可视化

  • 流变换

我们提供了两种将光流dump到文件的方法,分别是非压缩和压缩的方法。非压缩的方法直接将浮点数值的光流 保存至二进制文件,虽然光流无损但文件会比较大。而压缩的方法先量化光流至 0-255 整形数值再保存为 jpeg图像。光流的x维度和y维度会被拼接到图像中。

  1. 读取/保存

flow = np.random.rand(800, 600, 2).astype(np.float32)
# 保存光流到flo文件 (~3.7M)
mmcv.flowwrite(flow, 'uncompressed.flo')
# 保存光流为jpeg图像 (~230K),图像的尺寸为 (800, 1200)
mmcv.flowwrite(flow, 'compressed.jpg', quantize=True, concat_axis=1)

# 读取光流文件,以下两种方式读取的光流尺寸均为 (800, 600, 2)
flow = mmcv.flowread('uncompressed.flo')
flow = mmcv.flowread('compressed.jpg', quantize=True, concat_axis=1)
  1. 可视化

使用 mmcv.flowshow() 可视化光流

mmcv.flowshow(flow)

progress

  1. 流变换

img1 = mmcv.imread('img1.jpg')
flow = mmcv.flowread('flow.flo')
warped_img2 = mmcv.flow_warp(img1, flow)

img1 (左) and img2 (右)

raw images

光流 (img2 -> img1)

optical flow

变换后的图像和真实图像的差异

warped image

数据变换

在 OpenMMLab 算法库中,数据集的构建和数据的准备是相互解耦的。通常,数据集的构建只对数据集进行解析,记录每个样本的基本信息;而数据的准备则是通过一系列的数据变换,根据样本的基本信息进行数据加载、预处理、格式化等操作。

数据变换的设计

在 MMCV 中,我们使用各种可调用的数据变换类来进行数据的操作。这些数据变换类可以接受若干配置参数进行实例化,之后通过调用的方式对输入的数据字典进行处理。同时,我们约定所有数据变换都接受一个字典作为输入,并将处理后的数据输出为一个字典。一个简单的例子如下:

>>> import numpy as np
>>> from mmcv.transforms import Resize
>>>
>>> transform = Resize(scale=(224, 224))
>>> data_dict = {'img': np.random.rand(256, 256, 3)}
>>> data_dict = transform(data_dict)
>>> print(data_dict['img'].shape)
(224, 224, 3)

数据变换类会读取输入字典的某些字段,并且可能添加、或者更新某些字段。这些字段的键大部分情况下是固定的,如 Resize 会固定地读取输入字典中的 "img" 等字段。我们可以在对应类的文档中了解对输入输出字段的约定。

注解

默认情况下,在需要图像尺寸作为初始化参数的数据变换 (如Resize, Pad) 中,图像尺寸的顺序均为 (width, height)。在数据变换返回的字典中,图像相关的尺寸, 如 img_shapeori_shapepad_shape 等,均为 (height, width)。

MMCV 为所有的数据变换类提供了一个统一的基类 (BaseTransform):

class BaseTransform(metaclass=ABCMeta):

    def __call__(self, results: dict) -> dict:

        return self.transform(results)

    @abstractmethod
    def transform(self, results: dict) -> dict:
        pass

所有的数据变换类都需要继承 BaseTransform,并实现 transform 方法。transform 方法的输入和输出均为一个字典。在自定义数据变换类一节中,我们会更详细地介绍如何实现一个数据变换类。

数据流水线

如上所述,所有数据变换的输入和输出都是一个字典,而且根据 OpenMMLab 中 有关数据集的约定,数据集中每个样本的基本信息都是一个字典。这样一来,我们可以将所有的数据变换操作首尾相接,组合成为一条数据流水线(data pipeline),输入数据集中样本的信息字典,输出完成一系列处理后的信息字典。

以分类任务为例,我们在下图展示了一个典型的数据流水线。对每个样本,数据集中保存的基本信息是一个如图中最左侧所示的字典,之后每经过一个由蓝色块代表的数据变换操作,数据字典中都会加入新的字段(标记为绿色)或更新现有的字段(标记为橙色)。

在配置文件中,数据流水线是一个若干数据变换配置字典组成的列表,每个数据集都需要设置参数 pipeline 来定义该数据集需要进行的数据准备操作。如上数据流水线在配置文件中的配置如下:

pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=256, keep_ratio=True),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
    dict(type='ClsFormatBundle')
]

dataset = dict(
    ...
    pipeline=pipeline,
    ...
)

常用的数据变换类

按照功能,常用的数据变换类可以大致分为数据加载、数据预处理与增强、数据格式化。在 MMCV 中,我们提供了一些常用的数据变换类如下:

数据加载

为了支持大规模数据集的加载,通常在 Dataset 初始化时不加载数据,只加载相应的路径。因此需要在数据流水线中进行具体数据的加载。

class 功能
LoadImageFromFile 根据路径加载图像
LoadAnnotations 加载和组织标注信息,如 bbox、语义分割图等

数据预处理及增强

数据预处理和增强通常是对图像本身进行变换,如裁剪、填充、缩放等。

class 功能
Pad 填充图像边缘
CenterCrop 居中裁剪
Normalize 对图像进行归一化
Resize 按照指定尺寸或比例缩放图像
RandomResize 缩放图像至指定范围的随机尺寸
RandomMultiscaleResize 缩放图像至多个尺寸中的随机一个尺寸
RandomGrayscale 随机灰度化
RandomFlip 图像随机翻转
MultiScaleFlipAug 支持缩放和翻转的测试时数据增强

数据格式化

数据格式化操作通常是对数据进行的类型转换。

class 功能
ToTensor 将指定的数据转换为 torch.Tensor
ImageToTensor 将图像转换为 torch.Tensor

自定义数据变换类

要实现一个新的数据变换类,需要继承 BaseTransform,并实现 transform 方法。这里,我们使用一个简单的翻转变换(MyFlip)作为示例:

import random
import mmcv
from mmcv.transforms import BaseTransform, TRANSFORMS

@TRANSFORMS.register_module()
class MyFlip(BaseTransform):
    def __init__(self, direction: str):
        super().__init__()
        self.direction = direction

    def transform(self, results: dict) -> dict:
        img = results['img']
        results['img'] = mmcv.imflip(img, direction=self.direction)
        return results

从而,我们可以实例化一个 MyFlip 对象,并将之作为一个可调用对象,来处理我们的数据字典。

import numpy as np

transform = MyFlip(direction='horizontal')
data_dict = {'img': np.random.rand(224, 224, 3)}
data_dict = transform(data_dict)
processed_img = data_dict['img']

又或者,在配置文件的 pipeline 中使用 MyFlip 变换

pipeline = [
    ...
    dict(type='MyFlip', direction='horizontal'),
    ...
]

需要注意的是,如需在配置文件中使用,需要保证 MyFlip 类所在的文件在运行时能够被导入。

变换包装

变换包装是一种特殊的数据变换类,他们本身并不操作数据字典中的图像、标签等信息,而是对其中定义的数据变换的行为进行增强。

字段映射(KeyMapper)

字段映射包装(KeyMapper)用于对数据字典中的字段进行映射。例如,一般的图像处理变换都从数据字典中的 "img" 字段获得值。但有些时候,我们希望这些变换处理数据字典中其他字段中的图像,比如 "gt_img" 字段。

如果配合注册器和配置文件使用的话,在配置文件中数据集的 pipeline 中如下例使用字段映射包装:

pipeline = [
    ...
    dict(type='KeyMapper',
        mapping={
            'img': 'gt_img',  # 将 "gt_img" 字段映射至 "img" 字段
            'mask': ...,  # 不使用原始数据中的 "mask" 字段。即对于被包装的数据变换,数据中不包含 "mask" 字段
        },
        auto_remap=True,  # 在完成变换后,将 "img" 重映射回 "gt_img" 字段
        transforms=[
            # 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
            dict(type='RandomFlip'),
        ])
    ...
]

利用字段映射包装,我们在实现数据变换类时,不需要考虑在 transform 方法中考虑各种可能的输入字段名,只需要处理默认的字段即可。

随机选择(RandomChoice)和随机执行(RandomApply)

随机选择包装(RandomChoice)用于从一系列数据变换组合中随机应用一个数据变换组合。利用这一包装,我们可以简单地实现一些数据增强功能,比如 AutoAugment。

如果配合注册器和配置文件使用的话,在配置文件中数据集的 pipeline 中如下例使用随机选择包装:

pipeline = [
    ...
    dict(type='RandomChoice',
        transforms=[
            [
                dict(type='Posterize', bits=4),
                dict(type='Rotate', angle=30.)
            ],  # 第一种随机变化组合
            [
                dict(type='Equalize'),
                dict(type='Rotate', angle=30)
            ],  # 第二种随机变换组合
        ],
        prob=[0.4, 0.6]  # 两种随机变换组合各自的选用概率
        )
    ...
]

随机执行包装(RandomApply)用于以指定概率随机执行数据变换组合。例如:

pipeline = [
    ...
    dict(type='RandomApply',
        transforms=[dict(type='Rotate', angle=30.)],
        prob=0.3)  # 以 0.3 的概率执行被包装的数据变换
    ...
]

多目标扩展(TransformBroadcaster)

通常,一个数据变换类只会从一个固定的字段读取操作目标。虽然我们也可以使用 KeyMapper 来改变读取的字段,但无法将变换一次性应用于多个字段的数据。为了实现这一功能,我们需要借助多目标扩展包装(TransformBroadcaster)。

多目标扩展包装(TransformBroadcaster)有两个用法,一是将数据变换作用于指定的多个字段,二是将数据变换作用于某个字段下的一组目标中。

  1. 应用于多个字段

    假设我们需要将数据变换应用于 "lq" (low-quality) 和 "gt" (ground-truth) 两个字段中的图像上。

    pipeline = [
        dict(type='TransformBroadcaster',
            # 分别应用于 "lq" 和 "gt" 两个字段,并将二者应设置 "img" 字段
            mapping={'img': ['lq', 'gt']},
            # 在完成变换后,将 "img" 字段重映射回原先的字段
            auto_remap=True,
            # 是否在对各目标的变换中共享随机变量
            # 更多介绍参加后续章节(随机变量共享)
            share_random_params=True,
            transforms=[
                # 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
                dict(type='RandomFlip'),
            ])
    ]
    

    在多目标扩展的 mapping 设置中,我们同样可以使用 ... 来忽略指定的原始字段。如以下例子中,被包裹的 RandomCrop 会对字段 "img" 中的图像进行裁剪,并且在字段 "img_shape" 存在时更新剪裁后的图像大小。如果我们希望同时对两个图像字段 "lq""gt" 进行相同的随机裁剪,但只更新一次 "img_shape" 字段,可以通过例子中的方式实现:

    pipeline = [
        dict(type='TransformBroadcaster',
            mapping={
                'img': ['lq', 'gt'],
                'img_shape': ['img_shape', ...],
             },
            # 在完成变换后,将 "img" 和 "img_shape" 字段重映射回原先的字段
            auto_remap=True,
            # 是否在对各目标的变换中共享随机变量
            # 更多介绍参加后续章节(随机变量共享)
            share_random_params=True,
            transforms=[
                # `RandomCrop` 类中会操作 "img" 和 "img_shape" 字段。若 "img_shape" 空缺,
                # 则只操作 "img"
                dict(type='RandomCrop'),
            ])
    ]
    
  2. 应用于一个字段的一组目标

    假设我们需要将数据变换应用于 "images" 字段,该字段为一个图像组成的 list。

    pipeline = [
        dict(type='TransformBroadcaster',
            # 将 "images" 字段下的每张图片映射至 "img" 字段
            mapping={'img': 'images'},
            # 在完成变换后,将 "img" 字段下的图片重映射回 "images" 字段的列表中
            auto_remap=True,
            # 是否在对各目标的变换中共享随机变量
            share_random_params=True,
            transforms=[
                # 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
                dict(type='RandomFlip'),
            ])
    ]
    
装饰器 cache_randomness

TransformBroadcaster 中,我们提供了 share_random_params 选项来支持在多次数据变换中共享随机状态。例如,在超分辨率任务中,我们希望将随机变换同步作用于低分辨率图像和原始图像。如果我们希望在自定义的数据变换类中使用这一功能,需要在类中标注哪些随机变量是支持共享的。这可以通过装饰器 cache_randomness 来实现。

以上文中的 MyFlip 为例,我们希望以一定的概率随机执行翻转:

from mmcv.transforms.utils import cache_randomness

@TRANSFORMS.register_module()
class MyRandomFlip(BaseTransform):
    def __init__(self, prob: float, direction: str):
        super().__init__()
        self.prob = prob
        self.direction = direction

    @cache_randomness  # 标注该方法的输出为可共享的随机变量
    def do_flip(self):
        flip = True if random.random() > self.prob else False
        return flip

    def transform(self, results: dict) -> dict:
        img = results['img']
        if self.do_flip():
            results['img'] = mmcv.imflip(img, direction=self.direction)
        return results

在上面的例子中,我们用cache_randomness 装饰 do_flip方法,即将该方法返回值 flip 标注为一个支持共享的随机变量。进而,在 TransformBroadcaster 对多个目标的变换中,这一变量的值都会保持一致。

装饰器 avoid_cache_randomness

在一些情况下,我们无法将数据变换中产生随机变量的过程单独放在类方法中。例如数据变换中使用的来自第三方库的模块,这些模块将随机变量相关的部分封装在了内部,导致无法将其抽出为数据变换的类方法。这样的数据变换无法通过装饰器 cache_randomness 标注支持共享的随机变量,进而无法在多目标扩展时共享随机变量。

为了避免在多目标扩展中误用此类数据变换,我们提供了另一个装饰器 avoid_cache_randomness,用来对此类数据变换进行标记:

from mmcv.transforms.utils import avoid_cache_randomness

@TRANSFORMS.register_module()
@avoid_cache_randomness
class MyRandomTransform(BaseTransform):

    def transform(self, results: dict) -> dict:
        ...

avoid_cache_randomness 标记的数据变换类,当其实例被 TransformBroadcaster 包装且将参数 share_random_params 设置为 True 时,会抛出异常,以此提醒用户不能这样使用。

在使用 avoid_cache_randomness 时需要注意以下几点:

  1. avoid_cache_randomness 只用于装饰数据变换类(BaseTransfrom 的子类),而不能用与装饰其他一般的类、类方法或函数

  2. avoid_cache_randomness 修饰的数据变换作为基类时,其子类将不会继承这一特性。如果子类仍无法共享随机变量,则应再次使用 avoid_cache_randomness 修饰

  3. 只有当一个数据变换具有随机性,且无法共享随机参数时,才需要以 avoid_cache_randomness 修饰。无随机性的数据变换不需要修饰

可视化

mmcv 可以展示图像以及标注(目前只支持标注框)

# 展示图像文件
mmcv.imshow('a.jpg')

# 展示已加载的图像
img = np.random.rand(100, 100, 3)
mmcv.imshow(img)

# 展示带有标注框的图像
img = np.random.rand(100, 100, 3)
bboxes = np.array([[0, 0, 50, 50], [20, 20, 60, 60]])
mmcv.imshow_bboxes(img, bboxes)

mmcv 也可以展示特殊的图像,例如光流

flow = mmcv.flowread('test.flo')
mmcv.flowshow(flow)

卷积神经网络

我们为卷积神经网络提供了一些构建模块,包括层构建、模块组件和权重初始化。

网络层的构建

在运行实验时,我们可能需要尝试同属一种类型但不同配置的层,但又不希望每次都修改代码。于是我们提供一些层构建方法,可以从字典构建层,字典可以在配置文件中配置,也可以通过命令行参数指定。

用法

一个简单的例子:

from mmcv.cnn import build_conv_layer

cfg = dict(type='Conv3d')
layer = build_conv_layer(cfg, in_channels=3, out_channels=8, kernel_size=3)
  • build_conv_layer: 支持的类型包括 Conv1d、Conv2d、Conv3d、Conv (Conv是Conv2d的别名)

  • build_norm_layer: 支持的类型包括 BN1d、BN2d、BN3d、BN (alias for BN2d)、SyncBN、GN、LN、IN1d、IN2d、IN3d、IN(IN是IN2d的别名)

  • build_activation_layer:支持的类型包括 ReLU、LeakyReLU、PReLU、RReLU、ReLU6、ELU、Sigmoid、Tanh、GELU

  • build_upsample_layer: 支持的类型包括 nearest、bilinear、deconv、pixel_shuffle

  • build_padding_layer: 支持的类型包括 zero、reflect、replicate

拓展

我们还允许自定义层和算子来扩展构建方法。

  1. 编写和注册自己的模块:

    from mmengine.registry import MODELS
    
    @MODELS.register_module()
    class MyUpsample:
    
        def __init__(self, scale_factor):
            pass
    
        def forward(self, x):
            pass
    
  2. 在某处导入 MyUpsample (例如 __init__.py )然后使用它:

    from mmcv.cnn import build_upsample_layer
    
    cfg = dict(type='MyUpsample', scale_factor=2)
    layer = build_upsample_layer(cfg)
    

模块组件

我们还提供了常用的模块组件,以方便网络构建。 卷积组件 ConvModule 由 convolution、normalization以及activation layers 组成,更多细节请参考 ConvModule api

from mmcv.cnn import ConvModule

# conv + bn + relu
conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
# conv + gn + relu
conv = ConvModule(3, 8, 2, norm_cfg=dict(type='GN', num_groups=2))
# conv + relu
conv = ConvModule(3, 8, 2)
# conv
conv = ConvModule(3, 8, 2, act_cfg=None)
# conv + leaky relu
conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='LeakyReLU'))
# bn + conv + relu
conv = ConvModule(
    3, 8, 2, norm_cfg=dict(type='BN'), order=('norm', 'conv', 'act'))

算子

MMCV 提供了检测、分割等任务中常用的算子

Device CPU CUDA MLU MPS Ascend
ActiveRotatedFilter
AssignScoreWithK
BallQuery
BBoxOverlaps
BorderAlign
BoxIouRotated
BoxIouQuadri
CARAFE
ChamferDistance
CrissCrossAttention
ContourExpand
ConvexIoU
CornerPool
Correlation
Deformable Convolution v1/v2
Deformable RoIPool
DiffIoURotated
DynamicScatter
FurthestPointSample
FurthestPointSampleWithDist
FusedBiasLeakyrelu
GatherPoints
GroupPoints
Iou3d
KNN
MaskedConv
MergeCells
MinAreaPolygon
ModulatedDeformConv2d
MultiScaleDeformableAttn
NMS
NMSRotated
NMSQuadri
PixelGroup
PointsInBoxes
PointsInPolygons
PSAMask
RotatedFeatureAlign
RoIPointPool3d
RoIPool
RoIAlignRotated
RiRoIAlignRotated
RoIAlign
RoIAwarePool3d
SAConv2d
SigmoidFocalLoss
SoftmaxFocalLoss
SoftNMS
Sparse Convolution
Synchronized BatchNorm
ThreeInterpolate
ThreeNN
TINShift
UpFirDn2d
Voxelization
PrRoIPool
BezierAlign
BiasAct
FilteredLrelu
Conv2dGradfix

v2.0.0

OpenMMLab 团队于 2022 年 9 月 1 日在世界人工智能大会发布了新一代训练引擎 MMEngine,它是一个用于训练深度学习模型的基础库。相比于 MMCV,它提供了更高级且通用的训练器、接口更加统一的开放架构以及可定制化程度更高的训练流程。

OpenMMLab 团队于 2023 年 4 月 6 日发布 MMCV v2.0.0。在 2.x 版本中,它有以下重大变化:

(1)删除了以下组件:

  • mmcv.fileio 模块,删除于 PR #2179。在需要使用 FileIO 的地方使用 mmengine 中的 FileIO 模块

  • mmcv.runnermmcv.parallelmmcv.enginemmcv.device,删除于 PR #2216

  • mmcv.utils 的所有类(例如 ConfigRegistry)和大部分函数,删除于 PR #2217,只保留少数和 mmcv 相关的函数

  • mmcv.onnxmmcv.tensorrt 模块以及相关的函数,删除于 PR #2225

  • 删除 MMCV 所有的根注册器并将类或者函数注册到 MMEngine 的根注册器

(2)新增了 mmcv.transforms 数据变换模块

(3)在 PR #2235 中将包名 mmcv 重命名为 mmcv-litemmcv-full 重命名为 mmcv。此外,将环境变量 MMCV_WITH_OPS 的默认值从 0 改为 1

MMCV < 2.0 MMCV >= 2.0
# 包含算子,因为 mmcv-full 的最高版本小于 2.0.0,所以无需加版本限制
pip install openmim
mim install mmcv-full

# 不包含算子
pip install openmim
mim install "mmcv < 2.0.0"
# 包含算子
pip install openmim
mim install mmcv

# 不包含算子,因为 mmcv-lite 的起始版本为 2.0.0,所以无需加版本限制
pip install openmim
mim install mmcv-lite

v1.3.18

部分自定义算子对于不同的设备有不同实现,为此添加的大量宏命令与类型检查使得代码变得难以维护。例如:

  if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(input);
    CHECK_CUDA_INPUT(rois);
    CHECK_CUDA_INPUT(output);
    CHECK_CUDA_INPUT(argmax_y);
    CHECK_CUDA_INPUT(argmax_x);

    roi_align_forward_cuda(input, rois, output, argmax_y, argmax_x,
                           aligned_height, aligned_width, spatial_scale,
                           sampling_ratio, pool_mode, aligned);
#else
    AT_ERROR("RoIAlign is not compiled with GPU support");
#endif
  } else {
    CHECK_CPU_INPUT(input);
    CHECK_CPU_INPUT(rois);
    CHECK_CPU_INPUT(output);
    CHECK_CPU_INPUT(argmax_y);
    CHECK_CPU_INPUT(argmax_x);
    roi_align_forward_cpu(input, rois, output, argmax_y, argmax_x,
                          aligned_height, aligned_width, spatial_scale,
                          sampling_ratio, pool_mode, aligned);
  }

为此我们设计了注册与分发的机制以更好的管理这些算子实现。

void ROIAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                       Tensor argmax_y, Tensor argmax_x,
                                       int aligned_height, int aligned_width,
                                       float spatial_scale, int sampling_ratio,
                                       int pool_mode, bool aligned);

void roi_align_forward_cuda(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax_y, Tensor argmax_x,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned) {
  ROIAlignForwardCUDAKernelLauncher(
      input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width,
      spatial_scale, sampling_ratio, pool_mode, aligned);
}

// 注册算子的cuda实现
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax_y, Tensor argmax_x,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned);
REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);

// roi_align.cpp
// 使用dispatcher根据参数中的Tensor device类型对实现进行分发
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax_y, Tensor argmax_x,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned) {
  DISPATCH_DEVICE_IMPL(roi_align_forward_impl, input, rois, output, argmax_y,
                       argmax_x, aligned_height, aligned_width, spatial_scale,
                       sampling_ratio, pool_mode, aligned);
}

v1.3.11

为了灵活地支持更多的后端和硬件,例如 NVIDIA GPUsAMD GPUs,我们重构了 mmcv/ops/csrc 目录。注意,这次重构不会影响 API 的使用。更多相关信息,请参考 PR1206

原始的目录结构如下所示

.
├── common_cuda_helper.hpp
├── ops_cuda_kernel.cuh
├── pytorch_cpp_helper.hpp
├── pytorch_cuda_helper.hpp
├── parrots_cpp_helper.hpp
├── parrots_cuda_helper.hpp
├── parrots_cudawarpfunction.cuh
├── onnxruntime
│   ├── onnxruntime_register.h
│   ├── onnxruntime_session_options_config_keys.h
│   ├── ort_mmcv_utils.h
│   ├── ...
│   ├── onnx_ops.h
│   └── cpu
│       ├── onnxruntime_register.cpp
│       ├── ...
│       └── onnx_ops_impl.cpp
├── parrots
│   ├── ...
│   ├── ops.cpp
│   ├── ops_cuda.cu
│   ├── ops_parrots.cpp
│   └── ops_pytorch.h
├── pytorch
│   ├── ...
│   ├── ops.cpp
│   ├── ops_cuda.cu
│   ├── pybind.cpp
└── tensorrt
    ├── trt_cuda_helper.cuh
    ├── trt_plugin_helper.hpp
    ├── trt_plugin.hpp
    ├── trt_serialize.hpp
    ├── ...
    ├── trt_ops.hpp
    └── plugins
        ├── trt_cuda_helper.cu
        ├── trt_plugin.cpp
        ├── ...
        ├── trt_ops.cpp
        └── trt_ops_kernel.cu

重构之后,它的结构如下所示

.
├── common
│   ├── box_iou_rotated_utils.hpp
│   ├── parrots_cpp_helper.hpp
│   ├── parrots_cuda_helper.hpp
│   ├── pytorch_cpp_helper.hpp
│   ├── pytorch_cuda_helper.hpp
│   └── cuda
│       ├── common_cuda_helper.hpp
│       ├── parrots_cudawarpfunction.cuh
│       ├── ...
│       └── ops_cuda_kernel.cuh
├── onnxruntime
│   ├── onnxruntime_register.h
│   ├── onnxruntime_session_options_config_keys.h
│   ├── ort_mmcv_utils.h
│   ├── ...
│   ├── onnx_ops.h
│   └── cpu
│       ├── onnxruntime_register.cpp
│       ├── ...
│       └── onnx_ops_impl.cpp
├── parrots
│   ├── ...
│   ├── ops.cpp
│   ├── ops_parrots.cpp
│   └── ops_pytorch.h
├── pytorch
│   ├── info.cpp
│   ├── pybind.cpp
│   ├── ...
│   ├── ops.cpp
│   └── cuda
│       ├── ...
│       └── ops_cuda.cu
└── tensorrt
    ├── trt_cuda_helper.cuh
    ├── trt_plugin_helper.hpp
    ├── trt_plugin.hpp
    ├── trt_serialize.hpp
    ├── ...
    ├── trt_ops.hpp
    └── plugins
        ├── trt_cuda_helper.cu
        ├── trt_plugin.cpp
        ├── ...
        ├── trt_ops.cpp
        └── trt_ops_kernel.cu

常见问题

在这里我们列出了用户经常遇到的问题以及对应的解决方法。如果您遇到了其他常见的问题,并且知道可以帮到大家的解决办法, 欢迎随时丰富这个列表。

安装问题

  • KeyError: “xxx: ‘yyy is not in the zzz registry’”

    只有模块所在的文件被导入时,注册机制才会被触发,所以您需要在某处导入该文件,更多详情请查看 KeyError: “MaskRCNN: ‘RefineRoIHead is not in the models registry’”

  • “No module named ‘mmcv.ops’”; “No module named ‘mmcv._ext’”

    1. 使用 pip uninstall mmcv 卸载您环境中的 mmcv

    2. 参考 installation instruction 或者 Build MMCV from source 安装 mmcv-full

  • “invalid device function” 或者 “no kernel image is available for execution”

    1. 检查 GPU 的 CUDA 计算能力

    2. 运行 python mmdet/utils/collect_env.py 来检查 PyTorch、torchvision 和 MMCV 是否是针对正确的 GPU 架构构建的,您可能需要去设置 TORCH_CUDA_ARCH_LIST 来重新安装 MMCV。兼容性问题可能会出现在使用旧版的 GPUs,如:colab 上的 Tesla K80 (3.7)

    3. 检查运行环境是否和 mmcv/mmdet 编译时的环境相同。例如,您可能使用 CUDA 10.0 编译 mmcv,但在 CUDA 9.0 的环境中运行它

  • “undefined symbol” 或者 “cannot open xxx.so”

    1. 如果符号和 CUDA/C++ 相关(例如:libcudart.so 或者 GLIBCXX),请检查 CUDA/GCC 运行时的版本是否和编译 mmcv 的一致

    2. 如果符号和 PyTorch 相关(例如:符号包含 caffe、aten 和 TH),请检查 PyTorch 运行时的版本是否和编译 mmcv 的一致

    3. 运行 python mmdet/utils/collect_env.py 以检查 PyTorch、torchvision 和 MMCV 构建和运行的环境是否相同

  • “RuntimeError: CUDA error: invalid configuration argument”

    这个错误可能是由于您的 GPU 性能不佳造成的。尝试降低 THREADS_PER_BLOCK 的值并重新编译 mmcv。

  • “RuntimeError: nms is not compiled with GPU support”

    这个错误是由于您的 CUDA 环境没有正确安装。 您可以尝试重新安装您的 CUDA 环境,然后删除 mmcv/build 文件夹并重新编译 mmcv。

  • “Segmentation fault”

    1. 检查 GCC 的版本,通常是因为 PyTorch 版本与 GCC 版本不匹配 (例如 GCC < 4.9 ),我们推荐用户使用 GCC 5.4,我们也不推荐使用 GCC 5.5, 因为有反馈 GCC 5.5 会导致 “segmentation fault” 并且切换到 GCC 5.4 就可以解决问题

    2. 检查是否正确安装 CUDA 版本的 PyTorc。输入以下命令并检查是否返回 True

      python -c 'import torch; print(torch.cuda.is_available())'
      
    3. 如果 torch 安装成功,那么检查 MMCV 是否安装成功。输入以下命令,如果没有报错说明 mmcv-full 安装成。

      python -c 'import mmcv; import mmcv.ops'
      
    4. 如果 MMCV 与 PyTorch 都安装成功了,则可以使用 ipdb 设置断点或者使用 print 函数,分析是哪一部分的代码导致了 segmentation fault

  • “libtorch_cuda_cu.so: cannot open shared object file”

    mmcv-full 依赖 libtorch_cuda_cu.so 文件,但程序运行时没能找到该文件。我们可以检查该文件是否存在 ~/miniconda3/envs/{environment-name}/lib/python3.7/site-packages/torch/lib 也可以尝试重装 PyTorch。

  • “fatal error C1189: #error: – unsupported Microsoft Visual Studio version!”

    如果您在 Windows 上编译 mmcv-full 并且 CUDA 的版本是 9.2,您很可能会遇到这个问题 "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2\include\crt/host_config.h(133): fatal error C1189: #error:  -- unsupported Microsoft Visual Studio version! Only the versions 2012, 2013, 2015 and 2017 are supported!",您可以尝试使用低版本的 Microsoft Visual Studio,例如 vs2017。

  • “error: member “torch::jit::detail::ModulePolicy::all_slots” may not be initialized”

    如果您在 Windows 上编译 mmcv-full 并且 PyTorch 的版本是 1.5.0,您很可能会遇到这个问题 - torch/csrc/jit/api/module.h(474): error: member "torch::jit::detail::ModulePolicy::all_slots" may not be initialized。解决这个问题的方法是将 torch/csrc/jit/api/module.h 文件中所有 static constexpr bool all_slots = false; 替换为 static bool all_slots = false;。更多细节可以查看 member “torch::jit::detail::AttributePolicy::all_slots” may not be initialized

  • “error: a member with an in-class initializer must be const”

    如果您在 Windows 上编译 mmcv-full 并且 PyTorch 的版本是 1.6.0,您很可能会遇到这个问题 "- torch/include\torch/csrc/jit/api/module.h(483): error: a member with an in-class initializer must be const". 解决这个问题的方法是将 torch/include\torch/csrc/jit/api/module.h 文件中的所有 CONSTEXPR_EXCEPT_WIN_CUDA 替换为 const。更多细节可以查看 Ninja: build stopped: subcommand failed

  • “error: member “torch::jit::ProfileOptionalOp::Kind” may not be initialized”

    如果您在 Windows 上编译 mmcv-full 并且 PyTorch 的版本是 1.7.0,您很可能会遇到这个问题 torch/include\torch/csrc/jit/ir/ir.h(1347): error: member "torch::jit::ProfileOptionalOp::Kind" may not be initialized. 解决这个问题的方法是修改 PyTorch 中的几个文件:

    • 删除 torch/include\torch/csrc/jit/ir/ir.h 文件中的 static constexpr Symbol Kind = ::c10::prim::profile;tatic constexpr Symbol Kind = ::c10::prim::profile_optional;

    • torch\include\pybind11\cast.h 文件中的 explicit operator type&() { return *(this->value); } 替换为 explicit operator type&() { return *((type*)this->value); }

    • torch/include\torch/csrc/jit/api/module.h 文件中的 所有 CONSTEXPR_EXCEPT_WIN_CUDA 替换为 const

    更多细节可以查看 Ensure default extra_compile_args

  • MMCV 和 MMDetection 的兼容性问题;”ConvWS is already registered in conv layer”

    请参考 installation instruction 为您的 MMDetection 版本安装正确版本的 MMCV。

使用问题

  • “RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one”

    1. 这个错误是因为有些参数没有参与 loss 的计算,可能是代码中存在多个分支,导致有些分支没有参与 loss 的计算。更多细节见 Expected to have finished reduction in the prior iteration before starting a new one

    2. 你可以设置 DDP 中的 find_unused_parametersTrue,或者手动查找哪些参数没有用到。

  • “RuntimeError: Trying to backward through the graph a second time”

    不能同时设置 GradientCumulativeOptimizerHookOptimizerHook,这会导致 loss.backward() 被调用两次,于是程序抛出 RuntimeError。我们只需设置其中的一个。更多细节见 Trying to backward through the graph a second time

贡献代码

欢迎加入 MMCV 社区,我们致力于打造最前沿的计算机视觉基础库,我们欢迎任何类型的贡献,包括但不限于

修复错误

修复代码实现错误的步骤如下:

  1. 如果提交的代码改动较大,建议先提交 issue,并正确描述 issue 的现象、原因和复现方式,讨论后确认修复方案。

  2. 修复错误并补充相应的单元测试,提交拉取请求。

新增功能或组件

  1. 如果新功能或模块涉及较大的代码改动,建议先提交 issue,确认功能的必要性。

  2. 实现新增功能并添单元测试,提交拉取请求。

文档补充

修复文档可以直接提交拉取请求

添加文档或将文档翻译成其他语言步骤如下

  1. 提交 issue,确认添加文档的必要性。

  2. 添加文档,提交拉取请求。

拉取请求工作流

如果你对拉取请求不了解,没关系,接下来的内容将会从零开始,一步一步地指引你如何创建一个拉取请求。如果你想深入了解拉取请求的开发模式,可以参考 github 官方文档

1. 复刻仓库

当你第一次提交拉取请求时,先复刻 OpenMMLab 原代码库,点击 GitHub 页面右上角的 Fork 按钮,复刻后的代码库将会出现在你的 GitHub 个人主页下。

将代码克隆到本地

git clone git@github.com:{username}/mmcv.git

添加原代码库为上游代码库

git remote add upstream git@github.com:open-mmlab/mmcv

检查 remote 是否添加成功,在终端输入 git remote -v

origin	git@github.com:{username}/mmcv.git (fetch)
origin	git@github.com:{username}/mmcv.git (push)
upstream	git@github.com:open-mmlab/mmcv (fetch)
upstream	git@github.com:open-mmlab/mmcv (push)

注解

这里对 origin 和 upstream 进行一个简单的介绍,当我们使用 git clone 来克隆代码时,会默认创建一个 origin 的 remote,它指向我们克隆的代码库地址,而 upstream 则是我们自己添加的,用来指向原始代码库地址。当然如果你不喜欢他叫 upstream,也可以自己修改,比如叫 open-mmlab。我们通常向 origin 提交代码(即 fork 下来的远程仓库),然后向 upstream 提交一个 pull request。如果提交的代码和最新的代码发生冲突,再从 upstream 拉取最新的代码,和本地分支解决冲突,再提交到 origin。

2. 配置 pre-commit

在本地开发环境中,我们使用 pre-commit 来检查代码风格,以确保代码风格的统一。在提交代码,需要先安装 pre-commit(需要在 MMCV 目录下执行):

pip install -U pre-commit
pre-commit install

检查 pre-commit 是否配置成功,并安装 .pre-commit-config.yaml 中的钩子:

pre-commit run --all-files

注解

如果你是中国用户,由于网络原因,可能会出现安装失败的情况,这时可以使用国内源

pre-commit install -c .pre-commit-config-zh-cn.yaml

pre-commit run –all-files -c .pre-commit-config-zh-cn.yaml

如果安装过程被中断,可以重复执行 pre-commit run ... 继续安装。

如果提交的代码不符合代码风格规范,pre-commit 会发出警告,并自动修复部分错误。

如果我们想临时绕开 pre-commit 的检查提交一次代码,可以在 git commit 时加上 --no-verify(需要保证最后推送至远程仓库的代码能够通过 pre-commit 检查)。

git commit -m "xxx" --no-verify

3. 创建开发分支

安装完 pre-commit 之后,我们需要基于 main 创建开发分支,建议的分支命名规则为 username/pr_name

git checkout -b yhc/refactor_contributing_doc

在后续的开发中,如果本地仓库的 main 分支落后于 upstream 的 main 分支,我们需要先拉取 upstream 的代码进行同步,再执行上面的命令

git pull upstream main

4. 提交代码并在本地通过单元测试

  • MMCV 引入了 mypy 来做静态类型检查,以增加代码的鲁棒性。因此我们在提交代码时,需要补充 Type Hints。具体规则可以参考教程

  • 提交的代码同样需要通过单元测试

    # 通过全量单元测试
    pytest tests
    
    # 我们需要保证提交的代码能够通过修改模块的单元测试,以 runner 为例
    pytest tests/test_runner/test_runner.py
    

    如果你由于缺少依赖无法运行修改模块的单元测试,可以参考指引-单元测试

  • 如果修改/添加了文档,参考指引确认文档渲染正常。

5. 推送代码到远程

代码通过单元测试和 pre-commit 检查后,将代码推送到远程仓库,如果是第一次推送,可以在 git push 后加上 -u 参数以关联远程分支

git push -u origin {branch_name}

这样下次就可以直接使用 git push 命令推送代码了,而无需指定分支和远程仓库。

6. 提交拉取请求(PR)

(1) 在 GitHub 的 Pull request 界面创建拉取请求

(2) 根据指引修改 PR 描述,以便于其他开发者更好地理解你的修改

描述规范详见拉取请求规范

 

注意事项

(a) PR 描述应该包含修改理由、修改内容以及修改后带来的影响,并关联相关 Issue(具体方式见文档

(b) 如果是第一次为 OpenMMLab 做贡献,需要签署 CLA

(c) 检查提交的 PR 是否通过 CI(集成测试)

MMCV 会在不同的平台(Linux、Window、Mac),基于不同版本的 Python、PyTorch、CUDA 对提交的代码进行单元测试,以保证代码的正确性,如果有任何一个没有通过,我们可点击上图中的 Details 来查看具体的测试信息,以便于我们修改代码。

(3) 如果 PR 通过了 CI,那么就可以等待其他开发者的 review,并根据 reviewer 的意见,修改代码,并重复 4-5 步骤,直到 reviewer 同意合入 PR。

所有 reviewer 同意合入 PR 后,我们会尽快将 PR 合并到主分支。

7. 解决冲突

随着时间的推移,我们的代码库会不断更新,这时候,如果你的 PR 与主分支存在冲突,你需要解决冲突,解决冲突的方式有两种:

git fetch --all --prune
git rebase upstream/main

或者

git fetch --all --prune
git merge upstream/main

如果你非常善于处理冲突,那么可以使用 rebase 的方式来解决冲突,因为这能够保证你的 commit log 的整洁。如果你不太熟悉 rebase 的使用,那么可以使用 merge 的方式来解决冲突。

指引

单元测试

如果你无法正常执行部分模块的单元测试,例如 video 模块,可能是你的当前环境没有安装以下依赖

# Linux
sudo apt-get update -y
sudo apt-get install -y libturbojpeg
sudo apt-get install -y ffmpeg

# Windows
conda install ffmpeg

在提交修复代码错误或新增特性的拉取请求时,我们应该尽可能的让单元测试覆盖所有提交的代码,计算单元测试覆盖率的方法如下

python -m coverage run -m pytest /path/to/test_file
python -m coverage html
# check file in htmlcov/index.html

文档渲染

在提交修复代码错误或新增特性的拉取请求时,可能会需要修改/新增模块的 docstring。我们需要确认渲染后的文档样式是正确的。 本地生成渲染后的文档的方法如下

pip install -r requirements/docs.txt
cd docs/zh_cn/
# or docs/en
make html
# check file in ./docs/zh_cn/_build/html/index.html

代码风格

Python

PEP8 作为 OpenMMLab 算法库首选的代码规范,我们使用以下工具检查和格式化代码

  • flake8: Python 官方发布的代码规范检查工具,是多个检查工具的封装

  • isort: 自动调整模块导入顺序的工具

  • yapf: Google 发布的代码规范检查工具

  • codespell: 检查单词拼写是否有误

  • mdformat: 检查 markdown 文件的工具

  • docformatter: 格式化 docstring 的工具

yapf 和 isort 的配置可以在 setup.cfg 找到

通过配置 pre-commit hook ,我们可以在提交代码时自动检查和格式化 flake8yapfisorttrailing whitespacesmarkdown files, 修复 end-of-filesdouble-quoted-stringspython-encoding-pragmamixed-line-ending,调整 requirments.txt 的包顺序。 pre-commit 钩子的配置可以在 .pre-commit-config 找到。

pre-commit 具体的安装使用方式见拉取请求

更具体的规范请参考 OpenMMLab 代码规范

C++ and CUDA

C++ 和 CUDA 的代码规范遵从 Google C++ Style Guide

拉取请求规范

  1. 使用 pre-commit hook,尽量减少代码风格相关问题

  2. 一个拉取请求对应一个短期分支

  3. 粒度要细,一个拉取请求只做一件事情,避免超大的拉取请求

    • Bad:实现 Faster R-CNN

    • Acceptable:给 Faster R-CNN 添加一个 box head

    • Good:给 box head 增加一个参数来支持自定义的 conv 层数

  4. 每次 Commit 时需要提供清晰且有意义 commit 信息

  5. 提供清晰且有意义的拉取请求描述

    • 标题写明白任务名称,一般格式:[Prefix] Short description of the pull request (Suffix)

    • prefix: 新增功能 [Feature], 修 bug [Fix], 文档相关 [Docs], 开发中 [WIP] (暂时不会被review)

    • 描述里介绍拉取请求的主要修改内容,结果,以及对其他部分的影响, 参考拉取请求模板

    • 关联相关的议题 (issue) 和其他拉取请求

  6. 如果引入了其他三方库,或借鉴了三方库的代码,请确认他们的许可证和 mmcv 兼容,并在借鉴的代码上补充 This code is inspired from http://

拉取请求

本文档的内容已迁移到贡献指南

代码规范

代码规范标准

PEP 8 —— Python 官方代码规范

Python 官方的代码风格指南,包含了以下几个方面的内容:

  • 代码布局,介绍了 Python 中空行、断行以及导入相关的代码风格规范。比如一个常见的问题:当我的代码较长,无法在一行写下时,何处可以断行?

  • 表达式,介绍了 Python 中表达式空格相关的一些风格规范。

  • 尾随逗号相关的规范。当列表较长,无法一行写下而写成如下逐行列表时,推荐在末项后加逗号,从而便于追加选项、版本控制等。

    # Correct:
    FILES = ['setup.cfg', 'tox.ini']
    # Correct:
    FILES = [
        'setup.cfg',
        'tox.ini',
    ]
    # Wrong:
    FILES = ['setup.cfg', 'tox.ini',]
    # Wrong:
    FILES = [
        'setup.cfg',
        'tox.ini'
    ]
    
  • 命名相关规范、注释相关规范、类型注解相关规范,我们将在后续章节中做详细介绍。

    “A style guide is about consistency. Consistency with this style guide is important. Consistency within a project is more important. Consistency within one module or function is the most important.” PEP 8 – Style Guide for Python Code

注解

PEP 8 的代码规范并不是绝对的,项目内的一致性要优先于 PEP 8 的规范。OpenMMLab 各个项目都在 setup.cfg 设定了一些代码规范的设置,请遵照这些设置。一个例子是在 PEP 8 中有如下一个例子:

# Correct:
hypot2 = x*x + y*y
# Wrong:
hypot2 = x * x + y * y

这一规范是为了指示不同优先级,但 OpenMMLab 的设置中通常没有启用 yapf 的 ARITHMETIC_PRECEDENCE_INDICATION 选项,因而格式规范工具不会按照推荐样式格式化,以设置为准。

Google 开源项目风格指南

Google 使用的编程风格指南,包括了 Python 相关的章节。相较于 PEP 8,该指南提供了更为详尽的代码指南。该指南包括了语言规范和风格规范两个部分。

其中,语言规范对 Python 中很多语言特性进行了优缺点的分析,并给出了使用指导意见,如异常、Lambda 表达式、列表推导式、metaclass 等。

风格规范的内容与 PEP 8 较为接近,大部分约定建立在 PEP 8 的基础上,也有一些更为详细的约定,如函数长度、TODO 注释、文件与 socket 对象的访问等。

推荐将该指南作为参考进行开发,但不必严格遵照,一来该指南存在一些 Python 2 兼容需求,例如指南中要求所有无基类的类应当显式地继承 Object, 而在仅使用 Python 3 的环境中,这一要求是不必要的,依本项目中的惯例即可。二来 OpenMMLab 的项目作为框架级的开源软件,不必对一些高级技巧过于避讳,尤其是 MMCV。但尝试使用这些技巧前应当认真考虑是否真的有必要,并寻求其他开发人员的广泛评估。

另外需要注意的一处规范是关于包的导入,在该指南中,要求导入本地包时必须使用路径全称,且导入的每一个模块都应当单独成行,通常这是不必要的,而且也不符合目前项目的开发惯例,此处进行如下约定:

# Correct
from mmcv.cnn.bricks import (Conv2d, build_norm_layer, DropPath, MaxPool2d,
                             Linear)
from ..utils import ext_loader

# Wrong
from mmcv.cnn.bricks import Conv2d, build_norm_layer, DropPath, MaxPool2d, \
                            Linear  # 使用括号进行连接,而不是反斜杠
from ...utils import is_str  # 最多向上回溯一层,过多的回溯容易导致结构混乱

OpenMMLab 项目使用 pre-commit 工具自动格式化代码,详情见贡献代码

命名规范

命名规范的重要性

优秀的命名是良好代码可读的基础。基础的命名规范对各类变量的命名做了要求,使读者可以方便地根据代码名了解变量是一个类 / 局部变量 / 全局变量等。而优秀的命名则需要代码作者对于变量的功能有清晰的认识,以及良好的表达能力,从而使读者根据名称就能了解其含义,甚至帮助了解该段代码的功能。

基础命名规范

类型 公有 私有
模块 lower_with_under _lower_with_under
lower_with_under
CapWords _CapWords
异常 CapWordsError
函数(方法) lower_with_under _lower_with_under
函数 / 方法参数 lower_with_under
全局 / 类内常量 CAPS_WITH_UNDER _CAPS_WITH_UNDER
全局 / 类内变量 lower_with_under _lower_with_under
变量 lower_with_under _lower_with_under
局部变量 lower_with_under

注意:

  • 尽量避免变量名与保留字冲突,特殊情况下如不可避免,可使用一个后置下划线,如 class_

  • 尽量不要使用过于简单的命名,除了约定俗成的循环变量 i,文件变量 f,错误变量 e 等。

  • 不会被用到的变量可以命名为 _,逻辑检查器会将其忽略。

命名技巧

良好的变量命名需要保证三点:

  1. 含义准确,没有歧义

  2. 长短适中

  3. 前后统一

# Wrong
class Masks(metaclass=ABCMeta):  # 命名无法表现基类;Instance or Semantic?
    pass

# Correct
class BaseInstanceMasks(metaclass=ABCMeta):
    pass

# Wrong,不同地方含义相同的变量尽量用统一的命名
def __init__(self, inplanes, planes):
    pass

def __init__(self, in_channels, out_channels):
    pass

常见的函数命名方法:

  • 动宾命名法:crop_img, init_weights

  • 动宾倒置命名法:imread, bbox_flip

注意函数命名与参数的顺序,保证主语在前,符合语言习惯:

  • check_keys_exist(key, container)

  • check_keys_contain(container, key)

注意避免非常规或统一约定的缩写,如 nb -> num_blocks,in_nc -> in_channels

docstring 规范

为什么要写 docstring

docstring 是对一个类、一个函数功能与 API 接口的详细描述,有两个功能,一是帮助其他开发者了解代码功能,方便 debug 和复用代码;二是在 Readthedocs 文档中自动生成相关的 API reference 文档,帮助不了解源代码的社区用户使用相关功能。

如何写 docstring

与注释不同,一份规范的 docstring 有着严格的格式要求,以便于 Python 解释器以及 sphinx 进行文档解析,详细的 docstring 约定参见 PEP 257。此处以例子的形式介绍各种文档的标准格式,参考格式为 Google 风格

  1. 模块文档

    代码风格规范推荐为每一个模块(即 Python 文件)编写一个 docstring,但目前 OpenMMLab 项目大部分没有此类 docstring,因此不做硬性要求。

    """A one line summary of the module or program, terminated by a period.
    
    Leave one blank line. The rest of this docstring should contain an
    overall description of the module or program. Optionally, it may also
    contain a brief description of exported classes and functions and/or usage
    examples.
    
    Typical usage example:
    
    foo = ClassFoo()
    bar = foo.FunctionBar()
    """
    
  2. 类文档

    类文档是我们最常需要编写的,此处,按照 OpenMMLab 的惯例,我们使用了与 Google 风格不同的写法。如下例所示,文档中没有使用 Attributes 描述类属性,而是使用 Args 描述 init 函数的参数。

    在 Args 中,遵照 parameter (type): Description. 的格式,描述每一个参数类型和功能。其中,多种类型可使用 (float or str) 的写法,可以为 None 的参数可以写为 (int, optional)

    class BaseRunner(metaclass=ABCMeta):
        """The base class of Runner, a training helper for PyTorch.
    
        All subclasses should implement the following APIs:
    
        - ``run()``
        - ``train()``
        - ``val()``
        - ``save_checkpoint()``
    
        Args:
            model (:obj:`torch.nn.Module`): The model to be run.
            batch_processor (callable, optional): A callable method that process
                a data batch. The interface of this method should be
                ``batch_processor(model, data, train_mode) -> dict``.
                Defaults to None.
            optimizer (dict or :obj:`torch.optim.Optimizer`, optional): It can be
                either an optimizer (in most cases) or a dict of optimizers
                (in models that requires more than one optimizer, e.g., GAN).
                Defaults to None.
            work_dir (str, optional): The working directory to save checkpoints
                and logs. Defaults to None.
            logger (:obj:`logging.Logger`): Logger used during training.
                 Defaults to None. (The default value is just for backward
                 compatibility)
            meta (dict, optional): A dict records some import information such as
                environment info and seed, which will be logged in logger hook.
                Defaults to None.
            max_epochs (int, optional): Total training epochs. Defaults to None.
            max_iters (int, optional): Total training iterations. Defaults to None.
        """
    
        def __init__(self,
                     model,
                     batch_processor=None,
                     optimizer=None,
                     work_dir=None,
                     logger=None,
                     meta=None,
                     max_iters=None,
                     max_epochs=None):
            ...
    

    另外,在一些算法实现的主体类中,建议加入原论文的链接;如果参考了其他开源代码的实现,则应加入 modified from,而如果是直接复制了其他代码库的实现,则应加入 copied from ,并注意源码的 License。如有必要,也可以通过 .. math:: 来加入数学公式

    # 参考实现
    # This func is modified from `detectron2
    # <https://github.com/facebookresearch/detectron2/blob/ffff8acc35ea88ad1cb1806ab0f00b4c1c5dbfd9/detectron2/structures/masks.py#L387>`_.
    
    # 复制代码
    # This code was copied from the `ubelt
    # library<https://github.com/Erotemic/ubelt>`_.
    
    # 引用论文 & 添加公式
    class LabelSmoothLoss(nn.Module):
        r"""Initializer for the label smoothed cross entropy loss.
    
        Refers to `Rethinking the Inception Architecture for Computer Vision
        <https://arxiv.org/abs/1512.00567>`_.
    
        This decreases gap between output scores and encourages generalization.
        Labels provided to forward can be one-hot like vectors (NxC) or class
        indices (Nx1).
        And this accepts linear combination of one-hot like labels from mixup or
        cutmix except multi-label task.
    
        Args:
            label_smooth_val (float): The degree of label smoothing.
            num_classes (int, optional): Number of classes. Defaults to None.
            mode (str): Refers to notes, Options are "original", "classy_vision",
                "multi_label". Defaults to "classy_vision".
            reduction (str): The method used to reduce the loss.
                Options are "none", "mean" and "sum". Defaults to 'mean'.
            loss_weight (float):  Weight of the loss. Defaults to 1.0.
    
        Note:
            if the ``mode`` is "original", this will use the same label smooth
            method as the original paper as:
    
            .. math::
                (1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K}
    
            where :math:`\epsilon` is the ``label_smooth_val``, :math:`K` is
            the ``num_classes`` and :math:`\delta_{k,y}` is Dirac delta,
            which equals 1 for k=y and 0 otherwise.
    
            if the ``mode`` is "classy_vision", this will use the same label
            smooth method as the `facebookresearch/ClassyVision
            <https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/losses/label_smoothing_loss.py>`_ repo as:
    
            .. math::
                \frac{\delta_{k, y} + \epsilon/K}{1+\epsilon}
    
            if the ``mode`` is "multi_label", this will accept labels from
            multi-label task and smoothing them as:
    
            .. math::
                (1-2\epsilon)\delta_{k, y} + \epsilon
    

注解

注意 ``here``、`here`、”here” 三种引号功能是不同。

在 reStructured 语法中,``here`` 表示一段代码;`here` 表示斜体;”here” 无特殊含义,一般可用来表示字符串。其中 `here` 的用法与 Markdown 中不同,需要多加留意。 另外还有 :obj:`type` 这种更规范的表示类的写法,但鉴于长度,不做特别要求,一般仅用于表示非常用类型。

  1. 方法(函数)文档

    函数文档与类文档的结构基本一致,但需要加入返回值文档。对于较为复杂的函数和类,可以使用 Examples 字段加入示例;如果需要对参数加入一些较长的备注,可以加入 Note 字段进行说明。

    对于使用较为复杂的类或函数,比起看大段大段的说明文字和参数文档,添加合适的示例更能帮助用户迅速了解其用法。需要注意的是,这些示例最好是能够直接在 Python 交互式环境中运行的,并给出一些相对应的结果。如果存在多个示例,可以使用注释简单说明每段示例,也能起到分隔作用。

    def import_modules_from_strings(imports, allow_failed_imports=False):
        """Import modules from the given list of strings.
    
        Args:
            imports (list | str | None): The given module names to be imported.
            allow_failed_imports (bool): If True, the failed imports will return
                None. Otherwise, an ImportError is raise. Defaults to False.
    
        Returns:
            List[module] | module | None: The imported modules.
            All these three lines in docstring will be compiled into the same
            line in readthedocs.
    
        Examples:
            >>> osp, sys = import_modules_from_strings(
            ...     ['os.path', 'sys'])
            >>> import os.path as osp_
            >>> import sys as sys_
            >>> assert osp == osp_
            >>> assert sys == sys_
        """
        ...
    

    如果函数接口在某个版本发生了变化,需要在 docstring 中加入相关的说明,必要时添加 Note 或者 Warning 进行说明,例如:

    class CheckpointHook(Hook):
        """Save checkpoints periodically.
    
        Args:
            out_dir (str, optional): The root directory to save checkpoints. If
                not specified, ``runner.work_dir`` will be used by default. If
                specified, the ``out_dir`` will be the concatenation of
                ``out_dir`` and the last level directory of ``runner.work_dir``.
                Defaults to None. `Changed in version 1.3.15.`
            file_client_args (dict, optional): Arguments to instantiate a
                FileClient. See :class:`mmcv.fileio.FileClient` for details.
                Defaults to None. `New in version 1.3.15.`
    
        Warning:
            Before v1.3.15, the ``out_dir`` argument indicates the path where the
            checkpoint is stored. However, in v1.3.15 and later, ``out_dir``
            indicates the root directory and the final path to save checkpoint is
            the concatenation of out_dir and the last level directory of
            ``runner.work_dir``. Suppose the value of ``out_dir`` is
            "/path/of/A" and the value of ``runner.work_dir`` is "/path/of/B",
            then the final path will be "/path/of/A/B".
    

    如果参数或返回值里带有需要展开描述字段的 dict,则应该采用如下格式:

    def func(x):
        r"""
        Args:
            x (None): A dict with 2 keys, ``padded_targets``, and ``targets``.
    
                - ``targets`` (list[Tensor]): A list of tensors.
                  Each tensor has the shape of :math:`(T_i)`. Each
                  element is the index of a character.
                - ``padded_targets`` (Tensor): A tensor of shape :math:`(N)`.
                  Each item is the length of a word.
    
        Returns:
            dict: A dict with 2 keys, ``padded_targets``, and ``targets``.
    
            - ``targets`` (list[Tensor]): A list of tensors.
              Each tensor has the shape of :math:`(T_i)`. Each
              element is the index of a character.
            - ``padded_targets`` (Tensor): A tensor of shape :math:`(N)`.
              Each item is the length of a word.
        """
        return x
    

重要

为了生成 readthedocs 文档,文档的编写需要按照 ReStructrued 文档格式,否则会产生文档渲染错误,在提交 PR 前,最好生成并预览一下文档效果。 语法规范参考:

注释规范

为什么要写注释

对于一个开源项目,团队合作以及社区之间的合作是必不可少的,因而尤其要重视合理的注释。不写注释的代码,很有可能过几个月自己也难以理解,造成额外的阅读和修改成本。

如何写注释

最需要写注释的是代码中那些技巧性的部分。如果你在下次代码审查的时候必须解释一下,那么你应该现在就给它写注释。对于复杂的操作,应该在其操作开始前写上若干行注释。对于不是一目了然的代码,应在其行尾添加注释。 —— Google 开源项目风格指南

# We use a weighted dictionary search to find out where i is in
# the array. We extrapolate position based on the largest num
# in the array and the array size and then do binary search to
# get the exact number.
if i & (i-1) == 0:  # True if i is 0 or a power of 2.

为了提高可读性, 注释应该至少离开代码2个空格. 另一方面, 绝不要描述代码. 假设阅读代码的人比你更懂Python, 他只是不知道你的代码要做什么. —— Google 开源项目风格指南

# Wrong:
# Now go through the b array and make sure whenever i occurs
# the next element is i+1

# Wrong:
if i & (i-1) == 0:  # True if i bitwise and i-1 is 0.

在注释中,可以使用 Markdown 语法,因为开发人员通常熟悉 Markdown 语法,这样可以便于交流理解,如可使用单反引号表示代码和变量(注意不要和 docstring 中的 ReStructured 语法混淆)

# `_reversed_padding_repeated_twice` is the padding to be passed to
# `F.pad` if needed (e.g., for non-zero padding types that are
# implemented as two ops: padding + conv). `F.pad` accepts paddings in
# reverse order than the dimension.
self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)

注释示例

  1. 出自 mmcv/utils/registry.py,对于较为复杂的逻辑结构,通过注释,明确了优先级关系。

    # self.build_func will be set with the following priority:
    # 1. build_func
    # 2. parent.build_func
    # 3. build_from_cfg
    if build_func is None:
        if parent is not None:
            self.build_func = parent.build_func
        else:
            self.build_func = build_from_cfg
    else:
        self.build_func = build_func
    
  2. 出自 mmcv/runner/checkpoint.py,对于 bug 修复中的一些特殊处理,可以附带相关的 issue 链接,帮助其他人了解 bug 背景。

    def _save_ckpt(checkpoint, file):
        # The 1.6 release of PyTorch switched torch.save to use a new
        # zipfile-based file format. It will cause RuntimeError when a
        # checkpoint was saved in high version (PyTorch version>=1.6.0) but
        # loaded in low version (PyTorch version<1.6.0). More details at
        # https://github.com/open-mmlab/mmpose/issues/904
        if digit_version(TORCH_VERSION) >= digit_version('1.6.0'):
            torch.save(checkpoint, file, _use_new_zipfile_serialization=False)
        else:
            torch.save(checkpoint, file)
    

类型注解

为什么要写类型注解

类型注解是对函数中变量的类型做限定或提示,为代码的安全性提供保障、增强代码的可读性、避免出现类型相关的错误。 Python 没有对类型做强制限制,类型注解只起到一个提示作用,通常你的 IDE 会解析这些类型注解,然后在你调用相关代码时对类型做提示。另外也有类型注解检查工具,这些工具会根据类型注解,对代码中可能出现的问题进行检查,减少 bug 的出现。 需要注意的是,通常我们不需要注释模块中的所有函数:

  1. 公共的 API 需要注释

  2. 在代码的安全性,清晰性和灵活性上进行权衡是否注释

  3. 对于容易出现类型相关的错误的代码进行注释

  4. 难以理解的代码请进行注释

  5. 若代码中的类型已经稳定,可以进行注释. 对于一份成熟的代码,多数情况下,即使注释了所有的函数,也不会丧失太多的灵活性.

如何写类型注解

  1. 函数 / 方法类型注解,通常不对 self 和 cls 注释。

    from typing import Optional, List, Tuple
    
    # 全部位于一行
    def my_method(self, first_var: int) -> int:
        pass
    
    # 另起一行
    def my_method(
            self, first_var: int,
            second_var: float) -> Tuple[MyLongType1, MyLongType1, MyLongType1]:
        pass
    
    # 单独成行(具体的应用场合与行宽有关,建议结合 yapf 自动化格式使用)
    def my_method(
        self, first_var: int, second_var: float
    ) -> Tuple[MyLongType1, MyLongType1, MyLongType1]:
        pass
    
    # 引用尚未被定义的类型
    class MyClass:
        def __init__(self,
                     stack: List["MyClass"]) -> None:
            pass
    

    注:类型注解中的类型可以是 Python 内置类型,也可以是自定义类,还可以使用 Python 提供的 wrapper 类对类型注解进行装饰,一些常见的注解如下:

    # 数值类型
    from numbers import Number
    
    # 可选类型,指参数可以为 None
    from typing import Optional
    def foo(var: Optional[int] = None):
        pass
    
    # 联合类型,指同时接受多种类型
    from typing import Union
    def foo(var: Union[float, str]):
        pass
    
    from typing import Sequence  # 序列类型
    from typing import Iterable  # 可迭代类型
    from typing import Any  # 任意类型
    from typing import Callable  # 可调用类型
    
    from typing import List, Dict  # 列表和字典的泛型类型
    from typing import Tuple  # 元组的特殊格式
    # 虽然在 Python 3.9 中,list, tuple 和 dict 本身已支持泛型,但为了支持之前的版本
    # 我们在进行类型注解时还是需要使用 List, Tuple, Dict 类型
    # 另外,在对参数类型进行注解时,尽量使用 Sequence & Iterable & Mapping
    # List, Tuple, Dict 主要用于返回值类型注解
    # 参见 https://docs.python.org/3/library/typing.html#typing.List
    
  2. 变量类型注解,一般用于难以直接推断其类型时

    # Recommend: 带类型注解的赋值
    a: Foo = SomeUndecoratedFunction()
    a: List[int]: [1, 2, 3]         # List 只支持单一类型泛型,可使用 Union
    b: Tuple[int, int] = (1, 2)     # 长度固定为 2
    c: Tuple[int, ...] = (1, 2, 3)  # 变长
    d: Dict[str, int] = {'a': 1, 'b': 2}
    
    # Not Recommend:行尾类型注释
    # 虽然这种方式被写在了 Google 开源指南中,但这是一种为了支持 Python 2.7 版本
    # 而补充的注释方式,鉴于我们只支持 Python 3, 为了风格统一,不推荐使用这种方式。
    a = SomeUndecoratedFunction()  # type: Foo
    a = [1, 2, 3]  # type: List[int]
    b = (1, 2, 3)  # type: Tuple[int, ...]
    c = (1, "2", 3.5)  # type: Tuple[int, Text, float]
    
  3. 泛型

    上文中我们知道,typing 中提供了 list 和 dict 的泛型类型,那么我们自己是否可以定义类似的泛型呢?

    from typing import TypeVar, Generic
    
    KT = TypeVar('KT')
    VT = TypeVar('VT')
    
    class Mapping(Generic[KT, VT]):
        def __init__(self, data: Dict[KT, VT]):
            self._data = data
    
        def __getitem__(self, key: KT) -> VT:
            return self._data[key]
    

    使用上述方法,我们定义了一个拥有泛型能力的映射类,实际用法如下:

    mapping = Mapping[str, float]({'a': 0.5})
    value: float = example['a']
    

    另外,我们也可以利用 TypeVar 在函数签名中指定联动的多个类型:

    from typing import TypeVar, List
    
    T = TypeVar('T')  # Can be anything
    A = TypeVar('A', str, bytes)  # Must be str or bytes
    
    
    def repeat(x: T, n: int) -> List[T]:
        """Return a list containing n references to x."""
        return [x]*n
    
    
    def longest(x: A, y: A) -> A:
        """Return the longest of two strings."""
        return x if len(x) >= len(y) else y
    

更多关于类型注解的写法请参考 typing

类型注解检查工具

mypy 是一个 Python 静态类型检查工具。根据你的类型注解,mypy 会检查传参、赋值等操作是否符合类型注解,从而避免可能出现的 bug。

例如如下的一个 Python 脚本文件 test.py:

def foo(var: int) -> float:
    return float(var)

a: str = foo('2.0')
b: int = foo('3.0')  # type: ignore

运行 mypy test.py 可以得到如下检查结果,分别指出了第 4 行在函数调用和返回值赋值两处类型错误。而第 5 行同样存在两个类型错误,由于使用了 type: ignore 而被忽略了,只有部分特殊情况可能需要此类忽略。

test.py:4: error: Incompatible types in assignment (expression has type "float", variable has type "int")
test.py:4: error: Argument 1 to "foo" has incompatible type "str"; expected "int"
Found 2 errors in 1 file (checked 1 source file)

mmcv.image

IO

imfrombytes

Read an image from bytes.

imread

Read an image.

imwrite

Write image to file.

use_backend

Select a backend for image decoding.

Color Space

bgr2gray

Convert a BGR image to grayscale image.

bgr2hls

Convert a BGR image to HLS

bgr2hsv

Convert a BGR image to HSV

bgr2rgb

Convert a BGR image to RGB

bgr2ycbcr

Convert a BGR image to YCbCr image.

gray2bgr

Convert a grayscale image to BGR image.

gray2rgb

Convert a grayscale image to RGB image.

hls2bgr

Convert a HLS image to BGR

hsv2bgr

Convert a HSV image to BGR

imconvert

Convert an image from the src colorspace to dst colorspace.

rgb2bgr

Convert a RGB image to BGR

rgb2gray

Convert a RGB image to grayscale image.

rgb2ycbcr

Convert a RGB image to YCbCr image.

ycbcr2bgr

Convert a YCbCr image to BGR image.

ycbcr2rgb

Convert a YCbCr image to RGB image.

Geometric

cutout

Randomly cut out a rectangle from the original img.

imcrop

Crop image patches.

imflip

Flip an image horizontally or vertically.

impad

Pad the given image to a certain shape or pad on all sides with specified padding mode and padding value.

impad_to_multiple

Pad an image to ensure each edge to be multiple to some number.

imrescale

Resize image while keeping the aspect ratio.

imresize

Resize image to a given size.

imresize_like

Resize image to the same size of a given image.

imresize_to_multiple

Resize image according to a given size or scale factor and then rounds up the the resized or rescaled image size to the nearest value that can be divided by the divisor.

imrotate

Rotate an image.

imshear

Shear an image.

imtranslate

Translate an image.

rescale_size

Calculate the new size to be rescaled to.

Photometric

adjust_brightness

Adjust image brightness.

adjust_color

It blends the source image and its gray image:

adjust_contrast

Adjust image contrast.

adjust_hue

Adjust hue of an image.

adjust_lighting

AlexNet-style PCA jitter.

adjust_sharpness

Adjust image sharpness.

auto_contrast

Auto adjust image contrast.

clahe

Use CLAHE method to process the image.

imdenormalize

imequalize

Equalize the image histogram.

iminvert

Invert (negate) an image.

imnormalize

Normalize an image with mean and std.

lut_transform

Transform array by look-up table.

posterize

Posterize an image (reduce the number of bits for each color channel)

solarize

Solarize an image (invert all pixel values above a threshold)

Miscellaneous

tensor2imgs

Convert tensor to 3-channel images or 1-channel gray images.

mmcv.video

IO

VideoReader

Video class with similar usage to a list object.

Cache

frames2video

Read the frame images from a directory and join them as a video.

Optical Flow

dequantize_flow

Recover from quantized flow.

flow_from_bytes

Read dense optical flow from bytes.

flow_warp

Use flow to warp img.

flowread

Read an optical flow map.

flowwrite

Write optical flow to file.

quantize_flow

Quantize flow to [0, 255].

sparse_flow_from_bytes

Read the optical flow in KITTI datasets from bytes.

Video Processing

concat_video

Concatenate multiple videos into a single one.

convert_video

Convert a video with ffmpeg.

cut_video

Cut a clip from a video.

resize_video

Resize a video.

mmcv.visualization

mmcv.visualization

Color

Color

An enum that defines common colors.

color_val

Convert various input to color tuples.

Image

imshow

Show an image.

imshow_bboxes

Draw bboxes on an image.

imshow_det_bboxes

Draw bboxes and class labels (with scores) on an image.

Optical Flow

flow2rgb

Convert flow map to RGB image.

flowshow

Show optical flow.

make_color_wheel

Build a color wheel.

mmcv.cnn

Module

ContextBlock

ContextBlock module in GCNet.

Conv2d

Conv3d

ConvAWS2d

AWS (Adaptive Weight Standardization)

ConvModule

A conv block that bundles conv/norm/activation layers.

ConvTranspose2d

ConvTranspose3d

ConvWS2d

DepthwiseSeparableConvModule

Depthwise separable convolution module.

GeneralizedAttention

GeneralizedAttention module.

HSigmoid

Hard Sigmoid Module.

HSwish

Hard Swish Module.

Linear

MaxPool2d

MaxPool3d

NonLocal1d

1D Non-local module.

NonLocal2d

2D Non-local module.

NonLocal3d

3D Non-local module.

Scale

A learnable scale parameter.

Swish

Swish Module.

Conv2dRFSearchOp

Enable Conv2d with receptive field searching ability.

Build Function

build_activation_layer

Build activation layer.

build_conv_layer

Build convolution layer.

build_norm_layer

Build normalization layer.

build_padding_layer

Build padding layer.

build_plugin_layer

Build plugin layer.

build_upsample_layer

Build upsample layer.

Miscellaneous

fuse_conv_bn

Recursively fuse conv and bn in a module.

conv_ws_2d

is_norm

Check if a layer is a normalization layer.

make_res_layer

make_vgg_layer

get_model_complexity_info

Get complexity information of a model.

mmcv.ops

BorderAlign

Border align pooling layer.

CARAFE

CARAFE: Content-Aware ReAssembly of FEatures

CARAFENaive

CARAFEPack

A unified package of CARAFE upsampler that contains: 1) channel compressor 2) content encoder 3) CARAFE op.

Conv2d

alias of mmcv.ops.deprecated_wrappers.Conv2d_deprecated

ConvTranspose2d

alias of mmcv.ops.deprecated_wrappers.ConvTranspose2d_deprecated

CornerPool

Corner Pooling.

Correlation

Correlation operator

CrissCrossAttention

Criss-Cross Attention Module.

DeformConv2d

Deformable 2D convolution.

DeformConv2dPack

A Deformable Conv Encapsulation that acts as normal Conv layers.

DeformRoIPool

DeformRoIPoolPack

DynamicScatter

Scatters points into voxels, used in the voxel encoder with dynamic voxelization.

FusedBiasLeakyReLU

Fused bias leaky ReLU.

GroupAll

Group xyz with feature.

Linear

alias of mmcv.ops.deprecated_wrappers.Linear_deprecated

MaskedConv2d

A MaskedConv2d which inherits the official Conv2d.

MaxPool2d

alias of mmcv.ops.deprecated_wrappers.MaxPool2d_deprecated

ModulatedDeformConv2d

ModulatedDeformConv2dPack

A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.

ModulatedDeformRoIPoolPack

MultiScaleDeformableAttention

An attention module used in Deformable-Detr.

PSAMask

PointsSampler

Points sampling.

PrRoIPool

The operation of precision RoI pooling.

QueryAndGroup

Groups points with a ball query of radius.

RiRoIAlignRotated

Rotation-invariant RoI align pooling layer for rotated proposals.

RoIAlign

RoI align pooling layer.

RoIAlignRotated

RoI align pooling layer for rotated proposals.

RoIAwarePool3d

Encode the geometry-specific features of each 3D proposal.

RoIPointPool3d

Encode the geometry-specific features of each 3D proposal.

RoIPool

SAConv2d

SAC (Switchable Atrous Convolution)

SigmoidFocalLoss

SimpleRoIAlign

SoftmaxFocalLoss

SparseConv2d

SparseConv3d

SparseConvTensor

SparseConvTranspose2d

SparseConvTranspose3d

SparseInverseConv2d

SparseInverseConv3d

SparseMaxPool2d

SparseMaxPool3d

SparseModule

place holder, All module subclass from this will take sptensor in SparseSequential.

SparseSequential

A sequential container.

SubMConv2d

SubMConv3d

SyncBatchNorm

Synchronized Batch Normalization.

TINShift

Temporal Interlace Shift.

Voxelization

Convert kitti points(N, >=3) to voxels.

active_rotated_filter

assign_score_withk

ball_query

batched_nms

Performs non-maximum suppression in a batched fashion.

bbox_overlaps

Calculate overlap between two set of bboxes.

border_align

box_iou_rotated

Return intersection-over-union (Jaccard index) of boxes.

boxes_iou3d

Calculate boxes 3D IoU.

boxes_iou_bev

Calculate boxes IoU in the Bird’s Eye View.

boxes_overlap_bev

Calculate boxes BEV overlap.

carafe

carafe_naive

chamfer_distance

contour_expand

Expand kernel contours so that foreground pixels are assigned into instances.

convex_giou

Return generalized intersection-over-union (Jaccard index) between point sets and polygons.

convex_iou

Return intersection-over-union (Jaccard index) between point sets and polygons.

deform_conv2d

deform_roi_pool

diff_iou_rotated_2d

Calculate differentiable iou of rotated 2d boxes.

diff_iou_rotated_3d

Calculate differentiable iou of rotated 3d boxes.

dynamic_scatter

furthest_point_sample

furthest_point_sample_with_dist

fused_bias_leakyrelu

Fused bias leaky ReLU function.

gather_points

grouping_operation

knn

masked_conv2d

min_area_polygons

Find the smallest polygons that surrounds all points in the point sets.

modulated_deform_conv2d

nms

Dispatch to either CPU or GPU NMS implementations.

nms3d

3D NMS function GPU implementation (for BEV boxes).

nms3d_normal

Normal 3D NMS function GPU implementation.

nms_bev

NMS function GPU implementation (for BEV boxes).

nms_match

Matched dets into different groups by NMS.

nms_normal_bev

Normal NMS function GPU implementation (for BEV boxes).

nms_rotated

Performs non-maximum suppression (NMS) on the rotated boxes according to their intersection-over-union (IoU).

pixel_group

Group pixels into text instances, which is widely used text detection methods.

point_sample

A wrapper around grid_sample() to support 3D point_coords tensors Unlike torch.nn.functional.grid_sample() it assumes point_coords to lie inside [0, 1] x [0, 1] square.

points_in_boxes_all

Find all boxes in which each point is (CUDA).

points_in_boxes_cpu

Find all boxes in which each point is (CPU).

points_in_boxes_part

Find the box in which each point is (CUDA).

points_in_polygons

Judging whether points are inside polygons, which is used in the ATSS assignment for the rotated boxes.

prroi_pool

rel_roi_point_to_rel_img_point

Convert roi based relative point coordinates to image based absolute point coordinates.

riroi_align_rotated

roi_align

roi_align_rotated

roi_pool

rotated_feature_align

scatter_nd

pytorch edition of tensorflow scatter_nd.

sigmoid_focal_loss

soft_nms

Dispatch to only CPU Soft NMS implementations.

softmax_focal_loss

three_interpolate

three_nn

tin_shift

upfirdn2d

Pad, upsample, filter, and downsample a batch of 2D images.

voxelization

mmcv.transforms

BaseTransform

Base class for all transformations.

TestTimeAug

Test-time augmentation transform.

Loading

LoadAnnotations

Load and process the instances and seg_map annotation provided by dataset.

LoadImageFromFile

Load an image from file.

Processing

CenterCrop

Crop the center of the image, segmentation masks, bounding boxes and key points.

MultiScaleFlipAug

Test-time augmentation with multiple scales and flipping.

Normalize

Normalize the image.

Pad

Pad the image & segmentation map.

RandomChoiceResize

Resize images & bbox & mask from a list of multiple scales.

RandomFlip

Flip the image & bbox & keypoints & segmentation map.

RandomGrayscale

Randomly convert image to grayscale with a probability.

RandomResize

Random resize images & bbox & keypoints.

Resize

Resize images & bbox & seg & keypoints.

ToTensor

Convert some results to torch.Tensor by given keys.

ImageToTensor

Convert image to torch.Tensor by given keys.

Wrapper

Compose

Compose multiple transforms sequentially.

KeyMapper

A transform wrapper to map and reorganize the input/output of the wrapped transforms (or sub-pipeline).

RandomApply

Apply transforms randomly with a given probability.

RandomChoice

Process data with a randomly chosen transform from given candidates.

TransformBroadcaster

A transform wrapper to apply the wrapped transforms to multiple data items.

mmcv.arraymisc

quantize

Quantize an array of (-inf, inf) to [0, levels-1].

dequantize

Dequantize an array.

mmcv.utils

IS_CUDA_AVAILABLE

bool(x) -> bool

IS_MLU_AVAILABLE

bool(x) -> bool

IS_MPS_AVAILABLE

bool(x) -> bool

collect_env

Collect the information of the running environments.

jit

skip_no_elena

Indices and tables

Read the Docs v: stable
Versions
latest
stable
2.x
v2.0.1
v2.0.0
1.x
v1.7.1
v1.7.0
v1.6.2
v1.6.1
v1.6.0
v1.5.3
v1.5.2_a
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.