Add YOLOX support
This commit is contained in:
@@ -44,11 +44,11 @@ detectV8Layer(int layerIdx, std::map<std::string, std::string>& block, std::vect
|
||||
shuffle1Box->setName(shuffle1BoxLayerName.c_str());
|
||||
nvinfer1::Dims reshape1Dims = {3, {4, reg_max, inputDims.d[1]}};
|
||||
shuffle1Box->setReshapeDimensions(reshape1Dims);
|
||||
nvinfer1::Permutation permutation1;
|
||||
permutation1.order[0] = 1;
|
||||
permutation1.order[1] = 0;
|
||||
permutation1.order[2] = 2;
|
||||
shuffle1Box->setSecondTranspose(permutation1);
|
||||
nvinfer1::Permutation permutation1Box;
|
||||
permutation1Box.order[0] = 1;
|
||||
permutation1Box.order[1] = 0;
|
||||
permutation1Box.order[2] = 2;
|
||||
shuffle1Box->setSecondTranspose(permutation1Box);
|
||||
box = shuffle1Box->getOutput(0);
|
||||
|
||||
nvinfer1::ISoftMaxLayer* softmax = network->addSoftMax(*box);
|
||||
@@ -186,10 +186,10 @@ detectV8Layer(int layerIdx, std::map<std::string, std::string>& block, std::vect
|
||||
assert(shuffle != nullptr);
|
||||
std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx);
|
||||
shuffle->setName(shuffleLayerName.c_str());
|
||||
nvinfer1::Permutation permutation2;
|
||||
permutation2.order[0] = 1;
|
||||
permutation2.order[1] = 0;
|
||||
shuffle->setFirstTranspose(permutation2);
|
||||
nvinfer1::Permutation permutation;
|
||||
permutation.order[0] = 1;
|
||||
permutation.order[1] = 0;
|
||||
shuffle->setFirstTranspose(permutation);
|
||||
output = shuffle->getOutput(0);
|
||||
|
||||
return output;
|
||||
|
||||
@@ -18,16 +18,15 @@ shuffleLayer(int layerIdx, std::string& layer, std::map<std::string, std::string
|
||||
std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx);
|
||||
shuffle->setName(shuffleLayerName.c_str());
|
||||
|
||||
int from = -1;
|
||||
if (block.find("from") != block.end())
|
||||
from = std::stoi(block.at("from"));
|
||||
if (from < 0)
|
||||
from = tensorOutputs.size() + from;
|
||||
|
||||
layer = std::to_string(from);
|
||||
|
||||
if (block.find("reshape") != block.end()) {
|
||||
int from = -1;
|
||||
if (block.find("from") != block.end())
|
||||
from = std::stoi(block.at("from"));
|
||||
|
||||
if (from < 0)
|
||||
from = tensorOutputs.size() + from;
|
||||
|
||||
layer = std::to_string(from);
|
||||
|
||||
nvinfer1::Dims inputTensorDims = tensorOutputs[from]->getDimensions();
|
||||
|
||||
std::string strReshape = block.at("reshape");
|
||||
|
||||
@@ -136,7 +136,7 @@ Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition
|
||||
|
||||
float eps = 1.0e-5;
|
||||
if (m_NetworkType.find("yolov5") != std::string::npos || m_NetworkType.find("yolov7") != std::string::npos ||
|
||||
m_NetworkType.find("yolov8") != std::string::npos)
|
||||
m_NetworkType.find("yolov8") != std::string::npos || m_NetworkType.find("yolox") != std::string::npos)
|
||||
eps = 1.0e-3;
|
||||
else if (m_NetworkType.find("yolor") != std::string::npos)
|
||||
eps = 1.0e-4;
|
||||
@@ -398,6 +398,23 @@ Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition
|
||||
std::string layerName = "detect_v8";
|
||||
printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr));
|
||||
}
|
||||
else if (m_ConfigBlocks.at(i).at("type") == "detect_x") {
|
||||
modelType = 5;
|
||||
|
||||
std::string blobName = "detect_x_" + std::to_string(i);
|
||||
nvinfer1::Dims prevTensorDims = previous->getDimensions();
|
||||
TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs);
|
||||
curYoloTensor.blobName = blobName;
|
||||
curYoloTensor.numBBoxes = prevTensorDims.d[0];
|
||||
m_NumClasses = prevTensorDims.d[1] - 5;
|
||||
|
||||
std::string outputVol = dimsToString(previous->getDimensions());
|
||||
tensorOutputs.push_back(previous);
|
||||
yoloTensorInputs[yoloCountInputs] = previous;
|
||||
++yoloCountInputs;
|
||||
std::string layerName = "detect_x";
|
||||
printLayerInfo(layerIndex, layerName, "-", outputVol, std::to_string(weightPtr));
|
||||
}
|
||||
else {
|
||||
std::cerr << "\nUnsupported layer type --> \"" << m_ConfigBlocks.at(i).at("type") << "\"" << std::endl;
|
||||
assert(0);
|
||||
@@ -415,7 +432,7 @@ Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition
|
||||
uint64_t outputSize = 0;
|
||||
for (uint j = 0; j < yoloCountInputs; ++j) {
|
||||
TensorInfo& curYoloTensor = m_YoloTensors.at(j);
|
||||
if (modelType == 3 || modelType == 4)
|
||||
if (modelType == 3 || modelType == 4 || modelType == 5)
|
||||
outputSize = curYoloTensor.numBBoxes;
|
||||
else
|
||||
outputSize += curYoloTensor.gridSizeX * curYoloTensor.gridSizeY * curYoloTensor.numBBoxes;
|
||||
@@ -587,6 +604,41 @@ Yolo::parseConfigBlocks()
|
||||
TensorInfo outputTensor;
|
||||
m_YoloTensors.push_back(outputTensor);
|
||||
}
|
||||
else if (block.at("type") == "detect_x") {
|
||||
++m_YoloCount;
|
||||
TensorInfo outputTensor;
|
||||
|
||||
std::vector<int> strides;
|
||||
|
||||
std::string stridesString = block.at("strides");
|
||||
while (!stridesString.empty()) {
|
||||
int npos = stridesString.find_first_of(',');
|
||||
if (npos != -1) {
|
||||
int stride = std::stof(trim(stridesString.substr(0, npos)));
|
||||
strides.push_back(stride);
|
||||
stridesString.erase(0, npos + 1);
|
||||
}
|
||||
else {
|
||||
int stride = std::stof(trim(stridesString));
|
||||
strides.push_back(stride);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint i = 0; i < strides.size(); ++i) {
|
||||
int num_grid_y = m_InputH / strides[i];
|
||||
int num_grid_x = m_InputW / strides[i];
|
||||
for (int g1 = 0; g1 < num_grid_y; ++g1) {
|
||||
for (int g0 = 0; g0 < num_grid_x; ++g0) {
|
||||
outputTensor.anchors.push_back((float) g0);
|
||||
outputTensor.anchors.push_back((float) g1);
|
||||
outputTensor.mask.push_back(strides[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m_YoloTensors.push_back(outputTensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ __global__ void gpuYoloLayer_v8(const float* input, int* num_detections, float*
|
||||
int maxIndex = -1;
|
||||
|
||||
for (uint i = 0; i < numOutputClasses; ++i) {
|
||||
float prob = input[x_id * (4 + numOutputClasses) + i + 4];
|
||||
float prob = input[x_id * (4 + numOutputClasses) + 4 + i];
|
||||
if (prob > maxProb) {
|
||||
maxProb = prob;
|
||||
maxIndex = i;
|
||||
|
||||
73
nvdsinfer_custom_impl_Yolo/yoloForward_x.cu
Normal file
73
nvdsinfer_custom_impl_Yolo/yoloForward_x.cu
Normal file
@@ -0,0 +1,73 @@
|
||||
/*
|
||||
* Created by Marcos Luciano
|
||||
* https://www.github.com/marcoslucianops
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
__global__ void gpuYoloLayer_x(const float* input, int* num_detections, float* detection_boxes, float* detection_scores,
|
||||
int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight,
|
||||
const uint numOutputClasses, const uint64_t outputSize, const float* anchors, const int* mask)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (x_id >= outputSize)
|
||||
return;
|
||||
|
||||
const float objectness = input[x_id * (5 + numOutputClasses) + 4];
|
||||
|
||||
if (objectness < scoreThreshold)
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(num_detections, 1);
|
||||
|
||||
float x = (input[x_id * (5 + numOutputClasses) + 0] + anchors[x_id * 2]) * mask[x_id];
|
||||
|
||||
float y = (input[x_id * (5 + numOutputClasses) + 1] + anchors[x_id * 2 + 1]) * mask[x_id];
|
||||
|
||||
float w = __expf(input[x_id * (5 + numOutputClasses) + 2]) * mask[x_id];
|
||||
|
||||
float h = __expf(input[x_id * (5 + numOutputClasses) + 3]) * mask[x_id];
|
||||
|
||||
float maxProb = 0.0f;
|
||||
int maxIndex = -1;
|
||||
|
||||
for (uint i = 0; i < numOutputClasses; ++i) {
|
||||
float prob = input[x_id * (5 + numOutputClasses) + 5 + i];
|
||||
if (prob > maxProb) {
|
||||
maxProb = prob;
|
||||
maxIndex = i;
|
||||
}
|
||||
}
|
||||
|
||||
detection_boxes[count * 4 + 0] = x - 0.5 * w;
|
||||
detection_boxes[count * 4 + 1] = y - 0.5 * h;
|
||||
detection_boxes[count * 4 + 2] = x + 0.5 * w;
|
||||
detection_boxes[count * 4 + 3] = y + 0.5 * h;
|
||||
detection_scores[count] = objectness * maxProb;
|
||||
detection_classes[count] = maxIndex;
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream)
|
||||
{
|
||||
int threads_per_block = 16;
|
||||
int number_of_blocks = (outputSize / threads_per_block) + 1;
|
||||
|
||||
for (unsigned int batch = 0; batch < batchSize; ++batch) {
|
||||
gpuYoloLayer_x<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * (5 + numOutputClasses) * outputSize),
|
||||
reinterpret_cast<int*>(num_detections) + (batch),
|
||||
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
|
||||
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
|
||||
scoreThreshold, netWidth, netHeight, numOutputClasses, outputSize, reinterpret_cast<const float*>(anchors),
|
||||
reinterpret_cast<const int*>(mask));
|
||||
}
|
||||
return cudaGetLastError();
|
||||
}
|
||||
@@ -38,6 +38,10 @@ namespace {
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaYoloLayer_v8(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream);
|
||||
@@ -158,7 +162,35 @@ YoloLayer::enqueue(int batchSize, void const* const* inputs, void* const* output
|
||||
CUDA_CHECK(cudaMemsetAsync((float*)detection_scores, 0, sizeof(float) * m_OutputSize * batchSize, stream));
|
||||
CUDA_CHECK(cudaMemsetAsync((int*)detection_classes, 0, sizeof(int) * m_OutputSize * batchSize, stream));
|
||||
|
||||
if (m_Type == 4) {
|
||||
if (m_Type == 5) {
|
||||
TensorInfo& curYoloTensor = m_YoloTensors.at(0);
|
||||
std::vector<float> anchors = curYoloTensor.anchors;
|
||||
std::vector<int> mask = curYoloTensor.mask;
|
||||
|
||||
void* v_anchors;
|
||||
void* v_mask;
|
||||
if (anchors.size() > 0) {
|
||||
float* f_anchors = anchors.data();
|
||||
CUDA_CHECK(cudaMalloc(&v_anchors, sizeof(float) * anchors.size()));
|
||||
CUDA_CHECK(cudaMemcpyAsync(v_anchors, f_anchors, sizeof(float) * anchors.size(), cudaMemcpyHostToDevice, stream));
|
||||
}
|
||||
if (mask.size() > 0) {
|
||||
int* f_mask = mask.data();
|
||||
CUDA_CHECK(cudaMalloc(&v_mask, sizeof(int) * mask.size()));
|
||||
CUDA_CHECK(cudaMemcpyAsync(v_mask, f_mask, sizeof(int) * mask.size(), cudaMemcpyHostToDevice, stream));
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaYoloLayer_x(inputs[0], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
|
||||
m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, v_anchors, v_mask, stream));
|
||||
|
||||
if (anchors.size() > 0) {
|
||||
CUDA_CHECK(cudaFree(v_anchors));
|
||||
}
|
||||
if (mask.size() > 0) {
|
||||
CUDA_CHECK(cudaFree(v_mask));
|
||||
}
|
||||
}
|
||||
else if (m_Type == 4) {
|
||||
CUDA_CHECK(cudaYoloLayer_v8(inputs[0], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
|
||||
m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, stream));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user