GPU Batched NMS

This commit is contained in:
Marcos Luciano
2022-06-19 03:25:50 -03:00
parent 2300e3b44b
commit f621c0f429
24 changed files with 835 additions and 654 deletions

View File

@@ -29,7 +29,6 @@
#include "layers/convolutional_layer.h"
#include "layers/implicit_layer.h"
#include "layers/channels_layer.h"
#include "layers/dropout_layer.h"
#include "layers/shortcut_layer.h"
#include "layers/route_layer.h"
#include "layers/upsample_layer.h"
@@ -54,8 +53,10 @@ struct NetworkInfo
struct TensorInfo
{
std::string blobName;
uint gridSizeX {0};
uint gridSizeY {0};
uint numBBoxes {0};
uint numClasses {0};
float scaleXY;
std::vector<float> anchors;
std::vector<int> mask;
};
@@ -63,12 +64,15 @@ struct TensorInfo
class Yolo : public IModelParser {
public:
Yolo(const NetworkInfo& networkInfo);
~Yolo() override;
bool hasFullDimsSupported() const override { return false; }
const char* getModelName() const override {
return m_ConfigFilePath.empty() ? m_NetworkType.c_str()
: m_ConfigFilePath.c_str();
return m_ConfigFilePath.empty() ? m_NetworkType.c_str() : m_ConfigFilePath.c_str();
}
NvDsInferStatus parseModel(nvinfer1::INetworkDefinition& network) override;
nvinfer1::ICudaEngine *createEngine (nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config);
@@ -90,17 +94,26 @@ protected:
uint64_t m_InputSize;
uint m_NumClasses;
uint m_LetterBox;
uint m_NewCoords;
uint m_YoloCount;
float m_IouThreshold;
float m_ScoreThreshold;
uint m_TopK;
std::vector<TensorInfo> m_OutputTensors;
std::vector<TensorInfo> m_YoloTensors;
std::vector<std::map<std::string, std::string>> m_ConfigBlocks;
std::vector<std::map<std::string, std::string>> m_ConfigNMSBlocks;
std::vector<nvinfer1::Weights> m_TrtWeights;
private:
NvDsInferStatus buildYoloNetwork(
std::vector<float>& weights, nvinfer1::INetworkDefinition& network);
std::vector<std::map<std::string, std::string>> parseConfigFile(
const std::string cfgFilePath);
NvDsInferStatus buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition& network);
std::vector<std::map<std::string, std::string>> parseConfigFile(const std::string cfgFilePath);
void parseConfigBlocks();
void parseConfigNMSBlocks();
void destroyNetworkUtils();
};