Move YOLO Decoder from CPU to GPU
This commit is contained in:
@@ -143,50 +143,32 @@ static void addBBoxProposal(const float bx, const float by, const float bw, cons
|
||||
|
||||
static std::vector<NvDsInferParseObjectInfo>
|
||||
decodeYoloTensor(
|
||||
const float* detections, const std::vector<int> &mask, const std::vector<float> &anchors,
|
||||
const float* detections,
|
||||
const uint gridSizeW, const uint gridSizeH, const uint stride, const uint numBBoxes,
|
||||
const uint numOutputClasses, const uint& netW,
|
||||
const uint& netH,
|
||||
const float confThresh)
|
||||
const uint numOutputClasses, const uint& netW, const uint& netH, const float confThresh)
|
||||
{
|
||||
std::vector<NvDsInferParseObjectInfo> binfo;
|
||||
for (uint y = 0; y < gridSizeH; ++y) {
|
||||
for (uint x = 0; x < gridSizeW; ++x) {
|
||||
for (uint b = 0; b < numBBoxes; ++b)
|
||||
{
|
||||
const float pw = anchors[mask[b] * 2];
|
||||
const float ph = anchors[mask[b] * 2 + 1];
|
||||
|
||||
const int numGridCells = gridSizeH * gridSizeW;
|
||||
const int bbindex = y * gridSizeW + x;
|
||||
const float bx
|
||||
= x + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 0)];
|
||||
const float by
|
||||
= y + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 1)];
|
||||
const float bw
|
||||
= pw * detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 2)];
|
||||
const float bh
|
||||
= ph * detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 3)];
|
||||
|
||||
const float objectness
|
||||
const float bx
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 0)];
|
||||
const float by
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 1)];
|
||||
const float bw
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 2)];
|
||||
const float bh
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 3)];
|
||||
|
||||
const float maxProb
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 4)];
|
||||
|
||||
float maxProb = 0.0f;
|
||||
int maxIndex = -1;
|
||||
|
||||
for (uint i = 0; i < numOutputClasses; ++i)
|
||||
{
|
||||
float prob
|
||||
= (detections[bbindex
|
||||
+ numGridCells * (b * (5 + numOutputClasses) + (5 + i))]);
|
||||
|
||||
if (prob > maxProb)
|
||||
{
|
||||
maxProb = prob;
|
||||
maxIndex = i;
|
||||
}
|
||||
}
|
||||
maxProb = objectness * maxProb;
|
||||
const int maxIndex
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 5)];
|
||||
|
||||
if (maxProb > confThresh)
|
||||
{
|
||||
@@ -200,49 +182,32 @@ decodeYoloTensor(
|
||||
|
||||
static std::vector<NvDsInferParseObjectInfo>
|
||||
decodeYoloV2Tensor(
|
||||
const float* detections, const std::vector<float> &anchors,
|
||||
const float* detections,
|
||||
const uint gridSizeW, const uint gridSizeH, const uint stride, const uint numBBoxes,
|
||||
const uint numOutputClasses, const uint& netW,
|
||||
const uint& netH)
|
||||
const uint numOutputClasses, const uint& netW, const uint& netH)
|
||||
{
|
||||
std::vector<NvDsInferParseObjectInfo> binfo;
|
||||
for (uint y = 0; y < gridSizeH; ++y) {
|
||||
for (uint x = 0; x < gridSizeW; ++x) {
|
||||
for (uint b = 0; b < numBBoxes; ++b)
|
||||
{
|
||||
const float pw = anchors[b * 2];
|
||||
const float ph = anchors[b * 2 + 1];
|
||||
|
||||
const int numGridCells = gridSizeH * gridSizeW;
|
||||
const int bbindex = y * gridSizeW + x;
|
||||
const float bx
|
||||
= x + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 0)];
|
||||
const float by
|
||||
= y + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 1)];
|
||||
const float bw
|
||||
= pw * detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 2)];
|
||||
const float bh
|
||||
= ph * detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 3)];
|
||||
|
||||
const float objectness
|
||||
const float bx
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 0)];
|
||||
const float by
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 1)];
|
||||
const float bw
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 2)] * stride;
|
||||
const float bh
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 3)] * stride;
|
||||
|
||||
const float maxProb
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 4)];
|
||||
|
||||
float maxProb = 0.0f;
|
||||
int maxIndex = -1;
|
||||
|
||||
for (uint i = 0; i < numOutputClasses; ++i)
|
||||
{
|
||||
float prob
|
||||
= (detections[bbindex
|
||||
+ numGridCells * (b * (5 + numOutputClasses) + (5 + i))]);
|
||||
|
||||
if (prob > maxProb)
|
||||
{
|
||||
maxProb = prob;
|
||||
maxIndex = i;
|
||||
}
|
||||
}
|
||||
maxProb = objectness * maxProb;
|
||||
const int maxIndex
|
||||
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 5)];
|
||||
|
||||
addBBoxProposal(bx, by, bw, bh, stride, netW, netH, maxIndex, maxProb, binfo);
|
||||
}
|
||||
@@ -270,32 +235,30 @@ static bool NvDsInferParseYolo(
|
||||
NvDsInferNetworkInfo const& networkInfo,
|
||||
NvDsInferParseDetectionParams const& detectionParams,
|
||||
std::vector<NvDsInferParseObjectInfo>& objectList,
|
||||
const std::vector<float> &anchors,
|
||||
const std::vector<std::vector<int>> &masks,
|
||||
const uint &num_classes,
|
||||
const float &beta_nms)
|
||||
const uint &numBBoxes,
|
||||
const uint &numClasses,
|
||||
const float &betaNMS)
|
||||
{
|
||||
if (outputLayersInfo.empty()) {
|
||||
std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl;;
|
||||
return false;
|
||||
}
|
||||
|
||||
const float kCONF_THRESH = detectionParams.perClassThreshold[0];
|
||||
|
||||
const std::vector<const NvDsInferLayerInfo*> sortedLayers =
|
||||
SortLayers (outputLayersInfo);
|
||||
|
||||
if (sortedLayers.size() != masks.size()) {
|
||||
std::cerr << "ERROR: YOLO output layer.size: " << sortedLayers.size()
|
||||
<< " does not match mask.size: " << masks.size() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (num_classes != detectionParams.numClassesConfigured)
|
||||
if (numClasses != detectionParams.numClassesConfigured)
|
||||
{
|
||||
std::cerr << "WARNING: Num classes mismatch. Configured: "
|
||||
<< detectionParams.numClassesConfigured
|
||||
<< ", detected by network: " << num_classes << std::endl;
|
||||
<< ", detected by network: " << numClasses << std::endl;
|
||||
}
|
||||
|
||||
std::vector<NvDsInferParseObjectInfo> objects;
|
||||
|
||||
for (uint idx = 0; idx < masks.size(); ++idx) {
|
||||
for (uint idx = 0; idx < sortedLayers.size(); ++idx) {
|
||||
const NvDsInferLayerInfo &layer = *sortedLayers[idx]; // 255 x Grid x Grid
|
||||
|
||||
assert(layer.inferDims.numDims == 3);
|
||||
@@ -304,14 +267,13 @@ static bool NvDsInferParseYolo(
|
||||
const uint stride = DIVUP(networkInfo.width, gridSizeW);
|
||||
|
||||
std::vector<NvDsInferParseObjectInfo> outObjs =
|
||||
decodeYoloTensor((const float*)(layer.buffer), masks[idx], anchors, gridSizeW, gridSizeH, stride, masks[idx].size(),
|
||||
num_classes, networkInfo.width, networkInfo.height, kCONF_THRESH);
|
||||
decodeYoloTensor((const float*)(layer.buffer), gridSizeW, gridSizeH, stride, numBBoxes,
|
||||
numClasses, networkInfo.width, networkInfo.height, kCONF_THRESH);
|
||||
objects.insert(objects.end(), outObjs.begin(), outObjs.end());
|
||||
}
|
||||
|
||||
|
||||
objectList.clear();
|
||||
objectList = nmsAllClasses(beta_nms, objects, num_classes);
|
||||
objectList = nmsAllClasses(betaNMS, objects, numClasses);
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -321,34 +283,31 @@ static bool NvDsInferParseYoloV2(
|
||||
NvDsInferNetworkInfo const& networkInfo,
|
||||
NvDsInferParseDetectionParams const& detectionParams,
|
||||
std::vector<NvDsInferParseObjectInfo>& objectList,
|
||||
std::vector<float> &anchors,
|
||||
const uint &num_classes)
|
||||
const uint &numBBoxes,
|
||||
const uint &numClasses)
|
||||
{
|
||||
if (outputLayersInfo.empty()) {
|
||||
std::cerr << "Could not find output layer in bbox parsing" << std::endl;;
|
||||
std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl;;
|
||||
return false;
|
||||
}
|
||||
const uint kNUM_BBOXES = anchors.size() / 2;
|
||||
|
||||
const NvDsInferLayerInfo &layer = outputLayersInfo[0];
|
||||
|
||||
if (num_classes != detectionParams.numClassesConfigured)
|
||||
if (numClasses != detectionParams.numClassesConfigured)
|
||||
{
|
||||
std::cerr << "WARNING: Num classes mismatch. Configured: "
|
||||
<< detectionParams.numClassesConfigured
|
||||
<< ", detected by network: " << num_classes << std::endl;
|
||||
<< ", detected by network: " << numClasses << std::endl;
|
||||
}
|
||||
|
||||
assert(layer.inferDims.numDims == 3);
|
||||
const uint gridSizeH = layer.inferDims.d[1];
|
||||
const uint gridSizeW = layer.inferDims.d[2];
|
||||
const uint stride = DIVUP(networkInfo.width, gridSizeW);
|
||||
for (auto& anchor : anchors) {
|
||||
anchor *= stride;
|
||||
}
|
||||
|
||||
std::vector<NvDsInferParseObjectInfo> objects =
|
||||
decodeYoloV2Tensor((const float*)(layer.buffer), anchors, gridSizeW, gridSizeH, stride, kNUM_BBOXES,
|
||||
num_classes, networkInfo.width, networkInfo.height);
|
||||
decodeYoloV2Tensor((const float*)(layer.buffer), gridSizeW, gridSizeH, stride, numBBoxes,
|
||||
numClasses, networkInfo.width, networkInfo.height);
|
||||
|
||||
objectList = objects;
|
||||
|
||||
@@ -361,17 +320,18 @@ extern "C" bool NvDsInferParseYolo(
|
||||
NvDsInferParseDetectionParams const& detectionParams,
|
||||
std::vector<NvDsInferParseObjectInfo>& objectList)
|
||||
{
|
||||
|
||||
int model_type = kMODEL_TYPE;
|
||||
int num_bboxes = kNUM_BBOXES;
|
||||
int num_classes = kNUM_CLASSES;
|
||||
float beta_nms = kBETA_NMS;
|
||||
std::vector<float> anchors = kANCHORS;
|
||||
std::vector<std::vector<int>> mask = kMASK;
|
||||
|
||||
if (mask.size() > 0) {
|
||||
return NvDsInferParseYolo (outputLayersInfo, networkInfo, detectionParams, objectList, anchors, mask, num_classes, beta_nms);
|
||||
if (model_type != 0) {
|
||||
return NvDsInferParseYolo (outputLayersInfo, networkInfo, detectionParams, objectList,
|
||||
num_bboxes, num_classes, beta_nms);
|
||||
}
|
||||
else {
|
||||
return NvDsInferParseYoloV2 (outputLayersInfo, networkInfo, detectionParams, objectList, anchors, num_classes);
|
||||
return NvDsInferParseYoloV2 (outputLayersInfo, networkInfo, detectionParams, objectList,
|
||||
num_bboxes, num_classes);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -31,21 +31,6 @@
|
||||
#include "calibrator.h"
|
||||
#endif
|
||||
|
||||
void orderParams(std::vector<std::vector<int>> *maskVector) {
|
||||
std::vector<std::vector<int>> maskinput = *maskVector;
|
||||
std::vector<int> maskPartial;
|
||||
for (uint i = 0; i < maskinput.size(); i++) {
|
||||
for (uint j = i + 1; j < maskinput.size(); j++) {
|
||||
if (maskinput[i][0] <= maskinput[j][0]) {
|
||||
maskPartial = maskinput[i];
|
||||
maskinput[i] = maskinput[j];
|
||||
maskinput[j] = maskPartial;
|
||||
}
|
||||
}
|
||||
}
|
||||
*maskVector = maskinput;
|
||||
}
|
||||
|
||||
Yolo::Yolo(const NetworkInfo& networkInfo)
|
||||
: m_NetworkType(networkInfo.networkType), // YOLO type
|
||||
m_ConfigFilePath(networkInfo.configFilePath), // YOLO cfg
|
||||
@@ -71,7 +56,6 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder, nvinfer1
|
||||
|
||||
m_ConfigBlocks = parseConfigFile(m_ConfigFilePath);
|
||||
parseConfigBlocks();
|
||||
orderParams(&m_OutputMasks);
|
||||
|
||||
nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0);
|
||||
if (parseModel(*network) != NVDSINFER_SUCCESS) {
|
||||
@@ -366,7 +350,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(
|
||||
curYoloTensor.gridSizeY,
|
||||
model_type, new_coords, scale_x_y, beta_nms,
|
||||
curYoloTensor.anchors,
|
||||
m_OutputMasks);
|
||||
curYoloTensor.masks);
|
||||
assert(yoloPlugin != nullptr);
|
||||
nvinfer1::IPluginV2Layer* yolo =
|
||||
network.addPluginV2(&previous, 1, *yoloPlugin);
|
||||
@@ -396,7 +380,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(
|
||||
* (curRegionTensor.numBBoxes * (5 + curRegionTensor.numClasses));
|
||||
std::string layerName = "region_" + std::to_string(i);
|
||||
curRegionTensor.blobName = layerName;
|
||||
std::vector<std::vector<int>> mask;
|
||||
std::vector<int> mask;
|
||||
nvinfer1::IPluginV2* regionPlugin
|
||||
= new YoloLayer(curRegionTensor.numBBoxes,
|
||||
curRegionTensor.numClasses,
|
||||
@@ -541,26 +525,22 @@ void Yolo::parseConfigBlocks()
|
||||
|
||||
if (block.find("mask") != block.end()) {
|
||||
std::string maskString = block.at("mask");
|
||||
std::vector<int> pMASKS;
|
||||
while (!maskString.empty())
|
||||
{
|
||||
int npos = maskString.find_first_of(',');
|
||||
if (npos != -1)
|
||||
{
|
||||
int mask = std::stoul(trim(maskString.substr(0, npos)));
|
||||
pMASKS.push_back(mask);
|
||||
outputTensor.masks.push_back(mask);
|
||||
maskString.erase(0, npos + 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
int mask = std::stoul(trim(maskString));
|
||||
pMASKS.push_back(mask);
|
||||
outputTensor.masks.push_back(mask);
|
||||
break;
|
||||
}
|
||||
}
|
||||
m_OutputMasks.push_back(pMASKS);
|
||||
}
|
||||
|
||||
outputTensor.numBBoxes = outputTensor.masks.size() > 0
|
||||
|
||||
@@ -58,7 +58,7 @@ struct TensorInfo
|
||||
uint numClasses{0};
|
||||
uint numBBoxes{0};
|
||||
uint64_t volume{0};
|
||||
std::vector<uint> masks;
|
||||
std::vector<int> masks;
|
||||
std::vector<float> anchors;
|
||||
int bindingIndex{-1};
|
||||
float* hostBuffer{nullptr};
|
||||
@@ -86,7 +86,6 @@ protected:
|
||||
const std::string m_DeviceType;
|
||||
const std::string m_InputBlobName;
|
||||
std::vector<TensorInfo> m_OutputTensors;
|
||||
std::vector<std::vector<int>> m_OutputMasks;
|
||||
std::vector<std::map<std::string, std::string>> m_ConfigBlocks;
|
||||
uint m_InputH;
|
||||
uint m_InputW;
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
/*
|
||||
* Copyright (c) 2018-2019 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* NVIDIA Corporation and its licensors retain all intellectual property
|
||||
* and proprietary rights in and to this software, related documentation
|
||||
* and any modifications thereto. Any use, reproduction, disclosure or
|
||||
* distribution of this software and related documentation without an express
|
||||
* license agreement from NVIDIA Corporation is strictly prohibited.
|
||||
*
|
||||
* Edited by Marcos Luciano
|
||||
* Created by Marcos Luciano
|
||||
* https://www.github.com/marcoslucianops
|
||||
*
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
@@ -21,7 +12,7 @@
|
||||
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
||||
|
||||
__global__ void gpuYoloLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
||||
const uint numBBoxes, const float scale_x_y)
|
||||
const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
@@ -35,38 +26,53 @@ __global__ void gpuYoloLayer(const float* input, float* output, const uint gridS
|
||||
const int numGridCells = gridSizeX * gridSizeY;
|
||||
const int bbindex = y_id * gridSizeX + x_id;
|
||||
|
||||
const float alpha = scale_x_y;
|
||||
const float beta = -0.5 * (scale_x_y - 1);
|
||||
const float alpha = scaleXY;
|
||||
const float beta = -0.5 * (scaleXY - 1);
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta;
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta + x_id;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta;
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta + y_id;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
||||
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]);
|
||||
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * anchors[mask[z_id] * 2];
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
|
||||
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]);
|
||||
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * anchors[mask[z_id] * 2 + 1];
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||
const float objectness
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
|
||||
|
||||
float maxProb = 0.0f;
|
||||
int maxIndex = -1;
|
||||
|
||||
for (uint i = 0; i < numOutputClasses; ++i)
|
||||
{
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
|
||||
float prob
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
|
||||
|
||||
if (prob > maxProb)
|
||||
{
|
||||
maxProb = prob;
|
||||
maxIndex = i;
|
||||
}
|
||||
}
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||
= objectness * maxProb;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 5)]
|
||||
= maxIndex;
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||
const float modelScale);
|
||||
const float scaleXY, const void* anchors, const void* mask);
|
||||
|
||||
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||
const float modelScale)
|
||||
const float scaleXY, const void* anchors, const void* mask)
|
||||
{
|
||||
dim3 threads_per_block(16, 16, 4);
|
||||
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
|
||||
@@ -77,7 +83,7 @@ cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize
|
||||
gpuYoloLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
||||
numBBoxes, modelScale);
|
||||
numBBoxes, scaleXY, reinterpret_cast<const float*>(anchors), reinterpret_cast<const int*>(mask));
|
||||
}
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
@@ -9,10 +9,8 @@
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
||||
|
||||
__global__ void gpuYoloLayer_nc(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
||||
const uint numBBoxes, const float scale_x_y)
|
||||
const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
@@ -26,38 +24,53 @@ __global__ void gpuYoloLayer_nc(const float* input, float* output, const uint gr
|
||||
const int numGridCells = gridSizeX * gridSizeY;
|
||||
const int bbindex = y_id * gridSizeX + x_id;
|
||||
|
||||
const float alpha = scale_x_y;
|
||||
const float beta = -0.5 * (scale_x_y - 1);
|
||||
const float alpha = scaleXY;
|
||||
const float beta = -0.5 * (scaleXY - 1);
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
||||
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] * alpha + beta;
|
||||
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] * alpha + beta + x_id;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
||||
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] * alpha + beta;
|
||||
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] * alpha + beta + y_id;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
||||
= pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] * 2, 2);
|
||||
= __powf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] * 2, 2) * anchors[mask[z_id] * 2];
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
|
||||
= pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] * 2, 2);
|
||||
= __powf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] * 2, 2) * anchors[mask[z_id] * 2 + 1];
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||
const float objectness
|
||||
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)];
|
||||
|
||||
float maxProb = 0.0f;
|
||||
int maxIndex = -1;
|
||||
|
||||
for (uint i = 0; i < numOutputClasses; ++i)
|
||||
{
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
|
||||
float prob
|
||||
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
|
||||
|
||||
if (prob > maxProb)
|
||||
{
|
||||
maxProb = prob;
|
||||
maxIndex = i;
|
||||
}
|
||||
}
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||
= objectness * maxProb;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 5)]
|
||||
= maxIndex;
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer_nc(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||
const float modelScale);
|
||||
const float scaleXY, const void* anchors, const void* mask);
|
||||
|
||||
cudaError_t cudaYoloLayer_nc(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||
const float modelScale)
|
||||
const float scaleXY, const void* anchors, const void* mask)
|
||||
{
|
||||
dim3 threads_per_block(16, 16, 4);
|
||||
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
|
||||
@@ -68,7 +81,7 @@ cudaError_t cudaYoloLayer_nc(const void* input, void* output, const uint& batchS
|
||||
gpuYoloLayer_nc<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
||||
numBBoxes, modelScale);
|
||||
numBBoxes, scaleXY, reinterpret_cast<const float*>(anchors), reinterpret_cast<const int*>(mask));
|
||||
}
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
||||
|
||||
__global__ void gpuYoloLayer_r(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
||||
const uint numBBoxes, const float scale_x_y)
|
||||
const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
@@ -26,35 +26,53 @@ __global__ void gpuYoloLayer_r(const float* input, float* output, const uint gri
|
||||
const int numGridCells = gridSizeX * gridSizeY;
|
||||
const int bbindex = y_id * gridSizeX + x_id;
|
||||
|
||||
const float alpha = scaleXY;
|
||||
const float beta = -0.5 * (scaleXY - 1);
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * 2.0 - 0.5;
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta + x_id;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * 2.0 - 0.5;
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta + y_id;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
||||
= pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2);
|
||||
= __powf(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2) * anchors[mask[z_id] * 2];
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
|
||||
= pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2);
|
||||
= __powf(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2) * anchors[mask[z_id] * 2 + 1];
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||
const float objectness
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
|
||||
|
||||
float maxProb = 0.0f;
|
||||
int maxIndex = -1;
|
||||
|
||||
for (uint i = 0; i < numOutputClasses; ++i)
|
||||
{
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
|
||||
float prob
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
|
||||
|
||||
if (prob > maxProb)
|
||||
{
|
||||
maxProb = prob;
|
||||
maxIndex = i;
|
||||
}
|
||||
}
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||
= objectness * maxProb;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 5)]
|
||||
= maxIndex;
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer_r(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||
const float modelScale);
|
||||
const float scaleXY, const void* anchors, const void* mask);
|
||||
|
||||
cudaError_t cudaYoloLayer_r(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||
const float modelScale)
|
||||
const float scaleXY, const void* anchors, const void* mask)
|
||||
{
|
||||
dim3 threads_per_block(16, 16, 4);
|
||||
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
|
||||
@@ -65,7 +83,7 @@ cudaError_t cudaYoloLayer_r(const void* input, void* output, const uint& batchSi
|
||||
gpuYoloLayer_r<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
||||
numBBoxes, modelScale);
|
||||
numBBoxes, scaleXY, reinterpret_cast<const float*>(anchors), reinterpret_cast<const int*>(mask));
|
||||
}
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
@@ -11,8 +11,28 @@
|
||||
|
||||
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
||||
|
||||
__global__ void gpuRegionLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
||||
const uint numBBoxes)
|
||||
__device__ void softmaxGPU(const float* input, const int bbindex, const int numGridCells,
|
||||
uint z_id, const uint numOutputClasses, float temp, float* output)
|
||||
{
|
||||
int i;
|
||||
float sum = 0;
|
||||
float largest = -INFINITY;
|
||||
for (i = 0; i < numOutputClasses; ++i) {
|
||||
int val = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
|
||||
largest = (val>largest) ? val : largest;
|
||||
}
|
||||
for (i = 0; i < numOutputClasses; ++i) {
|
||||
float e = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] / temp - largest / temp);
|
||||
sum += e;
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] = e;
|
||||
}
|
||||
for (i = 0; i < numOutputClasses; ++i) {
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void gpuRegionLayer(const float* input, float* output, float* softmax, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
||||
const uint numBBoxes, const float* anchors)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
@@ -27,43 +47,51 @@ __global__ void gpuRegionLayer(const float* input, float* output, const uint gri
|
||||
const int bbindex = y_id * gridSizeX + x_id;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]);
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) + x_id;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]);
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) + y_id;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
|
||||
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]);
|
||||
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * anchors[z_id * 2];
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
|
||||
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]);
|
||||
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * anchors[z_id * 2 + 1];
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||
softmaxGPU(input, bbindex, numGridCells, z_id, numOutputClasses, 1.0, softmax);
|
||||
|
||||
const float objectness
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
|
||||
|
||||
float temp = 1.0;
|
||||
int i;
|
||||
float sum = 0;
|
||||
float largest = -INFINITY;
|
||||
for(i = 0; i < numOutputClasses; ++i){
|
||||
int val = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
|
||||
largest = (val>largest) ? val : largest;
|
||||
float maxProb = 0.0f;
|
||||
int maxIndex = -1;
|
||||
|
||||
for (uint i = 0; i < numOutputClasses; ++i)
|
||||
{
|
||||
float prob
|
||||
= softmax[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
|
||||
|
||||
if (prob > maxProb)
|
||||
{
|
||||
maxProb = prob;
|
||||
maxIndex = i;
|
||||
}
|
||||
for(i = 0; i < numOutputClasses; ++i){
|
||||
float e = exp(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] / temp - largest / temp);
|
||||
sum += e;
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] = e;
|
||||
}
|
||||
for(i = 0; i < numOutputClasses; ++i){
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] /= sum;
|
||||
}
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
|
||||
= objectness * maxProb;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 5)]
|
||||
= maxIndex;
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer_v2(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream);
|
||||
cudaError_t cudaYoloLayer_v2(const void* input, void* output, void* softmax, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||
const void* anchors);
|
||||
|
||||
cudaError_t cudaYoloLayer_v2(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream)
|
||||
cudaError_t cudaYoloLayer_v2(const void* input, void* output, void* softmax, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream,
|
||||
const void* anchors)
|
||||
{
|
||||
dim3 threads_per_block(16, 16, 4);
|
||||
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
|
||||
@@ -73,8 +101,9 @@ cudaError_t cudaYoloLayer_v2(const void* input, void* output, const uint& batchS
|
||||
{
|
||||
gpuRegionLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
||||
numBBoxes);
|
||||
reinterpret_cast<float*>(output) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(softmax) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
||||
numBBoxes, reinterpret_cast<const float*>(anchors));
|
||||
}
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
@@ -29,10 +29,10 @@
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
int kMODEL_TYPE;
|
||||
int kNUM_BBOXES;
|
||||
int kNUM_CLASSES;
|
||||
float kBETA_NMS;
|
||||
std::vector<float> kANCHORS;
|
||||
std::vector<std::vector<int>> kMASK;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
@@ -50,25 +50,28 @@ namespace {
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer (
|
||||
cudaError_t cudaYoloLayer_r (
|
||||
const void* input, void* output, const uint& batchSize,
|
||||
const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float modelScale);
|
||||
|
||||
cudaError_t cudaYoloLayer_v2 (
|
||||
const void* input, void* output, const uint& batchSize,
|
||||
const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream);
|
||||
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float scaleXY,
|
||||
const void* anchors, const void* mask);
|
||||
|
||||
cudaError_t cudaYoloLayer_nc (
|
||||
const void* input, void* output, const uint& batchSize,
|
||||
const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float modelScale);
|
||||
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float scaleXY,
|
||||
const void* anchors, const void* mask);
|
||||
|
||||
cudaError_t cudaYoloLayer_r (
|
||||
cudaError_t cudaYoloLayer (
|
||||
const void* input, void* output, const uint& batchSize,
|
||||
const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float modelScale);
|
||||
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const float scaleXY,
|
||||
const void* anchors, const void* mask);
|
||||
|
||||
cudaError_t cudaYoloLayer_v2 (
|
||||
const void* input, void* output, void* softmax, const uint& batchSize,
|
||||
const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const void* anchors);
|
||||
|
||||
YoloLayer::YoloLayer (const void* data, size_t length)
|
||||
{
|
||||
@@ -79,10 +82,11 @@ YoloLayer::YoloLayer (const void* data, size_t length)
|
||||
read(d, m_GridSizeY);
|
||||
read(d, m_OutputSize);
|
||||
|
||||
read(d, m_type);
|
||||
read(d, m_new_coords);
|
||||
read(d, m_scale_x_y);
|
||||
read(d, m_beta_nms);
|
||||
read(d, m_Type);
|
||||
read(d, m_NewCoords);
|
||||
read(d, m_ScaleXY);
|
||||
read(d, m_BetaNMS);
|
||||
|
||||
uint anchorsSize;
|
||||
read(d, anchorsSize);
|
||||
for (uint i = 0; i < anchorsSize; i++) {
|
||||
@@ -90,35 +94,43 @@ YoloLayer::YoloLayer (const void* data, size_t length)
|
||||
read(d, result);
|
||||
m_Anchors.push_back(result);
|
||||
}
|
||||
|
||||
uint maskSize;
|
||||
read(d, maskSize);
|
||||
for (uint i = 0; i < maskSize; i++) {
|
||||
uint nMask;
|
||||
read(d, nMask);
|
||||
std::vector<int> pMask;
|
||||
for (uint f = 0; f < nMask; f++) {
|
||||
int result;
|
||||
read(d, result);
|
||||
pMask.push_back(result);
|
||||
}
|
||||
m_Mask.push_back(pMask);
|
||||
m_Mask.push_back(result);
|
||||
}
|
||||
|
||||
kMODEL_TYPE = m_Type;
|
||||
kNUM_BBOXES = m_NumBoxes;
|
||||
kNUM_CLASSES = m_NumClasses;
|
||||
kBETA_NMS = m_beta_nms;
|
||||
kANCHORS = m_Anchors;
|
||||
kMASK = m_Mask;
|
||||
kBETA_NMS = m_BetaNMS;
|
||||
|
||||
if (m_Anchors.size() > 0) {
|
||||
float* m_anchors = m_Anchors.data();
|
||||
CHECK(cudaMallocHost(&mAnchors, m_Anchors.size() * sizeof(float)));
|
||||
CHECK(cudaMemcpy(mAnchors, m_anchors, m_Anchors.size() * sizeof(float), cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
if (m_Mask.size() > 0) {
|
||||
int* m_mask = m_Mask.data();
|
||||
CHECK(cudaMallocHost(&mMask, m_Mask.size() * sizeof(int)));
|
||||
CHECK(cudaMemcpy(mMask, m_mask, m_Mask.size() * sizeof(int), cudaMemcpyHostToDevice));
|
||||
}
|
||||
};
|
||||
|
||||
YoloLayer::YoloLayer (
|
||||
const uint& numBoxes, const uint& numClasses, const uint& gridSizeX, const uint& gridSizeY, const uint model_type, const uint new_coords, const float scale_x_y, const float beta_nms, const std::vector<float> anchors, std::vector<std::vector<int>> mask) :
|
||||
const uint& numBoxes, const uint& numClasses, const uint& gridSizeX, const uint& gridSizeY, const uint modelType, const uint newCoords, const float scaleXY, const float betaNMS, const std::vector<float> anchors, std::vector<int> mask) :
|
||||
m_NumBoxes(numBoxes),
|
||||
m_NumClasses(numClasses),
|
||||
m_GridSizeX(gridSizeX),
|
||||
m_GridSizeY(gridSizeY),
|
||||
m_type(model_type),
|
||||
m_new_coords(new_coords),
|
||||
m_scale_x_y(scale_x_y),
|
||||
m_beta_nms(beta_nms),
|
||||
m_Type(modelType),
|
||||
m_NewCoords(newCoords),
|
||||
m_ScaleXY(scaleXY),
|
||||
m_BetaNMS(betaNMS),
|
||||
m_Anchors(anchors),
|
||||
m_Mask(mask)
|
||||
{
|
||||
@@ -127,8 +139,30 @@ YoloLayer::YoloLayer (
|
||||
assert(m_GridSizeX > 0);
|
||||
assert(m_GridSizeY > 0);
|
||||
m_OutputSize = m_GridSizeX * m_GridSizeY * (m_NumBoxes * (4 + 1 + m_NumClasses));
|
||||
|
||||
if (m_Anchors.size() > 0) {
|
||||
float* m_anchors = m_Anchors.data();
|
||||
CHECK(cudaMallocHost(&mAnchors, m_Anchors.size() * sizeof(float)));
|
||||
CHECK(cudaMemcpy(mAnchors, m_anchors, m_Anchors.size() * sizeof(float), cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
if (m_Mask.size() > 0) {
|
||||
int* m_mask = m_Mask.data();
|
||||
CHECK(cudaMallocHost(&mMask, m_Mask.size() * sizeof(int)));
|
||||
CHECK(cudaMemcpy(mMask, m_mask, m_Mask.size() * sizeof(int), cudaMemcpyHostToDevice));
|
||||
}
|
||||
};
|
||||
|
||||
YoloLayer::~YoloLayer()
|
||||
{
|
||||
if (m_Anchors.size() > 0) {
|
||||
CHECK(cudaFreeHost(mAnchors));
|
||||
}
|
||||
if (m_Mask.size() > 0) {
|
||||
CHECK(cudaFreeHost(mMask));
|
||||
}
|
||||
}
|
||||
|
||||
nvinfer1::Dims
|
||||
YoloLayer::getOutputDimensions(
|
||||
int index, const nvinfer1::Dims* inputs, int nbInputDims) noexcept
|
||||
@@ -159,27 +193,33 @@ int YoloLayer::enqueue(
|
||||
int batchSize, void const* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) noexcept
|
||||
{
|
||||
if (m_type == 2) { // YOLOR incorrect param
|
||||
if (m_Type == 2) { // YOLOR incorrect param: scale_x_y = 2.0
|
||||
CHECK(cudaYoloLayer_r(
|
||||
inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
|
||||
m_OutputSize, stream, m_scale_x_y));
|
||||
m_OutputSize, stream, 2.0, mAnchors, mMask));
|
||||
}
|
||||
else if (m_type == 1) {
|
||||
if (m_new_coords) {
|
||||
else if (m_Type == 1) {
|
||||
if (m_NewCoords) {
|
||||
CHECK(cudaYoloLayer_nc(
|
||||
inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
|
||||
m_OutputSize, stream, m_scale_x_y));
|
||||
m_OutputSize, stream, m_ScaleXY, mAnchors, mMask));
|
||||
}
|
||||
else {
|
||||
CHECK(cudaYoloLayer(
|
||||
inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
|
||||
m_OutputSize, stream, m_scale_x_y));
|
||||
m_OutputSize, stream, m_ScaleXY, mAnchors, mMask));
|
||||
}
|
||||
}
|
||||
else {
|
||||
void* softmax;
|
||||
cudaMallocHost(&softmax, sizeof(outputs[0]));
|
||||
cudaMemcpy(softmax, outputs[0], sizeof(outputs[0]), cudaMemcpyHostToDevice);
|
||||
|
||||
CHECK(cudaYoloLayer_v2(
|
||||
inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
|
||||
m_OutputSize, stream));
|
||||
inputs[0], outputs[0], softmax, batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
|
||||
m_OutputSize, stream, mAnchors));
|
||||
|
||||
CHECK(cudaFreeHost(softmax));
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -193,13 +233,10 @@ size_t YoloLayer::getSerializationSize() const noexcept
|
||||
int maskSum = 1;
|
||||
for (uint i = 0; i < m_Mask.size(); i++) {
|
||||
maskSum += 1;
|
||||
for (uint f = 0; f < m_Mask[i].size(); f++) {
|
||||
maskSum += 1;
|
||||
}
|
||||
}
|
||||
|
||||
return sizeof(m_NumBoxes) + sizeof(m_NumClasses) + sizeof(m_GridSizeX) + sizeof(m_GridSizeY) + sizeof(m_OutputSize) + sizeof(m_type)
|
||||
+ sizeof(m_new_coords) + sizeof(m_scale_x_y) + sizeof(m_beta_nms) + anchorsSum * sizeof(float) + maskSum * sizeof(int);
|
||||
return sizeof(m_NumBoxes) + sizeof(m_NumClasses) + sizeof(m_GridSizeX) + sizeof(m_GridSizeY) + sizeof(m_OutputSize) + sizeof(m_Type)
|
||||
+ sizeof(m_NewCoords) + sizeof(m_ScaleXY) + sizeof(m_BetaNMS) + anchorsSum * sizeof(float) + maskSum * sizeof(int);
|
||||
}
|
||||
|
||||
void YoloLayer::serialize(void* buffer) const noexcept
|
||||
@@ -211,33 +248,32 @@ void YoloLayer::serialize(void* buffer) const noexcept
|
||||
write(d, m_GridSizeY);
|
||||
write(d, m_OutputSize);
|
||||
|
||||
write(d, m_type);
|
||||
write(d, m_new_coords);
|
||||
write(d, m_scale_x_y);
|
||||
write(d, m_beta_nms);
|
||||
write(d, m_Type);
|
||||
write(d, m_NewCoords);
|
||||
write(d, m_ScaleXY);
|
||||
write(d, m_BetaNMS);
|
||||
|
||||
uint anchorsSize = m_Anchors.size();
|
||||
write(d, anchorsSize);
|
||||
for (uint i = 0; i < anchorsSize; i++) {
|
||||
write(d, m_Anchors[i]);
|
||||
}
|
||||
|
||||
uint maskSize = m_Mask.size();
|
||||
write(d, maskSize);
|
||||
for (uint i = 0; i < maskSize; i++) {
|
||||
uint pMaskSize = m_Mask[i].size();
|
||||
write(d, pMaskSize);
|
||||
for (uint f = 0; f < pMaskSize; f++) {
|
||||
write(d, m_Mask[i][f]);
|
||||
}
|
||||
write(d, m_Mask[i]);
|
||||
}
|
||||
|
||||
kMODEL_TYPE = m_Type;
|
||||
kNUM_BBOXES = m_NumBoxes;
|
||||
kNUM_CLASSES = m_NumClasses;
|
||||
kBETA_NMS = m_beta_nms;
|
||||
kANCHORS = m_Anchors;
|
||||
kMASK = m_Mask;
|
||||
kBETA_NMS = m_BetaNMS;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* YoloLayer::clone() const noexcept
|
||||
{
|
||||
return new YoloLayer (m_NumBoxes, m_NumClasses, m_GridSizeX, m_GridSizeY, m_type, m_new_coords, m_scale_x_y, m_beta_nms, m_Anchors, m_Mask);
|
||||
return new YoloLayer (m_NumBoxes, m_NumClasses, m_GridSizeX, m_GridSizeY, m_Type, m_NewCoords, m_ScaleXY, m_BetaNMS, m_Anchors, m_Mask);
|
||||
}
|
||||
|
||||
REGISTER_TENSORRT_PLUGIN(YoloLayerPluginCreator);
|
||||
@@ -57,8 +57,9 @@ class YoloLayer : public nvinfer1::IPluginV2
|
||||
public:
|
||||
YoloLayer (const void* data, size_t length);
|
||||
YoloLayer (const uint& numBoxes, const uint& numClasses, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint model_type, const uint new_coords, const float scale_x_y, const float beta_nms,
|
||||
const std::vector<float> anchors, const std::vector<std::vector<int>> mask);
|
||||
const uint modelType, const uint newCoords, const float scaleXY, const float betaNMS,
|
||||
const std::vector<float> anchors, const std::vector<int> mask);
|
||||
~YoloLayer ();
|
||||
const char* getPluginType () const noexcept override { return YOLOLAYER_PLUGIN_NAME; }
|
||||
const char* getPluginVersion () const noexcept override { return YOLOLAYER_PLUGIN_VERSION; }
|
||||
int getNbOutputs () const noexcept override { return 1; }
|
||||
@@ -101,12 +102,15 @@ private:
|
||||
uint64_t m_OutputSize {0};
|
||||
std::string m_Namespace {""};
|
||||
|
||||
uint m_type {0};
|
||||
uint m_new_coords {0};
|
||||
float m_scale_x_y {0};
|
||||
float m_beta_nms {0};
|
||||
uint m_Type {0};
|
||||
uint m_NewCoords {0};
|
||||
float m_ScaleXY {0};
|
||||
float m_BetaNMS {0};
|
||||
std::vector<float> m_Anchors;
|
||||
std::vector<std::vector<int>> m_Mask;
|
||||
std::vector<int> m_Mask;
|
||||
|
||||
void* mAnchors;
|
||||
void* mMask;
|
||||
};
|
||||
|
||||
class YoloLayerPluginCreator : public nvinfer1::IPluginCreator
|
||||
@@ -148,9 +152,9 @@ private:
|
||||
std::string m_Namespace {""};
|
||||
};
|
||||
|
||||
extern int kMODEL_TYPE;
|
||||
extern int kNUM_BBOXES;
|
||||
extern int kNUM_CLASSES;
|
||||
extern float kBETA_NMS;
|
||||
extern std::vector<float> kANCHORS;
|
||||
extern std::vector<std::vector<int>> kMASK;
|
||||
|
||||
#endif // __YOLO_PLUGINS__
|
||||
|
||||
19
readme.md
19
readme.md
@@ -23,7 +23,8 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
|
||||
* Support for reorg, implicit and channel layers (YOLOR)
|
||||
* YOLOv5 6.0 native support
|
||||
* YOLOR native support
|
||||
* **Models benchmarks**
|
||||
* Models benchmarks
|
||||
* **GPU YOLO Decoder (moved from CPU to GPU to get better performance)**
|
||||
|
||||
##
|
||||
|
||||
@@ -43,6 +44,8 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
|
||||
|
||||
### Requirements
|
||||
|
||||
#### x86 platform
|
||||
|
||||
* [Ubuntu 18.04](https://releases.ubuntu.com/18.04.6/)
|
||||
* [CUDA 11.4.3](https://developer.nvidia.com/cuda-toolkit)
|
||||
* [TensorRT 8.0 GA (8.0.1)](https://developer.nvidia.com/tensorrt)
|
||||
@@ -51,10 +54,22 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
|
||||
* [NVIDIA DeepStream SDK 6.0](https://developer.nvidia.com/deepstream-sdk)
|
||||
* [DeepStream-Yolo](https://github.com/marcoslucianops/DeepStream-Yolo)
|
||||
|
||||
**For YOLOv5 and YOLOR**:
|
||||
#### Jetson platform
|
||||
|
||||
* [JetPack 4.6](https://developer.nvidia.com/embedded/jetpack)
|
||||
* [NVIDIA DeepStream SDK 6.0](https://developer.nvidia.com/deepstream-sdk)
|
||||
* [DeepStream-Yolo](https://github.com/marcoslucianops/DeepStream-Yolo)
|
||||
|
||||
### For YOLOv5 and YOLOR
|
||||
|
||||
#### x86 platform
|
||||
|
||||
* [PyTorch >= 1.7.0](https://pytorch.org/get-started/locally/)
|
||||
|
||||
#### Jetson platform
|
||||
|
||||
* [PyTorch >= 1.7.0](https://forums.developer.nvidia.com/t/pytorch-for-jetson-version-1-10-now-available/72048)
|
||||
|
||||
##
|
||||
|
||||
### Tested models
|
||||
|
||||
Reference in New Issue
Block a user