New optimized NMS
This commit is contained in:
@@ -41,6 +41,7 @@ Yolo::Yolo(const NetworkInfo& networkInfo)
|
||||
m_NumDetectedClasses(networkInfo.numDetectedClasses),
|
||||
m_ClusterMode(networkInfo.clusterMode),
|
||||
m_NetworkMode(networkInfo.networkMode),
|
||||
m_ScoreThreshold(networkInfo.scoreThreshold),
|
||||
m_InputH(0),
|
||||
m_InputW(0),
|
||||
m_InputC(0),
|
||||
@@ -48,10 +49,7 @@ Yolo::Yolo(const NetworkInfo& networkInfo)
|
||||
m_NumClasses(0),
|
||||
m_LetterBox(0),
|
||||
m_NewCoords(0),
|
||||
m_YoloCount(0),
|
||||
m_IouThreshold(0),
|
||||
m_ScoreThreshold(0),
|
||||
m_TopK(0)
|
||||
m_YoloCount(0)
|
||||
{}
|
||||
|
||||
Yolo::~Yolo()
|
||||
@@ -59,22 +57,13 @@ Yolo::~Yolo()
|
||||
destroyNetworkUtils();
|
||||
}
|
||||
|
||||
nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config)
|
||||
nvinfer1::ICudaEngine *Yolo::createEngine(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config)
|
||||
{
|
||||
assert (builder);
|
||||
|
||||
m_ConfigBlocks = parseConfigFile(m_ConfigFilePath);
|
||||
parseConfigBlocks();
|
||||
|
||||
std::string configNMS = getAbsPath(m_WtsFilePath) + "/config_nms.txt";
|
||||
if (!fileExists(configNMS))
|
||||
{
|
||||
std::cerr << "YOLO config_nms.txt file is not specified\n" << std::endl;
|
||||
assert(0);
|
||||
}
|
||||
m_ConfigNMSBlocks = parseConfigFile(configNMS);
|
||||
parseConfigNMSBlocks();
|
||||
|
||||
nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0);
|
||||
if (parseModel(*network) != NVDSINFER_SUCCESS)
|
||||
{
|
||||
@@ -94,9 +83,9 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder, nvinfer1
|
||||
std::cout << "NOTE: letter_box is set in cfg file, make sure to set maintain-aspect-ratio=1 in config_infer file"
|
||||
<< " to get better accuracy\n" << std::endl;
|
||||
}
|
||||
if (m_ClusterMode != 4)
|
||||
if (m_ClusterMode != 2)
|
||||
{
|
||||
std::cout << "NOTE: Wrong cluster-mode is set, make sure to set cluster-mode=4 in config_infer file\n"
|
||||
std::cout << "NOTE: Wrong cluster-mode is set, make sure to set cluster-mode=2 in config_infer file\n"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
@@ -452,54 +441,31 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
|
||||
outputSize += curYoloTensor.gridSizeX * curYoloTensor.gridSizeY * curYoloTensor.numBBoxes;
|
||||
}
|
||||
|
||||
if (m_TopK > outputSize) {
|
||||
std::cout << "\ntopk > Number of outputs\nPlease change the topk to " << outputSize
|
||||
<< " or less in config_nms.txt file\n" << std::endl;
|
||||
assert(0);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* yoloPlugin = new YoloLayer(
|
||||
m_InputW, m_InputH, m_NumClasses, m_NewCoords, m_YoloTensors, outputSize, modelType, m_TopK, m_ScoreThreshold);
|
||||
m_InputW, m_InputH, m_NumClasses, m_NewCoords, m_YoloTensors, outputSize, modelType, m_ScoreThreshold);
|
||||
assert(yoloPlugin != nullptr);
|
||||
nvinfer1::IPluginV2Layer* yolo = network.addPluginV2(yoloTensorInputs, m_YoloCount, *yoloPlugin);
|
||||
assert(yolo != nullptr);
|
||||
std::string yoloLayerName = "yolo";
|
||||
yolo->setName(yoloLayerName.c_str());
|
||||
|
||||
nvinfer1::ITensor* yoloTensorOutputs[] = {yolo->getOutput(0), yolo->getOutput(1)};
|
||||
|
||||
nvinfer1::plugin::NMSParameters nmsParams;
|
||||
nmsParams.shareLocation = true;
|
||||
nmsParams.backgroundLabelId = -1;
|
||||
nmsParams.numClasses = m_NumClasses;
|
||||
nmsParams.topK = m_TopK;
|
||||
nmsParams.keepTopK = m_TopK;
|
||||
nmsParams.scoreThreshold = m_ScoreThreshold;
|
||||
nmsParams.iouThreshold = m_IouThreshold;
|
||||
nmsParams.isNormalized = false;
|
||||
|
||||
std::string nmslayerName = "batchedNMS";
|
||||
nvinfer1::IPluginV2* batchedNMS = createBatchedNMSPlugin(nmsParams);
|
||||
nvinfer1::IPluginV2Layer* nms = network.addPluginV2(yoloTensorOutputs, 2, *batchedNMS);
|
||||
nms->setName(nmslayerName.c_str());
|
||||
nvinfer1::ITensor* num_detections = nms->getOutput(0);
|
||||
nmslayerName = "num_detections";
|
||||
num_detections->setName(nmslayerName.c_str());
|
||||
nvinfer1::ITensor* nmsed_boxes = nms->getOutput(1);
|
||||
nmslayerName = "nmsed_boxes";
|
||||
nmsed_boxes->setName(nmslayerName.c_str());
|
||||
nvinfer1::ITensor* nmsed_scores = nms->getOutput(2);
|
||||
nmslayerName = "nmsed_scores";
|
||||
nmsed_scores->setName(nmslayerName.c_str());
|
||||
nvinfer1::ITensor* nmsed_classes = nms->getOutput(3);
|
||||
nmslayerName = "nmsed_classes";
|
||||
nmsed_classes->setName(nmslayerName.c_str());
|
||||
std::string outputlayerName;
|
||||
nvinfer1::ITensor* num_detections = yolo->getOutput(0);
|
||||
outputlayerName = "num_detections";
|
||||
num_detections->setName(outputlayerName.c_str());
|
||||
nvinfer1::ITensor* detection_boxes = yolo->getOutput(1);
|
||||
outputlayerName = "detection_boxes";
|
||||
detection_boxes->setName(outputlayerName.c_str());
|
||||
nvinfer1::ITensor* detection_scores = yolo->getOutput(2);
|
||||
outputlayerName = "detection_scores";
|
||||
detection_scores->setName(outputlayerName.c_str());
|
||||
nvinfer1::ITensor* detection_classes = yolo->getOutput(3);
|
||||
outputlayerName = "detection_classes";
|
||||
detection_classes->setName(outputlayerName.c_str());
|
||||
network.markOutput(*num_detections);
|
||||
network.markOutput(*nmsed_boxes);
|
||||
network.markOutput(*nmsed_scores);
|
||||
network.markOutput(*nmsed_classes);
|
||||
|
||||
printLayerInfo("", "batched_nms", "-", "-", "-");
|
||||
network.markOutput(*detection_boxes);
|
||||
network.markOutput(*detection_scores);
|
||||
network.markOutput(*detection_classes);
|
||||
}
|
||||
else {
|
||||
std::cout << "\nError in yolo cfg file" << std::endl;
|
||||
@@ -659,20 +625,6 @@ void Yolo::parseConfigBlocks()
|
||||
}
|
||||
}
|
||||
|
||||
void Yolo::parseConfigNMSBlocks()
|
||||
{
|
||||
auto block = m_ConfigNMSBlocks[0];
|
||||
|
||||
assert((block.at("type") == "property") && "Missing 'property' param in nms cfg");
|
||||
assert((block.find("iou-threshold") != block.end()) && "Missing 'iou-threshold' param in nms cfg");
|
||||
assert((block.find("score-threshold") != block.end()) && "Missing 'score-threshold' param in nms cfg");
|
||||
assert((block.find("topk") != block.end()) && "Missing 'topk' param in nms cfg");
|
||||
|
||||
m_IouThreshold = std::stof(block.at("iou-threshold"));
|
||||
m_ScoreThreshold = std::stof(block.at("score-threshold"));
|
||||
m_TopK = std::stoul(block.at("topk"));
|
||||
}
|
||||
|
||||
void Yolo::destroyNetworkUtils()
|
||||
{
|
||||
for (uint i = 0; i < m_TrtWeights.size(); ++i)
|
||||
|
||||
Reference in New Issue
Block a user