Add RT-DETR

This commit is contained in:
Marcos Luciano
2023-11-01 18:23:28 -03:00
parent 000bcd676d
commit 1177624dd2
4 changed files with 349 additions and 4 deletions

View File

@@ -22,16 +22,14 @@ NVIDIA DeepStream SDK 6.3 / 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 / 5.1 configuration
* Support for non square models * Support for non square models
* Models benchmarks * Models benchmarks
* Support for Darknet models (YOLOv4, etc) using cfg and weights conversion with GPU post-processing * Support for Darknet models (YOLOv4, etc) using cfg and weights conversion with GPU post-processing
* Support for YOLO-NAS, PPYOLOE+, PPYOLOE, DAMO-YOLO, YOLOX, YOLOR, YOLOv8, YOLOv7, YOLOv6 and YOLOv5 using ONNX conversion with GPU post-processing * Support for RT-DETR, YOLO-NAS, PPYOLOE+, PPYOLOE, DAMO-YOLO, YOLOX, YOLOR, YOLOv8, YOLOv7, YOLOv6 and YOLOv5 using ONNX conversion with GPU post-processing
* GPU bbox parser (it is slightly slower than CPU bbox parser on V100 GPU tests) * GPU bbox parser (it is slightly slower than CPU bbox parser on V100 GPU tests)
* Support for DeepStream 5.1 * Support for DeepStream 5.1
* Custom ONNX model parser (`NvDsInferYoloCudaEngineGet`) * Custom ONNX model parser (`NvDsInferYoloCudaEngineGet`)
* Dynamic batch-size for Darknet and ONNX exported models * Dynamic batch-size for Darknet and ONNX exported models
* INT8 calibration (PTQ) for Darknet and ONNX exported models * INT8 calibration (PTQ) for Darknet and ONNX exported models
* New output structure (fix wrong output on DeepStream < 6.2) - it need to export the ONNX model with the new export file, generate the TensorRT engine again with the updated files, and use the new config_infer_primary file according to your model * New output structure (fix wrong output on DeepStream < 6.2) - it need to export the ONNX model with the new export file, generate the TensorRT engine again with the updated files, and use the new config_infer_primary file according to your model
* **YOLO-Pose: https://github.com/marcoslucianops/DeepStream-Yolo-Pose** * **RT-DETR (https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetr_pytorch)**
* **YOLO-Seg: https://github.com/marcoslucianops/DeepStream-Yolo-Seg**
* **YOLO-Face: https://github.com/marcoslucianops/DeepStream-Yolo-Face**
## ##
@@ -54,6 +52,7 @@ NVIDIA DeepStream SDK 6.3 / 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 / 5.1 configuration
* [DAMO-YOLO usage](docs/DAMOYOLO.md) * [DAMO-YOLO usage](docs/DAMOYOLO.md)
* [PP-YOLOE / PP-YOLOE+ usage](docs/PPYOLOE.md) * [PP-YOLOE / PP-YOLOE+ usage](docs/PPYOLOE.md)
* [YOLO-NAS usage](docs/YOLONAS.md) * [YOLO-NAS usage](docs/YOLONAS.md)
* [RT-DETR usage](docs/RTDETR.md)
* [Using your custom model](docs/customModels.md) * [Using your custom model](docs/customModels.md)
* [Multiple YOLO GIEs](docs/multipleGIEs.md) * [Multiple YOLO GIEs](docs/multipleGIEs.md)

View File

@@ -0,0 +1,27 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
onnx-file=rtdetr_r50vd_6x_coco_from_paddle.onnx
model-engine-file=model_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=0
#workspace-size=2000
parse-bbox-func-name=NvDsInferParseYolo
#parse-bbox-func-name=NvDsInferParseYoloCuda
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300

198
docs/RTDETR.md Normal file
View File

@@ -0,0 +1,198 @@
# RT_DETR usage
**NOTE**: For it is supported only the https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetr_pytorch version.
* [Convert model](#convert-model)
* [Compile the lib](#compile-the-lib)
* [Edit the config_infer_primary_rtdetr file](#edit-the-config_infer_primary_rtdetr-file)
* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file)
* [Testing the model](#testing-the-model)
##
### Convert model
#### 1. Download the RT-DETR repo and install the requirements
```
git clone https://github.com/lyuwenyu/RT-DETR.git
cd RT-DETR/rtdetr_pytorch
pip3 install -r requirements.txt
pip3 install onnx onnxsim onnxruntime
```
**NOTE**: It is recommended to use Python virtualenv.
#### 2. Copy conversor
Copy the `export_rtdetr_pytorch.py` file from `DeepStream-Yolo/utils` directory to the `RT-DETR/rtdetr_pytorch` folder.
#### 3. Download the model
Download the `pth` file from [RT-DETR](https://github.com/lyuwenyu/storage/releases) releases (example for RT-DETR-R50)
```
wget https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_6x_coco_from_paddle.pth
```
**NOTE**: You can use your custom model.
#### 4. Convert model
Generate the ONNX model file (example for RT-DETR-R50)
```
python3 export_rtdetr_pytorch.py -w rtdetr_r50vd_6x_coco_from_paddle.pth -c configs/rtdetr/rtdetr_r50vd_6x_coco.yml --dynamic
```
**NOTE**: To change the inference size (defaut: 640)
```
-s SIZE
--size SIZE
-s HEIGHT WIDTH
--size HEIGHT WIDTH
```
Example for 1280
```
-s 1280
```
or
```
-s 1280 1280
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)
```
--dynamic
```
**NOTE**: To use static batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using the DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 16.
```
--opset 12
```
#### 5. Copy generated files
Copy the generated ONNX model file and labels.txt file (if generated) to the `DeepStream-Yolo` folder.
##
### Compile the lib
Open the `DeepStream-Yolo` folder and compile the lib
* DeepStream 6.3 on x86 platform
```
CUDA_VER=12.1 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.2 on x86 platform
```
CUDA_VER=11.8 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.1.1 on x86 platform
```
CUDA_VER=11.7 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.1 on x86 platform
```
CUDA_VER=11.6 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.0.1 / 6.0 on x86 platform
```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 5.1 on x86 platform
```
CUDA_VER=11.1 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.3 / 6.2 / 6.1.1 / 6.1 on Jetson platform
```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.0.1 / 6.0 / 5.1 on Jetson platform
```
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
```
##
### Edit the config_infer_primary_rtdetr file
Edit the `config_infer_primary_rtdetr.txt` file according to your model (example for RT-DETR-R50 with 80 classes)
```
[property]
...
onnx-file=rtdetr_r50vd_6x_coco_from_paddle.onnx
...
num-detected-classes=80
...
parse-bbox-func-name=NvDsInferParseYolo
...
```
**NOTE**: The **RT-DETR** do not resize the input with padding. To get better accuracy, use
```
[property]
...
maintain-aspect-ratio=0
...
```
##
### Edit the deepstream_app_config file
```
...
[primary-gie]
...
config-file=config_infer_primary_rtdetr.txt
```
##
### Testing the model
```
deepstream-app -c deepstream_app_config.txt
```
**NOTE**: The TensorRT engine file may take a very long time to generate (sometimes more than 10 minutes).
**NOTE**: For more information about custom models configuration (`batch-size`, `network-mode`, etc), please check the [`docs/customModels.md`](customModels.md) file.

121
utils/export_rtdetr_pytorch.py Executable file
View File

@@ -0,0 +1,121 @@
import os
import sys
import argparse
import warnings
import onnx
import torch
import torch.nn as nn
from src.core import YAMLConfig
class DeepStreamOutput(nn.Module):
def __init__(self, img_size):
self.img_size = img_size
super().__init__()
def forward(self, x):
boxes = x['pred_boxes']
boxes[:, :, [0, 2]] *= self.img_size[1]
boxes[:, :, [1, 3]] *= self.img_size[0]
scores, classes = torch.max(x['pred_logits'], 2, keepdim=True)
classes = classes.float()
return boxes, scores, classes
class DeepStreamInput(nn.Module):
def __init__(self, img_size, device):
self.img_size = img_size
self.device = device
super().__init__()
def forward(self, x):
size = torch.tensor([[*self.img_size]]).to(self.device)
return x, size
def suppress_warnings():
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
def rtdetr_pytorch_export(weights, cfg_file, device):
cfg = YAMLConfig(cfg_file, resume=weights)
checkpoint = torch.load(weights, map_location=device)
if 'ema' in checkpoint:
state = checkpoint['ema']['module']
else:
state = checkpoint['model']
cfg.model.load_state_dict(state)
return cfg.model.deploy()
def main(args):
suppress_warnings()
print('\nStarting: %s' % args.weights)
print('Opening RT-DETR PyTorch model\n')
device = torch.device('cpu')
model = rtdetr_pytorch_export(args.weights, args.config, device)
img_size = args.size * 2 if len(args.size) == 1 else args.size
model = nn.Sequential(model, DeepStreamOutput(img_size))
onnx_input_im = torch.zeros(args.batch, 3, *img_size).to(device)
onnx_output_file = os.path.basename(args.weights).split('.pt')[0] + '.onnx'
dynamic_axes = {
'input': {
0: 'batch'
},
'boxes': {
0: 'batch'
},
'scores': {
0: 'batch'
},
'classes': {
0: 'batch'
}
}
print('\nExporting the model to ONNX')
torch.onnx.export(model, onnx_input_im, onnx_output_file, verbose=False, opset_version=args.opset,
do_constant_folding=True, input_names=['input'], output_names=['boxes', 'scores', 'classes'],
dynamic_axes=dynamic_axes if args.dynamic else None)
if args.simplify:
print('Simplifying the ONNX model')
import onnxsim
model_onnx = onnx.load(onnx_output_file)
model_onnx, _ = onnxsim.simplify(model_onnx)
onnx.save(model_onnx, onnx_output_file)
print('Done: %s\n' % onnx_output_file)
def parse_args():
parser = argparse.ArgumentParser(description='DeepStream RT-DETR PyTorch conversion')
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pth) file path (required)')
parser.add_argument('-c', '--config', required=True, help='Input YAML (.yml) file path (required)')
parser.add_argument('-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])')
parser.add_argument('--opset', type=int, default=16, help='ONNX opset version')
parser.add_argument('--simplify', action='store_true', help='ONNX simplify model')
parser.add_argument('--dynamic', action='store_true', help='Dynamic batch-size')
parser.add_argument('--batch', type=int, default=1, help='Static batch-size')
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit('Invalid weights file')
if not os.path.isfile(args.config):
raise SystemExit('Invalid config file')
if args.dynamic and args.batch > 1:
raise SystemExit('Cannot set dynamic batch-size and static batch-size at same time')
return args
if __name__ == '__main__':
args = parse_args()
sys.exit(main(args))