Add YOLO-NAS and ONNX support

This commit is contained in:
Marcos Luciano
2023-05-14 14:47:38 -03:00
parent a527d7807a
commit b4e2dbdcf8
13 changed files with 741 additions and 6 deletions

View File

@@ -6,7 +6,6 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod
### Future updates
* ONNX model support
* DeepStream tutorials
* Dynamic batch-size
* Segmentation model support
@@ -30,10 +29,12 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod
* YOLOv7 support
* Optimized NMS [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142)
* Models benchmarks
* **YOLOv8 support**
* **YOLOX support**
* **PP-YOLOE+ support**
* **YOLOv6 >= 2.0 support**
* YOLOv8 support
* YOLOX support
* PP-YOLOE+ support
* YOLOv6 >= 2.0 support
* **ONNX model support with GPU post-processing**
* **YOLO-NAS support (ONNX)**
##

View File

@@ -0,0 +1,25 @@
[property]
gpu-id=0
net-scale-factor=0.0173520735727919486
offsets=123.675;116.28;103.53
model-color-format=0
onnx-file=ppyoloe_crn_s_400e_coco.onnx
model-engine-file=ppyoloe_crn_s_400e_coco.onnx_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=0
parse-bbox-func-name=NvDsInferParse_PPYOLOE_ONNX
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
[class-attrs-all]
nms-iou-threshold=0.7
pre-cluster-threshold=0.25
topk=300

View File

@@ -0,0 +1,24 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
onnx-file=ppyoloe_plus_crn_s_80e_coco.onnx
model-engine-file=ppyoloe_plus_crn_s_80e_coco.onnx_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=0
parse-bbox-func-name=NvDsInferParse_PPYOLOE_ONNX
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
[class-attrs-all]
nms-iou-threshold=0.7
pre-cluster-threshold=0.25
topk=300

View File

@@ -0,0 +1,25 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
onnx-file=yolov5s.onnx
model-engine-file=yolov5s.onnx_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=1
parse-bbox-func-name=NvDsInferParseYolo_ONNX
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300

View File

@@ -0,0 +1,25 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
onnx-file=yolov6s.onnx
model-engine-file=yolov6s.onnx_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=1
parse-bbox-func-name=NvDsInferParseYolo_ONNX
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300

View File

@@ -0,0 +1,25 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
onnx-file=yolov7.onnx
model-engine-file=yolov7.onnx_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=1
parse-bbox-func-name=NvDsInferParseYolo_ONNX
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300

View File

@@ -0,0 +1,25 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
onnx-file=yolov8s.onnx
model-engine-file=yolov8s.onnx_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=1
parse-bbox-func-name=NvDsInferParseYoloV8_ONNX
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300

View File

@@ -0,0 +1,25 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
onnx-file=yolo_nas_s.onnx
model-engine-file=yolo_nas_s.onnx_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=1
parse-bbox-func-name=NvDsInferParse_YOLO_NAS_ONNX
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300

View File

@@ -0,0 +1,26 @@
[property]
gpu-id=0
net-scale-factor=0.0173520735727919486
offsets=123.675;116.28;103.53
model-color-format=0
onnx-file=yolox_s.onnx
model-engine-file=yolox_s.onnx_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=0
parse-bbox-func-name=NvDsInferParseYoloX_ONNX
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300

View File

@@ -0,0 +1,25 @@
[property]
gpu-id=0
net-scale-factor=0
model-color-format=0
onnx-file=yolox_s.onnx
model-engine-file=yolox_s.onnx_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=0
parse-bbox-func-name=NvDsInferParseYoloX_ONNX
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300

View File

@@ -45,6 +45,8 @@ ifeq ($(OPENCV), 1)
LIBS+= $(shell pkg-config --libs opencv4 2> /dev/null || pkg-config --libs opencv)
endif
CUFLAGS:= -I/opt/nvidia/deepstream/deepstream/sources/includes -I/usr/local/cuda-$(CUDA_VER)/include
LIBS+= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib64 -lcudart -lcublas -lstdc++fs
LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group
@@ -70,7 +72,7 @@ all: $(TARGET_LIB)
$(CC) -c $(COMMON) -o $@ $(CFLAGS) $<
%.o: %.cu $(INCS) Makefile
$(NVCC) -c -o $@ --compiler-options '-fPIC' $<
$(NVCC) -c -o $@ --compiler-options '-fPIC' $(CUFLAGS) $<
$(TARGET_LIB) : $(TARGET_OBJS)
$(CC) -o $@ $(TARGET_OBJS) $(LFLAGS)

View File

@@ -0,0 +1,38 @@
/*
* Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
* Edited by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "nvdsinfer_custom_impl.h"
bool
NvDsInferInitializeInputLayers(std::vector<NvDsInferLayerInfo> const &inputLayersInfo,
NvDsInferNetworkInfo const &networkInfo, unsigned int maxBatchSize)
{
float *scaleFactor = (float *) inputLayersInfo[0].buffer;
for (unsigned int i = 0; i < maxBatchSize; i++) {
scaleFactor[i * 2 + 0] = 1.0;
scaleFactor[i * 2 + 1] = 1.0;
}
return true;
}

View File

@@ -0,0 +1,469 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
* Edited by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include "nvdsinfer_custom_impl.h"
#include "utils.h"
#include "yoloPlugins.h"
__global__ void decodeTensorYolo_ONNX(NvDsInferParseObjectInfo *binfo, const float* detections, const int numClasses,
const int outputSize, float netW, float netH)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
if (x_id >= outputSize)
return;
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numClasses; ++i) {
float prob = detections[x_id * (5 + numClasses) + 5 + i];
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
}
const float objectness = detections[x_id * (5 + numClasses) + 4];
const float bxc = detections[x_id * (5 + numClasses) + 0];
const float byc = detections[x_id * (5 + numClasses) + 1];
const float bw = detections[x_id * (5 + numClasses) + 2];
const float bh = detections[x_id * (5 + numClasses) + 3];
float x0 = bxc - bw / 2;
float y0 = byc - bh / 2;
float x1 = x0 + bw;
float y1 = y0 + bh;
x0 = fminf(float(netW), fmaxf(float(0.0), x0));
y0 = fminf(float(netH), fmaxf(float(0.0), y0));
x1 = fminf(float(netW), fmaxf(float(0.0), x1));
y1 = fminf(float(netH), fmaxf(float(0.0), y1));
binfo[x_id].left = x0;
binfo[x_id].top = y0;
binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
binfo[x_id].detectionConfidence = objectness * maxProb;
binfo[x_id].classId = maxIndex;
}
__global__ void decodeTensorYoloV8_ONNX(NvDsInferParseObjectInfo *binfo, const float* detections, const int numClasses,
const int outputSize, float netW, float netH)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
if (x_id >= outputSize)
return;
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numClasses; ++i) {
float prob = detections[x_id + outputSize * (i + 4)];
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
}
const float bxc = detections[x_id + outputSize * 0];
const float byc = detections[x_id + outputSize * 1];
const float bw = detections[x_id + outputSize * 2];
const float bh = detections[x_id + outputSize * 3];
float x0 = bxc - bw / 2;
float y0 = byc - bh / 2;
float x1 = x0 + bw;
float y1 = y0 + bh;
x0 = fminf(float(netW), fmaxf(float(0.0), x0));
y0 = fminf(float(netH), fmaxf(float(0.0), y0));
x1 = fminf(float(netW), fmaxf(float(0.0), x1));
y1 = fminf(float(netH), fmaxf(float(0.0), y1));
binfo[x_id].left = x0;
binfo[x_id].top = y0;
binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
binfo[x_id].detectionConfidence = maxProb;
binfo[x_id].classId = maxIndex;
}
__global__ void decodeTensorYoloX_ONNX(NvDsInferParseObjectInfo *binfo, const float* detections, const int numClasses,
const int outputSize, float netW, float netH, const int *grid0, const int *grid1, const int *strides)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
if (x_id >= outputSize)
return;
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numClasses; ++i) {
float prob = detections[x_id * (5 + numClasses) + 5 + i];
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
}
const float objectness = detections[x_id * (5 + numClasses) + 4];
const float bxc = (detections[x_id * (5 + numClasses) + 0] + grid0[x_id]) * strides[x_id];
const float byc = (detections[x_id * (5 + numClasses) + 1] + grid1[x_id]) * strides[x_id];
const float bw = __expf(detections[x_id * (5 + numClasses) + 2]) * strides[x_id];
const float bh = __expf(detections[x_id * (5 + numClasses) + 3]) * strides[x_id];
float x0 = bxc - bw / 2;
float y0 = byc - bh / 2;
float x1 = x0 + bw;
float y1 = y0 + bh;
x0 = fminf(float(netW), fmaxf(float(0.0), x0));
y0 = fminf(float(netH), fmaxf(float(0.0), y0));
x1 = fminf(float(netW), fmaxf(float(0.0), x1));
y1 = fminf(float(netH), fmaxf(float(0.0), y1));
binfo[x_id].left = x0;
binfo[x_id].top = y0;
binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
binfo[x_id].detectionConfidence = objectness * maxProb;
binfo[x_id].classId = maxIndex;
}
__global__ void decodeTensor_YOLO_NAS_ONNX(NvDsInferParseObjectInfo *binfo, const float* scores, const float* boxes,
const int numClasses, const int outputSize, float netW, float netH)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
if (x_id >= outputSize)
return;
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numClasses; ++i) {
float prob = scores[x_id * numClasses + i];
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
}
float x0 = boxes[x_id * 4 + 0];
float y0 = boxes[x_id * 4 + 1];
float x1 = boxes[x_id * 4 + 2];
float y1 = boxes[x_id * 4 + 3];
x0 = fminf(float(netW), fmaxf(float(0.0), x0));
y0 = fminf(float(netH), fmaxf(float(0.0), y0));
x1 = fminf(float(netW), fmaxf(float(0.0), x1));
y1 = fminf(float(netH), fmaxf(float(0.0), y1));
binfo[x_id].left = x0;
binfo[x_id].top = y0;
binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
binfo[x_id].detectionConfidence = maxProb;
binfo[x_id].classId = maxIndex;
}
__global__ void decodeTensor_PPYOLOE_ONNX(NvDsInferParseObjectInfo *binfo, const float* scores, const float* boxes,
const int numClasses, const int outputSize, float netW, float netH)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
if (x_id >= outputSize)
return;
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numClasses; ++i) {
float prob = scores[x_id + outputSize * i];
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
}
float x0 = boxes[x_id * 4 + 0];
float y0 = boxes[x_id * 4 + 1];
float x1 = boxes[x_id * 4 + 2];
float y1 = boxes[x_id * 4 + 3];
x0 = fminf(float(netW), fmaxf(float(0.0), x0));
y0 = fminf(float(netH), fmaxf(float(0.0), y0));
x1 = fminf(float(netW), fmaxf(float(0.0), x1));
y1 = fminf(float(netH), fmaxf(float(0.0), y1));
binfo[x_id].left = x0;
binfo[x_id].top = y0;
binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
binfo[x_id].detectionConfidence = maxProb;
binfo[x_id].classId = maxIndex;
}
static bool
NvDsInferParseCustomYolo_ONNX(std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList)
{
if (outputLayersInfo.empty()) {
std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl;
return false;
}
const NvDsInferLayerInfo& layer = outputLayersInfo[0];
const uint outputSize = layer.inferDims.d[0];
const uint numClasses = layer.inferDims.d[1] - 5;
if (numClasses != detectionParams.numClassesConfigured) {
std::cerr << "WARNING: Number of classes mismatch, make sure to set num-detected-classes=" << numClasses
<< " in config_infer file\n" << std::endl;
}
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
int threads_per_block = 1024;
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
decodeTensorYolo_ONNX<<<threads_per_block, number_of_blocks>>>(
thrust::raw_pointer_cast(objects.data()), (const float*) (layer.buffer), numClasses, outputSize,
static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height));
objectList.resize(outputSize);
thrust::copy(objects.begin(), objects.end(), objectList.begin());
return true;
}
static bool
NvDsInferParseCustomYoloV8_ONNX(std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList)
{
if (outputLayersInfo.empty()) {
std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl;
return false;
}
const NvDsInferLayerInfo& layer = outputLayersInfo[0];
const uint numClasses = layer.inferDims.d[0] - 4;
const uint outputSize = layer.inferDims.d[1];
if (numClasses != detectionParams.numClassesConfigured) {
std::cerr << "WARNING: Number of classes mismatch, make sure to set num-detected-classes=" << numClasses
<< " in config_infer file\n" << std::endl;
}
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
int threads_per_block = 1024;
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
decodeTensorYoloV8_ONNX<<<threads_per_block, number_of_blocks>>>(
thrust::raw_pointer_cast(objects.data()), (const float*) (layer.buffer), numClasses, outputSize,
static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height));
objectList.resize(outputSize);
thrust::copy(objects.begin(), objects.end(), objectList.begin());
return true;
}
static bool
NvDsInferParseCustomYoloX_ONNX(std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList)
{
if (outputLayersInfo.empty()) {
std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl;
return false;
}
const NvDsInferLayerInfo& layer = outputLayersInfo[0];
const uint outputSize = layer.inferDims.d[0];
const uint numClasses = layer.inferDims.d[1] - 5;
if (numClasses != detectionParams.numClassesConfigured) {
std::cerr << "WARNING: Number of classes mismatch, make sure to set num-detected-classes=" << numClasses
<< " in config_infer file\n" << std::endl;
}
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
std::vector<int> strides = {8, 16, 32};
std::vector<int> grid0;
std::vector<int> grid1;
std::vector<int> grid_strides;
for (uint s = 0; s < strides.size(); ++s) {
int num_grid_y = networkInfo.height / strides[s];
int num_grid_x = networkInfo.width / strides[s];
for (int g1 = 0; g1 < num_grid_y; ++g1) {
for (int g0 = 0; g0 < num_grid_x; ++g0) {
grid0.push_back(g0);
grid1.push_back(g1);
grid_strides.push_back(strides[s]);
}
}
}
thrust::device_vector<int> d_grid0(grid0);
thrust::device_vector<int> d_grid1(grid1);
thrust::device_vector<int> d_grid_strides(grid_strides);
int threads_per_block = 1024;
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
decodeTensorYoloX_ONNX<<<threads_per_block, number_of_blocks>>>(
thrust::raw_pointer_cast(objects.data()), (const float*) (layer.buffer), numClasses, outputSize,
static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height),
thrust::raw_pointer_cast(d_grid0.data()), thrust::raw_pointer_cast(d_grid1.data()),
thrust::raw_pointer_cast(d_grid_strides.data()));
objectList.resize(outputSize);
thrust::copy(objects.begin(), objects.end(), objectList.begin());
return true;
}
static bool
NvDsInferParseCustom_YOLO_NAS_ONNX(std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList)
{
if (outputLayersInfo.empty()) {
std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl;
return false;
}
const NvDsInferLayerInfo& scores = outputLayersInfo[0];
const NvDsInferLayerInfo& boxes = outputLayersInfo[1];
const uint outputSize = scores.inferDims.d[0];
const uint numClasses = scores.inferDims.d[1];
if (numClasses != detectionParams.numClassesConfigured) {
std::cerr << "WARNING: Number of classes mismatch, make sure to set num-detected-classes=" << numClasses
<< " in config_infer file\n" << std::endl;
}
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
int threads_per_block = 1024;
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
decodeTensor_YOLO_NAS_ONNX<<<threads_per_block, number_of_blocks>>>(
thrust::raw_pointer_cast(objects.data()), (const float*) (scores.buffer), (const float*) (boxes.buffer), numClasses,
outputSize, static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height));
objectList.resize(outputSize);
thrust::copy(objects.begin(), objects.end(), objectList.begin());
return true;
}
static bool
NvDsInferParseCustom_PPYOLOE_ONNX(std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList)
{
if (outputLayersInfo.empty()) {
std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl;
return false;
}
const NvDsInferLayerInfo& scores = outputLayersInfo[0];
const NvDsInferLayerInfo& boxes = outputLayersInfo[1];
const uint numClasses = scores.inferDims.d[0];
const uint outputSize = scores.inferDims.d[1];
if (numClasses != detectionParams.numClassesConfigured) {
std::cerr << "WARNING: Number of classes mismatch, make sure to set num-detected-classes=" << numClasses
<< " in config_infer file\n" << std::endl;
}
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
int threads_per_block = 1024;
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
decodeTensor_PPYOLOE_ONNX<<<threads_per_block, number_of_blocks>>>(
thrust::raw_pointer_cast(objects.data()), (const float*) (scores.buffer), (const float*) (boxes.buffer), numClasses,
outputSize, static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height));
objectList.resize(outputSize);
thrust::copy(objects.begin(), objects.end(), objectList.begin());
return true;
}
extern "C" bool
NvDsInferParseYolo_ONNX(std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams, std::vector<NvDsInferParseObjectInfo>& objectList)
{
return NvDsInferParseCustomYolo_ONNX(outputLayersInfo, networkInfo, detectionParams, objectList);
}
extern "C" bool
NvDsInferParseYoloV8_ONNX(std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams, std::vector<NvDsInferParseObjectInfo>& objectList)
{
return NvDsInferParseCustomYoloV8_ONNX(outputLayersInfo, networkInfo, detectionParams, objectList);
}
extern "C" bool
NvDsInferParseYoloX_ONNX(std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams, std::vector<NvDsInferParseObjectInfo>& objectList)
{
return NvDsInferParseCustomYoloX_ONNX(outputLayersInfo, networkInfo, detectionParams, objectList);
}
extern "C" bool
NvDsInferParse_YOLO_NAS_ONNX(std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams, std::vector<NvDsInferParseObjectInfo>& objectList)
{
return NvDsInferParseCustom_YOLO_NAS_ONNX(outputLayersInfo, networkInfo, detectionParams, objectList);
}
extern "C" bool
NvDsInferParse_PPYOLOE_ONNX(std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams, std::vector<NvDsInferParseObjectInfo>& objectList)
{
return NvDsInferParseCustom_PPYOLOE_ONNX(outputLayersInfo, networkInfo, detectionParams, objectList);
}