Add YOLOv8 support
This commit is contained in:
@@ -26,7 +26,11 @@
|
||||
#ifndef _YOLO_H_
|
||||
#define _YOLO_H_
|
||||
|
||||
#include "NvInferPlugin.h"
|
||||
#include "nvdsinfer_custom_impl.h"
|
||||
|
||||
#include "layers/convolutional_layer.h"
|
||||
#include "layers/c2f_layer.h"
|
||||
#include "layers/batchnorm_layer.h"
|
||||
#include "layers/implicit_layer.h"
|
||||
#include "layers/channels_layer.h"
|
||||
@@ -40,36 +44,35 @@
|
||||
#include "layers/softmax_layer.h"
|
||||
#include "layers/cls_layer.h"
|
||||
#include "layers/reg_layer.h"
|
||||
|
||||
#include "nvdsinfer_custom_impl.h"
|
||||
#include "layers/detect_v8_layer.h"
|
||||
|
||||
struct NetworkInfo
|
||||
{
|
||||
std::string inputBlobName;
|
||||
std::string networkType;
|
||||
std::string configFilePath;
|
||||
std::string wtsFilePath;
|
||||
std::string int8CalibPath;
|
||||
std::string deviceType;
|
||||
uint numDetectedClasses;
|
||||
int clusterMode;
|
||||
float scoreThreshold;
|
||||
std::string networkMode;
|
||||
std::string inputBlobName;
|
||||
std::string networkType;
|
||||
std::string configFilePath;
|
||||
std::string wtsFilePath;
|
||||
std::string int8CalibPath;
|
||||
std::string deviceType;
|
||||
uint numDetectedClasses;
|
||||
int clusterMode;
|
||||
float scoreThreshold;
|
||||
std::string networkMode;
|
||||
};
|
||||
|
||||
struct TensorInfo
|
||||
{
|
||||
std::string blobName;
|
||||
uint gridSizeX {0};
|
||||
uint gridSizeY {0};
|
||||
uint numBBoxes {0};
|
||||
float scaleXY;
|
||||
std::vector<float> anchors;
|
||||
std::vector<int> mask;
|
||||
std::string blobName;
|
||||
uint gridSizeX {0};
|
||||
uint gridSizeY {0};
|
||||
uint numBBoxes {0};
|
||||
float scaleXY;
|
||||
std::vector<float> anchors;
|
||||
std::vector<int> mask;
|
||||
};
|
||||
|
||||
class Yolo : public IModelParser {
|
||||
public:
|
||||
public:
|
||||
Yolo(const NetworkInfo& networkInfo);
|
||||
|
||||
~Yolo() override;
|
||||
@@ -77,14 +80,14 @@ public:
|
||||
bool hasFullDimsSupported() const override { return false; }
|
||||
|
||||
const char* getModelName() const override {
|
||||
return m_ConfigFilePath.empty() ? m_NetworkType.c_str() : m_ConfigFilePath.c_str();
|
||||
return m_ConfigFilePath.empty() ? m_NetworkType.c_str() : m_ConfigFilePath.c_str();
|
||||
}
|
||||
|
||||
NvDsInferStatus parseModel(nvinfer1::INetworkDefinition& network) override;
|
||||
|
||||
nvinfer1::ICudaEngine *createEngine (nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config);
|
||||
nvinfer1::ICudaEngine* createEngine(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config);
|
||||
|
||||
protected:
|
||||
protected:
|
||||
const std::string m_InputBlobName;
|
||||
const std::string m_NetworkType;
|
||||
const std::string m_ConfigFilePath;
|
||||
@@ -109,7 +112,7 @@ protected:
|
||||
std::vector<std::map<std::string, std::string>> m_ConfigBlocks;
|
||||
std::vector<nvinfer1::Weights> m_TrtWeights;
|
||||
|
||||
private:
|
||||
private:
|
||||
NvDsInferStatus buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition& network);
|
||||
|
||||
std::vector<std::map<std::string, std::string>> parseConfigFile(const std::string cfgFilePath);
|
||||
|
||||
Reference in New Issue
Block a user