Add DAMO-YOLO + Fixes

Marcos Luciano
2023-05-21 17:11:39 -03:00
parent 79d22283c1
commit f9bfd65036
6 changed files with 273 additions and 5 deletions

README.md

@@ -22,7 +22,7 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod
* Support for non square models
* Models benchmarks
* **Support for Darknet YOLO models (YOLOv4, etc) using cfg and weights conversion with GPU post-processing**
* **Support for YOLO-NAS, PPYOLOE+, PPYOLOE, DAMO-YOLO, YOLOX, YOLOR, YOLOv8, YOLOv7, YOLOv6 and YOLOv5 using ONNX conversion with GPU post-processing**
##
@@ -42,6 +42,7 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod
* [YOLOv8 usage](docs/YOLOv8.md)
* [YOLOR usage](docs/YOLOR.md)
* [YOLOX usage](docs/YOLOX.md)
* [DAMO-YOLO usage](docs/DAMOYOLO.md)
* [PP-YOLOE / PP-YOLOE+ usage](docs/PPYOLOE.md)
* [YOLO-NAS usage](docs/YOLONAS.md)
* [Using your custom model](docs/customModels.md)
@@ -128,6 +129,7 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod
* [YOLOv8](https://github.com/ultralytics/ultralytics)
* [YOLOR](https://github.com/WongKinYiu/yolor)
* [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)
* [DAMO-YOLO](https://github.com/tinyvision/DAMO-YOLO)
* [PP-YOLOE / PP-YOLOE+](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/ppyoloe)
* [YOLO-NAS](https://github.com/Deci-AI/super-gradients/blob/master/YOLONAS.md)
@@ -170,7 +172,7 @@ topk = 300
**NOTE**: ** = The YOLOv4 is trained with the trainvalno5k set, so the mAP is high on val2017 test.
**NOTE**: The V100 GPU decoder seems to max out at 625-635 FPS on DeepStream even using lighter models.
| DeepStream | Precision | Resolution | IoU=0.5:0.95 | IoU=0.5 | IoU=0.75 | FPS<br />(without display) |
|:------------------:|:---------:|:----------:|:------------:|:-------:|:--------:|:--------------------------:|

config_infer_primary_damoyolo.txt Normal file

@@ -0,0 +1,24 @@
[property]
gpu-id=0
net-scale-factor=1
model-color-format=0
onnx-file=damoyolo_tinynasL25_S.onnx
model-engine-file=damoyolo_tinynasL25_S.onnx_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=0
parse-bbox-func-name=NvDsInferParseYoloE
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300
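A common tweak when testing (a sketch, not part of the committed file): switch the engine to FP16 by changing `network-mode` (0 = FP32, 1 = INT8, 2 = FP16) and matching the engine file name, assuming DeepStream generates it with the same naming pattern as the FP32 case above

```
network-mode=2
model-engine-file=damoyolo_tinynasL25_S.onnx_b1_gpu0_fp16.engine
```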

docs/DAMOYOLO.md Normal file

@@ -0,0 +1,158 @@
# DAMO-YOLO usage
* [Convert model](#convert-model)
* [Compile the lib](#compile-the-lib)
* [Edit the config_infer_primary_damoyolo file](#edit-the-config_infer_primary_damoyolo-file)
* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file)
* [Testing the model](#testing-the-model)
##
### Convert model
#### 1. Download the DAMO-YOLO repo and install the requirements
```
git clone https://github.com/tinyvision/DAMO-YOLO.git
cd DAMO-YOLO
pip3 install -r requirements.txt
pip3 install onnx onnxsim onnxruntime
```
**NOTE**: It is recommended to use a Python virtualenv.
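For example (a minimal sketch; the environment name is arbitrary)

```
python3 -m venv damo-yolo-env
source damo-yolo-env/bin/activate
```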
#### 2. Copy the converter
Copy the `export_damoyolo.py` file from `DeepStream-Yolo/utils` directory to the `DAMO-YOLO` folder.
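For example, assuming both repos are cloned side by side (paths are illustrative)

```
cp ../DeepStream-Yolo/utils/export_damoyolo.py .
```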
#### 3. Download the model
Download the `pth` file from [DAMO-YOLO](https://github.com/tinyvision/DAMO-YOLO) releases (example for DAMO-YOLO-S*)
```
wget https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/DAMO-YOLO/release_model/clean_model_0317/damoyolo_tinynasL25_S_477.pth
```
**NOTE**: You can use your custom model.
#### 4. Convert model
Generate the ONNX model file (example for DAMO-YOLO-S*)
```
python3 export_damoyolo.py -w damoyolo_tinynasL25_S_477.pth -c configs/damoyolo_tinynasL25_S.py --simplify
```
**NOTE**: To change the inference size (default: 640)
```
-s SIZE
--size SIZE
-s HEIGHT WIDTH
--size HEIGHT WIDTH
```
Example for 1280
```
-s 1280
```
or
```
-s 1280 1280
```
#### 5. Copy generated files
Copy the generated ONNX model file to the `DeepStream-Yolo` folder.
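For example, from the `DAMO-YOLO` folder (again assuming side-by-side clones; adjust the path to your layout)

```
cp damoyolo_tinynasL25_S.onnx ../DeepStream-Yolo/
```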
##
### Compile the lib
Open the `DeepStream-Yolo` folder and compile the lib (a CUDA version check is shown after this list)
* DeepStream 6.2 on x86 platform
```
CUDA_VER=11.8 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.1.1 on x86 platform
```
CUDA_VER=11.7 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.1 on x86 platform
```
CUDA_VER=11.6 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.0.1 / 6.0 on x86 platform
```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.2 / 6.1.1 / 6.1 on Jetson platform
```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.0.1 / 6.0 on Jetson platform
```
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
```
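If you are unsure which CUDA version your system has, a quick check before picking a line above

```
nvcc --version
```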
##
### Edit the config_infer_primary_damoyolo file
Edit the `config_infer_primary_damoyolo.txt` file according to your model (example for DAMO-YOLO-S* with 80 classes)
```
[property]
...
onnx-file=damoyolo_tinynasL25_S.onnx
model-engine-file=damoyolo_tinynasL25_S.onnx_b1_gpu0_fp32.engine
...
num-detected-classes=80
...
parse-bbox-func-name=NvDsInferParseYoloE
...
```
**NOTE**: **DAMO-YOLO** does not resize the input with padding. To get better accuracy, use
```
maintain-aspect-ratio=0
```
##
### Edit the deepstream_app_config file
```
...
[primary-gie]
...
config-file=config_infer_primary_damoyolo.txt
```
##
### Testing the model
```
deepstream-app -c deepstream_app_config.txt
```
**NOTE**: The TensorRT engine file may take a very long time to generate (sometimes more than 10 minutes).
**NOTE**: For more information about custom models configuration (`batch-size`, `network-mode`, etc), please check the [`docs/customModels.md`](customModels.md) file.

docs/YOLONAS.md

@@ -30,7 +30,7 @@ Copy the `export_yolonas.py` file from `DeepStream-Yolo/utils` directory to the
#### 3. Download the model
Download the `pth` file from [YOLO-NAS](https://sghub.deci.ai/) releases (example for YOLO-NAS S)
```
wget https://sghub.deci.ai/models/yolo_nas_s_coco.pth

docs/YOLOX.md

@@ -2,8 +2,6 @@
**NOTE**: You can use the main branch of the YOLOX repo to convert all model versions.
**NOTE**: The yaml file is not required.
* [Convert model](#convert-model)
* [Compile the lib](#compile-the-lib)
* [Edit the config_infer_primary_yolox file](#edit-the-config_infer_primary_yolox-file)

utils/export_damoyolo.py Normal file

@@ -0,0 +1,86 @@
import os
import sys
import argparse
import warnings

import onnx
import torch
import torch.nn as nn
from damo.base_models.core.ops import RepConv, SiLU
from damo.config.base import parse_config
from damo.detectors.detector import build_local_model
from damo.utils.model_utils import replace_module


class DeepStreamOutput(nn.Module):
    # Reshapes the head outputs into a single [boxes, score, class] tensor
    # per detection, the layout expected by the custom DeepStream parser.
    def __init__(self):
        super().__init__()

    def forward(self, x):
        boxes = x[1]
        scores, classes = torch.max(x[0], 2, keepdim=True)  # best class score and index per box
        return torch.cat((boxes, scores, classes), dim=2)


def suppress_warnings():
    warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
    warnings.filterwarnings('ignore', category=UserWarning)
    warnings.filterwarnings('ignore', category=DeprecationWarning)


def damoyolo_export(weights, config_file, device):
    config = parse_config(config_file)
    config.model.head.export_with_post = True  # export the head with its post-processing (decoded outputs)
    model = build_local_model(config, device)
    ckpt = torch.load(weights, map_location=device)
    model.eval()
    if 'model' in ckpt:
        ckpt = ckpt['model']
    model.load_state_dict(ckpt, strict=True)
    model = replace_module(model, nn.SiLU, SiLU)  # swap nn.SiLU for the export-friendly implementation
    for layer in model.modules():
        if isinstance(layer, RepConv):
            layer.switch_to_deploy()  # fuse RepConv training branches for inference
    model.head.nms = False  # NMS is handled by DeepStream, so it is left out of the ONNX graph
    return config, model


def main(args):
    suppress_warnings()
    device = torch.device('cpu')
    cfg, model = damoyolo_export(args.weights, args.config, device)
    model = nn.Sequential(model, DeepStreamOutput())
    img_size = args.size * 2 if len(args.size) == 1 else args.size  # a single value means a square input
    onnx_input_im = torch.zeros(1, 3, *img_size).to(device)
    onnx_output_file = cfg.miscs['exp_name'] + '.onnx'
    torch.onnx.export(model, onnx_input_im, onnx_output_file, verbose=False, opset_version=args.opset,
                      do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes=None)
    if args.simplify:
        import onnxsim
        model_onnx = onnx.load(onnx_output_file)
        model_onnx, _ = onnxsim.simplify(model_onnx)
        onnx.save(model_onnx, onnx_output_file)


def parse_args():
    parser = argparse.ArgumentParser(description='DeepStream DAMO-YOLO conversion')
    parser.add_argument('-w', '--weights', required=True, help='Input weights (.pth) file path (required)')
    parser.add_argument('-c', '--config', required=True, help='Input config (.py) file path (required)')
    parser.add_argument('-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])')
    parser.add_argument('--opset', type=int, default=11, help='ONNX opset version')
    parser.add_argument('--simplify', action='store_true', help='ONNX simplify model')
    args = parser.parse_args()
    if not os.path.isfile(args.weights):
        raise SystemExit('Invalid weights file')
    if not os.path.isfile(args.config):
        raise SystemExit('Invalid config file')
    return args


if __name__ == '__main__':
    args = parse_args()
    sys.exit(main(args))