Fix YOLO kernels
- Fix YOLO kernels - Update deprecated functions
This commit is contained in:
@@ -63,7 +63,10 @@ SRCFILES:= nvdsinfer_yolo_engine.cpp \
|
|||||||
layers/activation_layer.cpp \
|
layers/activation_layer.cpp \
|
||||||
utils.cpp \
|
utils.cpp \
|
||||||
yolo.cpp \
|
yolo.cpp \
|
||||||
yoloForward.cu
|
yoloForward.cu \
|
||||||
|
yoloForward_v2.cu \
|
||||||
|
yoloForward_nc.cu \
|
||||||
|
yoloForward_r.cu
|
||||||
|
|
||||||
ifeq ($(OPENCV), 1)
|
ifeq ($(OPENCV), 1)
|
||||||
SRCFILES+= calibrator.cpp
|
SRCFILES+= calibrator.cpp
|
||||||
|
|||||||
@@ -164,13 +164,13 @@ nvinfer1::ILayer* convolutionalLayer(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nvinfer1::IConvolutionLayer* conv = network->addConvolution(
|
nvinfer1::IConvolutionLayer* conv = network->addConvolutionNd(
|
||||||
*input, filters, nvinfer1::DimsHW{kernelSize, kernelSize}, convWt, convBias);
|
*input, filters, nvinfer1::DimsHW{kernelSize, kernelSize}, convWt, convBias);
|
||||||
assert(conv != nullptr);
|
assert(conv != nullptr);
|
||||||
std::string convLayerName = "conv_" + std::to_string(layerIdx);
|
std::string convLayerName = "conv_" + std::to_string(layerIdx);
|
||||||
conv->setName(convLayerName.c_str());
|
conv->setName(convLayerName.c_str());
|
||||||
conv->setStride(nvinfer1::DimsHW{stride, stride});
|
conv->setStrideNd(nvinfer1::DimsHW{stride, stride});
|
||||||
conv->setPadding(nvinfer1::DimsHW{pad, pad});
|
conv->setPaddingNd(nvinfer1::DimsHW{pad, pad});
|
||||||
|
|
||||||
if (block.find("groups") != block.end())
|
if (block.find("groups") != block.end())
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -19,10 +19,10 @@ nvinfer1::ILayer* maxpoolLayer(
|
|||||||
int stride = std::stoi(block.at("stride"));
|
int stride = std::stoi(block.at("stride"));
|
||||||
|
|
||||||
nvinfer1::IPoolingLayer* pool
|
nvinfer1::IPoolingLayer* pool
|
||||||
= network->addPooling(*input, nvinfer1::PoolingType::kMAX, nvinfer1::DimsHW{size, size});
|
= network->addPoolingNd(*input, nvinfer1::PoolingType::kMAX, nvinfer1::DimsHW{size, size});
|
||||||
assert(pool);
|
assert(pool);
|
||||||
std::string maxpoolLayerName = "maxpool_" + std::to_string(layerIdx);
|
std::string maxpoolLayerName = "maxpool_" + std::to_string(layerIdx);
|
||||||
pool->setStride(nvinfer1::DimsHW{stride, stride});
|
pool->setStrideNd(nvinfer1::DimsHW{stride, stride});
|
||||||
pool->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
|
pool->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
|
||||||
pool->setName(maxpoolLayerName.c_str());
|
pool->setName(maxpoolLayerName.c_str());
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ nvinfer1::ILayer* upsampleLayer(
|
|||||||
|
|
||||||
nvinfer1::IResizeLayer* resize_layer = network->addResize(*input);
|
nvinfer1::IResizeLayer* resize_layer = network->addResize(*input);
|
||||||
resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
|
resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
|
||||||
float scale[3] = {1, stride, stride};
|
float scale[3] = {1, static_cast<float>(stride), static_cast<float>(stride)};
|
||||||
resize_layer->setScales(scale, 3);
|
resize_layer->setScales(scale, 3);
|
||||||
std::string layer_name = "upsample_" + std::to_string(layerIdx);
|
std::string layer_name = "upsample_" + std::to_string(layerIdx);
|
||||||
resize_layer->setName(layer_name.c_str());
|
resize_layer->setName(layer_name.c_str());
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
|
|||||||
|
|
||||||
nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0);
|
nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0);
|
||||||
if (parseModel(*network) != NVDSINFER_SUCCESS) {
|
if (parseModel(*network) != NVDSINFER_SUCCESS) {
|
||||||
network->destroy();
|
delete network;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,7 +122,7 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
|
|||||||
std::cerr << "Building engine failed\n" << std::endl;
|
std::cerr << "Building engine failed\n" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
network->destroy();
|
delete network;
|
||||||
delete config;
|
delete config;
|
||||||
return engine;
|
return engine;
|
||||||
}
|
}
|
||||||
@@ -232,12 +232,11 @@ NvDsInferStatus Yolo::buildYoloNetwork(
|
|||||||
printLayerInfo(layerIndex, layerType, " -", outputVol, " -");
|
printLayerInfo(layerIndex, layerType, " -", outputVol, " -");
|
||||||
}
|
}
|
||||||
|
|
||||||
else if (m_ConfigBlocks.at(i).at("type") == "dropout") {
|
else if (m_ConfigBlocks.at(i).at("type") == "dropout") { // Skip dropout layer
|
||||||
assert(m_ConfigBlocks.at(i).find("probability") != m_ConfigBlocks.at(i).end());
|
assert(m_ConfigBlocks.at(i).find("probability") != m_ConfigBlocks.at(i).end());
|
||||||
//float probability = std::stof(m_ConfigBlocks.at(i).at("probability"));
|
//float probability = std::stof(m_ConfigBlocks.at(i).at("probability"));
|
||||||
//nvinfer1::ILayer* out = dropoutLayer(probability, previous, &network);
|
//nvinfer1::ILayer* out = dropoutLayer(probability, previous, &network);
|
||||||
//previous = out->getOutput(0);
|
//previous = out->getOutput(0);
|
||||||
//Skip dropout layer
|
|
||||||
assert(previous != nullptr);
|
assert(previous != nullptr);
|
||||||
tensorOutputs.push_back(previous);
|
tensorOutputs.push_back(previous);
|
||||||
printLayerInfo(layerIndex, "dropout", " -", " -", " -");
|
printLayerInfo(layerIndex, "dropout", " -", " -", " -");
|
||||||
@@ -300,6 +299,13 @@ NvDsInferStatus Yolo::buildYoloNetwork(
|
|||||||
}
|
}
|
||||||
|
|
||||||
else if (m_ConfigBlocks.at(i).at("type") == "yolo") {
|
else if (m_ConfigBlocks.at(i).at("type") == "yolo") {
|
||||||
|
uint model_type;
|
||||||
|
if (m_NetworkType.find("yolor") != std::string::npos) {
|
||||||
|
model_type = 2;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
model_type = 1;
|
||||||
|
}
|
||||||
nvinfer1::Dims prevTensorDims = previous->getDimensions();
|
nvinfer1::Dims prevTensorDims = previous->getDimensions();
|
||||||
TensorInfo& curYoloTensor = m_OutputTensors.at(outputTensorCount);
|
TensorInfo& curYoloTensor = m_OutputTensors.at(outputTensorCount);
|
||||||
curYoloTensor.gridSizeY = prevTensorDims.d[1];
|
curYoloTensor.gridSizeY = prevTensorDims.d[1];
|
||||||
@@ -327,7 +333,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(
|
|||||||
curYoloTensor.numClasses,
|
curYoloTensor.numClasses,
|
||||||
curYoloTensor.gridSizeX,
|
curYoloTensor.gridSizeX,
|
||||||
curYoloTensor.gridSizeY,
|
curYoloTensor.gridSizeY,
|
||||||
1, new_coords, scale_x_y, beta_nms,
|
model_type, new_coords, scale_x_y, beta_nms,
|
||||||
curYoloTensor.anchors,
|
curYoloTensor.anchors,
|
||||||
m_OutputMasks);
|
m_OutputMasks);
|
||||||
assert(yoloPlugin != nullptr);
|
assert(yoloPlugin != nullptr);
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
||||||
|
|
||||||
__global__ void gpuYoloLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
__global__ void gpuYoloLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
||||||
const uint numBBoxes, const uint new_coords, const float scale_x_y)
|
const uint numBBoxes, const float scale_x_y)
|
||||||
{
|
{
|
||||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
||||||
@@ -35,97 +35,14 @@ __global__ void gpuYoloLayer(const float* input, float* output, const uint gridS
|
|||||||
const int numGridCells = gridSizeX * gridSizeY;
|
const int numGridCells = gridSizeX * gridSizeY;
|
||||||
const int bbindex = y_id * gridSizeX + x_id;
|
const int bbindex = y_id * gridSizeX + x_id;
|
||||||
|
|
||||||
float alpha = scale_x_y;
|
const float alpha = scale_x_y;
|
||||||
float beta = -0.5 * (scale_x_y - 1);
|
const float beta = -0.5 * (scale_x_y - 1);
|
||||||
|
|
||||||
if (new_coords == 1) {
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
|
||||||
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] * alpha + beta;
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
|
||||||
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] * alpha + beta;
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
|
||||||
= pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] * 2, 2);
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
|
|
||||||
= pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] * 2, 2);
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
|
||||||
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)];
|
|
||||||
|
|
||||||
for (uint i = 0; i < numOutputClasses; ++i)
|
|
||||||
{
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
|
|
||||||
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (new_coords == 0 && scale_x_y != 1) { // YOLOR incorrect param
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
|
||||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * 2.0 - 0.5;
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
|
||||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * 2.0 - 0.5;
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
|
||||||
= pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2);
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
|
|
||||||
= pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2);
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
|
||||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
|
|
||||||
|
|
||||||
for (uint i = 0; i < numOutputClasses; ++i)
|
|
||||||
{
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
|
|
||||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
|
||||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta;
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
|
||||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta;
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
|
||||||
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]);
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
|
|
||||||
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]);
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
|
||||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
|
|
||||||
|
|
||||||
for (uint i = 0; i < numOutputClasses; ++i)
|
|
||||||
{
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
|
|
||||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void gpuRegionLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
|
||||||
const uint numBBoxes)
|
|
||||||
{
|
|
||||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
|
||||||
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
|
||||||
uint z_id = blockIdx.z * blockDim.z + threadIdx.z;
|
|
||||||
|
|
||||||
if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
|
|
||||||
{
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int numGridCells = gridSizeX * gridSizeY;
|
|
||||||
const int bbindex = y_id * gridSizeX + x_id;
|
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
||||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]);
|
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta;
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
||||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]);
|
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta;
|
||||||
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
||||||
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]);
|
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]);
|
||||||
@@ -136,53 +53,31 @@ __global__ void gpuRegionLayer(const float* input, float* output, const uint gri
|
|||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
|
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
|
||||||
|
|
||||||
float temp = 1.0;
|
for (uint i = 0; i < numOutputClasses; ++i)
|
||||||
int i;
|
{
|
||||||
float sum = 0;
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
|
||||||
float largest = -INFINITY;
|
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
|
||||||
for(i = 0; i < numOutputClasses; ++i){
|
|
||||||
int val = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
|
|
||||||
largest = (val>largest) ? val : largest;
|
|
||||||
}
|
|
||||||
for(i = 0; i < numOutputClasses; ++i){
|
|
||||||
float e = exp(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] / temp - largest / temp);
|
|
||||||
sum += e;
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] = e;
|
|
||||||
}
|
|
||||||
for(i = 0; i < numOutputClasses; ++i){
|
|
||||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] /= sum;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||||
const uint& numOutputClasses, const uint& numBBoxes,
|
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||||
uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType);
|
const float modelScale);
|
||||||
|
|
||||||
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||||
const uint& numOutputClasses, const uint& numBBoxes,
|
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||||
uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType)
|
const float modelScale)
|
||||||
{
|
{
|
||||||
dim3 threads_per_block(16, 16, 4);
|
dim3 threads_per_block(16, 16, 4);
|
||||||
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
|
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
|
||||||
(gridSizeY / threads_per_block.y) + 1,
|
(gridSizeY / threads_per_block.y) + 1,
|
||||||
(numBBoxes / threads_per_block.z) + 1);
|
(numBBoxes / threads_per_block.z) + 1);
|
||||||
if (modelType == 1) {
|
for (unsigned int batch = 0; batch < batchSize; ++batch)
|
||||||
for (unsigned int batch = 0; batch < batchSize; ++batch)
|
{
|
||||||
{
|
gpuYoloLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||||
gpuYoloLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
||||||
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
||||||
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
numBBoxes, modelScale);
|
||||||
numBBoxes, modelCoords, modelScale);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (modelType == 0) {
|
|
||||||
for (unsigned int batch = 0; batch < batchSize; ++batch)
|
|
||||||
{
|
|
||||||
gpuRegionLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
|
||||||
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
|
||||||
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
|
||||||
numBBoxes);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return cudaGetLastError();
|
return cudaGetLastError();
|
||||||
}
|
}
|
||||||
|
|||||||
74
nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu
Normal file
74
nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
/*
|
||||||
|
* Created by Marcos Luciano
|
||||||
|
* https://www.github.com/marcoslucianops
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
||||||
|
|
||||||
|
__global__ void gpuYoloLayer_nc(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
||||||
|
const uint numBBoxes, const float scale_x_y)
|
||||||
|
{
|
||||||
|
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
||||||
|
uint z_id = blockIdx.z * blockDim.z + threadIdx.z;
|
||||||
|
|
||||||
|
if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int numGridCells = gridSizeX * gridSizeY;
|
||||||
|
const int bbindex = y_id * gridSizeX + x_id;
|
||||||
|
|
||||||
|
const float alpha = scale_x_y;
|
||||||
|
const float beta = -0.5 * (scale_x_y - 1);
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
||||||
|
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] * alpha + beta;
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
||||||
|
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] * alpha + beta;
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
||||||
|
= pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] * 2, 2);
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
|
||||||
|
= pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] * 2, 2);
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||||
|
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)];
|
||||||
|
|
||||||
|
for (uint i = 0; i < numOutputClasses; ++i)
|
||||||
|
{
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
|
||||||
|
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaError_t cudaYoloLayer_nc(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||||
|
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||||
|
const float modelScale);
|
||||||
|
|
||||||
|
cudaError_t cudaYoloLayer_nc(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||||
|
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||||
|
const float modelScale)
|
||||||
|
{
|
||||||
|
dim3 threads_per_block(16, 16, 4);
|
||||||
|
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
|
||||||
|
(gridSizeY / threads_per_block.y) + 1,
|
||||||
|
(numBBoxes / threads_per_block.z) + 1);
|
||||||
|
for (unsigned int batch = 0; batch < batchSize; ++batch)
|
||||||
|
{
|
||||||
|
gpuYoloLayer_nc<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||||
|
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
||||||
|
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
||||||
|
numBBoxes, modelScale);
|
||||||
|
}
|
||||||
|
return cudaGetLastError();
|
||||||
|
}
|
||||||
71
nvdsinfer_custom_impl_Yolo/yoloForward_r.cu
Normal file
71
nvdsinfer_custom_impl_Yolo/yoloForward_r.cu
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
/*
|
||||||
|
* Created by Marcos Luciano
|
||||||
|
* https://www.github.com/marcoslucianops
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
||||||
|
|
||||||
|
__global__ void gpuYoloLayer_r(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
||||||
|
const uint numBBoxes, const float scale_x_y)
|
||||||
|
{
|
||||||
|
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
||||||
|
uint z_id = blockIdx.z * blockDim.z + threadIdx.z;
|
||||||
|
|
||||||
|
if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int numGridCells = gridSizeX * gridSizeY;
|
||||||
|
const int bbindex = y_id * gridSizeX + x_id;
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
||||||
|
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * 2.0 - 0.5;
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
||||||
|
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * 2.0 - 0.5;
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
||||||
|
= pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2);
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
|
||||||
|
= pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2);
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||||
|
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
|
||||||
|
|
||||||
|
for (uint i = 0; i < numOutputClasses; ++i)
|
||||||
|
{
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
|
||||||
|
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaError_t cudaYoloLayer_r(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||||
|
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||||
|
const float modelScale);
|
||||||
|
|
||||||
|
cudaError_t cudaYoloLayer_r(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||||
|
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||||
|
const float modelScale)
|
||||||
|
{
|
||||||
|
dim3 threads_per_block(16, 16, 4);
|
||||||
|
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
|
||||||
|
(gridSizeY / threads_per_block.y) + 1,
|
||||||
|
(numBBoxes / threads_per_block.z) + 1);
|
||||||
|
for (unsigned int batch = 0; batch < batchSize; ++batch)
|
||||||
|
{
|
||||||
|
gpuYoloLayer_r<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||||
|
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
||||||
|
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
||||||
|
numBBoxes, modelScale);
|
||||||
|
}
|
||||||
|
return cudaGetLastError();
|
||||||
|
}
|
||||||
80
nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu
Normal file
80
nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
/*
|
||||||
|
* Created by Marcos Luciano
|
||||||
|
* https://www.github.com/marcoslucianops
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
||||||
|
|
||||||
|
__global__ void gpuRegionLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
||||||
|
const uint numBBoxes)
|
||||||
|
{
|
||||||
|
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
||||||
|
uint z_id = blockIdx.z * blockDim.z + threadIdx.z;
|
||||||
|
|
||||||
|
if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int numGridCells = gridSizeX * gridSizeY;
|
||||||
|
const int bbindex = y_id * gridSizeX + x_id;
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
||||||
|
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]);
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
||||||
|
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]);
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
||||||
|
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]);
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
|
||||||
|
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]);
|
||||||
|
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||||
|
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
|
||||||
|
|
||||||
|
float temp = 1.0;
|
||||||
|
int i;
|
||||||
|
float sum = 0;
|
||||||
|
float largest = -INFINITY;
|
||||||
|
for(i = 0; i < numOutputClasses; ++i){
|
||||||
|
int val = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
|
||||||
|
largest = (val>largest) ? val : largest;
|
||||||
|
}
|
||||||
|
for(i = 0; i < numOutputClasses; ++i){
|
||||||
|
float e = exp(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] / temp - largest / temp);
|
||||||
|
sum += e;
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] = e;
|
||||||
|
}
|
||||||
|
for(i = 0; i < numOutputClasses; ++i){
|
||||||
|
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaError_t cudaYoloLayer_v2(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||||
|
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream);
|
||||||
|
|
||||||
|
cudaError_t cudaYoloLayer_v2(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||||
|
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream)
|
||||||
|
{
|
||||||
|
dim3 threads_per_block(16, 16, 4);
|
||||||
|
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
|
||||||
|
(gridSizeY / threads_per_block.y) + 1,
|
||||||
|
(numBBoxes / threads_per_block.z) + 1);
|
||||||
|
for (unsigned int batch = 0; batch < batchSize; ++batch)
|
||||||
|
{
|
||||||
|
gpuRegionLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||||
|
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
||||||
|
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
||||||
|
numBBoxes);
|
||||||
|
}
|
||||||
|
return cudaGetLastError();
|
||||||
|
}
|
||||||
@@ -35,25 +35,40 @@ std::vector<float> kANCHORS;
|
|||||||
std::vector<std::vector<int>> kMASK;
|
std::vector<std::vector<int>> kMASK;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void write(char*& buffer, const T& val)
|
void write(char*& buffer, const T& val)
|
||||||
{
|
{
|
||||||
*reinterpret_cast<T*>(buffer) = val;
|
*reinterpret_cast<T*>(buffer) = val;
|
||||||
buffer += sizeof(T);
|
buffer += sizeof(T);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void read(const char*& buffer, T& val)
|
void read(const char*& buffer, T& val)
|
||||||
{
|
{
|
||||||
val = *reinterpret_cast<const T*>(buffer);
|
val = *reinterpret_cast<const T*>(buffer);
|
||||||
buffer += sizeof(T);
|
buffer += sizeof(T);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cudaError_t cudaYoloLayer (
|
cudaError_t cudaYoloLayer (
|
||||||
const void* input, void* output, const uint& batchSize,
|
const void* input, void* output, const uint& batchSize,
|
||||||
const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||||
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType);
|
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float modelScale);
|
||||||
|
|
||||||
|
cudaError_t cudaYoloLayer_v2 (
|
||||||
|
const void* input, void* output, const uint& batchSize,
|
||||||
|
const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||||
|
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream);
|
||||||
|
|
||||||
|
cudaError_t cudaYoloLayer_nc (
|
||||||
|
const void* input, void* output, const uint& batchSize,
|
||||||
|
const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||||
|
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float modelScale);
|
||||||
|
|
||||||
|
cudaError_t cudaYoloLayer_r (
|
||||||
|
const void* input, void* output, const uint& batchSize,
|
||||||
|
const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||||
|
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float modelScale);
|
||||||
|
|
||||||
YoloLayer::YoloLayer (const void* data, size_t length)
|
YoloLayer::YoloLayer (const void* data, size_t length)
|
||||||
{
|
{
|
||||||
@@ -144,9 +159,28 @@ int YoloLayer::enqueue(
|
|||||||
int batchSize, void const* const* inputs, void* const* outputs, void* workspace,
|
int batchSize, void const* const* inputs, void* const* outputs, void* workspace,
|
||||||
cudaStream_t stream) noexcept
|
cudaStream_t stream) noexcept
|
||||||
{
|
{
|
||||||
CHECK(cudaYoloLayer(
|
if (m_type == 2) { // YOLOR incorrect param
|
||||||
inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
|
CHECK(cudaYoloLayer_r(
|
||||||
m_OutputSize, stream, m_new_coords, m_scale_x_y, m_type));
|
inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
|
||||||
|
m_OutputSize, stream, m_scale_x_y));
|
||||||
|
}
|
||||||
|
else if (m_type == 1) {
|
||||||
|
if (m_new_coords) {
|
||||||
|
CHECK(cudaYoloLayer_nc(
|
||||||
|
inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
|
||||||
|
m_OutputSize, stream, m_scale_x_y));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
CHECK(cudaYoloLayer(
|
||||||
|
inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
|
||||||
|
m_OutputSize, stream, m_scale_x_y));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
CHECK(cudaYoloLayer_v2(
|
||||||
|
inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
|
||||||
|
m_OutputSize, stream));
|
||||||
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
|
|||||||
* Darknet CFG params parser (it doesn't need to edit nvdsparsebbox_Yolo.cpp or another file for native models)
|
* Darknet CFG params parser (it doesn't need to edit nvdsparsebbox_Yolo.cpp or another file for native models)
|
||||||
* Support for new_coords, beta_nms and scale_x_y params
|
* Support for new_coords, beta_nms and scale_x_y params
|
||||||
* Support for new models
|
* Support for new models
|
||||||
* Support for new layers types
|
* Support for new layers
|
||||||
* Support for new activations
|
* Support for new activations
|
||||||
* Support for convolutional groups
|
* Support for convolutional groups
|
||||||
* Support for INT8 calibration
|
* Support for INT8 calibration
|
||||||
|
|||||||
Reference in New Issue
Block a user