Deploying yolort on TensorRT

Unlike other pipelines that deal with yolov5 on TensorRT, we embed the whole post-processing into the graph with onnx-graphsurgeon. This whole-pipeline approach buys us a lot. The ablation experiment results are below: the first summary is the result without the EfficientNMS_TRT plugin, and the second is the result with EfficientNMS_TRT embedded. As you can see, the inference latency is actually reduced; we attribute this to the much smaller amount of data that has to be copied back to the host once EfficientNMS_TRT filters the detections on the device. (The mean D2H latency is reduced from 0.868048 ms to 0.0102295 ms, running on an NVIDIA GeForce GTX 1080 Ti, using TensorRT 8.2 with yolov5n6 and images scaled to 512x640.) A sketch of the grafting step follows the install command below.

onnx-graphsurgeon is easy to install; you can just use NVIDIA's prebuilt wheels:

python3 -m pip install onnx_graphsurgeon --index-url https://pypi.ngc.nvidia.com
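
To make the grafting concrete, here is a minimal sketch of how the EfficientNMS_TRT plugin can be appended to an exported graph with onnx-graphsurgeon. The file names and the assumption that the graph ends with a (boxes, scores) pair are illustrative; yolort wraps the real logic inside the export_tensorrt_engine helper used later in this notebook, and the plugin attributes mirror the ones reported in the export log below.

import numpy as np
import onnx
import onnx_graphsurgeon as gs

# Load the PyTorch->ONNX graph that still ends with raw boxes and scores.
graph = gs.import_onnx(onnx.load("yolov5n6-no-nms.onnx"))
boxes, scores = graph.outputs  # assumed output order

# Declare the four tensors the plugin produces (shapes match the ones
# printed in the "Network Description" section further down).
batch_size, max_det = 1, 100
outputs = [
    gs.Variable("num_detections", dtype=np.int32, shape=[batch_size, 1]),
    gs.Variable("detection_boxes", dtype=np.float32, shape=[batch_size, max_det, 4]),
    gs.Variable("detection_scores", dtype=np.float32, shape=[batch_size, max_det]),
    gs.Variable("detection_classes", dtype=np.int32, shape=[batch_size, max_det]),
]

# Append the plugin node. ONNX has no schema for this op (hence the
# "No importer registered" message in the logs); TensorRT's parser
# resolves the name against its plugin registry at build time.
graph.layer(
    op="EfficientNMS_TRT",
    name="EfficientNMS",
    inputs=[boxes, scores],
    outputs=outputs,
    attrs={
        "plugin_version": "1",
        "background_class": -1,
        "max_output_boxes": max_det,
        "score_threshold": 0.35,
        "iou_threshold": 0.45,
        "score_activation": False,
        "box_coding": 0,
    },
)
graph.outputs = outputs
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "yolov5n6-efficient-nms.onnx")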

The detailed results (the exact trtexec command for each run appears at the end of its summary):

[I] === Performance summary w/o EfficientNMS_TRT plugin ===
[I] Throughput: 383.298 qps
[I] Latency: min = 3.66479 ms, max = 5.41199 ms, mean = 4.00543 ms, median = 3.99316 ms, percentile(99%) = 4.23831 ms
[I] End-to-End Host Latency: min = 3.76599 ms, max = 6.45874 ms, mean = 5.08597 ms, median = 5.07544 ms, percentile(99%) = 5.50839 ms
[I] Enqueue Time: min = 0.743408 ms, max = 5.27966 ms, mean = 0.940805 ms, median = 0.924805 ms, percentile(99%) = 1.37329 ms
[I] H2D Latency: min = 0.502045 ms, max = 0.62674 ms, mean = 0.538255 ms, median = 0.537354 ms, percentile(99%) = 0.582153 ms
[I] GPU Compute Time: min = 2.23233 ms, max = 3.92395 ms, mean = 2.59913 ms, median = 2.58661 ms, percentile(99%) = 2.8201 ms
[I] D2H Latency: min = 0.851807 ms, max = 0.900421 ms, mean = 0.868048 ms, median = 0.867676 ms, percentile(99%) = 0.889191 ms
[I] Total Host Walltime: 3.0081 s
[I] Total GPU Compute Time: 2.99679 s
[I] Explanations of the performance metrics are printed in the verbose logs.
[I]
&&&& PASSED TensorRT.trtexec [TensorRT v8201] # trtexec --onnx=yolov5n6-no-nms.onnx --workspace=8096
[I] === Performance summary w/ EfficientNMS_TRT plugin ===
[I] Throughput: 389.234 qps
[I] Latency: min = 2.81482 ms, max = 9.77234 ms, mean = 3.1062 ms, median = 3.07642 ms, percentile(99%) = 3.33548 ms
[I] End-to-End Host Latency: min = 2.82202 ms, max = 11.6749 ms, mean = 4.939 ms, median = 4.95587 ms, percentile(99%) = 5.45207 ms
[I] Enqueue Time: min = 0.999878 ms, max = 11.3833 ms, mean = 1.28942 ms, median = 1.18579 ms, percentile(99%) = 4.53088 ms
[I] H2D Latency: min = 0.488159 ms, max = 0.633881 ms, mean = 0.546754 ms, median = 0.546631 ms, percentile(99%) = 0.570557 ms
[I] GPU Compute Time: min = 2.30298 ms, max = 9.21094 ms, mean = 2.54921 ms, median = 2.51904 ms, percentile(99%) = 2.78528 ms
[I] D2H Latency: min = 0.00610352 ms, max = 0.302734 ms, mean = 0.0102295 ms, median = 0.00976562 ms, percentile(99%) = 0.0151367 ms
[I] Total Host Walltime: 3.00591 s
[I] Total GPU Compute Time: 2.98258 s
[I] Explanations of the performance metrics are printed in the verbose logs.
[I]
&&&& PASSED TensorRT.trtexec [TensorRT v8201] # trtexec --onnx=yolov5n6-efficient-nms.onnx --workspace=8096

TensorRT Installation Instructions

Before we unfold the subsequent story, TensorRT must be installed; the minimum version required to run this demo is 8.2.0. Check out the TensorRT installation guide at https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html. If you only want the Python interface, you can use the Python wheels provided by TensorRT:

pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com

These wheels only work on Ubuntu 18.04+ or CentOS 7+ with Python versions 3.6 to 3.9 and CUDA 11.x. Check out the details at https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#installing-pip.

There are several ways to install the full TensorRT stack; see https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#installing. One option for using the C++ interface of TensorRT is Docker: we’ve tested the PyTorch container published by NVIDIA on NVIDIA GPU Cloud (NGC), which ships with both PyTorch and TensorRT: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
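
For example, you can pull and run one of these containers as shown below; the tag is illustrative, so pick a current one from the NGC catalog:

docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:21.12-py3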

[1]:
import os
import torch

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
cuda_visible = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible

assert torch.cuda.is_available()
device = torch.device('cuda')

import tensorrt as trt
print(f"We're using TensorRT: {trt.__version__} on {device} device: {cuda_visible}.")
We're using TensorRT: 8.2.0.6 on cuda device: 0.
[2]:
import cv2

from yolort.utils import cv2_imshow
from yolort.utils.image_utils import plot_one_box, color_list
from yolort.v5 import attempt_download
from yolort.v5.utils.downloads import safe_download

Prepare the image and model weights for testing

[3]:
# Define some parameters
batch_size = 1
img_size = 640
size_divisible = 64
fixed_shape = True
score_thresh = 0.35
nms_thresh = 0.45
detections_per_img = 100
precision = "fp32"  # Currently only supports fp32
[4]:
# img_source = "https://huggingface.co/spaces/zhiqwang/assets/resolve/main/zidane.jpg"
img_source = "https://huggingface.co/spaces/zhiqwang/assets/resolve/main/bus.jpg"
img_path = img_source.split("/")[-1]
safe_download(img_path, img_source)
img_raw = cv2.imread(img_path)
Downloading https://huggingface.co/spaces/zhiqwang/assets/resolve/main/bus.jpg to bus.jpg...

[5]:
# yolov5n6.pt is downloaded from 'https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5n6.pt'
model_path = "yolov5n6.pt"

checkpoint_path = attempt_download(model_path)
onnx_path = "yolov5n6.onnx"
engine_path = "yolov5n6.engine"

Export the ONNX model and TensorRT engine

We provide a utility export_tensorrt_engine for exporting TensorRT engines.

[6]:
from yolort.runtime.trt_helper import export_tensorrt_engine
[7]:
# Dummy input that fixes the input shape for the ONNX export
input_sample = torch.rand(batch_size, 3, img_size, img_size)
[8]:
export_tensorrt_engine(
    model_path,
    score_thresh=score_thresh,
    nms_thresh=nms_thresh,
    onnx_path=onnx_path,
    engine_path=engine_path,
    input_sample=input_sample,
    detections_per_img=detections_per_img,
)

                 from  n    params  module                                  arguments
  0                -1  1      1760  yolort.v5.models.common.Conv            [3, 16, 6, 2, 2]
  1                -1  1      4672  yolort.v5.models.common.Conv            [16, 32, 3, 2]
  2                -1  1      4800  yolort.v5.models.common.C3              [32, 32, 1]
  3                -1  1     18560  yolort.v5.models.common.Conv            [32, 64, 3, 2]
  4                -1  2     29184  yolort.v5.models.common.C3              [64, 64, 2]
  5                -1  1     73984  yolort.v5.models.common.Conv            [64, 128, 3, 2]
  6                -1  3    156928  yolort.v5.models.common.C3              [128, 128, 3]
  7                -1  1    221568  yolort.v5.models.common.Conv            [128, 192, 3, 2]
  8                -1  1    167040  yolort.v5.models.common.C3              [192, 192, 1]
  9                -1  1    442880  yolort.v5.models.common.Conv            [192, 256, 3, 2]
 10                -1  1    296448  yolort.v5.models.common.C3              [256, 256, 1]
 11                -1  1    164608  yolort.v5.models.common.SPPF            [256, 256, 5]
 12                -1  1     49536  yolort.v5.models.common.Conv            [256, 192, 1, 1]
 13                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 14           [-1, 8]  1         0  yolort.v5.models.common.Concat          [1]
 15                -1  1    203904  yolort.v5.models.common.C3              [384, 192, 1, False]
 16                -1  1     24832  yolort.v5.models.common.Conv            [192, 128, 1, 1]
 17                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 18           [-1, 6]  1         0  yolort.v5.models.common.Concat          [1]
 19                -1  1     90880  yolort.v5.models.common.C3              [256, 128, 1, False]
 20                -1  1      8320  yolort.v5.models.common.Conv            [128, 64, 1, 1]
 21                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 22           [-1, 4]  1         0  yolort.v5.models.common.Concat          [1]
 23                -1  1     22912  yolort.v5.models.common.C3              [128, 64, 1, False]
 24                -1  1     36992  yolort.v5.models.common.Conv            [64, 64, 3, 2]
 25          [-1, 20]  1         0  yolort.v5.models.common.Concat          [1]
 26                -1  1     74496  yolort.v5.models.common.C3              [128, 128, 1, False]
 27                -1  1    147712  yolort.v5.models.common.Conv            [128, 128, 3, 2]
 28          [-1, 16]  1         0  yolort.v5.models.common.Concat          [1]
 29                -1  1    179328  yolort.v5.models.common.C3              [256, 192, 1, False]
 30                -1  1    332160  yolort.v5.models.common.Conv            [192, 192, 3, 2]
 31          [-1, 12]  1         0  yolort.v5.models.common.Concat          [1]
 32                -1  1    329216  yolort.v5.models.common.C3              [384, 256, 1, False]
 33  [23, 26, 29, 32]  1    164220  yolort.v5.models.yolo.Detect            [80, [[19, 27, 44, 40, 38, 94], [96, 68, 86, 152, 180, 137], [140, 301, 303, 264, 238, 542], [436, 615, 739, 380, 925, 792]], [64, 128, 192, 256]]
/opt/conda/lib/python3.8/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Model Summary: 355 layers, 3246940 parameters, 3246940 gradients, 4.6 GFLOPs

Loaded saved model from yolov5n6.pt
/coding/yolov5-rt-stack/yolort/models/anchor_utils.py:45: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  anchors = torch.as_tensor(self.anchor_grids, dtype=torch.float32, device=device).to(dtype=dtype)
/coding/yolov5-rt-stack/yolort/models/anchor_utils.py:46: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  strides = torch.as_tensor(self.strides, dtype=torch.float32, device=device).to(dtype=dtype)
/coding/yolov5-rt-stack/yolort/relaying/logits_decoder.py:45: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  strides = torch.as_tensor(self.strides, dtype=torch.float32, device=device).to(dtype=dtype)
/coding/yolov5-rt-stack/yolort/models/box_head.py:333: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for head_output, grid, shift, stride in zip(head_outputs, grids, shifts, strides):
PyTorch2ONNX graph created successfully
[W] 'Shape tensor cast elision' routine failed with: None
Created NMS plugin 'EfficientNMS_TRT' with attributes: {'plugin_version': '1', 'background_class': -1, 'max_output_boxes': 100, 'score_threshold': 0.35, 'iou_threshold': 0.45, 'score_activation': False, 'box_coding': 0}
Warning: Unsupported operator EfficientNMS_TRT. No schema registered for this operator.
Saved ONNX model to yolov5n6.onnx
Network Description
Input 'images' with shape (1, 3, 640, 640) and dtype DataType.FLOAT
Output 'num_detections' with shape (1, 1) and dtype DataType.INT32
Output 'detection_boxes' with shape (1, 100, 4) and dtype DataType.FLOAT
Output 'detection_scores' with shape (1, 100) and dtype DataType.FLOAT
Output 'detection_classes' with shape (1, 100) and dtype DataType.INT32
Building fp32 Engine in yolov5n6.engine
Using fp32 mode.
[02/12/2022-17:42:08] [TRT] [I] [MemUsageChange] Init CUDA: CPU +176, GPU +0, now: CPU 405, GPU 4987 (MiB)
[02/12/2022-17:42:09] [TRT] [I] ----------------------------------------------------------------
[02/12/2022-17:42:09] [TRT] [I] Input filename:   yolov5n6.onnx
[02/12/2022-17:42:09] [TRT] [I] ONNX IR version:  0.0.8
[02/12/2022-17:42:09] [TRT] [I] Opset version:    11
[02/12/2022-17:42:09] [TRT] [I] Producer name:
[02/12/2022-17:42:09] [TRT] [I] Producer version:
[02/12/2022-17:42:09] [TRT] [I] Domain:
[02/12/2022-17:42:09] [TRT] [I] Model version:    0
[02/12/2022-17:42:09] [TRT] [I] Doc string:
[02/12/2022-17:42:09] [TRT] [I] ----------------------------------------------------------------
[02/12/2022-17:42:09] [TRT] [W] parsers/onnx/onnx2trt_utils.cpp:364: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[02/12/2022-17:42:09] [TRT] [W] parsers/onnx/onnx2trt_utils.cpp:392: One or more weights outside the range of INT32 was clamped
[02/12/2022-17:42:09] [TRT] [W] parsers/onnx/onnx2trt_utils.cpp:392: One or more weights outside the range of INT32 was clamped
[02/12/2022-17:42:09] [TRT] [W] parsers/onnx/onnx2trt_utils.cpp:392: One or more weights outside the range of INT32 was clamped
[02/12/2022-17:42:09] [TRT] [W] parsers/onnx/onnx2trt_utils.cpp:392: One or more weights outside the range of INT32 was clamped
[02/12/2022-17:42:09] [TRT] [W] parsers/onnx/onnx2trt_utils.cpp:392: One or more weights outside the range of INT32 was clamped
[02/12/2022-17:42:09] [TRT] [I] No importer registered for op: EfficientNMS_TRT. Attempting to import as plugin.
[02/12/2022-17:42:09] [TRT] [I] Searching for plugin: EfficientNMS_TRT, plugin_version: 1, plugin_namespace:
[02/12/2022-17:42:09] [TRT] [I] Successfully created plugin: EfficientNMS_TRT
[02/12/2022-17:42:09] [TRT] [I] [MemUsageSnapshot] Builder begin: CPU 477 MiB, GPU 4987 MiB
[02/12/2022-17:42:10] [TRT] [W] TensorRT was linked against cuBLAS/cuBLAS LT 11.6.1 but loaded cuBLAS/cuBLAS LT 11.5.1
[02/12/2022-17:42:10] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +232, GPU +94, now: CPU 711, GPU 5081 (MiB)
[02/12/2022-17:42:10] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +188, GPU +84, now: CPU 899, GPU 5165 (MiB)
[02/12/2022-17:42:10] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[02/12/2022-17:42:53] [TRT] [I] [BlockAssignment] Algorithm Linear took 0.054667ms to assign 150 blocks to 150 nodes requiring 13011862528 bytes.
[02/12/2022-17:42:53] [TRT] [I] Total Activation Memory: 126960640
[02/12/2022-17:42:53] [TRT] [I] Detected 1 inputs and 4 output network tensors.
Serialize engine success, saved as yolov5n6.engine
[02/12/2022-17:42:58] [TRT] [I] Total Host Persistent Memory: 169904
[02/12/2022-17:42:58] [TRT] [I] Total Device Persistent Memory: 11788800
[02/12/2022-17:42:58] [TRT] [I] Total Scratch Memory: 48960768
[02/12/2022-17:42:58] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 3 MiB, GPU 2096 MiB
[02/12/2022-17:42:58] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 29.751ms to assign 8 blocks to 151 nodes requiring 66721280 bytes.
[02/12/2022-17:42:58] [TRT] [W] TensorRT was linked against cuBLAS/cuBLAS LT 11.6.1 but loaded cuBLAS/cuBLAS LT 11.5.1
[02/12/2022-17:42:58] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 2794, GPU 6061 (MiB)
[02/12/2022-17:42:58] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 2794, GPU 6069 (MiB)
[02/12/2022-17:42:58] [TRT] [I] [MemUsageSnapshot] Builder end: CPU 2793 MiB, GPU 6035 MiB

Test the exported TensorRT engine

Note that the exported TensorRT engine above contains the model and the post-processing (NMS), but not the pre-processing. We therefore wrap the pre-processing, named YOLOTransform, together with the engine into a new module PredictorTRT for ease of use.
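
For the curious, PredictorTRT essentially deserializes the engine and drives a TensorRT execution context. Below is a simplified sketch of that flow using the TensorRT 8.2 binding API, with torch tensors as device buffers; yolort's actual buffer handling may differ.

import numpy as np
import tensorrt as trt
import torch

# Deserialize the engine built above.
logger = trt.Logger(trt.Logger.INFO)
with open("yolov5n6.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

# Allocate a torch CUDA tensor for every binding and record device pointers.
buffers, bindings = {}, []
for i in range(engine.num_bindings):
    name = engine.get_binding_name(i)
    shape = tuple(engine.get_binding_shape(i))
    np_dtype = trt.nptype(engine.get_binding_dtype(i))
    buffers[name] = torch.from_numpy(np.zeros(shape, dtype=np_dtype)).cuda()
    bindings.append(buffers[name].data_ptr())

# Copy a pre-processed (letterboxed, normalized) image into the input
# binding and run synchronous inference.
buffers["images"].copy_(input_sample.cuda())
context.execute_v2(bindings)
print(int(buffers["num_detections"][0, 0]), "detections")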

[9]:
from yolort.runtime import PredictorTRT
[10]:
y_runtime = PredictorTRT(engine_path, device=device)
Loading yolov5n6.engine for TensorRT inference...
[02/12/2022-17:42:58] [TRT] [I] The logger passed into createInferRuntime differs from one already provided for an existing builder, runtime, or refitter. Uses of the global logger, returned by nvinfer1::getLogger(), will return the existing value.

[02/12/2022-17:42:58] [TRT] [I] [MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 2779, GPU 6003 (MiB)
[02/12/2022-17:42:58] [TRT] [I] Loaded engine size: 18 MiB
[02/12/2022-17:42:58] [TRT] [I] [MemUsageSnapshot] deserializeCudaEngine begin: CPU 2797 MiB, GPU 6003 MiB
[02/12/2022-17:42:58] [TRT] [W] TensorRT was linked against cuBLAS/cuBLAS LT 11.6.1 but loaded cuBLAS/cuBLAS LT 11.5.1
[02/12/2022-17:42:58] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +1, GPU +10, now: CPU 2810, GPU 6031 (MiB)
[02/12/2022-17:42:58] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 2810, GPU 6039 (MiB)
[02/12/2022-17:42:58] [TRT] [I] [MemUsageSnapshot] deserializeCudaEngine end: CPU 2809 MiB, GPU 6021 MiB
[02/12/2022-17:42:58] [TRT] [I] [MemUsageSnapshot] ExecutionContext creation begin: CPU 2791 MiB, GPU 6043 MiB
[02/12/2022-17:42:58] [TRT] [W] TensorRT was linked against cuBLAS/cuBLAS LT 11.6.1 but loaded cuBLAS/cuBLAS LT 11.5.1
[02/12/2022-17:42:58] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +10, now: CPU 2791, GPU 6053 (MiB)
[02/12/2022-17:42:58] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 2791, GPU 6061 (MiB)
[02/12/2022-17:42:58] [TRT] [I] [MemUsageSnapshot] ExecutionContext creation end: CPU 2793 MiB, GPU 6145 MiB

Let’s warm up the engine by running inference once on the GPU; the first run pays one-time CUDA initialization and allocation costs, so subsequent timings are more representative.

[11]:
y_runtime.warmup()

Inferencing with TensorRT

[12]:
predictions_trt = y_runtime.predict(img_path)
[13]:
predictions_trt
[13]:
[{'scores': tensor([0.88487, 0.84572, 0.78537, 0.72515], device='cuda:0'),
  'labels': tensor([5, 0, 0, 0], device='cuda:0', dtype=torch.int32),
  'boxes': tensor([[ 35.89635, 226.08698, 808.98145, 739.60950],
          [ 50.38255, 387.47092, 240.80370, 898.42114],
          [677.84216, 380.34561, 809.81213, 876.44397],
          [224.94446, 391.35855, 347.93188, 866.14307]], device='cuda:0')}]

Predict with yolort

[14]:
from yolort.models import YOLOv5
[15]:
model = YOLOv5.load_from_yolov5(
    model_path,
    size=(img_size, img_size),
    size_divisible=size_divisible,
    fixed_shape=(img_size, img_size),
    score_thresh=score_thresh,
    nms_thresh=nms_thresh,
)

model = model.eval()
model = model.to(device)

                 from  n    params  module                                  arguments
  0                -1  1      1760  yolort.v5.models.common.Conv            [3, 16, 6, 2, 2]
  1                -1  1      4672  yolort.v5.models.common.Conv            [16, 32, 3, 2]
  2                -1  1      4800  yolort.v5.models.common.C3              [32, 32, 1]
  3                -1  1     18560  yolort.v5.models.common.Conv            [32, 64, 3, 2]
  4                -1  2     29184  yolort.v5.models.common.C3              [64, 64, 2]
  5                -1  1     73984  yolort.v5.models.common.Conv            [64, 128, 3, 2]
  6                -1  3    156928  yolort.v5.models.common.C3              [128, 128, 3]
  7                -1  1    221568  yolort.v5.models.common.Conv            [128, 192, 3, 2]
  8                -1  1    167040  yolort.v5.models.common.C3              [192, 192, 1]
  9                -1  1    442880  yolort.v5.models.common.Conv            [192, 256, 3, 2]
 10                -1  1    296448  yolort.v5.models.common.C3              [256, 256, 1]
 11                -1  1    164608  yolort.v5.models.common.SPPF            [256, 256, 5]
 12                -1  1     49536  yolort.v5.models.common.Conv            [256, 192, 1, 1]
 13                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 14           [-1, 8]  1         0  yolort.v5.models.common.Concat          [1]
 15                -1  1    203904  yolort.v5.models.common.C3              [384, 192, 1, False]
 16                -1  1     24832  yolort.v5.models.common.Conv            [192, 128, 1, 1]
 17                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 18           [-1, 6]  1         0  yolort.v5.models.common.Concat          [1]
 19                -1  1     90880  yolort.v5.models.common.C3              [256, 128, 1, False]
 20                -1  1      8320  yolort.v5.models.common.Conv            [128, 64, 1, 1]
 21                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 22           [-1, 4]  1         0  yolort.v5.models.common.Concat          [1]
 23                -1  1     22912  yolort.v5.models.common.C3              [128, 64, 1, False]
 24                -1  1     36992  yolort.v5.models.common.Conv            [64, 64, 3, 2]
 25          [-1, 20]  1         0  yolort.v5.models.common.Concat          [1]
 26                -1  1     74496  yolort.v5.models.common.C3              [128, 128, 1, False]
 27                -1  1    147712  yolort.v5.models.common.Conv            [128, 128, 3, 2]
 28          [-1, 16]  1         0  yolort.v5.models.common.Concat          [1]
 29                -1  1    179328  yolort.v5.models.common.C3              [256, 192, 1, False]
 30                -1  1    332160  yolort.v5.models.common.Conv            [192, 192, 3, 2]
 31          [-1, 12]  1         0  yolort.v5.models.common.Concat          [1]
 32                -1  1    329216  yolort.v5.models.common.C3              [384, 256, 1, False]
 33  [23, 26, 29, 32]  1    164220  yolort.v5.models.yolo.Detect            [80, [[19, 27, 44, 40, 38, 94], [96, 68, 86, 152, 180, 137], [140, 301, 303, 264, 238, 542], [436, 615, 739, 380, 925, 792]], [64, 128, 192, 256]]
Model Summary: 355 layers, 3246940 parameters, 3246940 gradients, 4.6 GFLOPs

Inferencing with PyTorch

[16]:
predictions_pt = model.predict(img_path)

Verify the detection results between yolort and TensorRT

[17]:
# Testing boxes
torch.testing.assert_close(predictions_pt[0]["boxes"], predictions_trt[0]["boxes"])
# Testing scores
torch.testing.assert_close(predictions_pt[0]["scores"], predictions_trt[0]["scores"])
# Testing labels
torch.testing.assert_close(predictions_pt[0]["labels"], predictions_trt[0]["labels"].to(dtype=torch.int64))

print("Exported model has been tested, and the result looks good!")
Exported model has been tested, and the result looks good!

Visualise the TensorRT detections

[18]:
# Get label names
import requests

# label_path = "https://raw.githubusercontent.com/zhiqwang/yolov5-rt-stack/main/notebooks/assets/coco.names"
label_path = "https://huggingface.co/spaces/zhiqwang/assets/resolve/main/coco.names"
response = requests.get(label_path)
names = response.text

LABELS = names.strip().split('\n')

COLORS = color_list()
[19]:
for box, label in zip(predictions_trt[0]['boxes'].tolist(), predictions_trt[0]['labels'].tolist()):
    img_raw = plot_one_box(box, img_raw, color=COLORS[label % len(COLORS)], label=LABELS[label])
[20]:
cv2_imshow(img_raw, imshow_scale=0.5)
[Output image: bus.jpg with the TensorRT detection boxes and labels drawn]

View this document as a notebook: https://github.com/zhiqwang/yolov5-rt-stack/blob/main/notebooks/onnx-graphsurgeon-inference-tensorrt.ipynb