Add YOLO-NAS and ONNX support
This commit is contained in:
@@ -32,7 +32,7 @@
|
||||
#include "yoloPlugins.h"
|
||||
|
||||
__global__ void decodeTensor_YOLO_ONNX(NvDsInferParseObjectInfo *binfo, const float* detections, const int numClasses,
|
||||
const int outputSize, float netW, float netH)
|
||||
const int outputSize, float netW, float netH, const float* preclusterThreshold, int* numDetections)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
@@ -52,6 +52,11 @@ __global__ void decodeTensor_YOLO_ONNX(NvDsInferParseObjectInfo *binfo, const fl
|
||||
|
||||
const float objectness = detections[x_id * (5 + numClasses) + 4];
|
||||
|
||||
if (objectness * maxProb < preclusterThreshold[maxIndex])
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(numDetections, 1);
|
||||
|
||||
const float bxc = detections[x_id * (5 + numClasses) + 0];
|
||||
const float byc = detections[x_id * (5 + numClasses) + 1];
|
||||
const float bw = detections[x_id * (5 + numClasses) + 2];
|
||||
@@ -66,16 +71,16 @@ __global__ void decodeTensor_YOLO_ONNX(NvDsInferParseObjectInfo *binfo, const fl
|
||||
x1 = fminf(float(netW), fmaxf(float(0.0), x1));
|
||||
y1 = fminf(float(netH), fmaxf(float(0.0), y1));
|
||||
|
||||
binfo[x_id].left = x0;
|
||||
binfo[x_id].top = y0;
|
||||
binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
|
||||
binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
|
||||
binfo[x_id].detectionConfidence = objectness * maxProb;
|
||||
binfo[x_id].classId = maxIndex;
|
||||
binfo[count].left = x0;
|
||||
binfo[count].top = y0;
|
||||
binfo[count].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
|
||||
binfo[count].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
|
||||
binfo[count].detectionConfidence = objectness * maxProb;
|
||||
binfo[count].classId = maxIndex;
|
||||
}
|
||||
|
||||
__global__ void decodeTensor_YOLOV8_ONNX(NvDsInferParseObjectInfo* binfo, const float* detections, const int numClasses,
|
||||
const int outputSize, float netW, float netH)
|
||||
const int outputSize, float netW, float netH, const float* preclusterThreshold, int* numDetections)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
@@ -93,6 +98,11 @@ __global__ void decodeTensor_YOLOV8_ONNX(NvDsInferParseObjectInfo *binfo, const
|
||||
}
|
||||
}
|
||||
|
||||
if (maxProb < preclusterThreshold[maxIndex])
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(numDetections, 1);
|
||||
|
||||
const float bxc = detections[x_id + outputSize * 0];
|
||||
const float byc = detections[x_id + outputSize * 1];
|
||||
const float bw = detections[x_id + outputSize * 2];
|
||||
@@ -107,16 +117,17 @@ __global__ void decodeTensor_YOLOV8_ONNX(NvDsInferParseObjectInfo *binfo, const
|
||||
x1 = fminf(float(netW), fmaxf(float(0.0), x1));
|
||||
y1 = fminf(float(netH), fmaxf(float(0.0), y1));
|
||||
|
||||
binfo[x_id].left = x0;
|
||||
binfo[x_id].top = y0;
|
||||
binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
|
||||
binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
|
||||
binfo[x_id].detectionConfidence = maxProb;
|
||||
binfo[x_id].classId = maxIndex;
|
||||
binfo[count].left = x0;
|
||||
binfo[count].top = y0;
|
||||
binfo[count].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
|
||||
binfo[count].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
|
||||
binfo[count].detectionConfidence = maxProb;
|
||||
binfo[count].classId = maxIndex;
|
||||
}
|
||||
|
||||
__global__ void decodeTensor_YOLOX_ONNX(NvDsInferParseObjectInfo *binfo, const float* detections, const int numClasses,
|
||||
const int outputSize, float netW, float netH, const int *grid0, const int *grid1, const int *strides)
|
||||
const int outputSize, float netW, float netH, const int *grid0, const int *grid1, const int *strides,
|
||||
const float* preclusterThreshold, int* numDetections)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
@@ -136,6 +147,11 @@ __global__ void decodeTensor_YOLOX_ONNX(NvDsInferParseObjectInfo *binfo, const f
|
||||
|
||||
const float objectness = detections[x_id * (5 + numClasses) + 4];
|
||||
|
||||
if (objectness * maxProb < preclusterThreshold[maxIndex])
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(numDetections, 1);
|
||||
|
||||
const float bxc = (detections[x_id * (5 + numClasses) + 0] + grid0[x_id]) * strides[x_id];
|
||||
const float byc = (detections[x_id * (5 + numClasses) + 1] + grid1[x_id]) * strides[x_id];
|
||||
const float bw = __expf(detections[x_id * (5 + numClasses) + 2]) * strides[x_id];
|
||||
@@ -150,16 +166,16 @@ __global__ void decodeTensor_YOLOX_ONNX(NvDsInferParseObjectInfo *binfo, const f
|
||||
x1 = fminf(float(netW), fmaxf(float(0.0), x1));
|
||||
y1 = fminf(float(netH), fmaxf(float(0.0), y1));
|
||||
|
||||
binfo[x_id].left = x0;
|
||||
binfo[x_id].top = y0;
|
||||
binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
|
||||
binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
|
||||
binfo[x_id].detectionConfidence = objectness * maxProb;
|
||||
binfo[x_id].classId = maxIndex;
|
||||
binfo[count].left = x0;
|
||||
binfo[count].top = y0;
|
||||
binfo[count].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
|
||||
binfo[count].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
|
||||
binfo[count].detectionConfidence = objectness * maxProb;
|
||||
binfo[count].classId = maxIndex;
|
||||
}
|
||||
|
||||
__global__ void decodeTensor_YOLO_NAS_ONNX(NvDsInferParseObjectInfo *binfo, const float* scores, const float* boxes,
|
||||
const int numClasses, const int outputSize, float netW, float netH)
|
||||
const int numClasses, const int outputSize, float netW, float netH, const float* preclusterThreshold, int* numDetections)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
@@ -177,6 +193,11 @@ __global__ void decodeTensor_YOLO_NAS_ONNX(NvDsInferParseObjectInfo *binfo, cons
|
||||
}
|
||||
}
|
||||
|
||||
if (maxProb < preclusterThreshold[maxIndex])
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(numDetections, 1);
|
||||
|
||||
float x0 = boxes[x_id * 4 + 0];
|
||||
float y0 = boxes[x_id * 4 + 1];
|
||||
float x1 = boxes[x_id * 4 + 2];
|
||||
@@ -187,16 +208,16 @@ __global__ void decodeTensor_YOLO_NAS_ONNX(NvDsInferParseObjectInfo *binfo, cons
|
||||
x1 = fminf(float(netW), fmaxf(float(0.0), x1));
|
||||
y1 = fminf(float(netH), fmaxf(float(0.0), y1));
|
||||
|
||||
binfo[x_id].left = x0;
|
||||
binfo[x_id].top = y0;
|
||||
binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
|
||||
binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
|
||||
binfo[x_id].detectionConfidence = maxProb;
|
||||
binfo[x_id].classId = maxIndex;
|
||||
binfo[count].left = x0;
|
||||
binfo[count].top = y0;
|
||||
binfo[count].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
|
||||
binfo[count].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
|
||||
binfo[count].detectionConfidence = maxProb;
|
||||
binfo[count].classId = maxIndex;
|
||||
}
|
||||
|
||||
__global__ void decodeTensor_PPYOLOE_ONNX(NvDsInferParseObjectInfo *binfo, const float* scores, const float* boxes,
|
||||
const int numClasses, const int outputSize, float netW, float netH)
|
||||
const int numClasses, const int outputSize, float netW, float netH, const float* preclusterThreshold, int* numDetections)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
@@ -214,6 +235,11 @@ __global__ void decodeTensor_PPYOLOE_ONNX(NvDsInferParseObjectInfo *binfo, const
|
||||
}
|
||||
}
|
||||
|
||||
if (maxProb < preclusterThreshold[maxIndex])
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(numDetections, 1);
|
||||
|
||||
float x0 = boxes[x_id * 4 + 0];
|
||||
float y0 = boxes[x_id * 4 + 1];
|
||||
float x1 = boxes[x_id * 4 + 2];
|
||||
@@ -224,12 +250,12 @@ __global__ void decodeTensor_PPYOLOE_ONNX(NvDsInferParseObjectInfo *binfo, const
|
||||
x1 = fminf(float(netW), fmaxf(float(0.0), x1));
|
||||
y1 = fminf(float(netH), fmaxf(float(0.0), y1));
|
||||
|
||||
binfo[x_id].left = x0;
|
||||
binfo[x_id].top = y0;
|
||||
binfo[x_id].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
|
||||
binfo[x_id].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
|
||||
binfo[x_id].detectionConfidence = maxProb;
|
||||
binfo[x_id].classId = maxIndex;
|
||||
binfo[count].left = x0;
|
||||
binfo[count].top = y0;
|
||||
binfo[count].width = fminf(float(netW), fmaxf(float(0.0), x1 - x0));
|
||||
binfo[count].height = fminf(float(netH), fmaxf(float(0.0), y1 - y0));
|
||||
binfo[count].detectionConfidence = maxProb;
|
||||
binfo[count].classId = maxIndex;
|
||||
}
|
||||
|
||||
static bool
|
||||
@@ -254,15 +280,22 @@ NvDsInferParseCustom_YOLO_ONNX(std::vector<NvDsInferLayerInfo> const& outputLaye
|
||||
|
||||
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
|
||||
|
||||
std::vector<int> numDetections = { 0 };
|
||||
thrust::device_vector<int> d_numDetections(numDetections);
|
||||
|
||||
thrust::device_vector<float> preclusterThreshold(detectionParams.perClassPreclusterThreshold);
|
||||
|
||||
int threads_per_block = 1024;
|
||||
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
|
||||
|
||||
decodeTensor_YOLO_ONNX<<<threads_per_block, number_of_blocks>>>(
|
||||
thrust::raw_pointer_cast(objects.data()), (const float*) (layer.buffer), numClasses, outputSize,
|
||||
static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height));
|
||||
static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height),
|
||||
thrust::raw_pointer_cast(preclusterThreshold.data()), thrust::raw_pointer_cast(d_numDetections.data()));
|
||||
|
||||
objectList.resize(outputSize);
|
||||
thrust::copy(objects.begin(), objects.end(), objectList.begin());
|
||||
thrust::copy(d_numDetections.begin(), d_numDetections.end(), numDetections.begin());
|
||||
objectList.resize(numDetections[0]);
|
||||
thrust::copy(objects.begin(), objects.begin() + numDetections[0], objectList.begin());
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -289,15 +322,22 @@ NvDsInferParseCustom_YOLOV8_ONNX(std::vector<NvDsInferLayerInfo> const& outputLa
|
||||
|
||||
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
|
||||
|
||||
std::vector<int> numDetections = { 0 };
|
||||
thrust::device_vector<int> d_numDetections(numDetections);
|
||||
|
||||
thrust::device_vector<float> preclusterThreshold(detectionParams.perClassPreclusterThreshold);
|
||||
|
||||
int threads_per_block = 1024;
|
||||
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
|
||||
|
||||
decodeTensor_YOLOV8_ONNX<<<threads_per_block, number_of_blocks>>>(
|
||||
thrust::raw_pointer_cast(objects.data()), (const float*) (layer.buffer), numClasses, outputSize,
|
||||
static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height));
|
||||
static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height),
|
||||
thrust::raw_pointer_cast(preclusterThreshold.data()), thrust::raw_pointer_cast(d_numDetections.data()));
|
||||
|
||||
objectList.resize(outputSize);
|
||||
thrust::copy(objects.begin(), objects.end(), objectList.begin());
|
||||
thrust::copy(d_numDetections.begin(), d_numDetections.end(), numDetections.begin());
|
||||
objectList.resize(numDetections[0]);
|
||||
thrust::copy(objects.begin(), objects.begin() + numDetections[0], objectList.begin());
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -324,11 +364,16 @@ NvDsInferParseCustom_YOLOX_ONNX(std::vector<NvDsInferLayerInfo> const& outputLay
|
||||
|
||||
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
|
||||
|
||||
std::vector<int> numDetections = { 0 };
|
||||
thrust::device_vector<int> d_numDetections(numDetections);
|
||||
|
||||
thrust::device_vector<float> preclusterThreshold(detectionParams.perClassPreclusterThreshold);
|
||||
|
||||
std::vector<int> strides = {8, 16, 32};
|
||||
|
||||
std::vector<int> grid0;
|
||||
std::vector<int> grid1;
|
||||
std::vector<int> grid_strides;
|
||||
std::vector<int> gridStrides;
|
||||
|
||||
for (uint s = 0; s < strides.size(); ++s) {
|
||||
int num_grid_y = networkInfo.height / strides[s];
|
||||
@@ -337,14 +382,14 @@ NvDsInferParseCustom_YOLOX_ONNX(std::vector<NvDsInferLayerInfo> const& outputLay
|
||||
for (int g0 = 0; g0 < num_grid_x; ++g0) {
|
||||
grid0.push_back(g0);
|
||||
grid1.push_back(g1);
|
||||
grid_strides.push_back(strides[s]);
|
||||
gridStrides.push_back(strides[s]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
thrust::device_vector<int> d_grid0(grid0);
|
||||
thrust::device_vector<int> d_grid1(grid1);
|
||||
thrust::device_vector<int> d_grid_strides(grid_strides);
|
||||
thrust::device_vector<int> d_gridStrides(gridStrides);
|
||||
|
||||
int threads_per_block = 1024;
|
||||
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
|
||||
@@ -353,10 +398,12 @@ NvDsInferParseCustom_YOLOX_ONNX(std::vector<NvDsInferLayerInfo> const& outputLay
|
||||
thrust::raw_pointer_cast(objects.data()), (const float*) (layer.buffer), numClasses, outputSize,
|
||||
static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height),
|
||||
thrust::raw_pointer_cast(d_grid0.data()), thrust::raw_pointer_cast(d_grid1.data()),
|
||||
thrust::raw_pointer_cast(d_grid_strides.data()));
|
||||
thrust::raw_pointer_cast(d_gridStrides.data()), thrust::raw_pointer_cast(preclusterThreshold.data()),
|
||||
thrust::raw_pointer_cast(d_numDetections.data()));
|
||||
|
||||
objectList.resize(outputSize);
|
||||
thrust::copy(objects.begin(), objects.end(), objectList.begin());
|
||||
thrust::copy(d_numDetections.begin(), d_numDetections.end(), numDetections.begin());
|
||||
objectList.resize(numDetections[0]);
|
||||
thrust::copy(objects.begin(), objects.begin() + numDetections[0], objectList.begin());
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -384,15 +431,22 @@ NvDsInferParseCustom_YOLO_NAS_ONNX(std::vector<NvDsInferLayerInfo> const& output
|
||||
|
||||
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
|
||||
|
||||
std::vector<int> numDetections = { 0 };
|
||||
thrust::device_vector<int> d_numDetections(numDetections);
|
||||
|
||||
thrust::device_vector<float> preclusterThreshold(detectionParams.perClassPreclusterThreshold);
|
||||
|
||||
int threads_per_block = 1024;
|
||||
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
|
||||
|
||||
decodeTensor_YOLO_NAS_ONNX<<<threads_per_block, number_of_blocks>>>(
|
||||
thrust::raw_pointer_cast(objects.data()), (const float*) (scores.buffer), (const float*) (boxes.buffer), numClasses,
|
||||
outputSize, static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height));
|
||||
outputSize, static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height),
|
||||
thrust::raw_pointer_cast(preclusterThreshold.data()), thrust::raw_pointer_cast(d_numDetections.data()));
|
||||
|
||||
objectList.resize(outputSize);
|
||||
thrust::copy(objects.begin(), objects.end(), objectList.begin());
|
||||
thrust::copy(d_numDetections.begin(), d_numDetections.end(), numDetections.begin());
|
||||
objectList.resize(numDetections[0]);
|
||||
thrust::copy(objects.begin(), objects.begin() + numDetections[0], objectList.begin());
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -420,15 +474,22 @@ NvDsInferParseCustom_PPYOLOE_ONNX(std::vector<NvDsInferLayerInfo> const& outputL
|
||||
|
||||
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
|
||||
|
||||
std::vector<int> numDetections = { 0 };
|
||||
thrust::device_vector<int> d_numDetections(numDetections);
|
||||
|
||||
thrust::device_vector<float> preclusterThreshold(detectionParams.perClassPreclusterThreshold);
|
||||
|
||||
int threads_per_block = 1024;
|
||||
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
|
||||
|
||||
decodeTensor_PPYOLOE_ONNX<<<threads_per_block, number_of_blocks>>>(
|
||||
thrust::raw_pointer_cast(objects.data()), (const float*) (scores.buffer), (const float*) (boxes.buffer), numClasses,
|
||||
outputSize, static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height));
|
||||
outputSize, static_cast<float>(networkInfo.width), static_cast<float>(networkInfo.height),
|
||||
thrust::raw_pointer_cast(preclusterThreshold.data()), thrust::raw_pointer_cast(d_numDetections.data()));
|
||||
|
||||
objectList.resize(outputSize);
|
||||
thrust::copy(objects.begin(), objects.end(), objectList.begin());
|
||||
thrust::copy(d_numDetections.begin(), d_numDetections.end(), numDetections.begin());
|
||||
objectList.resize(numDetections[0]);
|
||||
thrust::copy(objects.begin(), objects.begin() + numDetections[0], objectList.begin());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user