diff --git a/README.md b/README.md
index 5694949..c2f017e 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod
 * Support for non square models
 * Models benchmarks
 * **Support for Darknet YOLO models (YOLOv4, etc) using cfg and weights conversion with GPU post-processing**
-* **Support for YOLO-NAS, PPYOLOE+, PPYOLOE, YOLOX, YOLOR, YOLOv8, YOLOv7, YOLOv6 and YOLOv5 using ONNX conversion with GPU post-processing**
+* **Support for YOLO-NAS, PPYOLOE+, PPYOLOE, DAMO-YOLO, YOLOX, YOLOR, YOLOv8, YOLOv7, YOLOv6 and YOLOv5 using ONNX conversion with GPU post-processing**
 
 ##
 
@@ -42,6 +42,7 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod
 * [YOLOv8 usage](docs/YOLOv8.md)
 * [YOLOR usage](docs/YOLOR.md)
 * [YOLOX usage](docs/YOLOX.md)
+* [DAMO-YOLO usage](docs/DAMOYOLO.md)
 * [PP-YOLOE / PP-YOLOE+ usage](docs/PPYOLOE.md)
 * [YOLO-NAS usage](docs/YOLONAS.md)
 * [Using your custom model](docs/customModels.md)
@@ -128,6 +129,7 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod
 * [YOLOv8](https://github.com/ultralytics/ultralytics)
 * [YOLOR](https://github.com/WongKinYiu/yolor)
 * [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)
+* [DAMO-YOLO](https://github.com/tinyvision/DAMO-YOLO)
 * [PP-YOLOE / PP-YOLOE+](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/ppyoloe)
 * [YOLO-NAS](https://github.com/Deci-AI/super-gradients/blob/master/YOLONAS.md)
@@ -170,7 +172,7 @@ topk = 300
 
 **NOTE**: ** = The YOLOv4 is trained with the trainvalno5k set, so the mAP is high on val2017 test.
 
-**NOTE**: The p3.2xlarge instance (AWS) seems to max out at 625-635 FPS on DeepStream even using lighter models.
+**NOTE**: The V100 GPU decoder seems to max out at 625-635 FPS on DeepStream even using lighter models.
 
 | DeepStream | Precision | Resolution | IoU=0.5:0.95 | IoU=0.5 | IoU=0.75 | FPS<br />(without display) |
 |:------------------:|:---------:|:----------:|:------------:|:-------:|:--------:|:--------------------------:|
diff --git a/config_infer_primary_damoyolo.txt b/config_infer_primary_damoyolo.txt
new file mode 100644
index 0000000..6ab6541
--- /dev/null
+++ b/config_infer_primary_damoyolo.txt
@@ -0,0 +1,24 @@
+[property]
+gpu-id=0
+net-scale-factor=1
+model-color-format=0
+onnx-file=damoyolo_tinynasL25_S.onnx
+model-engine-file=damoyolo_tinynasL25_S.onnx_b1_gpu0_fp32.engine
+#int8-calib-file=calib.table
+labelfile-path=labels.txt
+batch-size=1
+network-mode=0
+num-detected-classes=80
+interval=0
+gie-unique-id=1
+process-mode=1
+network-type=0
+cluster-mode=2
+maintain-aspect-ratio=0
+parse-bbox-func-name=NvDsInferParseYoloE
+custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
+
+[class-attrs-all]
+nms-iou-threshold=0.45
+pre-cluster-threshold=0.25
+topk=300
diff --git a/docs/DAMOYOLO.md b/docs/DAMOYOLO.md
new file mode 100644
index 0000000..a8be40c
--- /dev/null
+++ b/docs/DAMOYOLO.md
@@ -0,0 +1,169 @@
+# DAMO-YOLO usage
+
+* [Convert model](#convert-model)
+* [Compile the lib](#compile-the-lib)
+* [Edit the config_infer_primary_damoyolo file](#edit-the-config_infer_primary_damoyolo-file)
+* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file)
+* [Testing the model](#testing-the-model)
+
+##
+
+### Convert model
+
+#### 1. Download the DAMO-YOLO repo and install the requirements
+
+```
+git clone https://github.com/tinyvision/DAMO-YOLO.git
+cd DAMO-YOLO
+pip3 install -r requirements.txt
+pip3 install onnx onnxsim onnxruntime
+```
+
+**NOTE**: It is recommended to use a Python virtualenv.
+
+#### 2. Copy the converter
+
+Copy the `export_damoyolo.py` file from the `DeepStream-Yolo/utils` directory to the `DAMO-YOLO` folder.
+
+#### 3. Download the model
+
+Download the `pth` file from the [DAMO-YOLO](https://github.com/tinyvision/DAMO-YOLO) releases (example for DAMO-YOLO-S*)
+
+```
+wget https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/DAMO-YOLO/release_model/clean_model_0317/damoyolo_tinynasL25_S_477.pth
+```
+
+**NOTE**: You can use your custom model.
+
+#### 4. Convert model
+
+Generate the ONNX model file (example for DAMO-YOLO-S*)
+
+```
+python3 export_damoyolo.py -w damoyolo_tinynasL25_S_477.pth -c configs/damoyolo_tinynasL25_S.py --simplify
+```
+
+**NOTE**: To change the inference size (default: 640), use
+
+```
+-s SIZE
+--size SIZE
+-s HEIGHT WIDTH
+--size HEIGHT WIDTH
+```
+
+Example for 1280
+
+```
+-s 1280
+```
+
+or
+
+```
+-s 1280 1280
+```
+
+#### 5. Copy generated files
+
+Copy the generated ONNX model file to the `DeepStream-Yolo` folder.
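+
+Optionally, sanity-check the exported model first (a minimal sketch, run inside the `DAMO-YOLO` folder; it assumes the `onnxruntime` package from step 1, the default 640 input size and the tensor names set by `export_damoyolo.py`):
+
+```
+import onnxruntime
+import numpy as np
+
+session = onnxruntime.InferenceSession('damoyolo_tinynasL25_S.onnx')
+dummy = np.zeros((1, 3, 640, 640), dtype=np.float32)
+print(session.run(None, {'input': dummy})[0].shape)  # (1, num_candidates, 6): box (4), score, class id
+```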
+ +## + +### Compile the lib + +Open the `DeepStream-Yolo` folder and compile the lib + +* DeepStream 6.2 on x86 platform + + ``` + CUDA_VER=11.8 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.1.1 on x86 platform + + ``` + CUDA_VER=11.7 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.1 on x86 platform + + ``` + CUDA_VER=11.6 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.0.1 / 6.0 on x86 platform + + ``` + CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.2 / 6.1.1 / 6.1 on Jetson platform + + ``` + CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.0.1 / 6.0 on Jetson platform + + ``` + CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo + ``` + +## + +### Edit the config_infer_primary_damoyolo file + +Edit the `config_infer_primary_damoyolo.txt` file according to your model (example for DAMO-YOLO-S* with 80 classes) + +``` +[property] +... +onnx-file=damoyolo_tinynasL25_S.onnx +model-engine-file=damoyolo_tinynasL25_S.onnx_b1_gpu0_fp32.engine +... +num-detected-classes=80 +... +parse-bbox-func-name=NvDsInferParseYoloE +... +``` + +**NOTE**: The **DAMO-YOLO** do not resize the input with padding. To get better accuracy, use + +``` +maintain-aspect-ratio=0 +``` + +## + +### Edit the deepstream_app_config file + +``` +... +[primary-gie] +... +config-file=config_infer_primary_damoyolo.txt +``` + +## + +### Testing the model + +``` +deepstream-app -c deepstream_app_config.txt +``` + +**NOTE**: The TensorRT engine file may take a very long time to generate (sometimes more than 10 minutes). + +**NOTE**: For more information about custom models configuration (`batch-size`, `network-mode`, etc), please check the [`docs/customModels.md`](customModels.md) file. diff --git a/docs/YOLONAS.md b/docs/YOLONAS.md index cce48ab..eb94fa6 100644 --- a/docs/YOLONAS.md +++ b/docs/YOLONAS.md @@ -30,7 +30,7 @@ Copy the `export_yolonas.py` file from `DeepStream-Yolo/utils` directory to the #### 3. Download the model -Download the `pth` file from [YOLO-NAS](https://sghub.deci.ai/) website (example for YOLO-NAS S) +Download the `pth` file from [YOLO-NAS](https://sghub.deci.ai/) releases (example for YOLO-NAS S) ``` wget https://sghub.deci.ai/models/yolo_nas_s_coco.pth diff --git a/docs/YOLOX.md b/docs/YOLOX.md index f1beea7..da8835e 100644 --- a/docs/YOLOX.md +++ b/docs/YOLOX.md @@ -2,8 +2,6 @@ **NOTE**: You can use the main branch of the YOLOX repo to convert all model versions. -**NOTE**: The yaml file is not required. 
-
 * [Convert model](#convert-model)
 * [Compile the lib](#compile-the-lib)
 * [Edit the config_infer_primary_yolox file](#edit-the-config_infer_primary_yolox-file)
diff --git a/utils/export_damoyolo.py b/utils/export_damoyolo.py
new file mode 100644
index 0000000..68a6744
--- /dev/null
+++ b/utils/export_damoyolo.py
@@ -0,0 +1,100 @@
+import os
+import sys
+import argparse
+import warnings
+import onnx
+import torch
+import torch.nn as nn
+from damo.base_models.core.ops import RepConv, SiLU
+from damo.config.base import parse_config
+from damo.detectors.detector import build_local_model
+from damo.utils.model_utils import replace_module
+
+
+class DeepStreamOutput(nn.Module):
+    """Flattens the DAMO-YOLO head outputs (scores, boxes) into a single
+    (batch, num_candidates, 6) tensor: 4 box coordinates, top score, class id."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        boxes = x[1]
+        scores, classes = torch.max(x[0], 2, keepdim=True)
+        return torch.cat((boxes, scores, classes), dim=2)
+
+
+def suppress_warnings():
+    warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
+    warnings.filterwarnings('ignore', category=UserWarning)
+    warnings.filterwarnings('ignore', category=DeprecationWarning)
+
+
+def damoyolo_export(weights, config_file, device):
+    config = parse_config(config_file)
+    # Export the head with its post-processing step (decoded scores and boxes)
+    config.model.head.export_with_post = True
+    model = build_local_model(config, device)
+    ckpt = torch.load(weights, map_location=device)
+    model.eval()
+    # Some checkpoints wrap the state dict in a 'model' key
+    if 'model' in ckpt:
+        ckpt = ckpt['model']
+    model.load_state_dict(ckpt, strict=True)
+    # Replace SiLU with the export-friendly version and fuse RepConv branches
+    model = replace_module(model, nn.SiLU, SiLU)
+    for layer in model.modules():
+        if isinstance(layer, RepConv):
+            layer.switch_to_deploy()
+    # Skip in-model NMS, it is done by the DeepStream GPU post-processing
+    model.head.nms = False
+    return config, model
+
+
+def main(args):
+    suppress_warnings()
+    device = torch.device('cpu')
+    cfg, model = damoyolo_export(args.weights, args.config, device)
+
+    model = nn.Sequential(model, DeepStreamOutput())
+
+    # '-s SIZE' becomes a square input, '-s HEIGHT WIDTH' is used as-is
+    img_size = args.size * 2 if len(args.size) == 1 else args.size
+
+    onnx_input_im = torch.zeros(1, 3, *img_size).to(device)
+    # The output file is named after the exp_name defined in the config
+    onnx_output_file = cfg.miscs['exp_name'] + '.onnx'
+
+    torch.onnx.export(model, onnx_input_im, onnx_output_file, verbose=False, opset_version=args.opset,
+                      do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes=None)
+
+    if args.simplify:
+        import onnxsim
+        model_onnx = onnx.load(onnx_output_file)
+        model_onnx, _ = onnxsim.simplify(model_onnx)
+        onnx.save(model_onnx, onnx_output_file)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='DeepStream DAMO-YOLO conversion')
+    parser.add_argument('-w', '--weights', required=True, help='Input weights (.pth) file path (required)')
+    parser.add_argument('-c', '--config', required=True, help='Input config (.py) file path (required)')
+    parser.add_argument('-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])')
+    parser.add_argument('--opset', type=int, default=11, help='ONNX opset version')
+    parser.add_argument('--simplify', action='store_true', help='ONNX simplify model')
+    args = parser.parse_args()
+    if not os.path.isfile(args.weights):
+        raise SystemExit('Invalid weights file')
+    if not os.path.isfile(args.config):
+        raise SystemExit('Invalid config file')
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    sys.exit(main(args))
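+
+
+# Usage sketch (example model from docs/DAMOYOLO.md):
+#   python3 export_damoyolo.py -w damoyolo_tinynasL25_S_477.pth -c configs/damoyolo_tinynasL25_S.py --simplify
+# This writes damoyolo_tinynasL25_S.onnx, ready to be copied to the DeepStream-Yolo folder.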