Add DAMO-YOLO + Fixes
This commit is contained in:
@@ -22,7 +22,7 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod
|
|||||||
* Support for non square models
|
* Support for non square models
|
||||||
* Models benchmarks
|
* Models benchmarks
|
||||||
* **Support for Darknet YOLO models (YOLOv4, etc) using cfg and weights conversion with GPU post-processing**
|
* **Support for Darknet YOLO models (YOLOv4, etc) using cfg and weights conversion with GPU post-processing**
|
||||||
* **Support for YOLO-NAS, PPYOLOE+, PPYOLOE, YOLOX, YOLOR, YOLOv8, YOLOv7, YOLOv6 and YOLOv5 using ONNX conversion with GPU post-processing**
|
* **Support for YOLO-NAS, PPYOLOE+, PPYOLOE, DAMO-YOLO, YOLOX, YOLOR, YOLOv8, YOLOv7, YOLOv6 and YOLOv5 using ONNX conversion with GPU post-processing**
|
||||||
|
|
||||||
##
|
##
|
||||||
|
|
||||||
@@ -42,6 +42,7 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod
|
|||||||
* [YOLOv8 usage](docs/YOLOv8.md)
|
* [YOLOv8 usage](docs/YOLOv8.md)
|
||||||
* [YOLOR usage](docs/YOLOR.md)
|
* [YOLOR usage](docs/YOLOR.md)
|
||||||
* [YOLOX usage](docs/YOLOX.md)
|
* [YOLOX usage](docs/YOLOX.md)
|
||||||
|
* [DAMO-YOLO usage](docs/DAMOYOLO.md)
|
||||||
* [PP-YOLOE / PP-YOLOE+ usage](docs/PPYOLOE.md)
|
* [PP-YOLOE / PP-YOLOE+ usage](docs/PPYOLOE.md)
|
||||||
* [YOLO-NAS usage](docs/YOLONAS.md)
|
* [YOLO-NAS usage](docs/YOLONAS.md)
|
||||||
* [Using your custom model](docs/customModels.md)
|
* [Using your custom model](docs/customModels.md)
|
||||||
@@ -128,6 +129,7 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod
|
|||||||
* [YOLOv8](https://github.com/ultralytics/ultralytics)
|
* [YOLOv8](https://github.com/ultralytics/ultralytics)
|
||||||
* [YOLOR](https://github.com/WongKinYiu/yolor)
|
* [YOLOR](https://github.com/WongKinYiu/yolor)
|
||||||
* [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)
|
* [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)
|
||||||
|
* [DAMO-YOLO](https://github.com/tinyvision/DAMO-YOLO)
|
||||||
* [PP-YOLOE / PP-YOLOE+](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/ppyoloe)
|
* [PP-YOLOE / PP-YOLOE+](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/ppyoloe)
|
||||||
* [YOLO-NAS](https://github.com/Deci-AI/super-gradients/blob/master/YOLONAS.md)
|
* [YOLO-NAS](https://github.com/Deci-AI/super-gradients/blob/master/YOLONAS.md)
|
||||||
|
|
||||||
@@ -170,7 +172,7 @@ topk = 300
|
|||||||
|
|
||||||
**NOTE**: ** = The YOLOv4 is trained with the trainvalno5k set, so the mAP is high on val2017 test.
|
**NOTE**: ** = The YOLOv4 is trained with the trainvalno5k set, so the mAP is high on val2017 test.
|
||||||
|
|
||||||
**NOTE**: The p3.2xlarge instance (AWS) seems to max out at 625-635 FPS on DeepStream even using lighter models.
|
**NOTE**: The V100 GPU decoder seems to max out at 625-635 FPS on DeepStream even using lighter models.
|
||||||
|
|
||||||
| DeepStream | Precision | Resolution | IoU=0.5:0.95 | IoU=0.5 | IoU=0.75 | FPS<br />(without display) |
|
| DeepStream | Precision | Resolution | IoU=0.5:0.95 | IoU=0.5 | IoU=0.75 | FPS<br />(without display) |
|
||||||
|:------------------:|:---------:|:----------:|:------------:|:-------:|:--------:|:--------------------------:|
|
|:------------------:|:---------:|:----------:|:------------:|:-------:|:--------:|:--------------------------:|
|
||||||
|
|||||||
24
config_infer_primary_damoyolo.txt
Normal file
24
config_infer_primary_damoyolo.txt
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
[property]
|
||||||
|
gpu-id=0
|
||||||
|
net-scale-factor=1
|
||||||
|
model-color-format=0
|
||||||
|
onnx-file=damoyolo_tinynasL25_S.onnx
|
||||||
|
model-engine-file=damoyolo_tinynasL25_S.onnx_b1_gpu0_fp32.engine
|
||||||
|
#int8-calib-file=calib.table
|
||||||
|
labelfile-path=labels.txt
|
||||||
|
batch-size=1
|
||||||
|
network-mode=0
|
||||||
|
num-detected-classes=80
|
||||||
|
interval=0
|
||||||
|
gie-unique-id=1
|
||||||
|
process-mode=1
|
||||||
|
network-type=0
|
||||||
|
cluster-mode=2
|
||||||
|
maintain-aspect-ratio=0
|
||||||
|
parse-bbox-func-name=NvDsInferParseYoloE
|
||||||
|
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
|
||||||
|
|
||||||
|
[class-attrs-all]
|
||||||
|
nms-iou-threshold=0.45
|
||||||
|
pre-cluster-threshold=0.25
|
||||||
|
topk=300
|
||||||
158
docs/DAMOYOLO.md
Normal file
158
docs/DAMOYOLO.md
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
# DAMO-YOLO usage
|
||||||
|
|
||||||
|
* [Convert model](#convert-model)
|
||||||
|
* [Compile the lib](#compile-the-lib)
|
||||||
|
* [Edit the config_infer_primary_damoyolo file](#edit-the-config_infer_primary_damoyolo-file)
|
||||||
|
* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file)
|
||||||
|
* [Testing the model](#testing-the-model)
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
### Convert model
|
||||||
|
|
||||||
|
#### 1. Download the DAMO-YOLO repo and install the requirements
|
||||||
|
|
||||||
|
```
|
||||||
|
git clone https://github.com/tinyvision/DAMO-YOLO.git
|
||||||
|
cd DAMO-YOLO
|
||||||
|
pip3 install -r requirements.txt
|
||||||
|
pip3 install onnx onnxsim onnxruntime
|
||||||
|
```
|
||||||
|
|
||||||
|
**NOTE**: It is recommended to use Python virtualenv.
|
||||||
|
|
||||||
|
#### 2. Copy conversor
|
||||||
|
|
||||||
|
Copy the `export_damoyolo.py` file from `DeepStream-Yolo/utils` directory to the `DAMO-YOLO` folder.
|
||||||
|
|
||||||
|
#### 3. Download the model
|
||||||
|
|
||||||
|
Download the `pth` file from [DAMO-YOLO](https://github.com/tinyvision/DAMO-YOLO) releases (example for DAMO-YOLO-S*)
|
||||||
|
|
||||||
|
```
|
||||||
|
wget https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/DAMO-YOLO/release_model/clean_model_0317/damoyolo_tinynasL25_S_477.pth
|
||||||
|
```
|
||||||
|
|
||||||
|
**NOTE**: You can use your custom model.
|
||||||
|
|
||||||
|
#### 4. Convert model
|
||||||
|
|
||||||
|
Generate the ONNX model file (example for DAMO-YOLO-S*)
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 export_damoyolo.py -w damoyolo_tinynasL25_S_477.pth -c configs/damoyolo_tinynasL25_S.py --simplify
|
||||||
|
```
|
||||||
|
|
||||||
|
**NOTE**: To change the inference size (defaut: 640)
|
||||||
|
|
||||||
|
```
|
||||||
|
-s SIZE
|
||||||
|
--size SIZE
|
||||||
|
-s HEIGHT WIDTH
|
||||||
|
--size HEIGHT WIDTH
|
||||||
|
```
|
||||||
|
|
||||||
|
Example for 1280
|
||||||
|
|
||||||
|
```
|
||||||
|
-s 1280
|
||||||
|
```
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
```
|
||||||
|
-s 1280 1280
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 5. Copy generated files
|
||||||
|
|
||||||
|
Copy the generated ONNX model file to the `DeepStream-Yolo` folder.
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
### Compile the lib
|
||||||
|
|
||||||
|
Open the `DeepStream-Yolo` folder and compile the lib
|
||||||
|
|
||||||
|
* DeepStream 6.2 on x86 platform
|
||||||
|
|
||||||
|
```
|
||||||
|
CUDA_VER=11.8 make -C nvdsinfer_custom_impl_Yolo
|
||||||
|
```
|
||||||
|
|
||||||
|
* DeepStream 6.1.1 on x86 platform
|
||||||
|
|
||||||
|
```
|
||||||
|
CUDA_VER=11.7 make -C nvdsinfer_custom_impl_Yolo
|
||||||
|
```
|
||||||
|
|
||||||
|
* DeepStream 6.1 on x86 platform
|
||||||
|
|
||||||
|
```
|
||||||
|
CUDA_VER=11.6 make -C nvdsinfer_custom_impl_Yolo
|
||||||
|
```
|
||||||
|
|
||||||
|
* DeepStream 6.0.1 / 6.0 on x86 platform
|
||||||
|
|
||||||
|
```
|
||||||
|
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
|
||||||
|
```
|
||||||
|
|
||||||
|
* DeepStream 6.2 / 6.1.1 / 6.1 on Jetson platform
|
||||||
|
|
||||||
|
```
|
||||||
|
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
|
||||||
|
```
|
||||||
|
|
||||||
|
* DeepStream 6.0.1 / 6.0 on Jetson platform
|
||||||
|
|
||||||
|
```
|
||||||
|
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
|
||||||
|
```
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
### Edit the config_infer_primary_damoyolo file
|
||||||
|
|
||||||
|
Edit the `config_infer_primary_damoyolo.txt` file according to your model (example for DAMO-YOLO-S* with 80 classes)
|
||||||
|
|
||||||
|
```
|
||||||
|
[property]
|
||||||
|
...
|
||||||
|
onnx-file=damoyolo_tinynasL25_S.onnx
|
||||||
|
model-engine-file=damoyolo_tinynasL25_S.onnx_b1_gpu0_fp32.engine
|
||||||
|
...
|
||||||
|
num-detected-classes=80
|
||||||
|
...
|
||||||
|
parse-bbox-func-name=NvDsInferParseYoloE
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
**NOTE**: The **DAMO-YOLO** do not resize the input with padding. To get better accuracy, use
|
||||||
|
|
||||||
|
```
|
||||||
|
maintain-aspect-ratio=0
|
||||||
|
```
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
### Edit the deepstream_app_config file
|
||||||
|
|
||||||
|
```
|
||||||
|
...
|
||||||
|
[primary-gie]
|
||||||
|
...
|
||||||
|
config-file=config_infer_primary_damoyolo.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
### Testing the model
|
||||||
|
|
||||||
|
```
|
||||||
|
deepstream-app -c deepstream_app_config.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
**NOTE**: The TensorRT engine file may take a very long time to generate (sometimes more than 10 minutes).
|
||||||
|
|
||||||
|
**NOTE**: For more information about custom models configuration (`batch-size`, `network-mode`, etc), please check the [`docs/customModels.md`](customModels.md) file.
|
||||||
@@ -30,7 +30,7 @@ Copy the `export_yolonas.py` file from `DeepStream-Yolo/utils` directory to the
|
|||||||
|
|
||||||
#### 3. Download the model
|
#### 3. Download the model
|
||||||
|
|
||||||
Download the `pth` file from [YOLO-NAS](https://sghub.deci.ai/) website (example for YOLO-NAS S)
|
Download the `pth` file from [YOLO-NAS](https://sghub.deci.ai/) releases (example for YOLO-NAS S)
|
||||||
|
|
||||||
```
|
```
|
||||||
wget https://sghub.deci.ai/models/yolo_nas_s_coco.pth
|
wget https://sghub.deci.ai/models/yolo_nas_s_coco.pth
|
||||||
|
|||||||
@@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
**NOTE**: You can use the main branch of the YOLOX repo to convert all model versions.
|
**NOTE**: You can use the main branch of the YOLOX repo to convert all model versions.
|
||||||
|
|
||||||
**NOTE**: The yaml file is not required.
|
|
||||||
|
|
||||||
* [Convert model](#convert-model)
|
* [Convert model](#convert-model)
|
||||||
* [Compile the lib](#compile-the-lib)
|
* [Compile the lib](#compile-the-lib)
|
||||||
* [Edit the config_infer_primary_yolox file](#edit-the-config_infer_primary_yolox-file)
|
* [Edit the config_infer_primary_yolox file](#edit-the-config_infer_primary_yolox-file)
|
||||||
|
|||||||
86
utils/export_damoyolo.py
Normal file
86
utils/export_damoyolo.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
import warnings
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from damo.base_models.core.ops import RepConv, SiLU
|
||||||
|
from damo.config.base import parse_config
|
||||||
|
from damo.detectors.detector import build_local_model
|
||||||
|
from damo.utils.model_utils import replace_module
|
||||||
|
|
||||||
|
|
||||||
|
class DeepStreamOutput(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
boxes = x[1]
|
||||||
|
scores, classes = torch.max(x[0], 2, keepdim=True)
|
||||||
|
return torch.cat((boxes, scores, classes), dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
def suppress_warnings():
|
||||||
|
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
|
||||||
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
|
|
||||||
|
|
||||||
|
def damoyolo_export(weights, config_file, device):
|
||||||
|
config = parse_config(config_file)
|
||||||
|
config.model.head.export_with_post = True
|
||||||
|
model = build_local_model(config, device)
|
||||||
|
ckpt = torch.load(weights, map_location=device)
|
||||||
|
model.eval()
|
||||||
|
if 'model' in ckpt:
|
||||||
|
ckpt = ckpt['model']
|
||||||
|
model.load_state_dict(ckpt, strict=True)
|
||||||
|
model = replace_module(model, nn.SiLU, SiLU)
|
||||||
|
for layer in model.modules():
|
||||||
|
if isinstance(layer, RepConv):
|
||||||
|
layer.switch_to_deploy()
|
||||||
|
model.head.nms = False
|
||||||
|
return config, model
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
suppress_warnings()
|
||||||
|
device = torch.device('cpu')
|
||||||
|
cfg, model = damoyolo_export(args.weights, args.config, device)
|
||||||
|
|
||||||
|
model = nn.Sequential(model, DeepStreamOutput())
|
||||||
|
|
||||||
|
img_size = args.size * 2 if len(args.size) == 1 else args.size
|
||||||
|
|
||||||
|
onnx_input_im = torch.zeros(1, 3, *img_size).to(device)
|
||||||
|
onnx_output_file = cfg.miscs['exp_name'] + '.onnx'
|
||||||
|
|
||||||
|
torch.onnx.export(model, onnx_input_im, onnx_output_file, verbose=False, opset_version=args.opset,
|
||||||
|
do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes=None)
|
||||||
|
|
||||||
|
if args.simplify:
|
||||||
|
import onnxsim
|
||||||
|
model_onnx = onnx.load(onnx_output_file)
|
||||||
|
model_onnx, _ = onnxsim.simplify(model_onnx)
|
||||||
|
onnx.save(model_onnx, onnx_output_file)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description='DeepStream DAMO-YOLO conversion')
|
||||||
|
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pth) file path (required)')
|
||||||
|
parser.add_argument('-c', '--config', required=True, help='Input config (.py) file path (required)')
|
||||||
|
parser.add_argument('-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])')
|
||||||
|
parser.add_argument('--opset', type=int, default=11, help='ONNX opset version')
|
||||||
|
parser.add_argument('--simplify', action='store_true', help='ONNX simplify model')
|
||||||
|
args = parser.parse_args()
|
||||||
|
if not os.path.isfile(args.weights):
|
||||||
|
raise SystemExit('Invalid weights file')
|
||||||
|
if not os.path.isfile(args.config):
|
||||||
|
raise SystemExit('Invalid config file')
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
sys.exit(main(args))
|
||||||
Reference in New Issue
Block a user