From e5d994e2d73966d059978de7315c3f2964025581 Mon Sep 17 00:00:00 2001
From: Marcos Luciano
Date: Wed, 27 Nov 2024 23:16:57 -0300
Subject: [PATCH] Add support for CO-DETR (MMDetection)

---
 README.md                       |   6 +-
 config_infer_primary_codetr.txt |  28 +++++
 docs/CODETR.md                  | 187 ++++++++++++++++++++++++++++++++
 docs/PPYOLOE.md                 |   4 +-
 docs/RTDETR_Paddle.md           |   2 +-
 utils/export_codetr.py          | 149 +++++++++++++++++++++++++
 6 files changed, 371 insertions(+), 5 deletions(-)
 create mode 100644 config_infer_primary_codetr.txt
 create mode 100644 docs/CODETR.md
 create mode 100644 utils/export_codetr.py

diff --git a/README.md b/README.md
index b27c37a..1743bc3 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ NVIDIA DeepStream SDK 7.1 / 7.0 / 6.4 / 6.3 / 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 /
 * Support for non square models
 * Models benchmarks
 * Support for Darknet models (YOLOv4, etc) using cfg and weights conversion with GPU post-processing
-* Support for RT-DETR, YOLO-NAS, PPYOLOE+, PPYOLOE, DAMO-YOLO, Gold-YOLO, RTMDet (MMYOLO), YOLOX, YOLOR, YOLOv9, YOLOv8, YOLOv7, YOLOv6 and YOLOv5 using ONNX conversion with GPU post-processing
+* Support for RT-DETR, CO-DETR (MMDetection), YOLO-NAS, PPYOLOE+, PPYOLOE, DAMO-YOLO, Gold-YOLO, RTMDet (MMYOLO), YOLOX, YOLOR, YOLOv9, YOLOv8, YOLOv7, YOLOv6 and YOLOv5 using ONNX conversion with GPU post-processing
 * GPU bbox parser
 * Custom ONNX model parser
 * Dynamic batch-size
@@ -49,6 +49,7 @@ NVIDIA DeepStream SDK 7.1 / 7.0 / 6.4 / 6.3 / 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 /
 * [DAMO-YOLO usage](docs/DAMOYOLO.md)
 * [PP-YOLOE / PP-YOLOE+ usage](docs/PPYOLOE.md)
 * [YOLO-NAS usage](docs/YOLONAS.md)
+* [CO-DETR (MMDetection) usage](docs/CODETR.md)
 * [RT-DETR PyTorch usage](docs/RTDETR_PyTorch.md)
 * [RT-DETR Paddle usage](docs/RTDETR_Paddle.md)
 * [RT-DETR Ultralytics usage](docs/RTDETR_Ultralytics.md)
@@ -220,8 +221,9 @@ NVIDIA DeepStream SDK 7.1 / 7.0 / 6.4 / 6.3 / 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 /
 * [RTMDet (MMYOLO)](https://github.com/open-mmlab/mmyolo/tree/main/configs/rtmdet)
 * [Gold-YOLO](https://github.com/huawei-noah/Efficient-Computing/tree/master/Detection/Gold-YOLO)
 * [DAMO-YOLO](https://github.com/tinyvision/DAMO-YOLO)
-* [PP-YOLOE / PP-YOLOE+](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ppyoloe)
+* [PP-YOLOE / PP-YOLOE+](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.8/configs/ppyoloe)
 * [YOLO-NAS](https://github.com/Deci-AI/super-gradients/blob/master/YOLONAS.md)
+* [CO-DETR (MMDetection)](https://github.com/open-mmlab/mmdetection/tree/main/projects/CO-DETR)
 * [RT-DETR](https://github.com/lyuwenyu/RT-DETR)
 
 ##

diff --git a/config_infer_primary_codetr.txt b/config_infer_primary_codetr.txt
new file mode 100644
index 0000000..6263cb6
--- /dev/null
+++ b/config_infer_primary_codetr.txt
@@ -0,0 +1,28 @@
+[property]
+gpu-id=0
+net-scale-factor=0.0039215697906911373
+model-color-format=0
+onnx-file=co_dino_5scale_r50_1x_coco-7481f903.pth.onnx
+model-engine-file=model_b1_gpu0_fp32.engine
+#int8-calib-file=calib.table
+labelfile-path=labels.txt
+batch-size=1
+network-mode=0
+num-detected-classes=80
+interval=0
+gie-unique-id=1
+process-mode=1
+network-type=0
+cluster-mode=2
+maintain-aspect-ratio=1
+symmetric-padding=0
+#workspace-size=2000
+parse-bbox-func-name=NvDsInferParseYolo
+#parse-bbox-func-name=NvDsInferParseYoloCuda
+custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
+engine-create-func-name=NvDsInferYoloCudaEngineGet
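+
+# NOTE: net-scale-factor above is ~1/255, i.e. 8-bit pixel values are scaled
+# to the [0, 1] range before inference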
+
+[class-attrs-all]
+nms-iou-threshold=0.45
+pre-cluster-threshold=0.25
+topk=300
diff --git a/docs/CODETR.md b/docs/CODETR.md
new file mode 100644
index 0000000..61b7080
--- /dev/null
+++ b/docs/CODETR.md
@@ -0,0 +1,187 @@
+# CO-DETR (MMDetection) usage
+
+* [Convert model](#convert-model)
+* [Compile the lib](#compile-the-lib)
+* [Edit the config_infer_primary_codetr file](#edit-the-config_infer_primary_codetr-file)
+* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file)
+* [Testing the model](#testing-the-model)
+
+##
+
+### Convert model
+
+#### 1. Download the CO-DETR (MMDetection) repo and install the requirements
+
+```
+git clone https://github.com/open-mmlab/mmdetection.git
+cd mmdetection
+pip3 install openmim
+mim install mmengine
+mim install mmdeploy
+mim install "mmcv>=2.0.0rc4,<2.2.0"
+pip3 install -v -e .
+pip3 install onnx onnxslim onnxruntime
+```
+
+**NOTE**: It is recommended to use a Python virtualenv.
+
+#### 2. Copy the converter
+
+Copy the `export_codetr.py` file from the `DeepStream-Yolo/utils` directory to the `mmdetection` folder.
+
+#### 3. Download the model
+
+Download the `pth` file from the [CO-DETR (MMDetection)](https://github.com/open-mmlab/mmdetection/tree/main/projects/CO-DETR) releases (example for Co-DINO R50 DETR)
+
+```
+wget https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_r50_1x_coco-7481f903.pth
+```
+
+**NOTE**: You can use your custom model.
+
+#### 4. Convert model
+
+Generate the ONNX model file (example for Co-DINO R50 DETR)
+
+```
+python3 export_codetr.py -w co_dino_5scale_r50_1x_coco-7481f903.pth -c projects/CO-DETR/configs/codino/co_dino_5scale_r50_8xb2_1x_coco.py --dynamic
+```
+
+**NOTE**: To change the inference size (default: 640)
+
+```
+-s SIZE
+--size SIZE
+-s HEIGHT WIDTH
+--size HEIGHT WIDTH
+```
+
+Example for 1280
+
+```
+-s 1280
+```
+
+or
+
+```
+-s 1280 1280
+```
+
+**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)
+
+```
+--simplify
+```
+
+**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)
+
+```
+--dynamic
+```
+
+**NOTE**: To use static batch-size (example for batch-size = 4)
+
+```
+--batch 4
+```
+
+**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 11.
+
+```
+--opset 12
+```
+
+#### 5. Copy generated files
+
+Copy the generated ONNX model file and labels.txt file (if generated) to the `DeepStream-Yolo` folder.
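+
+**NOTE**: Before building the TensorRT engine, you can sanity-check the exported model with `onnxruntime` (installed in step 1). A minimal sketch, assuming the Co-DINO R50 example above and the default 640 input size
+
+```
+import numpy as np
+import onnxruntime as ort
+
+# Run a dummy [0, 1] normalized input through the exported model
+session = ort.InferenceSession('co_dino_5scale_r50_1x_coco-7481f903.pth.onnx')
+dummy = np.random.rand(1, 3, 640, 640).astype(np.float32)
+output = session.run(['output'], {'input': dummy})[0]
+
+# The exporter emits a single tensor of shape (batch, num_detections, 6),
+# where the last axis is x1, y1, x2, y2, score, label
+print(output.shape)
+```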
+
+##
+
+### Compile the lib
+
+1. Open the `DeepStream-Yolo` folder and compile the lib
+
+2. Set the `CUDA_VER` according to your DeepStream version
+
+```
+export CUDA_VER=XY.Z
+```
+
+* x86 platform
+
+  ```
+  DeepStream 7.1 = 12.6
+  DeepStream 7.0 / 6.4 = 12.2
+  DeepStream 6.3 = 12.1
+  DeepStream 6.2 = 11.8
+  DeepStream 6.1.1 = 11.7
+  DeepStream 6.1 = 11.6
+  DeepStream 6.0.1 / 6.0 = 11.4
+  DeepStream 5.1 = 11.1
+  ```
+
+* Jetson platform
+
+  ```
+  DeepStream 7.1 = 12.6
+  DeepStream 7.0 / 6.4 = 12.2
+  DeepStream 6.3 / 6.2 / 6.1.1 / 6.1 = 11.4
+  DeepStream 6.0.1 / 6.0 / 5.1 = 10.2
+  ```
+
+3. Make the lib
+
+```
+make -C nvdsinfer_custom_impl_Yolo clean && make -C nvdsinfer_custom_impl_Yolo
+```
+
+##
+
+### Edit the config_infer_primary_codetr file
+
+Edit the `config_infer_primary_codetr.txt` file according to your model (example for Co-DINO R50 DETR with 80 classes)
+
+```
+[property]
+...
+onnx-file=co_dino_5scale_r50_1x_coco-7481f903.pth.onnx
+...
+num-detected-classes=80
+...
+parse-bbox-func-name=NvDsInferParseYolo
+...
+```
+
+**NOTE**: **CO-DETR (MMDetection)** resizes the input with left/top padding. To get better accuracy, use
+
+```
+[property]
+...
+maintain-aspect-ratio=1
+symmetric-padding=0
+...
+```
+
+##
+
+### Edit the deepstream_app_config file
+
+```
+...
+[primary-gie]
+...
+config-file=config_infer_primary_codetr.txt
+```
+
+##
+
+### Testing the model
+
+```
+deepstream-app -c deepstream_app_config.txt
+```
+
+**NOTE**: The TensorRT engine file may take a very long time to generate (sometimes more than 10 minutes).
+
+**NOTE**: For more information about custom models configuration (`batch-size`, `network-mode`, etc), please check the [`docs/customModels.md`](customModels.md) file.
diff --git a/docs/PPYOLOE.md b/docs/PPYOLOE.md
index 5071810..ec3a293 100644
--- a/docs/PPYOLOE.md
+++ b/docs/PPYOLOE.md
@@ -14,7 +14,7 @@
 
 #### 1. Download the PaddleDetection repo and install the requirements
 
-https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/INSTALL.md
+https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.8/docs/tutorials/INSTALL.md
 
 **NOTE**: It is recommended to use Python virtualenv.
 
@@ -24,7 +24,7 @@ Copy the `export_ppyoloe.py` file from `DeepStream-Yolo/utils` directory to the
 
 #### 3. Download the model
 
-Download the `pdparams` file from [PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ppyoloe) releases (example for PP-YOLOE+_s)
+Download the `pdparams` file from [PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.8/configs/ppyoloe) releases (example for PP-YOLOE+_s)
 
 ```
 wget https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_s_80e_coco.pdparams
diff --git a/docs/RTDETR_Paddle.md b/docs/RTDETR_Paddle.md
index 1dd94a2..311529d 100644
--- a/docs/RTDETR_Paddle.md
+++ b/docs/RTDETR_Paddle.md
@@ -14,7 +14,7 @@
 
 #### 1. Download the PaddleDetection repo and install the requirements
 
-https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/INSTALL.md
+https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.8/docs/tutorials/INSTALL.md
 
 ```
 git clone https://github.com/lyuwenyu/RT-DETR.git
diff --git a/utils/export_codetr.py b/utils/export_codetr.py
new file mode 100644
index 0000000..c963b14
--- /dev/null
+++ b/utils/export_codetr.py
@@ -0,0 +1,149 @@
+import os
+import types
+import onnx
+import torch
+import torch.nn as nn
+from copy import deepcopy
+
+from projects import *
+from mmengine.registry import MODELS
+from mmdeploy.utils import load_config
+from mmdet.utils import register_all_modules
+from mmengine.model import revert_sync_batchnorm
+from mmengine.runner.checkpoint import load_checkpoint
+
+
+class DeepStreamOutput(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        # Stack the per-image detection results into batched tensors
+        boxes = []
+        scores = []
+        labels = []
+        for det in x:
+            boxes.append(det.bboxes)
+            scores.append(det.scores.unsqueeze(-1))
+            labels.append(det.labels.unsqueeze(-1))
+        boxes = torch.stack(boxes, dim=0)
+        scores = torch.stack(scores, dim=0)
+        labels = torch.stack(labels, dim=0)
+        # Single output tensor: (batch, num_detections, 6) -> box (4), score, label
+        return torch.cat([boxes, scores, labels.to(boxes.dtype)], dim=-1)
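+
+
+# The two helpers below are monkey-patched onto the built model in
+# codetr_export(): CO-DETR's auxiliary RPN/ROI heads are only used during
+# training, so the replacement forward builds minimal image metadata from the
+# input tensor shape and runs the query head alone, while the replacement
+# predict wraps predict_by_feat in no_grad with plain dict metas so the
+# graph can be traced for ONNX export.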
+
+
+def forward_deepstream(self, batch_inputs, batch_data_samples):
+    # Rebuild per-image metadata from the actual input tensor shape
+    b, _, h, w = batch_inputs.shape
+    batch_data_samples = [{'batch_input_shape': (h, w), 'img_shape': (h, w)} for _ in range(b)]
+    img_feats = self.extract_feat(batch_inputs)
+    return self.predict_query_head(img_feats, batch_data_samples, rescale=False)
+
+
+def query_head_predict_deepstream(self, feats, batch_data_samples, rescale=False):
+    with torch.no_grad():
+        outs = self.forward(feats, batch_data_samples)
+        predictions = self.predict_by_feat(
+            *outs, batch_img_metas=batch_data_samples, rescale=rescale)
+    return predictions
+
+
+def codetr_export(weights, config, device):
+    register_all_modules()
+    model_cfg = load_config(config)[0]
+    model = deepcopy(model_cfg.model)
+    model.pop('pretrained', None)
+    # Clear train-time proposal settings and use empty test cfgs for export
+    for cfg in model['train_cfg']:
+        if 'rpn_proposal' in cfg:
+            cfg['rpn_proposal'] = {}
+    model['test_cfg'] = [{}, {'rpn': {}, 'rcnn': {}}, {}]
+    preprocess_cfg = deepcopy(model_cfg.get('preprocess_cfg', {}))
+    preprocess_cfg.update(deepcopy(model_cfg.get('data_preprocessor', {})))
+    model.setdefault('data_preprocessor', preprocess_cfg)
+    model = MODELS.build(model)
+    load_checkpoint(model, weights, map_location=device)
+    model = revert_sync_batchnorm(model)
+    if hasattr(model, 'backbone') and hasattr(model.backbone, 'switch_to_deploy'):
+        model.backbone.switch_to_deploy()
+    if hasattr(model, 'switch_to_deploy') and callable(model.switch_to_deploy):
+        model.switch_to_deploy()
+    model = model.to(device)
+    model.eval()
+    # Bypass the data preprocessor and the default prediction path for tracing
+    del model.data_preprocessor
+    model._forward = types.MethodType(forward_deepstream, model)
+    model.query_head.predict = types.MethodType(query_head_predict_deepstream, model.query_head)
+    return model
+
+
+def suppress_warnings():
+    import warnings
+    warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
+    warnings.filterwarnings('ignore', category=UserWarning)
+    warnings.filterwarnings('ignore', category=DeprecationWarning)
+    warnings.filterwarnings('ignore', category=FutureWarning)
+    warnings.filterwarnings('ignore', category=ResourceWarning)
+
+
+def main(args):
+    suppress_warnings()
+
+    print(f'\nStarting: {args.weights}')
+
+    print('Opening CO-DETR model')
+
+    device = torch.device('cpu')
+    model = codetr_export(args.weights, args.config, device)
+
+    model = nn.Sequential(model, DeepStreamOutput())
+
+    img_size = args.size * 2 if len(args.size) == 1 else args.size
+
+    onnx_input_im = torch.zeros(args.batch, 3, *img_size).to(device)
+    onnx_output_file = f'{args.weights}.onnx'
+
+    dynamic_axes = {
+        'input': {
+            0: 'batch'
+        },
+        'output': {
+            0: 'batch'
+        }
+    }
+
+    print('Exporting the model to ONNX')
+    torch.onnx.export(
+        model, onnx_input_im, onnx_output_file, verbose=False, opset_version=args.opset, do_constant_folding=True,
+        input_names=['input'], output_names=['output'], dynamic_axes=dynamic_axes if args.dynamic else None
+    )
+
+    if args.simplify:
+        print('Simplifying the ONNX model')
+        import onnxslim
+        model_onnx = onnx.load(onnx_output_file)
+        model_onnx = onnxslim.slim(model_onnx)
+        onnx.save(model_onnx, onnx_output_file)
+
+    print(f'Done: {onnx_output_file}\n')
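+
+
+# Example invocation (the Co-DINO R50 command from docs/CODETR.md):
+#   python3 export_codetr.py -w co_dino_5scale_r50_1x_coco-7481f903.pth \
+#       -c projects/CO-DETR/configs/codino/co_dino_5scale_r50_8xb2_1x_coco.py --dynamic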
+
+
+def parse_args():
+    import argparse
+    parser = argparse.ArgumentParser(description='DeepStream CO-DETR conversion')
+    parser.add_argument('-w', '--weights', required=True, type=str, help='Input weights (.pth) file path (required)')
+    parser.add_argument('-c', '--config', required=True, help='Input config (.py) file path (required)')
+    parser.add_argument('-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])')
+    parser.add_argument('--opset', type=int, default=11, help='ONNX opset version')
+    parser.add_argument('--simplify', action='store_true', help='ONNX simplify model')
+    parser.add_argument('--dynamic', action='store_true', help='Dynamic batch-size')
+    parser.add_argument('--batch', type=int, default=1, help='Static batch-size')
+    args = parser.parse_args()
+    if not os.path.isfile(args.weights):
+        raise SystemExit('Invalid weights file')
+    if not os.path.isfile(args.config):
+        raise SystemExit('Invalid config file')
+    if args.dynamic and args.batch > 1:
+        raise SystemExit('Cannot set dynamic batch-size and static batch-size at the same time')
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)