From 9565254551418c3b5309671df38386817ae67d7c Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 12 Dec 2021 09:58:23 -0300 Subject: [PATCH] Fix YOLO kernels - Fix YOLO kernels - Update deprecated functions --- nvdsinfer_custom_impl_Yolo/Makefile | 5 +- .../layers/convolutional_layer.cpp | 6 +- .../layers/maxpool_layer.cpp | 4 +- .../layers/upsample_layer.cpp | 2 +- nvdsinfer_custom_impl_Yolo/yolo.cpp | 16 +- nvdsinfer_custom_impl_Yolo/yoloForward.cu | 143 +++--------------- nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu | 74 +++++++++ nvdsinfer_custom_impl_Yolo/yoloForward_r.cu | 71 +++++++++ nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu | 80 ++++++++++ nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp | 66 ++++++-- readme.md | 2 +- 11 files changed, 316 insertions(+), 153 deletions(-) create mode 100644 nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu create mode 100644 nvdsinfer_custom_impl_Yolo/yoloForward_r.cu create mode 100644 nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu diff --git a/nvdsinfer_custom_impl_Yolo/Makefile b/nvdsinfer_custom_impl_Yolo/Makefile index b063e83..9b8b82b 100644 --- a/nvdsinfer_custom_impl_Yolo/Makefile +++ b/nvdsinfer_custom_impl_Yolo/Makefile @@ -63,7 +63,10 @@ SRCFILES:= nvdsinfer_yolo_engine.cpp \ layers/activation_layer.cpp \ utils.cpp \ yolo.cpp \ - yoloForward.cu + yoloForward.cu \ + yoloForward_v2.cu \ + yoloForward_nc.cu \ + yoloForward_r.cu ifeq ($(OPENCV), 1) SRCFILES+= calibrator.cpp diff --git a/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp index 1be7b3f..08bd57e 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp @@ -164,13 +164,13 @@ nvinfer1::ILayer* convolutionalLayer( } } - nvinfer1::IConvolutionLayer* conv = network->addConvolution( + nvinfer1::IConvolutionLayer* conv = network->addConvolutionNd( *input, filters, nvinfer1::DimsHW{kernelSize, kernelSize}, convWt, convBias); assert(conv != nullptr); std::string convLayerName = "conv_" + std::to_string(layerIdx); conv->setName(convLayerName.c_str()); - conv->setStride(nvinfer1::DimsHW{stride, stride}); - conv->setPadding(nvinfer1::DimsHW{pad, pad}); + conv->setStrideNd(nvinfer1::DimsHW{stride, stride}); + conv->setPaddingNd(nvinfer1::DimsHW{pad, pad}); if (block.find("groups") != block.end()) { diff --git a/nvdsinfer_custom_impl_Yolo/layers/maxpool_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/maxpool_layer.cpp index 06948dc..e5e53bf 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/maxpool_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/maxpool_layer.cpp @@ -19,10 +19,10 @@ nvinfer1::ILayer* maxpoolLayer( int stride = std::stoi(block.at("stride")); nvinfer1::IPoolingLayer* pool - = network->addPooling(*input, nvinfer1::PoolingType::kMAX, nvinfer1::DimsHW{size, size}); + = network->addPoolingNd(*input, nvinfer1::PoolingType::kMAX, nvinfer1::DimsHW{size, size}); assert(pool); std::string maxpoolLayerName = "maxpool_" + std::to_string(layerIdx); - pool->setStride(nvinfer1::DimsHW{stride, stride}); + pool->setStrideNd(nvinfer1::DimsHW{stride, stride}); pool->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); pool->setName(maxpoolLayerName.c_str()); diff --git a/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.cpp index eb49011..f268bd2 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.cpp @@ -16,7 +16,7 @@ nvinfer1::ILayer* upsampleLayer( nvinfer1::IResizeLayer* resize_layer = network->addResize(*input); resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST); - float scale[3] = {1, stride, stride}; + float scale[3] = {1, static_cast(stride), static_cast(stride)}; resize_layer->setScales(scale, 3); std::string layer_name = "upsample_" + std::to_string(layerIdx); resize_layer->setName(layer_name.c_str()); diff --git a/nvdsinfer_custom_impl_Yolo/yolo.cpp b/nvdsinfer_custom_impl_Yolo/yolo.cpp index 9d33ff3..bf65f26 100644 --- a/nvdsinfer_custom_impl_Yolo/yolo.cpp +++ b/nvdsinfer_custom_impl_Yolo/yolo.cpp @@ -75,7 +75,7 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder) nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0); if (parseModel(*network) != NVDSINFER_SUCCESS) { - network->destroy(); + delete network; return nullptr; } @@ -122,7 +122,7 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder) std::cerr << "Building engine failed\n" << std::endl; } - network->destroy(); + delete network; delete config; return engine; } @@ -232,12 +232,11 @@ NvDsInferStatus Yolo::buildYoloNetwork( printLayerInfo(layerIndex, layerType, " -", outputVol, " -"); } - else if (m_ConfigBlocks.at(i).at("type") == "dropout") { + else if (m_ConfigBlocks.at(i).at("type") == "dropout") { // Skip dropout layer assert(m_ConfigBlocks.at(i).find("probability") != m_ConfigBlocks.at(i).end()); //float probability = std::stof(m_ConfigBlocks.at(i).at("probability")); //nvinfer1::ILayer* out = dropoutLayer(probability, previous, &network); //previous = out->getOutput(0); - //Skip dropout layer assert(previous != nullptr); tensorOutputs.push_back(previous); printLayerInfo(layerIndex, "dropout", " -", " -", " -"); @@ -300,6 +299,13 @@ NvDsInferStatus Yolo::buildYoloNetwork( } else if (m_ConfigBlocks.at(i).at("type") == "yolo") { + uint model_type; + if (m_NetworkType.find("yolor") != std::string::npos) { + model_type = 2; + } + else { + model_type = 1; + } nvinfer1::Dims prevTensorDims = previous->getDimensions(); TensorInfo& curYoloTensor = m_OutputTensors.at(outputTensorCount); curYoloTensor.gridSizeY = prevTensorDims.d[1]; @@ -327,7 +333,7 @@ NvDsInferStatus Yolo::buildYoloNetwork( curYoloTensor.numClasses, curYoloTensor.gridSizeX, curYoloTensor.gridSizeY, - 1, new_coords, scale_x_y, beta_nms, + model_type, new_coords, scale_x_y, beta_nms, curYoloTensor.anchors, m_OutputMasks); assert(yoloPlugin != nullptr); diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward.cu b/nvdsinfer_custom_impl_Yolo/yoloForward.cu index cbaa29f..b3ceab2 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloForward.cu +++ b/nvdsinfer_custom_impl_Yolo/yoloForward.cu @@ -21,7 +21,7 @@ inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); } __global__ void gpuYoloLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses, - const uint numBBoxes, const uint new_coords, const float scale_x_y) + const uint numBBoxes, const float scale_x_y) { uint x_id = blockIdx.x * blockDim.x + threadIdx.x; uint y_id = blockIdx.y * blockDim.y + threadIdx.y; @@ -35,97 +35,14 @@ __global__ void gpuYoloLayer(const float* input, float* output, const uint gridS const int numGridCells = gridSizeX * gridSizeY; const int bbindex = y_id * gridSizeX + x_id; - float alpha = scale_x_y; - float beta = -0.5 * (scale_x_y - 1); - - if (new_coords == 1) { - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] - = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] * alpha + beta; - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] - = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] * alpha + beta; - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] - = pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] * 2, 2); - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] - = pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] * 2, 2); - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)] - = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]; - - for (uint i = 0; i < numOutputClasses; ++i) - { - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] - = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]; - } - } - else if (new_coords == 0 && scale_x_y != 1) { // YOLOR incorrect param - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * 2.0 - 0.5; - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * 2.0 - 0.5; - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] - = pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2); - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] - = pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2); - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)] - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]); - - for (uint i = 0; i < numOutputClasses; ++i) - { - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]); - } - } - else { - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta; - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta; - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] - = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]); - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] - = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]); - - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)] - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]); - - for (uint i = 0; i < numOutputClasses; ++i) - { - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]); - } - } -} - -__global__ void gpuRegionLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses, - const uint numBBoxes) -{ - uint x_id = blockIdx.x * blockDim.x + threadIdx.x; - uint y_id = blockIdx.y * blockDim.y + threadIdx.y; - uint z_id = blockIdx.z * blockDim.z + threadIdx.z; - - if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes)) - { - return; - } - - const int numGridCells = gridSizeX * gridSizeY; - const int bbindex = y_id * gridSizeX + x_id; + const float alpha = scale_x_y; + const float beta = -0.5 * (scale_x_y - 1); output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]); + = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta; output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]); + = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta; output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]); @@ -136,53 +53,31 @@ __global__ void gpuRegionLayer(const float* input, float* output, const uint gri output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)] = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]); - float temp = 1.0; - int i; - float sum = 0; - float largest = -INFINITY; - for(i = 0; i < numOutputClasses; ++i){ - int val = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]; - largest = (val>largest) ? val : largest; - } - for(i = 0; i < numOutputClasses; ++i){ - float e = exp(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] / temp - largest / temp); - sum += e; - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] = e; - } - for(i = 0; i < numOutputClasses; ++i){ - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] /= sum; + for (uint i = 0; i < numOutputClasses; ++i) + { + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] + = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]); } } cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY, - const uint& numOutputClasses, const uint& numBBoxes, - uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType); + const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, + const float modelScale); cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY, - const uint& numOutputClasses, const uint& numBBoxes, - uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType) + const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, + const float modelScale) { dim3 threads_per_block(16, 16, 4); dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, (gridSizeY / threads_per_block.y) + 1, (numBBoxes / threads_per_block.z) + 1); - if (modelType == 1) { - for (unsigned int batch = 0; batch < batchSize; ++batch) - { - gpuYoloLayer<<>>( - reinterpret_cast(input) + (batch * outputSize), - reinterpret_cast(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses, - numBBoxes, modelCoords, modelScale); - } - } - else if (modelType == 0) { - for (unsigned int batch = 0; batch < batchSize; ++batch) - { - gpuRegionLayer<<>>( - reinterpret_cast(input) + (batch * outputSize), - reinterpret_cast(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses, - numBBoxes); - } + for (unsigned int batch = 0; batch < batchSize; ++batch) + { + gpuYoloLayer<<>>( + reinterpret_cast(input) + (batch * outputSize), + reinterpret_cast(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses, + numBBoxes, modelScale); } return cudaGetLastError(); } diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu new file mode 100644 index 0000000..4894ae7 --- /dev/null +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu @@ -0,0 +1,74 @@ +/* + * Created by Marcos Luciano + * https://www.github.com/marcoslucianops + */ + +#include +#include +#include +#include +#include + +inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); } + +__global__ void gpuYoloLayer_nc(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses, + const uint numBBoxes, const float scale_x_y) +{ + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + uint y_id = blockIdx.y * blockDim.y + threadIdx.y; + uint z_id = blockIdx.z * blockDim.z + threadIdx.z; + + if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes)) + { + return; + } + + const int numGridCells = gridSizeX * gridSizeY; + const int bbindex = y_id * gridSizeX + x_id; + + const float alpha = scale_x_y; + const float beta = -0.5 * (scale_x_y - 1); + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] + = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] * alpha + beta; + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] + = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] * alpha + beta; + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] + = pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] * 2, 2); + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] + = pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] * 2, 2); + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)] + = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]; + + for (uint i = 0; i < numOutputClasses; ++i) + { + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] + = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]; + } +} + +cudaError_t cudaYoloLayer_nc(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY, + const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, + const float modelScale); + +cudaError_t cudaYoloLayer_nc(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY, + const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, + const float modelScale) +{ + dim3 threads_per_block(16, 16, 4); + dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, + (gridSizeY / threads_per_block.y) + 1, + (numBBoxes / threads_per_block.z) + 1); + for (unsigned int batch = 0; batch < batchSize; ++batch) + { + gpuYoloLayer_nc<<>>( + reinterpret_cast(input) + (batch * outputSize), + reinterpret_cast(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses, + numBBoxes, modelScale); + } + return cudaGetLastError(); +} diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_r.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_r.cu new file mode 100644 index 0000000..197b22d --- /dev/null +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_r.cu @@ -0,0 +1,71 @@ +/* + * Created by Marcos Luciano + * https://www.github.com/marcoslucianops + */ + +#include +#include +#include +#include +#include + +inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); } + +__global__ void gpuYoloLayer_r(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses, + const uint numBBoxes, const float scale_x_y) +{ + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + uint y_id = blockIdx.y * blockDim.y + threadIdx.y; + uint z_id = blockIdx.z * blockDim.z + threadIdx.z; + + if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes)) + { + return; + } + + const int numGridCells = gridSizeX * gridSizeY; + const int bbindex = y_id * gridSizeX + x_id; + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] + = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * 2.0 - 0.5; + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] + = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * 2.0 - 0.5; + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] + = pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2); + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] + = pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2); + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)] + = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]); + + for (uint i = 0; i < numOutputClasses; ++i) + { + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] + = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]); + } +} + +cudaError_t cudaYoloLayer_r(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY, + const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, + const float modelScale); + +cudaError_t cudaYoloLayer_r(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY, + const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, + const float modelScale) +{ + dim3 threads_per_block(16, 16, 4); + dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, + (gridSizeY / threads_per_block.y) + 1, + (numBBoxes / threads_per_block.z) + 1); + for (unsigned int batch = 0; batch < batchSize; ++batch) + { + gpuYoloLayer_r<<>>( + reinterpret_cast(input) + (batch * outputSize), + reinterpret_cast(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses, + numBBoxes, modelScale); + } + return cudaGetLastError(); +} diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu new file mode 100644 index 0000000..d7dc96e --- /dev/null +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu @@ -0,0 +1,80 @@ +/* + * Created by Marcos Luciano + * https://www.github.com/marcoslucianops + */ + +#include +#include +#include +#include +#include + +inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); } + +__global__ void gpuRegionLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses, + const uint numBBoxes) +{ + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + uint y_id = blockIdx.y * blockDim.y + threadIdx.y; + uint z_id = blockIdx.z * blockDim.z + threadIdx.z; + + if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes)) + { + return; + } + + const int numGridCells = gridSizeX * gridSizeY; + const int bbindex = y_id * gridSizeX + x_id; + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] + = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]); + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] + = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]); + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] + = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]); + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] + = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]); + + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)] + = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]); + + float temp = 1.0; + int i; + float sum = 0; + float largest = -INFINITY; + for(i = 0; i < numOutputClasses; ++i){ + int val = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]; + largest = (val>largest) ? val : largest; + } + for(i = 0; i < numOutputClasses; ++i){ + float e = exp(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] / temp - largest / temp); + sum += e; + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] = e; + } + for(i = 0; i < numOutputClasses; ++i){ + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] /= sum; + } +} + +cudaError_t cudaYoloLayer_v2(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY, + const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream); + +cudaError_t cudaYoloLayer_v2(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY, + const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream) +{ + dim3 threads_per_block(16, 16, 4); + dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, + (gridSizeY / threads_per_block.y) + 1, + (numBBoxes / threads_per_block.z) + 1); + for (unsigned int batch = 0; batch < batchSize; ++batch) + { + gpuRegionLayer<<>>( + reinterpret_cast(input) + (batch * outputSize), + reinterpret_cast(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses, + numBBoxes); + } + return cudaGetLastError(); +} diff --git a/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp b/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp index 3d88377..2030d22 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp +++ b/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp @@ -35,25 +35,40 @@ std::vector kANCHORS; std::vector> kMASK; namespace { -template -void write(char*& buffer, const T& val) -{ - *reinterpret_cast(buffer) = val; - buffer += sizeof(T); -} + template + void write(char*& buffer, const T& val) + { + *reinterpret_cast(buffer) = val; + buffer += sizeof(T); + } -template -void read(const char*& buffer, T& val) -{ - val = *reinterpret_cast(buffer); - buffer += sizeof(T); -} + template + void read(const char*& buffer, T& val) + { + val = *reinterpret_cast(buffer); + buffer += sizeof(T); + } } cudaError_t cudaYoloLayer ( const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, - const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType); + const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float modelScale); + +cudaError_t cudaYoloLayer_v2 ( + const void* input, void* output, const uint& batchSize, + const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, + const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream); + +cudaError_t cudaYoloLayer_nc ( + const void* input, void* output, const uint& batchSize, + const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, + const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float modelScale); + +cudaError_t cudaYoloLayer_r ( + const void* input, void* output, const uint& batchSize, + const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, + const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float modelScale); YoloLayer::YoloLayer (const void* data, size_t length) { @@ -144,9 +159,28 @@ int YoloLayer::enqueue( int batchSize, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { - CHECK(cudaYoloLayer( - inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes, - m_OutputSize, stream, m_new_coords, m_scale_x_y, m_type)); + if (m_type == 2) { // YOLOR incorrect param + CHECK(cudaYoloLayer_r( + inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes, + m_OutputSize, stream, m_scale_x_y)); + } + else if (m_type == 1) { + if (m_new_coords) { + CHECK(cudaYoloLayer_nc( + inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes, + m_OutputSize, stream, m_scale_x_y)); + } + else { + CHECK(cudaYoloLayer( + inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes, + m_OutputSize, stream, m_scale_x_y)); + } + } + else { + CHECK(cudaYoloLayer_v2( + inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes, + m_OutputSize, stream)); + } return 0; } diff --git a/readme.md b/readme.md index cb536f0..f26bc1e 100644 --- a/readme.md +++ b/readme.md @@ -16,7 +16,7 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models * Darknet CFG params parser (it doesn't need to edit nvdsparsebbox_Yolo.cpp or another file for native models) * Support for new_coords, beta_nms and scale_x_y params * Support for new models -* Support for new layers types +* Support for new layers * Support for new activations * Support for convolutional groups * Support for INT8 calibration