New optimized NMS
This commit is contained in:
16
README.md
16
README.md
@@ -24,9 +24,9 @@ NVIDIA DeepStream SDK 6.1 / 6.0.1 / 6.0 configuration for YOLO models
|
||||
* YOLOv5 support
|
||||
* YOLOR support
|
||||
* **GPU YOLO Decoder** [#138](https://github.com/marcoslucianops/DeepStream-Yolo/issues/138)
|
||||
* **GPU Batched NMS** [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142)
|
||||
* **PP-YOLOE support**
|
||||
* **YOLOv7 support**
|
||||
* **Optimized NMS** [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142)
|
||||
|
||||
##
|
||||
|
||||
@@ -446,20 +446,20 @@ config-file=config_infer_primary_yoloV2.txt
|
||||
|
||||
### NMS Configuration
|
||||
|
||||
To change the `iou-threshold`, `score-threshold` and `topk` values, modify the `config_nms.txt` file and regenerate the model engine file.
|
||||
To change the `nms-iou-threshold`, `pre-cluster-threshold` and `topk` values, modify the config_infer file and regenerate the model engine file.
|
||||
|
||||
```
|
||||
[property]
|
||||
iou-threshold=0.45
|
||||
score-threshold=0.25
|
||||
[class-attrs-all]
|
||||
nms-iou-threshold=0.45
|
||||
pre-cluster-threshold=0.25
|
||||
topk=300
|
||||
```
|
||||
|
||||
**NOTE**: It is important to regenerate the engine to get the maximum detection speed for the `pre-cluster-threshold` value you set.
|
||||
|
||||
**NOTE**: Lower `topk` values will result in better performance.
|
||||
|
||||
**NOTE**: Make sure to set `cluster-mode=4` in the config_infer file.
|
||||
|
||||
**NOTE**: You are still able to change the `pre-cluster-threshold` values in the config_infer files.
|
||||
**NOTE**: Make sure to set `cluster-mode=2` in the config_infer file.
|
||||
|
||||
##
|
||||
|
||||
|
||||
@@ -14,11 +14,13 @@ interval=0
|
||||
gie-unique-id=1
|
||||
process-mode=1
|
||||
network-type=0
|
||||
cluster-mode=4
|
||||
cluster-mode=2
|
||||
maintain-aspect-ratio=0
|
||||
parse-bbox-func-name=NvDsInferParseYolo
|
||||
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
|
||||
engine-create-func-name=NvDsInferYoloCudaEngineGet
|
||||
|
||||
[class-attrs-all]
|
||||
pre-cluster-threshold=0
|
||||
nms-iou-threshold=0.45
|
||||
pre-cluster-threshold=0.25
|
||||
topk=300
|
||||
|
||||
@@ -15,11 +15,13 @@ interval=0
|
||||
gie-unique-id=1
|
||||
process-mode=1
|
||||
network-type=0
|
||||
cluster-mode=4
|
||||
cluster-mode=2
|
||||
maintain-aspect-ratio=0
|
||||
parse-bbox-func-name=NvDsInferParseYolo
|
||||
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
|
||||
engine-create-func-name=NvDsInferYoloCudaEngineGet
|
||||
|
||||
[class-attrs-all]
|
||||
pre-cluster-threshold=0
|
||||
nms-iou-threshold=0.45
|
||||
pre-cluster-threshold=0.25
|
||||
topk=300
|
||||
|
||||
@@ -14,11 +14,13 @@ interval=0
|
||||
gie-unique-id=1
|
||||
process-mode=1
|
||||
network-type=0
|
||||
cluster-mode=4
|
||||
cluster-mode=2
|
||||
maintain-aspect-ratio=0
|
||||
parse-bbox-func-name=NvDsInferParseYolo
|
||||
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
|
||||
engine-create-func-name=NvDsInferYoloCudaEngineGet
|
||||
|
||||
[class-attrs-all]
|
||||
pre-cluster-threshold=0
|
||||
nms-iou-threshold=0.45
|
||||
pre-cluster-threshold=0.25
|
||||
topk=300
|
||||
|
||||
@@ -14,11 +14,13 @@ interval=0
|
||||
gie-unique-id=1
|
||||
process-mode=1
|
||||
network-type=0
|
||||
cluster-mode=4
|
||||
cluster-mode=2
|
||||
maintain-aspect-ratio=1
|
||||
parse-bbox-func-name=NvDsInferParseYolo
|
||||
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
|
||||
engine-create-func-name=NvDsInferYoloCudaEngineGet
|
||||
|
||||
[class-attrs-all]
|
||||
pre-cluster-threshold=0
|
||||
nms-iou-threshold=0.45
|
||||
pre-cluster-threshold=0.25
|
||||
topk=300
|
||||
|
||||
@@ -14,11 +14,13 @@ interval=0
|
||||
gie-unique-id=1
|
||||
process-mode=1
|
||||
network-type=0
|
||||
cluster-mode=4
|
||||
cluster-mode=2
|
||||
maintain-aspect-ratio=0
|
||||
parse-bbox-func-name=NvDsInferParseYolo
|
||||
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
|
||||
engine-create-func-name=NvDsInferYoloCudaEngineGet
|
||||
|
||||
[class-attrs-all]
|
||||
pre-cluster-threshold=0
|
||||
nms-iou-threshold=0.45
|
||||
pre-cluster-threshold=0.25
|
||||
topk=300
|
||||
|
||||
@@ -14,11 +14,13 @@ interval=0
|
||||
gie-unique-id=1
|
||||
process-mode=1
|
||||
network-type=0
|
||||
cluster-mode=4
|
||||
cluster-mode=2
|
||||
maintain-aspect-ratio=1
|
||||
parse-bbox-func-name=NvDsInferParseYolo
|
||||
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
|
||||
engine-create-func-name=NvDsInferYoloCudaEngineGet
|
||||
|
||||
[class-attrs-all]
|
||||
pre-cluster-threshold=0
|
||||
nms-iou-threshold=0.45
|
||||
pre-cluster-threshold=0.25
|
||||
topk=300
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
[property]
|
||||
iou-threshold=0.45
|
||||
score-threshold=0.25
|
||||
topk=300
|
||||
@@ -69,12 +69,8 @@ all: $(TARGET_LIB)
|
||||
%.o: %.cpp $(INCS) Makefile
|
||||
$(CC) -c $(COMMON) -o $@ $(CFLAGS) $<
|
||||
|
||||
ifeq ($(CUDA_VER), 10.2)
|
||||
CUB=-I/usr/local/cuda-$(CUDA_VER)/include/thrust/system/cuda/detail
|
||||
endif
|
||||
|
||||
%.o: %.cu $(INCS) Makefile
|
||||
$(NVCC) -c -o $@ $(CUB) --compiler-options '-fPIC' $<
|
||||
$(NVCC) -c -o $@ --compiler-options '-fPIC' $<
|
||||
|
||||
$(TARGET_LIB) : $(TARGET_OBJS)
|
||||
$(CC) -o $@ $(TARGET_OBJS) $(LFLAGS)
|
||||
|
||||
@@ -45,6 +45,9 @@ nvinfer1::ITensor* routeLayer(
|
||||
}
|
||||
layers += std::to_string(idxLayers[idxLayers.size() - 1]);
|
||||
|
||||
if (concatInputs.size() == 1)
|
||||
return concatInputs[0];
|
||||
|
||||
int axis = 0;
|
||||
if (block.find("axis") != block.end())
|
||||
axis = std::stoi(block.at("axis"));
|
||||
|
||||
@@ -38,7 +38,8 @@ static bool getYoloNetworkInfo (NetworkInfo &networkInfo, const NvDsInferContext
|
||||
std::string yoloType;
|
||||
|
||||
std::transform(yoloCfg.begin(), yoloCfg.end(), yoloCfg.begin(), [] (uint8_t c) {
|
||||
return std::tolower (c);});
|
||||
return std::tolower(c);
|
||||
});
|
||||
|
||||
yoloType = yoloCfg.substr(0, yoloCfg.find(".cfg"));
|
||||
|
||||
@@ -50,25 +51,23 @@ static bool getYoloNetworkInfo (NetworkInfo &networkInfo, const NvDsInferContext
|
||||
networkInfo.deviceType = (initParams->useDLA ? "kDLA" : "kGPU");
|
||||
networkInfo.numDetectedClasses = initParams->numDetectedClasses;
|
||||
networkInfo.clusterMode = initParams->clusterMode;
|
||||
networkInfo.scoreThreshold = initParams->perClassDetectionParams->preClusterThreshold;
|
||||
|
||||
if(initParams->networkMode == 0) {
|
||||
if (initParams->networkMode == 0)
|
||||
networkInfo.networkMode = "FP32";
|
||||
}
|
||||
else if(initParams->networkMode == 1) {
|
||||
else if (initParams->networkMode == 1)
|
||||
networkInfo.networkMode = "INT8";
|
||||
}
|
||||
else if(initParams->networkMode == 2) {
|
||||
else if (initParams->networkMode == 2)
|
||||
networkInfo.networkMode = "FP16";
|
||||
}
|
||||
|
||||
if (networkInfo.configFilePath.empty() ||
|
||||
networkInfo.wtsFilePath.empty()) {
|
||||
if (networkInfo.configFilePath.empty() || networkInfo.wtsFilePath.empty())
|
||||
{
|
||||
std::cerr << "YOLO config file or weights file is not specified\n" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!fileExists(networkInfo.configFilePath) ||
|
||||
!fileExists(networkInfo.wtsFilePath)) {
|
||||
if (!fileExists(networkInfo.configFilePath) || !fileExists(networkInfo.wtsFilePath))
|
||||
{
|
||||
std::cerr << "YOLO config file or weights file is not exist\n" << std::endl;
|
||||
return false;
|
||||
}
|
||||
@@ -80,9 +79,8 @@ static bool getYoloNetworkInfo (NetworkInfo &networkInfo, const NvDsInferContext
|
||||
IModelParser* NvDsInferCreateModelParser(
|
||||
const NvDsInferContextInitParams* initParams) {
|
||||
NetworkInfo networkInfo;
|
||||
if (!getYoloNetworkInfo(networkInfo, initParams)) {
|
||||
if (!getYoloNetworkInfo(networkInfo, initParams))
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return new Yolo(networkInfo);
|
||||
}
|
||||
@@ -102,16 +100,14 @@ bool NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder * const builder,
|
||||
nvinfer1::ICudaEngine *& cudaEngine)
|
||||
{
|
||||
NetworkInfo networkInfo;
|
||||
if (!getYoloNetworkInfo(networkInfo, initParams)) {
|
||||
if (!getYoloNetworkInfo(networkInfo, initParams))
|
||||
return false;
|
||||
}
|
||||
|
||||
Yolo yolo(networkInfo);
|
||||
cudaEngine = yolo.createEngine (builder, builderConfig);
|
||||
if (cudaEngine == nullptr)
|
||||
{
|
||||
std::cerr << "Failed to build CUDA engine on "
|
||||
<< networkInfo.configFilePath << std::endl;
|
||||
std::cerr << "Failed to build CUDA engine on " << networkInfo.configFilePath << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ static void addBBoxProposal(
|
||||
}
|
||||
|
||||
static std::vector<NvDsInferParseObjectInfo> decodeYoloTensor(
|
||||
const int* counts, const float* boxes, const float* scores, const float* classes, const uint& netW, const uint& netH)
|
||||
const int* counts, const float* boxes, const float* scores, const int* classes, const uint& netW, const uint& netH)
|
||||
{
|
||||
std::vector<NvDsInferParseObjectInfo> binfo;
|
||||
|
||||
@@ -118,7 +118,7 @@ static bool NvDsInferParseCustomYolo(
|
||||
std::vector<NvDsInferParseObjectInfo> outObjs =
|
||||
decodeYoloTensor(
|
||||
(const int*)(counts.buffer), (const float*)(boxes.buffer), (const float*)(scores.buffer),
|
||||
(const float*)(classes.buffer), networkInfo.width, networkInfo.height);
|
||||
(const int*)(classes.buffer), networkInfo.width, networkInfo.height);
|
||||
|
||||
objects.insert(objects.end(), outObjs.begin(), outObjs.end());
|
||||
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
/*
|
||||
* Created by Marcos Luciano
|
||||
* https://www.github.com/marcoslucianops
|
||||
*/
|
||||
|
||||
#include <cub/device/device_radix_sort.cuh>
|
||||
|
||||
__global__ void sortOutput(
|
||||
int* d_indexes, float* d_scores, float* d_boxes, int* d_classes, float* bboxData, float* scoreData,
|
||||
const uint numOutputClasses, const int topk)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (x_id >= topk)
|
||||
return;
|
||||
|
||||
int index = d_indexes[x_id];
|
||||
int maxIndex = d_classes[index];
|
||||
bboxData[x_id * 4 + 0] = d_boxes[index * 4 + 0];
|
||||
bboxData[x_id * 4 + 1] = d_boxes[index * 4 + 1];
|
||||
bboxData[x_id * 4 + 2] = d_boxes[index * 4 + 2];
|
||||
bboxData[x_id * 4 + 3] = d_boxes[index * 4 + 3];
|
||||
scoreData[x_id * numOutputClasses + maxIndex] = d_scores[x_id] - 1.f;
|
||||
}
|
||||
|
||||
cudaError_t sortDetections(
|
||||
void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* bboxData, void* scoreData, void* countData,
|
||||
const uint& batchSize, uint64_t& outputSize, uint& topK, const uint& numOutputClasses, cudaStream_t stream);
|
||||
|
||||
cudaError_t sortDetections(
|
||||
void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* bboxData, void* scoreData, void* countData,
|
||||
const uint& batchSize, uint64_t& outputSize, uint& topK, const uint& numOutputClasses, cudaStream_t stream)
|
||||
{
|
||||
for (unsigned int batch = 0; batch < batchSize; ++batch)
|
||||
{
|
||||
int* _d_indexes = reinterpret_cast<int*>(d_indexes) + (batch * outputSize);
|
||||
float* _d_scores = reinterpret_cast<float*>(d_scores) + (batch * outputSize);
|
||||
|
||||
int* _countData = reinterpret_cast<int*>(countData) + (batch);
|
||||
int count;
|
||||
cudaMemcpy(&count, _countData, sizeof(int), cudaMemcpyDeviceToHost);
|
||||
|
||||
if (count == 0)
|
||||
{
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
size_t begin_bit = 0;
|
||||
size_t end_bit = sizeof(float) * 8;
|
||||
|
||||
float *d_keys_out = NULL;
|
||||
int *d_values_out = NULL;
|
||||
|
||||
cudaMalloc((void **)&d_keys_out, count * sizeof(float));
|
||||
cudaMalloc((void **)&d_values_out, count * sizeof(int));
|
||||
|
||||
void* d_temp_storage = NULL;
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, _d_scores, d_keys_out, _d_indexes,
|
||||
d_values_out, count, begin_bit, end_bit);
|
||||
|
||||
cudaMalloc(&d_temp_storage, temp_storage_bytes);
|
||||
|
||||
cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, _d_scores, d_keys_out, _d_indexes,
|
||||
d_values_out, count, begin_bit, end_bit);
|
||||
|
||||
cudaMemcpy(_d_scores, d_keys_out, count * sizeof(float), cudaMemcpyDeviceToDevice);
|
||||
cudaMemcpy(_d_indexes, d_values_out, count * sizeof(int), cudaMemcpyDeviceToDevice);
|
||||
|
||||
int _topK = count < topK ? count : topK;
|
||||
|
||||
int threads_per_block = 16;
|
||||
int number_of_blocks = 0;
|
||||
|
||||
if (_topK % 2 == 0 && _topK >= threads_per_block)
|
||||
number_of_blocks = _topK / threads_per_block;
|
||||
else
|
||||
number_of_blocks = (_topK / threads_per_block) + 1;
|
||||
|
||||
sortOutput<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
_d_indexes, _d_scores, reinterpret_cast<float*>(d_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<int*>(d_classes) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(bboxData) + (batch * topK * 4),
|
||||
reinterpret_cast<float*>(scoreData) + (batch * topK * numOutputClasses), numOutputClasses, _topK);
|
||||
|
||||
cudaFree(d_keys_out);
|
||||
cudaFree(d_values_out);
|
||||
cudaFree(d_temp_storage);
|
||||
}
|
||||
return cudaGetLastError();
|
||||
}
|
||||
@@ -155,11 +155,3 @@ void printLayerInfo(
|
||||
std::cout << std::setw(20) << std::left << layerInput << std::setw(20) << std::left << layerOutput;
|
||||
std::cout << weightPtr << std::endl;
|
||||
}
|
||||
|
||||
std::string getAbsPath(std::string path)
|
||||
{
|
||||
std::size_t found = path.rfind("/");
|
||||
if (found != std::string::npos)
|
||||
path.erase(path.begin() + found, path.end());
|
||||
return path;
|
||||
}
|
||||
|
||||
@@ -43,6 +43,5 @@ std::string dimsToString(const nvinfer1::Dims d);
|
||||
int getNumChannels(nvinfer1::ITensor* t);
|
||||
void printLayerInfo(
|
||||
std::string layerIndex, std::string layerName, std::string layerInput, std::string layerOutput, std::string weightPtr);
|
||||
std::string getAbsPath(std::string path);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -41,6 +41,7 @@ Yolo::Yolo(const NetworkInfo& networkInfo)
|
||||
m_NumDetectedClasses(networkInfo.numDetectedClasses),
|
||||
m_ClusterMode(networkInfo.clusterMode),
|
||||
m_NetworkMode(networkInfo.networkMode),
|
||||
m_ScoreThreshold(networkInfo.scoreThreshold),
|
||||
m_InputH(0),
|
||||
m_InputW(0),
|
||||
m_InputC(0),
|
||||
@@ -48,10 +49,7 @@ Yolo::Yolo(const NetworkInfo& networkInfo)
|
||||
m_NumClasses(0),
|
||||
m_LetterBox(0),
|
||||
m_NewCoords(0),
|
||||
m_YoloCount(0),
|
||||
m_IouThreshold(0),
|
||||
m_ScoreThreshold(0),
|
||||
m_TopK(0)
|
||||
m_YoloCount(0)
|
||||
{}
|
||||
|
||||
Yolo::~Yolo()
|
||||
@@ -66,15 +64,6 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder, nvinfer1
|
||||
m_ConfigBlocks = parseConfigFile(m_ConfigFilePath);
|
||||
parseConfigBlocks();
|
||||
|
||||
std::string configNMS = getAbsPath(m_WtsFilePath) + "/config_nms.txt";
|
||||
if (!fileExists(configNMS))
|
||||
{
|
||||
std::cerr << "YOLO config_nms.txt file is not specified\n" << std::endl;
|
||||
assert(0);
|
||||
}
|
||||
m_ConfigNMSBlocks = parseConfigFile(configNMS);
|
||||
parseConfigNMSBlocks();
|
||||
|
||||
nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0);
|
||||
if (parseModel(*network) != NVDSINFER_SUCCESS)
|
||||
{
|
||||
@@ -94,9 +83,9 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder, nvinfer1
|
||||
std::cout << "NOTE: letter_box is set in cfg file, make sure to set maintain-aspect-ratio=1 in config_infer file"
|
||||
<< " to get better accuracy\n" << std::endl;
|
||||
}
|
||||
if (m_ClusterMode != 4)
|
||||
if (m_ClusterMode != 2)
|
||||
{
|
||||
std::cout << "NOTE: Wrong cluster-mode is set, make sure to set cluster-mode=4 in config_infer file\n"
|
||||
std::cout << "NOTE: Wrong cluster-mode is set, make sure to set cluster-mode=2 in config_infer file\n"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
@@ -452,54 +441,31 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
|
||||
outputSize += curYoloTensor.gridSizeX * curYoloTensor.gridSizeY * curYoloTensor.numBBoxes;
|
||||
}
|
||||
|
||||
if (m_TopK > outputSize) {
|
||||
std::cout << "\ntopk > Number of outputs\nPlease change the topk to " << outputSize
|
||||
<< " or less in config_nms.txt file\n" << std::endl;
|
||||
assert(0);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* yoloPlugin = new YoloLayer(
|
||||
m_InputW, m_InputH, m_NumClasses, m_NewCoords, m_YoloTensors, outputSize, modelType, m_TopK, m_ScoreThreshold);
|
||||
m_InputW, m_InputH, m_NumClasses, m_NewCoords, m_YoloTensors, outputSize, modelType, m_ScoreThreshold);
|
||||
assert(yoloPlugin != nullptr);
|
||||
nvinfer1::IPluginV2Layer* yolo = network.addPluginV2(yoloTensorInputs, m_YoloCount, *yoloPlugin);
|
||||
assert(yolo != nullptr);
|
||||
std::string yoloLayerName = "yolo";
|
||||
yolo->setName(yoloLayerName.c_str());
|
||||
|
||||
nvinfer1::ITensor* yoloTensorOutputs[] = {yolo->getOutput(0), yolo->getOutput(1)};
|
||||
|
||||
nvinfer1::plugin::NMSParameters nmsParams;
|
||||
nmsParams.shareLocation = true;
|
||||
nmsParams.backgroundLabelId = -1;
|
||||
nmsParams.numClasses = m_NumClasses;
|
||||
nmsParams.topK = m_TopK;
|
||||
nmsParams.keepTopK = m_TopK;
|
||||
nmsParams.scoreThreshold = m_ScoreThreshold;
|
||||
nmsParams.iouThreshold = m_IouThreshold;
|
||||
nmsParams.isNormalized = false;
|
||||
|
||||
std::string nmslayerName = "batchedNMS";
|
||||
nvinfer1::IPluginV2* batchedNMS = createBatchedNMSPlugin(nmsParams);
|
||||
nvinfer1::IPluginV2Layer* nms = network.addPluginV2(yoloTensorOutputs, 2, *batchedNMS);
|
||||
nms->setName(nmslayerName.c_str());
|
||||
nvinfer1::ITensor* num_detections = nms->getOutput(0);
|
||||
nmslayerName = "num_detections";
|
||||
num_detections->setName(nmslayerName.c_str());
|
||||
nvinfer1::ITensor* nmsed_boxes = nms->getOutput(1);
|
||||
nmslayerName = "nmsed_boxes";
|
||||
nmsed_boxes->setName(nmslayerName.c_str());
|
||||
nvinfer1::ITensor* nmsed_scores = nms->getOutput(2);
|
||||
nmslayerName = "nmsed_scores";
|
||||
nmsed_scores->setName(nmslayerName.c_str());
|
||||
nvinfer1::ITensor* nmsed_classes = nms->getOutput(3);
|
||||
nmslayerName = "nmsed_classes";
|
||||
nmsed_classes->setName(nmslayerName.c_str());
|
||||
std::string outputlayerName;
|
||||
nvinfer1::ITensor* num_detections = yolo->getOutput(0);
|
||||
outputlayerName = "num_detections";
|
||||
num_detections->setName(outputlayerName.c_str());
|
||||
nvinfer1::ITensor* detection_boxes = yolo->getOutput(1);
|
||||
outputlayerName = "detection_boxes";
|
||||
detection_boxes->setName(outputlayerName.c_str());
|
||||
nvinfer1::ITensor* detection_scores = yolo->getOutput(2);
|
||||
outputlayerName = "detection_scores";
|
||||
detection_scores->setName(outputlayerName.c_str());
|
||||
nvinfer1::ITensor* detection_classes = yolo->getOutput(3);
|
||||
outputlayerName = "detection_classes";
|
||||
detection_classes->setName(outputlayerName.c_str());
|
||||
network.markOutput(*num_detections);
|
||||
network.markOutput(*nmsed_boxes);
|
||||
network.markOutput(*nmsed_scores);
|
||||
network.markOutput(*nmsed_classes);
|
||||
|
||||
printLayerInfo("", "batched_nms", "-", "-", "-");
|
||||
network.markOutput(*detection_boxes);
|
||||
network.markOutput(*detection_scores);
|
||||
network.markOutput(*detection_classes);
|
||||
}
|
||||
else {
|
||||
std::cout << "\nError in yolo cfg file" << std::endl;
|
||||
@@ -659,20 +625,6 @@ void Yolo::parseConfigBlocks()
|
||||
}
|
||||
}
|
||||
|
||||
void Yolo::parseConfigNMSBlocks()
|
||||
{
|
||||
auto block = m_ConfigNMSBlocks[0];
|
||||
|
||||
assert((block.at("type") == "property") && "Missing 'property' param in nms cfg");
|
||||
assert((block.find("iou-threshold") != block.end()) && "Missing 'iou-threshold' param in nms cfg");
|
||||
assert((block.find("score-threshold") != block.end()) && "Missing 'score-threshold' param in nms cfg");
|
||||
assert((block.find("topk") != block.end()) && "Missing 'topk' param in nms cfg");
|
||||
|
||||
m_IouThreshold = std::stof(block.at("iou-threshold"));
|
||||
m_ScoreThreshold = std::stof(block.at("score-threshold"));
|
||||
m_TopK = std::stoul(block.at("topk"));
|
||||
}
|
||||
|
||||
void Yolo::destroyNetworkUtils()
|
||||
{
|
||||
for (uint i = 0; i < m_TrtWeights.size(); ++i)
|
||||
|
||||
@@ -53,6 +53,7 @@ struct NetworkInfo
|
||||
std::string deviceType;
|
||||
uint numDetectedClasses;
|
||||
int clusterMode;
|
||||
float scoreThreshold;
|
||||
std::string networkMode;
|
||||
};
|
||||
|
||||
@@ -93,6 +94,7 @@ protected:
|
||||
const uint m_NumDetectedClasses;
|
||||
const int m_ClusterMode;
|
||||
const std::string m_NetworkMode;
|
||||
const float m_ScoreThreshold;
|
||||
|
||||
uint m_InputH;
|
||||
uint m_InputW;
|
||||
@@ -102,13 +104,9 @@ protected:
|
||||
uint m_LetterBox;
|
||||
uint m_NewCoords;
|
||||
uint m_YoloCount;
|
||||
float m_IouThreshold;
|
||||
float m_ScoreThreshold;
|
||||
uint m_TopK;
|
||||
|
||||
std::vector<TensorInfo> m_YoloTensors;
|
||||
std::vector<std::map<std::string, std::string>> m_ConfigBlocks;
|
||||
std::vector<std::map<std::string, std::string>> m_ConfigNMSBlocks;
|
||||
std::vector<nvinfer1::Weights> m_TrtWeights;
|
||||
|
||||
private:
|
||||
@@ -118,8 +116,6 @@ private:
|
||||
|
||||
void parseConfigBlocks();
|
||||
|
||||
void parseConfigNMSBlocks();
|
||||
|
||||
void destroyNetworkUtils();
|
||||
};
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
||||
|
||||
__global__ void gpuYoloLayer(
|
||||
const float* input, int* d_indexes, float* d_scores, float* d_boxes, int* d_classes, int* countData,
|
||||
const float* input, int* num_detections, float* detection_boxes, float* detection_scores, int* detection_classes,
|
||||
const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, const uint gridSizeY,
|
||||
const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask)
|
||||
{
|
||||
@@ -28,7 +28,7 @@ __global__ void gpuYoloLayer(
|
||||
if (objectness < scoreThreshold)
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(countData, 1);
|
||||
int count = (int)atomicAdd(num_detections, 1);
|
||||
|
||||
const float alpha = scaleXY;
|
||||
const float beta = -0.5 * (scaleXY - 1);
|
||||
@@ -64,23 +64,22 @@ __global__ void gpuYoloLayer(
|
||||
}
|
||||
}
|
||||
|
||||
d_indexes[count] = count;
|
||||
d_scores[count] = objectness * maxProb + 1.f;
|
||||
d_boxes[count * 4 + 0] = x - 0.5 * w;
|
||||
d_boxes[count * 4 + 1] = y - 0.5 * h;
|
||||
d_boxes[count * 4 + 2] = x + 0.5 * w;
|
||||
d_boxes[count * 4 + 3] = y + 0.5 * h;
|
||||
d_classes[count] = maxIndex;
|
||||
detection_boxes[count * 4 + 0] = x - 0.5 * w;
|
||||
detection_boxes[count * 4 + 1] = y - 0.5 * h;
|
||||
detection_boxes[count * 4 + 2] = x + 0.5 * w;
|
||||
detection_boxes[count * 4 + 3] = y + 0.5 * h;
|
||||
detection_scores[count] = objectness * maxProb;
|
||||
detection_classes[count] = maxIndex;
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer(
|
||||
const void* input, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaYoloLayer(
|
||||
const void* input, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream)
|
||||
@@ -94,10 +93,10 @@ cudaError_t cudaYoloLayer(
|
||||
{
|
||||
gpuYoloLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * inputSize),
|
||||
reinterpret_cast<int*>(d_indexes) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(d_scores) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(d_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<int*>(d_classes) + (batch * outputSize), reinterpret_cast<int*>(countData) + (batch),
|
||||
reinterpret_cast<int*>(num_detections) + (batch),
|
||||
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
|
||||
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
|
||||
scoreThreshold, netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, scaleXY,
|
||||
reinterpret_cast<const float*>(anchors), reinterpret_cast<const int*>(mask));
|
||||
}
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
#include <stdio.h>
|
||||
|
||||
__global__ void gpuYoloLayer_e(
|
||||
const float* cls, const float* reg, int* d_indexes, float* d_scores, float* d_boxes, int* d_classes, int* countData,
|
||||
const float scoreThreshold, const uint netWidth, const uint netHeight, const uint numOutputClasses,
|
||||
const uint64_t outputSize)
|
||||
const float* cls, const float* reg, int* num_detections, float* detection_boxes, float* detection_scores,
|
||||
int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight,
|
||||
const uint numOutputClasses, const uint64_t outputSize)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
@@ -34,39 +34,38 @@ __global__ void gpuYoloLayer_e(
|
||||
if (maxProb < scoreThreshold)
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(countData, 1);
|
||||
int count = (int)atomicAdd(num_detections, 1);
|
||||
|
||||
d_indexes[count] = count;
|
||||
d_scores[count] = maxProb + 1.f;
|
||||
d_boxes[count * 4 + 0] = reg[x_id * 4 + 0];
|
||||
d_boxes[count * 4 + 1] = reg[x_id * 4 + 1];
|
||||
d_boxes[count * 4 + 2] = reg[x_id * 4 + 2];
|
||||
d_boxes[count * 4 + 3] = reg[x_id * 4 + 3];
|
||||
d_classes[count] = maxIndex;
|
||||
detection_boxes[count * 4 + 0] = reg[x_id * 4 + 0];
|
||||
detection_boxes[count * 4 + 1] = reg[x_id * 4 + 1];
|
||||
detection_boxes[count * 4 + 2] = reg[x_id * 4 + 2];
|
||||
detection_boxes[count * 4 + 3] = reg[x_id * 4 + 3];
|
||||
detection_scores[count] = maxProb;
|
||||
detection_classes[count] = maxIndex;
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer_e(
|
||||
const void* cls, const void* reg, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, const uint& netHeight,
|
||||
const uint& numOutputClasses, cudaStream_t stream);
|
||||
const void* cls, const void* reg, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaYoloLayer_e(
|
||||
const void* cls, const void* reg, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, const uint& netHeight,
|
||||
const uint& numOutputClasses, cudaStream_t stream)
|
||||
const void* cls, const void* reg, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream)
|
||||
{
|
||||
int threads_per_block = 16;
|
||||
int number_of_blocks = 525;
|
||||
int number_of_blocks = (outputSize / threads_per_block) + 1;
|
||||
|
||||
for (unsigned int batch = 0; batch < batchSize; ++batch)
|
||||
{
|
||||
gpuYoloLayer_e<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(cls) + (batch * numOutputClasses * outputSize),
|
||||
reinterpret_cast<const float*>(reg) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<int*>(d_indexes) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(d_scores) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(d_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<int*>(d_classes) + (batch * outputSize), reinterpret_cast<int*>(countData) + (batch),
|
||||
reinterpret_cast<int*>(num_detections) + (batch),
|
||||
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
|
||||
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
|
||||
scoreThreshold, netWidth, netHeight, numOutputClasses, outputSize);
|
||||
}
|
||||
return cudaGetLastError();
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <stdint.h>
|
||||
|
||||
__global__ void gpuYoloLayer_nc(
|
||||
const float* input, int* d_indexes, float* d_scores, float* d_boxes, int* d_classes, int* countData,
|
||||
const float* input, int* num_detections, float* detection_boxes, float* detection_scores, int* detection_classes,
|
||||
const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, const uint gridSizeY,
|
||||
const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask)
|
||||
{
|
||||
@@ -26,7 +26,7 @@ __global__ void gpuYoloLayer_nc(
|
||||
if (objectness < scoreThreshold)
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(countData, 1);
|
||||
int count = (int)atomicAdd(num_detections, 1);
|
||||
|
||||
const float alpha = scaleXY;
|
||||
const float beta = -0.5 * (scaleXY - 1);
|
||||
@@ -62,23 +62,22 @@ __global__ void gpuYoloLayer_nc(
|
||||
}
|
||||
}
|
||||
|
||||
d_indexes[count] = count;
|
||||
d_scores[count] = objectness * maxProb + 1.f;
|
||||
d_boxes[count * 4 + 0] = x - 0.5 * w;
|
||||
d_boxes[count * 4 + 1] = y - 0.5 * h;
|
||||
d_boxes[count * 4 + 2] = x + 0.5 * w;
|
||||
d_boxes[count * 4 + 3] = y + 0.5 * h;
|
||||
d_classes[count] = maxIndex;
|
||||
detection_boxes[count * 4 + 0] = x - 0.5 * w;
|
||||
detection_boxes[count * 4 + 1] = y - 0.5 * h;
|
||||
detection_boxes[count * 4 + 2] = x + 0.5 * w;
|
||||
detection_boxes[count * 4 + 3] = y + 0.5 * h;
|
||||
detection_scores[count] = objectness * maxProb;
|
||||
detection_classes[count] = maxIndex;
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer_nc(
|
||||
const void* input, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaYoloLayer_nc(
|
||||
const void* input, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream)
|
||||
@@ -92,10 +91,10 @@ cudaError_t cudaYoloLayer_nc(
|
||||
{
|
||||
gpuYoloLayer_nc<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * inputSize),
|
||||
reinterpret_cast<int*>(d_indexes) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(d_scores) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(d_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<int*>(d_classes) + (batch * outputSize), reinterpret_cast<int*>(countData) + (batch),
|
||||
reinterpret_cast<int*>(num_detections) + (batch),
|
||||
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
|
||||
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
|
||||
scoreThreshold, netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, scaleXY,
|
||||
reinterpret_cast<const float*>(anchors), reinterpret_cast<const int*>(mask));
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
||||
|
||||
__global__ void gpuYoloLayer_r(
|
||||
const float* input, int* d_indexes, float* d_scores, float* d_boxes, int* d_classes, int* countData,
|
||||
const float* input, int* num_detections, float* detection_boxes, float* detection_scores, int* detection_classes,
|
||||
const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, const uint gridSizeY,
|
||||
const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask)
|
||||
{
|
||||
@@ -28,7 +28,7 @@ __global__ void gpuYoloLayer_r(
|
||||
if (objectness < scoreThreshold)
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(countData, 1);
|
||||
int count = (int)atomicAdd(num_detections, 1);
|
||||
|
||||
const float alpha = scaleXY;
|
||||
const float beta = -0.5 * (scaleXY - 1);
|
||||
@@ -64,23 +64,22 @@ __global__ void gpuYoloLayer_r(
|
||||
}
|
||||
}
|
||||
|
||||
d_indexes[count] = count;
|
||||
d_scores[count] = objectness * maxProb + 1.f;
|
||||
d_boxes[count * 4 + 0] = x - 0.5 * w;
|
||||
d_boxes[count * 4 + 1] = y - 0.5 * h;
|
||||
d_boxes[count * 4 + 2] = x + 0.5 * w;
|
||||
d_boxes[count * 4 + 3] = y + 0.5 * h;
|
||||
d_classes[count] = maxIndex;
|
||||
detection_boxes[count * 4 + 0] = x - 0.5 * w;
|
||||
detection_boxes[count * 4 + 1] = y - 0.5 * h;
|
||||
detection_boxes[count * 4 + 2] = x + 0.5 * w;
|
||||
detection_boxes[count * 4 + 3] = y + 0.5 * h;
|
||||
detection_scores[count] = objectness * maxProb;
|
||||
detection_classes[count] = maxIndex;
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer_r(
|
||||
const void* input, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaYoloLayer_r(
|
||||
const void* input, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream)
|
||||
@@ -94,10 +93,10 @@ cudaError_t cudaYoloLayer_r(
|
||||
{
|
||||
gpuYoloLayer_r<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * inputSize),
|
||||
reinterpret_cast<int*>(d_indexes) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(d_scores) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(d_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<int*>(d_classes) + (batch * outputSize), reinterpret_cast<int*>(countData) + (batch),
|
||||
reinterpret_cast<int*>(num_detections) + (batch),
|
||||
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
|
||||
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
|
||||
scoreThreshold, netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, scaleXY,
|
||||
reinterpret_cast<const float*>(anchors), reinterpret_cast<const int*>(mask));
|
||||
}
|
||||
|
||||
@@ -29,9 +29,9 @@ __device__ void softmaxGPU(
|
||||
}
|
||||
|
||||
__global__ void gpuRegionLayer(
|
||||
const float* input, float* softmax, int* d_indexes, float* d_scores, float* d_boxes, int* d_classes, int* countData,
|
||||
const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, const uint gridSizeY,
|
||||
const uint numOutputClasses, const uint numBBoxes, const float* anchors)
|
||||
const float* input, float* softmax, int* num_detections, float* detection_boxes, float* detection_scores,
|
||||
int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX,
|
||||
const uint gridSizeY, const uint numOutputClasses, const uint numBBoxes, const float* anchors)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
@@ -49,7 +49,7 @@ __global__ void gpuRegionLayer(
|
||||
if (objectness < scoreThreshold)
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(countData, 1);
|
||||
int count = (int)atomicAdd(num_detections, 1);
|
||||
|
||||
float x
|
||||
= (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)])
|
||||
@@ -84,26 +84,25 @@ __global__ void gpuRegionLayer(
|
||||
}
|
||||
}
|
||||
|
||||
d_indexes[count] = count;
|
||||
d_scores[count] = objectness * maxProb + 1.f;
|
||||
d_boxes[count * 4 + 0] = x - 0.5 * w;
|
||||
d_boxes[count * 4 + 1] = y - 0.5 * h;
|
||||
d_boxes[count * 4 + 2] = x + 0.5 * w;
|
||||
d_boxes[count * 4 + 3] = y + 0.5 * h;
|
||||
d_classes[count] = maxIndex;
|
||||
detection_boxes[count * 4 + 0] = x - 0.5 * w;
|
||||
detection_boxes[count * 4 + 1] = y - 0.5 * h;
|
||||
detection_boxes[count * 4 + 2] = x + 0.5 * w;
|
||||
detection_boxes[count * 4 + 3] = y + 0.5 * h;
|
||||
detection_scores[count] = objectness * maxProb;
|
||||
detection_classes[count] = maxIndex;
|
||||
}
|
||||
|
||||
cudaError_t cudaRegionLayer(
|
||||
const void* input, void* softmax, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const void* anchors, cudaStream_t stream);
|
||||
const void* input, void* softmax, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
|
||||
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||
const uint& numBBoxes, const void* anchors, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaRegionLayer(
|
||||
const void* input, void* softmax, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const void* anchors, cudaStream_t stream)
|
||||
const void* input, void* softmax, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
|
||||
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||
const uint& numBBoxes, const void* anchors, cudaStream_t stream)
|
||||
{
|
||||
dim3 threads_per_block(16, 16, 4);
|
||||
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
|
||||
@@ -115,10 +114,10 @@ cudaError_t cudaRegionLayer(
|
||||
gpuRegionLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * inputSize),
|
||||
reinterpret_cast<float*>(softmax) + (batch * inputSize),
|
||||
reinterpret_cast<int*>(d_indexes) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(d_scores) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(d_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<int*>(d_classes) + (batch * outputSize), reinterpret_cast<int*>(countData) + (batch),
|
||||
reinterpret_cast<int*>(num_detections) + (batch),
|
||||
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
|
||||
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
|
||||
scoreThreshold, netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes,
|
||||
reinterpret_cast<const float*>(anchors));
|
||||
}
|
||||
|
||||
@@ -48,37 +48,33 @@ namespace {
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer_e(
|
||||
const void* cls, const void* reg, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, const uint& netHeight,
|
||||
const uint& numOutputClasses, cudaStream_t stream);
|
||||
const void* cls, const void* reg, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaYoloLayer_r(
|
||||
const void* input, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaYoloLayer_nc(
|
||||
const void* input, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaYoloLayer(
|
||||
const void* input, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaRegionLayer(
|
||||
const void* input, void* softmax, void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* countData,
|
||||
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
|
||||
const void* anchors, cudaStream_t stream);
|
||||
|
||||
cudaError_t sortDetections(
|
||||
void* d_indexes, void* d_scores, void* d_boxes, void* d_classes, void* bboxData, void* scoreData, void* countData,
|
||||
const uint& batchSize, uint64_t& outputSize, uint& topK, const uint& numOutputClasses, cudaStream_t stream);
|
||||
const void* input, void* softmax, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
|
||||
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
|
||||
const uint& numBBoxes, const void* anchors, cudaStream_t stream);
|
||||
|
||||
YoloLayer::YoloLayer (const void* data, size_t length)
|
||||
{
|
||||
@@ -90,7 +86,6 @@ YoloLayer::YoloLayer (const void* data, size_t length)
|
||||
read(d, m_NewCoords);
|
||||
read(d, m_OutputSize);
|
||||
read(d, m_Type);
|
||||
read(d, m_TopK);
|
||||
read(d, m_ScoreThreshold);
|
||||
|
||||
if (m_Type != 3) {
|
||||
@@ -130,7 +125,7 @@ YoloLayer::YoloLayer (const void* data, size_t length)
|
||||
|
||||
YoloLayer::YoloLayer(
|
||||
const uint& netWidth, const uint& netHeight, const uint& numClasses, const uint& newCoords,
|
||||
const std::vector<TensorInfo>& yoloTensors, const uint64_t& outputSize, const uint& modelType, const uint& topK,
|
||||
const std::vector<TensorInfo>& yoloTensors, const uint64_t& outputSize, const uint& modelType,
|
||||
const float& scoreThreshold) :
|
||||
m_NetWidth(netWidth),
|
||||
m_NetHeight(netHeight),
|
||||
@@ -139,7 +134,6 @@ YoloLayer::YoloLayer(
|
||||
m_YoloTensors(yoloTensors),
|
||||
m_OutputSize(outputSize),
|
||||
m_Type(modelType),
|
||||
m_TopK(topK),
|
||||
m_ScoreThreshold(scoreThreshold)
|
||||
{
|
||||
assert(m_NetWidth > 0);
|
||||
@@ -152,11 +146,14 @@ nvinfer1::Dims
|
||||
YoloLayer::getOutputDimensions(
|
||||
int index, const nvinfer1::Dims* inputs, int nbInputDims) noexcept
|
||||
{
|
||||
assert(index < 3);
|
||||
assert(index <= 4);
|
||||
if (index == 0) {
|
||||
return nvinfer1::Dims{3, {static_cast<int>(m_TopK), 1, 4}};
|
||||
return nvinfer1::Dims{1, {1}};
|
||||
}
|
||||
return nvinfer1::Dims{2, {static_cast<int>(m_TopK), static_cast<int>(m_NumClasses)}};
|
||||
else if (index == 1) {
|
||||
return nvinfer1::Dims{2, {static_cast<int>(m_OutputSize), 4}};
|
||||
}
|
||||
return nvinfer1::Dims{1, {static_cast<int>(m_OutputSize)}};
|
||||
}
|
||||
|
||||
bool YoloLayer::supportsFormat (
|
||||
@@ -180,37 +177,21 @@ int32_t YoloLayer::enqueue (
|
||||
int batchSize, void const* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) noexcept
|
||||
{
|
||||
void* bboxData = outputs[0];
|
||||
void* scoreData = outputs[1];
|
||||
void* num_detections = outputs[0];
|
||||
void* detection_boxes = outputs[1];
|
||||
void* detection_scores = outputs[2];
|
||||
void* detection_classes = outputs[3];
|
||||
|
||||
CUDA_CHECK(cudaMemsetAsync((float*)bboxData, 0, sizeof(float) * m_TopK * 4 * batchSize, stream));
|
||||
CUDA_CHECK(cudaMemsetAsync((float*)scoreData, 0, sizeof(float) * m_TopK * m_NumClasses * batchSize, stream));
|
||||
|
||||
void* countData;
|
||||
CUDA_CHECK(cudaMalloc(&countData, sizeof(int) * batchSize));
|
||||
CUDA_CHECK(cudaMemsetAsync((int*)countData, 0, sizeof(int) * batchSize, stream));
|
||||
|
||||
void* d_indexes;
|
||||
CUDA_CHECK(cudaMalloc(&d_indexes, sizeof(int) * m_OutputSize * batchSize));
|
||||
CUDA_CHECK(cudaMemsetAsync((int*)d_indexes, 0, sizeof(int) * m_OutputSize * batchSize, stream));
|
||||
|
||||
void* d_scores;
|
||||
CUDA_CHECK(cudaMalloc(&d_scores, sizeof(float) * m_OutputSize * batchSize));
|
||||
CUDA_CHECK(cudaMemsetAsync((float*)d_scores, 0, sizeof(float) * m_OutputSize * batchSize, stream));
|
||||
|
||||
void* d_boxes;
|
||||
CUDA_CHECK(cudaMalloc(&d_boxes, sizeof(float) * m_OutputSize * 4 * batchSize));
|
||||
CUDA_CHECK(cudaMemsetAsync((float*)d_boxes, 0, sizeof(float) * m_OutputSize * 4 * batchSize, stream));
|
||||
|
||||
void* d_classes;
|
||||
CUDA_CHECK(cudaMalloc(&d_classes, sizeof(int) * m_OutputSize * batchSize));
|
||||
CUDA_CHECK(cudaMemsetAsync((float*)d_classes, 0, sizeof(int) * m_OutputSize * batchSize, stream));
|
||||
CUDA_CHECK(cudaMemsetAsync((int*)num_detections, 0, sizeof(int) * batchSize, stream));
|
||||
CUDA_CHECK(cudaMemsetAsync((float*)detection_boxes, 0, sizeof(float) * m_OutputSize * 4 * batchSize, stream));
|
||||
CUDA_CHECK(cudaMemsetAsync((float*)detection_scores, 0, sizeof(float) * m_OutputSize * batchSize, stream));
|
||||
CUDA_CHECK(cudaMemsetAsync((int*)detection_classes, 0, sizeof(int) * m_OutputSize * batchSize, stream));
|
||||
|
||||
if (m_Type == 3)
|
||||
{
|
||||
CUDA_CHECK(cudaYoloLayer_e(
|
||||
inputs[0], inputs[1], d_indexes, d_scores, d_boxes, d_classes, countData, batchSize, m_OutputSize,
|
||||
m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, stream));
|
||||
inputs[0], inputs[1], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
|
||||
m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, stream));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -243,22 +224,22 @@ int32_t YoloLayer::enqueue (
|
||||
|
||||
if (m_Type == 2) { // YOLOR incorrect param: scale_x_y = 2.0
|
||||
CUDA_CHECK(cudaYoloLayer_r(
|
||||
inputs[i], d_indexes, d_scores, d_boxes, d_classes, countData, batchSize, inputSize, m_OutputSize,
|
||||
m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, m_NumClasses, numBBoxes, 2.0, v_anchors,
|
||||
v_mask, stream));
|
||||
inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, batchSize, inputSize,
|
||||
m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, m_NumClasses, numBBoxes,
|
||||
2.0, v_anchors, v_mask, stream));
|
||||
}
|
||||
else if (m_Type == 1) {
|
||||
if (m_NewCoords) {
|
||||
CUDA_CHECK(cudaYoloLayer_nc(
|
||||
inputs[i], d_indexes, d_scores, d_boxes, d_classes, countData, batchSize, inputSize, m_OutputSize,
|
||||
m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, m_NumClasses, numBBoxes, scaleXY,
|
||||
v_anchors, v_mask, stream));
|
||||
inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
|
||||
inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY,
|
||||
m_NumClasses, numBBoxes, scaleXY, v_anchors, v_mask, stream));
|
||||
}
|
||||
else {
|
||||
CUDA_CHECK(cudaYoloLayer(
|
||||
inputs[i], d_indexes, d_scores, d_boxes, d_classes, countData, batchSize, inputSize, m_OutputSize,
|
||||
m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, m_NumClasses, numBBoxes, scaleXY,
|
||||
v_anchors, v_mask, stream));
|
||||
inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
|
||||
inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY,
|
||||
m_NumClasses, numBBoxes, scaleXY, v_anchors, v_mask, stream));
|
||||
}
|
||||
}
|
||||
else {
|
||||
@@ -267,9 +248,9 @@ int32_t YoloLayer::enqueue (
|
||||
CUDA_CHECK(cudaMemsetAsync((float*)softmax, 0, sizeof(float) * inputSize * batchSize));
|
||||
|
||||
CUDA_CHECK(cudaRegionLayer(
|
||||
inputs[i], softmax, d_indexes, d_scores, d_boxes, d_classes, countData, batchSize, inputSize, m_OutputSize,
|
||||
m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, m_NumClasses, numBBoxes, v_anchors,
|
||||
stream));
|
||||
inputs[i], softmax, num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
|
||||
inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, m_NumClasses,
|
||||
numBBoxes, v_anchors, stream));
|
||||
|
||||
CUDA_CHECK(cudaFree(softmax));
|
||||
}
|
||||
@@ -283,16 +264,6 @@ int32_t YoloLayer::enqueue (
|
||||
}
|
||||
}
|
||||
|
||||
CUDA_CHECK(sortDetections(
|
||||
d_indexes, d_scores, d_boxes, d_classes, bboxData, scoreData, countData, batchSize, m_OutputSize, m_TopK,
|
||||
m_NumClasses, stream));
|
||||
|
||||
CUDA_CHECK(cudaFree(countData));
|
||||
CUDA_CHECK(cudaFree(d_indexes));
|
||||
CUDA_CHECK(cudaFree(d_scores));
|
||||
CUDA_CHECK(cudaFree(d_boxes));
|
||||
CUDA_CHECK(cudaFree(d_classes));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -306,7 +277,6 @@ size_t YoloLayer::getSerializationSize() const noexcept
|
||||
totalSize += sizeof(m_NewCoords);
|
||||
totalSize += sizeof(m_OutputSize);
|
||||
totalSize += sizeof(m_Type);
|
||||
totalSize += sizeof(m_TopK);
|
||||
totalSize += sizeof(m_ScoreThreshold);
|
||||
|
||||
if (m_Type != 3) {
|
||||
@@ -338,7 +308,6 @@ void YoloLayer::serialize(void* buffer) const noexcept
|
||||
write(d, m_NewCoords);
|
||||
write(d, m_OutputSize);
|
||||
write(d, m_Type);
|
||||
write(d, m_TopK);
|
||||
write(d, m_ScoreThreshold);
|
||||
|
||||
if (m_Type != 3) {
|
||||
@@ -372,8 +341,7 @@ void YoloLayer::serialize(void* buffer) const noexcept
|
||||
nvinfer1::IPluginV2* YoloLayer::clone() const noexcept
|
||||
{
|
||||
return new YoloLayer (
|
||||
m_NetWidth, m_NetHeight, m_NumClasses, m_NewCoords, m_YoloTensors, m_OutputSize, m_Type, m_TopK,
|
||||
m_ScoreThreshold);
|
||||
m_NetWidth, m_NetHeight, m_NumClasses, m_NewCoords, m_YoloTensors, m_OutputSize, m_Type, m_ScoreThreshold);
|
||||
}
|
||||
|
||||
REGISTER_TENSORRT_PLUGIN(YoloLayerPluginCreator);
|
||||
|
||||
@@ -61,14 +61,14 @@ public:
|
||||
|
||||
YoloLayer (
|
||||
const uint& netWidth, const uint& netHeight, const uint& numClasses, const uint& newCoords,
|
||||
const std::vector<TensorInfo>& yoloTensors, const uint64_t& outputSize, const uint& modelType, const uint& topK,
|
||||
const std::vector<TensorInfo>& yoloTensors, const uint64_t& outputSize, const uint& modelType,
|
||||
const float& scoreThreshold);
|
||||
|
||||
const char* getPluginType () const noexcept override { return YOLOLAYER_PLUGIN_NAME; }
|
||||
|
||||
const char* getPluginVersion () const noexcept override { return YOLOLAYER_PLUGIN_VERSION; }
|
||||
|
||||
int getNbOutputs () const noexcept override { return 2; }
|
||||
int getNbOutputs () const noexcept override { return 4; }
|
||||
|
||||
nvinfer1::Dims getOutputDimensions (
|
||||
int index, const nvinfer1::Dims* inputs,
|
||||
@@ -116,7 +116,6 @@ private:
|
||||
std::vector<TensorInfo> m_YoloTensors;
|
||||
uint64_t m_OutputSize {0};
|
||||
uint m_Type {0};
|
||||
uint m_TopK {0};
|
||||
float m_ScoreThreshold {0};
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user