Shortcuts

MMCV中的TensorRT自定义算子 (实验性)

介绍

NVIDIA TensorRT是一个为深度学习模型高性能推理准备的软件开发工具(SDK)。它包括深度学习推理优化器和运行时,可为深度学习推理应用提供低延迟和高吞吐量。请访问developer’s website了解更多信息。 为了简化TensorRT部署带有MMCV自定义算子的模型的流程,MMCV中添加了一系列TensorRT插件。

MMCV中的TensorRT插件列表

ONNX算子 TensorRT插件 MMCV版本
MMCVRoiAlign MMCVRoiAlign 1.2.6
ScatterND ScatterND 1.2.6
NonMaxSuppression NonMaxSuppression 1.3.0
MMCVDeformConv2d MMCVDeformConv2d 1.3.0
grid_sampler grid_sampler 1.3.1
cummax cummax 1.3.5
cummin cummin 1.3.5
MMCVInstanceNormalization MMCVInstanceNormalization 1.3.5
MMCVModulatedDeformConv2d MMCVModulatedDeformConv2d master

注意

  • 以上所有算子均在 TensorRT-7.2.1.6.Ubuntu-16.04.x86_64-gnu.cuda-10.2.cudnn8.0 环境下开发。

如何编译MMCV中的TensorRT插件

准备

  • 克隆代码仓库

git clone https://github.com/open-mmlab/mmcv.git
  • 安装TensorRT

NVIDIA Developer Zone 下载合适的TensorRT版本。

比如,对安装了cuda-10.2的x86-64的Ubuntu 16.04,下载文件为TensorRT-7.2.1.6.Ubuntu-16.04.x86_64-gnu.cuda-10.2.cudnn8.0.tar.gz.

然后使用下面方式安装并配置环境

cd ~/Downloads
tar -xvzf TensorRT-7.2.1.6.Ubuntu-16.04.x86_64-gnu.cuda-10.2.cudnn8.0.tar.gz
export TENSORRT_DIR=`pwd`/TensorRT-7.2.1.6
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TENSORRT_DIR/lib

安装python依赖: tensorrt, graphsurgeon, onnx-graphsurgeon

pip install $TENSORRT_DIR/python/tensorrt-7.2.1.6-cp37-none-linux_x86_64.whl
pip install $TENSORRT_DIR/onnx_graphsurgeon/onnx_graphsurgeon-0.2.6-py2.py3-none-any.whl
pip install $TENSORRT_DIR/graphsurgeon/graphsurgeon-0.4.5-py2.py3-none-any.whl

想了解更多通过tar包安装TensorRT,请访问Nvidia’ website.

在Linux上编译

cd mmcv ## to MMCV root directory
MMCV_WITH_OPS=1 MMCV_WITH_TRT=1 pip install -e .

创建TensorRT推理引擎并在python下进行推理

范例如下:

import torch
import onnx

from mmcv.tensorrt import (TRTWrapper, onnx2trt, save_trt_engine,
                                   is_tensorrt_plugin_loaded)

assert is_tensorrt_plugin_loaded(), 'Requires to complie TensorRT plugins in mmcv'

onnx_file = 'sample.onnx'
trt_file = 'sample.trt'
onnx_model = onnx.load(onnx_file)

## Model input
inputs = torch.rand(1, 3, 224, 224).cuda()
## Model input shape info
opt_shape_dict = {
    'input': [list(inputs.shape),
              list(inputs.shape),
              list(inputs.shape)]
}

## Create TensorRT engine
max_workspace_size = 1 << 30
trt_engine = onnx2trt(
    onnx_model,
    opt_shape_dict,
    max_workspace_size=max_workspace_size)

## Save TensorRT engine
save_trt_engine(trt_engine, trt_file)

## Run inference with TensorRT
trt_model = TRTWrapper(trt_file, ['input'], ['output'])

with torch.no_grad():
    trt_outputs = trt_model({'input': inputs})
    output = trt_outputs['output']

如何在MMCV中添加新的TensorRT自定义算子

主要流程

下面是主要的步骤:

  1. 添加c++头文件

  2. 添加c++源文件

  3. 添加cuda kernel文件

  4. trt_plugin.cpp中注册插件

  5. tests/test_ops/test_tensorrt.py中添加单元测试

以RoIAlign算子插件roi_align举例。

  1. 在TensorRT包含目录mmcv/ops/csrc/tensorrt/中添加头文件trt_roi_align.hpp

  2. 在TensorRT源码目录mmcv/ops/csrc/tensorrt/plugins/中添加头文件trt_roi_align.cpp

  3. 在TensorRT源码目录mmcv/ops/csrc/tensorrt/plugins/中添加cuda kernel文件trt_roi_align_kernel.cu

  4. trt_plugin.cpp中注册roi_align插件

    #include "trt_plugin.hpp"
    
    #include "trt_roi_align.hpp"
    
    REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
    
    extern "C" {
    bool initLibMMCVInferPlugins() { return true; }
    }  // extern "C"
    
  5. tests/test_ops/test_tensorrt.py中添加单元测试

注意

  • 部分MMCV中的自定义算子存在对应的cuda实现,在进行TensorRT插件开发的时候可以参考。

已知问题

Read the Docs v: v1.4.3
Versions
latest
stable
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.