diff --git a/README.md b/README.md index 77583bd..379bd7c 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,6 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod ### Future updates -* ONNX model support * DeepStream tutorials * Dynamic batch-size * Segmentation model support @@ -30,10 +29,12 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod * YOLOv7 support * Optimized NMS [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142) * Models benchmarks -* **YOLOv8 support** -* **YOLOX support** -* **PP-YOLOE+ support** -* **YOLOv6 >= 2.0 support** +* YOLOv8 support +* YOLOX support +* PP-YOLOE+ support +* YOLOv6 >= 2.0 support +* **ONNX model support with GPU post-processing** +* **YOLO-NAS support (ONNX)** ## diff --git a/config_infer_primary_ppyoloe_onnx.txt b/config_infer_primary_ppyoloe_onnx.txt new file mode 100644 index 0000000..f5c0036 --- /dev/null +++ b/config_infer_primary_ppyoloe_onnx.txt @@ -0,0 +1,25 @@ +[property] +gpu-id=0 +net-scale-factor=0.0173520735727919486 +offsets=123.675;116.28;103.53 +model-color-format=0 +onnx-file=ppyoloe_crn_s_400e_coco.onnx +model-engine-file=ppyoloe_crn_s_400e_coco.onnx_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=2 +maintain-aspect-ratio=0 +parse-bbox-func-name=NvDsInferParse_PPYOLOE_ONNX +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so + +[class-attrs-all] +nms-iou-threshold=0.7 +pre-cluster-threshold=0.25 +topk=300 diff --git a/config_infer_primary_ppyoloe_plus_onnx.txt b/config_infer_primary_ppyoloe_plus_onnx.txt new file mode 100644 index 0000000..0baa131 --- /dev/null +++ b/config_infer_primary_ppyoloe_plus_onnx.txt @@ -0,0 +1,24 @@ +[property] +gpu-id=0 +net-scale-factor=0.0039215697906911373 +model-color-format=0 +onnx-file=ppyoloe_plus_crn_s_80e_coco.onnx +model-engine-file=ppyoloe_plus_crn_s_80e_coco.onnx_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=2 +maintain-aspect-ratio=0 +parse-bbox-func-name=NvDsInferParse_PPYOLOE_ONNX +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so + +[class-attrs-all] +nms-iou-threshold=0.7 +pre-cluster-threshold=0.25 +topk=300 diff --git a/config_infer_primary_yoloV5_onnx.txt b/config_infer_primary_yoloV5_onnx.txt new file mode 100644 index 0000000..7313f42 --- /dev/null +++ b/config_infer_primary_yoloV5_onnx.txt @@ -0,0 +1,25 @@ +[property] +gpu-id=0 +net-scale-factor=0.0039215697906911373 +model-color-format=0 +onnx-file=yolov5s.onnx +model-engine-file=yolov5s.onnx_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=2 +maintain-aspect-ratio=1 +symmetric-padding=1 +parse-bbox-func-name=NvDsInferParseYolo_ONNX +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so + +[class-attrs-all] +nms-iou-threshold=0.45 +pre-cluster-threshold=0.25 +topk=300 diff --git a/config_infer_primary_yoloV6_onnx.txt b/config_infer_primary_yoloV6_onnx.txt new file mode 100644 index 0000000..92b151f --- /dev/null +++ b/config_infer_primary_yoloV6_onnx.txt @@ -0,0 +1,25 @@ +[property] +gpu-id=0 +net-scale-factor=0.0039215697906911373 +model-color-format=0 +onnx-file=yolov6s.onnx +model-engine-file=yolov6s.onnx_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=2 +maintain-aspect-ratio=1 +symmetric-padding=1 +parse-bbox-func-name=NvDsInferParseYolo_ONNX +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so + +[class-attrs-all] +nms-iou-threshold=0.45 +pre-cluster-threshold=0.25 +topk=300 diff --git a/config_infer_primary_yoloV7_onnx.txt b/config_infer_primary_yoloV7_onnx.txt new file mode 100644 index 0000000..fab408b --- /dev/null +++ b/config_infer_primary_yoloV7_onnx.txt @@ -0,0 +1,25 @@ +[property] +gpu-id=0 +net-scale-factor=0.0039215697906911373 +model-color-format=0 +onnx-file=yolov7.onnx +model-engine-file=yolov7.onnx_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=2 +maintain-aspect-ratio=1 +symmetric-padding=1 +parse-bbox-func-name=NvDsInferParseYolo_ONNX +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so + +[class-attrs-all] +nms-iou-threshold=0.45 +pre-cluster-threshold=0.25 +topk=300 diff --git a/config_infer_primary_yoloV8_onnx.txt b/config_infer_primary_yoloV8_onnx.txt new file mode 100644 index 0000000..b18a3b2 --- /dev/null +++ b/config_infer_primary_yoloV8_onnx.txt @@ -0,0 +1,25 @@ +[property] +gpu-id=0 +net-scale-factor=0.0039215697906911373 +model-color-format=0 +onnx-file=yolov8s.onnx +model-engine-file=yolov8s.onnx_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=2 +maintain-aspect-ratio=1 +symmetric-padding=1 +parse-bbox-func-name=NvDsInferParseYoloV8_ONNX +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so + +[class-attrs-all] +nms-iou-threshold=0.45 +pre-cluster-threshold=0.25 +topk=300 diff --git a/config_infer_primary_yolo_nas_onnx.txt b/config_infer_primary_yolo_nas_onnx.txt new file mode 100644 index 0000000..5364ad7 --- /dev/null +++ b/config_infer_primary_yolo_nas_onnx.txt @@ -0,0 +1,25 @@ +[property] +gpu-id=0 +net-scale-factor=0.0039215697906911373 +model-color-format=0 +onnx-file=yolo_nas_s.onnx +model-engine-file=yolo_nas_s.onnx_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=2 +maintain-aspect-ratio=1 +symmetric-padding=1 +parse-bbox-func-name=NvDsInferParse_YOLO_NAS_ONNX +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so + +[class-attrs-all] +nms-iou-threshold=0.45 +pre-cluster-threshold=0.25 +topk=300 diff --git a/config_infer_primary_yolox_legacy_onnx.txt b/config_infer_primary_yolox_legacy_onnx.txt new file mode 100644 index 0000000..1d9f410 --- /dev/null +++ b/config_infer_primary_yolox_legacy_onnx.txt @@ -0,0 +1,26 @@ +[property] +gpu-id=0 +net-scale-factor=0.0173520735727919486 +offsets=123.675;116.28;103.53 +model-color-format=0 +onnx-file=yolox_s.onnx +model-engine-file=yolox_s.onnx_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=2 +maintain-aspect-ratio=1 +symmetric-padding=0 +parse-bbox-func-name=NvDsInferParseYoloX_ONNX +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so + +[class-attrs-all] +nms-iou-threshold=0.45 +pre-cluster-threshold=0.25 +topk=300 diff --git a/config_infer_primary_yolox_onnx.txt b/config_infer_primary_yolox_onnx.txt new file mode 100644 index 0000000..49fb4aa --- /dev/null +++ b/config_infer_primary_yolox_onnx.txt @@ -0,0 +1,25 @@ +[property] +gpu-id=0 +net-scale-factor=0 +model-color-format=0 +onnx-file=yolox_s.onnx +model-engine-file=yolox_s.onnx_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=2 +maintain-aspect-ratio=1 +symmetric-padding=0 +parse-bbox-func-name=NvDsInferParseYoloX_ONNX +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so + +[class-attrs-all] +nms-iou-threshold=0.45 +pre-cluster-threshold=0.25 +topk=300 diff --git a/nvdsinfer_custom_impl_Yolo/Makefile b/nvdsinfer_custom_impl_Yolo/Makefile index c6e19b8..9d7316d 100644 --- a/nvdsinfer_custom_impl_Yolo/Makefile +++ b/nvdsinfer_custom_impl_Yolo/Makefile @@ -45,6 +45,8 @@ ifeq ($(OPENCV), 1) LIBS+= $(shell pkg-config --libs opencv4 2> /dev/null || pkg-config --libs opencv) endif +CUFLAGS:= -I/opt/nvidia/deepstream/deepstream/sources/includes -I/usr/local/cuda-$(CUDA_VER)/include + LIBS+= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib64 -lcudart -lcublas -lstdc++fs LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group @@ -70,7 +72,7 @@ all: $(TARGET_LIB) $(CC) -c $(COMMON) -o $@ $(CFLAGS) $< %.o: %.cu $(INCS) Makefile - $(NVCC) -c -o $@ --compiler-options '-fPIC' $< + $(NVCC) -c -o $@ --compiler-options '-fPIC' $(CUFLAGS) $< $(TARGET_LIB) : $(TARGET_OBJS) $(CC) -o $@ $(TARGET_OBJS) $(LFLAGS) diff --git a/nvdsinfer_custom_impl_Yolo/nvdsinitinputlayers_Yolo.cpp b/nvdsinfer_custom_impl_Yolo/nvdsinitinputlayers_Yolo.cpp new file mode 100644 index 0000000..2742dfb --- /dev/null +++ b/nvdsinfer_custom_impl_Yolo/nvdsinitinputlayers_Yolo.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + * Edited by Marcos Luciano + * https://www.github.com/marcoslucianops + */ + +#include "nvdsinfer_custom_impl.h" + +bool +NvDsInferInitializeInputLayers(std::vector const &inputLayersInfo, + NvDsInferNetworkInfo const &networkInfo, unsigned int maxBatchSize) +{ + float *scaleFactor = (float *) inputLayersInfo[0].buffer; + for (unsigned int i = 0; i < maxBatchSize; i++) { + scaleFactor[i * 2 + 0] = 1.0; + scaleFactor[i * 2 + 1] = 1.0; + } + return true; +} diff --git a/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo_cuda.cu b/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo_cuda.cu new file mode 100644 index 0000000..f25c0ec --- /dev/null +++ b/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo_cuda.cu @@ -0,0 +1,469 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + * Edited by Marcos Luciano + * https://www.github.com/marcoslucianops + */ + +#include +#include + +#include "nvdsinfer_custom_impl.h" + +#include "utils.h" +#include "yoloPlugins.h" + +__global__ void decodeTensorYolo_ONNX(NvDsInferParseObjectInfo *binfo, const float* detections, const int numClasses, + const int outputSize, float netW, float netH) +{ + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (x_id >= outputSize) + return; + + float maxProb = 0.0f; + int maxIndex = -1; + + for (uint i = 0; i < numClasses; ++i) { + float prob = detections[x_id * (5 + numClasses) + 5 + i]; + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; + } + } + + const float objectness = detections[x_id * (5 + numClasses) + 4]; + + const float bxc = detections[x_id * (5 + numClasses) + 0]; + const float byc = detections[x_id * (5 + numClasses) + 1]; + const float bw = detections[x_id * (5 + numClasses) + 2]; + const float bh = detections[x_id * (5 + numClasses) + 3]; + + float x0 = bxc - bw / 2; + float y0 = byc - bh / 2; + float x1 = x0 + bw; + float y1 = y0 + bh; + x0 = fminf(float(netW), fmaxf(float(0.0), x0)); + y0 = fminf(float(netH), fmaxf(float(0.0), y0)); + x1 = fminf(float(netW), fmaxf(float(0.0), x1)); + y1 = fminf(float(netH), fmaxf(float(0.0), y1)); + + binfo[x_id].left = x0; + binfo[x_id].top = y0; + binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0)); + binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0)); + binfo[x_id].detectionConfidence = objectness * maxProb; + binfo[x_id].classId = maxIndex; +} + +__global__ void decodeTensorYoloV8_ONNX(NvDsInferParseObjectInfo *binfo, const float* detections, const int numClasses, + const int outputSize, float netW, float netH) +{ + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (x_id >= outputSize) + return; + + float maxProb = 0.0f; + int maxIndex = -1; + + for (uint i = 0; i < numClasses; ++i) { + float prob = detections[x_id + outputSize * (i + 4)]; + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; + } + } + + const float bxc = detections[x_id + outputSize * 0]; + const float byc = detections[x_id + outputSize * 1]; + const float bw = detections[x_id + outputSize * 2]; + const float bh = detections[x_id + outputSize * 3]; + + float x0 = bxc - bw / 2; + float y0 = byc - bh / 2; + float x1 = x0 + bw; + float y1 = y0 + bh; + x0 = fminf(float(netW), fmaxf(float(0.0), x0)); + y0 = fminf(float(netH), fmaxf(float(0.0), y0)); + x1 = fminf(float(netW), fmaxf(float(0.0), x1)); + y1 = fminf(float(netH), fmaxf(float(0.0), y1)); + + binfo[x_id].left = x0; + binfo[x_id].top = y0; + binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0)); + binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0)); + binfo[x_id].detectionConfidence = maxProb; + binfo[x_id].classId = maxIndex; +} + +__global__ void decodeTensorYoloX_ONNX(NvDsInferParseObjectInfo *binfo, const float* detections, const int numClasses, + const int outputSize, float netW, float netH, const int *grid0, const int *grid1, const int *strides) +{ + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (x_id >= outputSize) + return; + + float maxProb = 0.0f; + int maxIndex = -1; + + for (uint i = 0; i < numClasses; ++i) { + float prob = detections[x_id * (5 + numClasses) + 5 + i]; + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; + } + } + + const float objectness = detections[x_id * (5 + numClasses) + 4]; + + const float bxc = (detections[x_id * (5 + numClasses) + 0] + grid0[x_id]) * strides[x_id]; + const float byc = (detections[x_id * (5 + numClasses) + 1] + grid1[x_id]) * strides[x_id]; + const float bw = __expf(detections[x_id * (5 + numClasses) + 2]) * strides[x_id]; + const float bh = __expf(detections[x_id * (5 + numClasses) + 3]) * strides[x_id]; + + float x0 = bxc - bw / 2; + float y0 = byc - bh / 2; + float x1 = x0 + bw; + float y1 = y0 + bh; + x0 = fminf(float(netW), fmaxf(float(0.0), x0)); + y0 = fminf(float(netH), fmaxf(float(0.0), y0)); + x1 = fminf(float(netW), fmaxf(float(0.0), x1)); + y1 = fminf(float(netH), fmaxf(float(0.0), y1)); + + binfo[x_id].left = x0; + binfo[x_id].top = y0; + binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0)); + binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0)); + binfo[x_id].detectionConfidence = objectness * maxProb; + binfo[x_id].classId = maxIndex; +} + +__global__ void decodeTensor_YOLO_NAS_ONNX(NvDsInferParseObjectInfo *binfo, const float* scores, const float* boxes, + const int numClasses, const int outputSize, float netW, float netH) +{ + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (x_id >= outputSize) + return; + + float maxProb = 0.0f; + int maxIndex = -1; + + for (uint i = 0; i < numClasses; ++i) { + float prob = scores[x_id * numClasses + i]; + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; + } + } + + float x0 = boxes[x_id * 4 + 0]; + float y0 = boxes[x_id * 4 + 1]; + float x1 = boxes[x_id * 4 + 2]; + float y1 = boxes[x_id * 4 + 3]; + + x0 = fminf(float(netW), fmaxf(float(0.0), x0)); + y0 = fminf(float(netH), fmaxf(float(0.0), y0)); + x1 = fminf(float(netW), fmaxf(float(0.0), x1)); + y1 = fminf(float(netH), fmaxf(float(0.0), y1)); + + binfo[x_id].left = x0; + binfo[x_id].top = y0; + binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0)); + binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0)); + binfo[x_id].detectionConfidence = maxProb; + binfo[x_id].classId = maxIndex; +} + +__global__ void decodeTensor_PPYOLOE_ONNX(NvDsInferParseObjectInfo *binfo, const float* scores, const float* boxes, + const int numClasses, const int outputSize, float netW, float netH) +{ + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (x_id >= outputSize) + return; + + float maxProb = 0.0f; + int maxIndex = -1; + + for (uint i = 0; i < numClasses; ++i) { + float prob = scores[x_id + outputSize * i]; + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; + } + } + + float x0 = boxes[x_id * 4 + 0]; + float y0 = boxes[x_id * 4 + 1]; + float x1 = boxes[x_id * 4 + 2]; + float y1 = boxes[x_id * 4 + 3]; + + x0 = fminf(float(netW), fmaxf(float(0.0), x0)); + y0 = fminf(float(netH), fmaxf(float(0.0), y0)); + x1 = fminf(float(netW), fmaxf(float(0.0), x1)); + y1 = fminf(float(netH), fmaxf(float(0.0), y1)); + + binfo[x_id].left = x0; + binfo[x_id].top = y0; + binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0)); + binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0)); + binfo[x_id].detectionConfidence = maxProb; + binfo[x_id].classId = maxIndex; +} + +static bool +NvDsInferParseCustomYolo_ONNX(std::vector const& outputLayersInfo, + NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams, + std::vector& objectList) +{ + if (outputLayersInfo.empty()) { + std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl; + return false; + } + + const NvDsInferLayerInfo& layer = outputLayersInfo[0]; + + const uint outputSize = layer.inferDims.d[0]; + const uint numClasses = layer.inferDims.d[1] - 5; + + if (numClasses != detectionParams.numClassesConfigured) { + std::cerr << "WARNING: Number of classes mismatch, make sure to set num-detected-classes=" << numClasses + << " in config_infer file\n" << std::endl; + } + + thrust::device_vector objects(outputSize); + + int threads_per_block = 1024; + int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1; + + decodeTensorYolo_ONNX<<>>( + thrust::raw_pointer_cast(objects.data()), (const float*) (layer.buffer), numClasses, outputSize, + static_cast(networkInfo.width), static_cast(networkInfo.height)); + + objectList.resize(outputSize); + thrust::copy(objects.begin(), objects.end(), objectList.begin()); + + return true; +} + +static bool +NvDsInferParseCustomYoloV8_ONNX(std::vector const& outputLayersInfo, + NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams, + std::vector& objectList) +{ + if (outputLayersInfo.empty()) { + std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl; + return false; + } + + const NvDsInferLayerInfo& layer = outputLayersInfo[0]; + + const uint numClasses = layer.inferDims.d[0] - 4; + const uint outputSize = layer.inferDims.d[1]; + + if (numClasses != detectionParams.numClassesConfigured) { + std::cerr << "WARNING: Number of classes mismatch, make sure to set num-detected-classes=" << numClasses + << " in config_infer file\n" << std::endl; + } + + thrust::device_vector objects(outputSize); + + int threads_per_block = 1024; + int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1; + + decodeTensorYoloV8_ONNX<<>>( + thrust::raw_pointer_cast(objects.data()), (const float*) (layer.buffer), numClasses, outputSize, + static_cast(networkInfo.width), static_cast(networkInfo.height)); + + objectList.resize(outputSize); + thrust::copy(objects.begin(), objects.end(), objectList.begin()); + + return true; +} + +static bool +NvDsInferParseCustomYoloX_ONNX(std::vector const& outputLayersInfo, + NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams, + std::vector& objectList) +{ + if (outputLayersInfo.empty()) { + std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl; + return false; + } + + const NvDsInferLayerInfo& layer = outputLayersInfo[0]; + + const uint outputSize = layer.inferDims.d[0]; + const uint numClasses = layer.inferDims.d[1] - 5; + + if (numClasses != detectionParams.numClassesConfigured) { + std::cerr << "WARNING: Number of classes mismatch, make sure to set num-detected-classes=" << numClasses + << " in config_infer file\n" << std::endl; + } + + thrust::device_vector objects(outputSize); + + std::vector strides = {8, 16, 32}; + + std::vector grid0; + std::vector grid1; + std::vector grid_strides; + + for (uint s = 0; s < strides.size(); ++s) { + int num_grid_y = networkInfo.height / strides[s]; + int num_grid_x = networkInfo.width / strides[s]; + for (int g1 = 0; g1 < num_grid_y; ++g1) { + for (int g0 = 0; g0 < num_grid_x; ++g0) { + grid0.push_back(g0); + grid1.push_back(g1); + grid_strides.push_back(strides[s]); + } + } + } + + thrust::device_vector d_grid0(grid0); + thrust::device_vector d_grid1(grid1); + thrust::device_vector d_grid_strides(grid_strides); + + int threads_per_block = 1024; + int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1; + + decodeTensorYoloX_ONNX<<>>( + thrust::raw_pointer_cast(objects.data()), (const float*) (layer.buffer), numClasses, outputSize, + static_cast(networkInfo.width), static_cast(networkInfo.height), + thrust::raw_pointer_cast(d_grid0.data()), thrust::raw_pointer_cast(d_grid1.data()), + thrust::raw_pointer_cast(d_grid_strides.data())); + + objectList.resize(outputSize); + thrust::copy(objects.begin(), objects.end(), objectList.begin()); + + return true; +} + +static bool +NvDsInferParseCustom_YOLO_NAS_ONNX(std::vector const& outputLayersInfo, + NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams, + std::vector& objectList) +{ + if (outputLayersInfo.empty()) { + std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl; + return false; + } + + const NvDsInferLayerInfo& scores = outputLayersInfo[0]; + const NvDsInferLayerInfo& boxes = outputLayersInfo[1]; + + const uint outputSize = scores.inferDims.d[0]; + const uint numClasses = scores.inferDims.d[1]; + + if (numClasses != detectionParams.numClassesConfigured) { + std::cerr << "WARNING: Number of classes mismatch, make sure to set num-detected-classes=" << numClasses + << " in config_infer file\n" << std::endl; + } + + thrust::device_vector objects(outputSize); + + int threads_per_block = 1024; + int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1; + + decodeTensor_YOLO_NAS_ONNX<<>>( + thrust::raw_pointer_cast(objects.data()), (const float*) (scores.buffer), (const float*) (boxes.buffer), numClasses, + outputSize, static_cast(networkInfo.width), static_cast(networkInfo.height)); + + objectList.resize(outputSize); + thrust::copy(objects.begin(), objects.end(), objectList.begin()); + + return true; +} + +static bool +NvDsInferParseCustom_PPYOLOE_ONNX(std::vector const& outputLayersInfo, + NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams, + std::vector& objectList) +{ + if (outputLayersInfo.empty()) { + std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl; + return false; + } + + const NvDsInferLayerInfo& scores = outputLayersInfo[0]; + const NvDsInferLayerInfo& boxes = outputLayersInfo[1]; + + const uint numClasses = scores.inferDims.d[0]; + const uint outputSize = scores.inferDims.d[1]; + + if (numClasses != detectionParams.numClassesConfigured) { + std::cerr << "WARNING: Number of classes mismatch, make sure to set num-detected-classes=" << numClasses + << " in config_infer file\n" << std::endl; + } + + thrust::device_vector objects(outputSize); + + int threads_per_block = 1024; + int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1; + + decodeTensor_PPYOLOE_ONNX<<>>( + thrust::raw_pointer_cast(objects.data()), (const float*) (scores.buffer), (const float*) (boxes.buffer), numClasses, + outputSize, static_cast(networkInfo.width), static_cast(networkInfo.height)); + + objectList.resize(outputSize); + thrust::copy(objects.begin(), objects.end(), objectList.begin()); + + return true; +} + +extern "C" bool +NvDsInferParseYolo_ONNX(std::vector const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo, + NvDsInferParseDetectionParams const& detectionParams, std::vector& objectList) +{ + return NvDsInferParseCustomYolo_ONNX(outputLayersInfo, networkInfo, detectionParams, objectList); +} + +extern "C" bool +NvDsInferParseYoloV8_ONNX(std::vector const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo, + NvDsInferParseDetectionParams const& detectionParams, std::vector& objectList) +{ + return NvDsInferParseCustomYoloV8_ONNX(outputLayersInfo, networkInfo, detectionParams, objectList); +} + +extern "C" bool +NvDsInferParseYoloX_ONNX(std::vector const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo, + NvDsInferParseDetectionParams const& detectionParams, std::vector& objectList) +{ + return NvDsInferParseCustomYoloX_ONNX(outputLayersInfo, networkInfo, detectionParams, objectList); +} + +extern "C" bool +NvDsInferParse_YOLO_NAS_ONNX(std::vector const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo, + NvDsInferParseDetectionParams const& detectionParams, std::vector& objectList) +{ + return NvDsInferParseCustom_YOLO_NAS_ONNX(outputLayersInfo, networkInfo, detectionParams, objectList); +} + +extern "C" bool +NvDsInferParse_PPYOLOE_ONNX(std::vector const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo, + NvDsInferParseDetectionParams const& detectionParams, std::vector& objectList) +{ + return NvDsInferParseCustom_PPYOLOE_ONNX(outputLayersInfo, networkInfo, detectionParams, objectList); +}