New features
- Added support for INT8 calibration
- Added support for non-square models
- Updated mAP comparison between models
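The INT8 path is only compiled when the library is built with OPENCV=1, and it is driven by two environment variables, INT8_CALIB_IMG_PATH (a text file listing calibration images, one path per line) and INT8_CALIB_BATCH_SIZE. A condensed, hedged sketch of how the new pieces fit together follows; the helper name enableInt8 is illustrative and not part of the commit, while the builder calls mirror the createEngine changes below:

```cpp
#include <cstdlib>
#include <string>
#include "NvInfer.h"
#include "calibrator.h"

// Illustrative helper (not part of the commit): wires the new calibrator into a builder.
// Returns false when the required environment variables are missing.
bool enableInt8(nvinfer1::IBuilder& builder, int channels, int height, int width,
                int letterBox, const std::string& calibTablePath)
{
    const char* imgList = std::getenv("INT8_CALIB_IMG_PATH");     // text file, one image path per line
    const char* batch   = std::getenv("INT8_CALIB_BATCH_SIZE");   // calibration batch size
    if (!imgList || !batch)
        return false;

    auto* calibrator = new nvinfer1::int8EntroyCalibrator(
        std::stoi(batch), channels, height, width, letterBox, imgList, calibTablePath);
    builder.setInt8Mode(true);              // builder calls as used in createEngine below
    builder.setInt8Calibrator(calibrator);
    return true;
}
```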
@@ -27,13 +27,25 @@ CUDA_VER?=
ifeq ($(CUDA_VER),)
$(error "CUDA_VER is not set")
endif

OPENCV?=
ifeq ($(OPENCV),)
OPENCV=0
endif

CC:= g++
NVCC:=/usr/local/cuda-$(CUDA_VER)/bin/nvcc

CFLAGS:= -Wall -std=c++11 -shared -fPIC -Wno-error=deprecated-declarations
CFLAGS+= -I../../includes -I/usr/local/cuda-$(CUDA_VER)/include
CFLAGS+= -I/opt/nvidia/deepstream/deepstream-5.1/sources/includes -I/usr/local/cuda-$(CUDA_VER)/include

LIBS:= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib64 -lcudart -lcublas -lstdc++fs
ifeq ($(OPENCV), 1)
COMMON= -DOPENCV
CFLAGS+= $(shell pkg-config --cflags opencv4 2> /dev/null || pkg-config --cflags opencv)
LIBS+= $(shell pkg-config --libs opencv4 2> /dev/null || pkg-config --libs opencv)
endif

LIBS+= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib64 -lcudart -lcublas -lstdc++fs
LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group

INCS:= $(wildcard *.h)
@@ -50,6 +62,11 @@ SRCFILES:= nvdsinfer_yolo_engine.cpp \
utils.cpp \
yolo.cpp \
yoloForward.cu

ifeq ($(OPENCV), 1)
SRCFILES+= calibrator.cpp
endif

TARGET_LIB:= libnvdsinfer_custom_impl_Yolo.so

TARGET_OBJS:= $(SRCFILES:.cpp=.o)
@@ -58,7 +75,7 @@ TARGET_OBJS:= $(TARGET_OBJS:.cu=.o)
all: $(TARGET_LIB)

%.o: %.cpp $(INCS) Makefile
$(CC) -c -o $@ $(CFLAGS) $<
$(CC) -c $(COMMON) -o $@ $(CFLAGS) $<

%.o: %.cu $(INCS) Makefile
$(NVCC) -c -o $@ --compiler-options '-fPIC' $<
native/nvdsinfer_custom_impl_Yolo/calibrator.cpp (new file, 130 lines)
@@ -0,0 +1,130 @@
/*
 * Created by Marcos Luciano
 * https://www.github.com/marcoslucianops
 */

#include "calibrator.h"
#include <fstream>
#include <iterator>

namespace nvinfer1
{
    int8EntroyCalibrator::int8EntroyCalibrator(const int &batchsize, const int &channels, const int &height, const int &width, const int &letterbox, const std::string &imgPath,
        const std::string &calibTablePath):batchSize(batchsize), inputC(channels), inputH(height), inputW(width), letterBox(letterbox), calibTablePath(calibTablePath), imageIndex(0)
    {
        inputCount = batchsize * channels * height * width;
        std::fstream f(imgPath);
        if (f.is_open())
        {
            std::string temp;
            while (std::getline(f, temp)) imgPaths.push_back(temp);
        }
        batchData = new float[inputCount];
        CUDA_CHECK(cudaMalloc(&deviceInput, inputCount * sizeof(float)));
    }

    int8EntroyCalibrator::~int8EntroyCalibrator()
    {
        CUDA_CHECK(cudaFree(deviceInput));
        if (batchData)
            delete[] batchData;
    }

    bool int8EntroyCalibrator::getBatch(void **bindings, const char **names, int nbBindings)
    {
        if (imageIndex + batchSize > uint(imgPaths.size()))
            return false;

        float* ptr = batchData;
        for (size_t j = imageIndex; j < imageIndex + batchSize; ++j)
        {
            cv::Mat img = cv::imread(imgPaths[j], cv::IMREAD_COLOR);
            std::vector<float> inputData = prepareImage(img, inputC, inputH, inputW, letterBox);

            int len = (int)(inputData.size());
            memcpy(ptr, inputData.data(), len * sizeof(float));

            ptr += inputData.size();
            std::cout << "Load image: " << imgPaths[j] << std::endl;
            std::cout << "Progress: " << (j + 1)*100. / imgPaths.size() << "%" << std::endl;
        }
        imageIndex += batchSize;
        CUDA_CHECK(cudaMemcpy(deviceInput, batchData, inputCount * sizeof(float), cudaMemcpyHostToDevice));
        bindings[0] = deviceInput;
        return true;
    }

    const void* int8EntroyCalibrator::readCalibrationCache(std::size_t &length)
    {
        calibrationCache.clear();
        std::ifstream input(calibTablePath, std::ios::binary);
        input >> std::noskipws;
        if (readCache && input.good())
        {
            std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
                std::back_inserter(calibrationCache));
        }
        length = calibrationCache.size();
        return length ? calibrationCache.data() : nullptr;
    }

    void int8EntroyCalibrator::writeCalibrationCache(const void *cache, std::size_t length)
    {
        std::ofstream output(calibTablePath, std::ios::binary);
        output.write(reinterpret_cast<const char*>(cache), length);
    }
}

std::vector<float> prepareImage(cv::Mat& img, int input_c, int input_h, int input_w, int letter_box)
{
    cv::Mat out;
    if (letter_box == 2)
    {
        int image_w = img.cols;
        int image_h = img.rows;
        int resize_w = 0;
        int resize_h = 0;
        int offset_top = 0;
        int offset_bottom = 0;
        int offset_left = 0;
        int offset_right = 0;
        if ((float)input_h / image_h > (float)input_w / image_w)
        {
            resize_w = input_w;
            resize_h = (input_w * image_h) / image_w;
            offset_bottom = input_h - resize_h;
        }
        else
        {
            resize_h = input_h;
            resize_w = (input_h * image_w) / image_h;
            offset_right = input_w - resize_w;
        }
        cv::resize(img, out, cv::Size(resize_w, resize_h), 0, 0, cv::INTER_CUBIC);
        cv::copyMakeBorder(out, out, offset_top, offset_bottom, offset_left, offset_right, cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
    }
    else
    {
        cv::resize(img, out, cv::Size(input_w, input_h), 0, 0, cv::INTER_CUBIC);
    }
    cv::cvtColor(out, out, cv::COLOR_BGR2RGB);
    if (input_c == 3)
    {
        out.convertTo(out, CV_32FC3, 1.0 / 255.0);
    }
    else
    {
        out.convertTo(out, CV_32FC1, 1.0 / 255.0);
    }
    std::vector<cv::Mat> input_channels(input_c);
    cv::split(out, input_channels);
    std::vector<float> result(input_h * input_w * input_c);
    auto data = result.data();
    int channelLength = input_h * input_w;
    for (int i = 0; i < input_c; ++i)
    {
        memcpy(data, input_channels[i].data, channelLength * sizeof(float));
        data += channelLength;
    }
    return result;
}
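For a quick sanity check of the letter_box == 2 path above, here is a hypothetical, self-contained use of prepareImage. The 1280x720 frame and 416x416 network size are illustrative values, not part of the commit: the frame is resized to 416x234 to preserve aspect ratio, padded with 182 black rows at the bottom, and returned as planar RGB floats in [0, 1].

```cpp
#include "calibrator.h"

int main()
{
    // Hypothetical 1280x720 input frame for a 416x416 network, letter_box == 2.
    cv::Mat frame(720, 1280, CV_8UC3, cv::Scalar(0, 0, 0));
    std::vector<float> blob = prepareImage(frame, 3, 416, 416, 2);
    // Aspect ratio is kept: the image becomes 416x234 plus 182 padded rows,
    // and the returned buffer holds 3 * 416 * 416 floats in CHW order.
    return blob.size() == 3 * 416 * 416 ? 0 : 1;
}
```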
native/nvdsinfer_custom_impl_Yolo/calibrator.h (new file, 62 lines)
@@ -0,0 +1,62 @@
/*
 * Created by Marcos Luciano
 * https://www.github.com/marcoslucianops
 */

#ifndef CALIBRATOR_H
#define CALIBRATOR_H

#include "opencv2/opencv.hpp"
#include "cuda_runtime.h"
#include "NvInfer.h"
#include <vector>
#include <string>

#ifndef CUDA_CHECK
#define CUDA_CHECK(callstr) \
{ \
    cudaError_t error_code = callstr; \
    if (error_code != cudaSuccess) { \
        std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
        assert(0); \
    } \
}
#endif

namespace nvinfer1 {
    class int8EntroyCalibrator : public nvinfer1::IInt8EntropyCalibrator2 {
    public:
        int8EntroyCalibrator(const int &batchsize,
            const int &channels,
            const int &height,
            const int &width,
            const int &letterbox,
            const std::string &imgPath,
            const std::string &calibTablePath);

        virtual ~int8EntroyCalibrator();
        int getBatchSize() const override { return batchSize; }
        bool getBatch(void *bindings[], const char *names[], int nbBindings) override;
        const void *readCalibrationCache(std::size_t &length) override;
        void writeCalibrationCache(const void *ptr, std::size_t length) override;

    private:
        int batchSize;
        int inputC;
        int inputH;
        int inputW;
        int letterBox;
        std::string calibTablePath;
        size_t imageIndex;
        size_t inputCount;
        std::vector<std::string> imgPaths;
        float *batchData{ nullptr };
        void *deviceInput{ nullptr };
        bool readCache;
        std::vector<char> calibrationCache;
    };
}

std::vector<float> prepareImage(cv::Mat& img, int input_c, int input_h, int input_w, int letter_box);

#endif //CALIBRATOR_H
@@ -8,79 +8,17 @@
nvinfer1::ILayer* upsampleLayer(
    int layerIdx,
    std::map<std::string, std::string>& block,
    std::vector<float>& weights,
    std::vector<nvinfer1::Weights>& trtWeights,
    int& inputChannels,
    nvinfer1::ITensor* input,
    nvinfer1::INetworkDefinition* network)
{
    assert(block.at("type") == "upsample");
    nvinfer1::Dims inpDims = input->getDimensions();
    assert(inpDims.nbDims == 3);
    assert(inpDims.d[1] == inpDims.d[2]);
    int h = inpDims.d[1];
    int w = inpDims.d[2];
    int stride = std::stoi(block.at("stride"));

    nvinfer1::Dims preDims{3,
        {1, stride * h, w},
        {nvinfer1::DimensionType::kCHANNEL, nvinfer1::DimensionType::kSPATIAL,
        nvinfer1::DimensionType::kSPATIAL}};
    int size = stride * h * w;
    nvinfer1::Weights preMul{nvinfer1::DataType::kFLOAT, nullptr, size};
    float* preWt = new float[size];

    for (int i = 0, idx = 0; i < h; ++i)
    {
        for (int s = 0; s < stride; ++s)
        {
            for (int j = 0; j < w; ++j, ++idx)
            {
                preWt[idx] = (i == j) ? 1.0 : 0.0;
            }
        }
    }
    preMul.values = preWt;
    trtWeights.push_back(preMul);
    nvinfer1::IConstantLayer* preM = network->addConstant(preDims, preMul);
    assert(preM != nullptr);
    std::string preLayerName = "preMul_" + std::to_string(layerIdx);
    preM->setName(preLayerName.c_str());

    nvinfer1::Dims postDims{3,
        {1, h, stride * w},
        {nvinfer1::DimensionType::kCHANNEL, nvinfer1::DimensionType::kSPATIAL,
        nvinfer1::DimensionType::kSPATIAL}};
    size = stride * h * w;
    nvinfer1::Weights postMul{nvinfer1::DataType::kFLOAT, nullptr, size};
    float* postWt = new float[size];

    for (int i = 0, idx = 0; i < h; ++i)
    {
        for (int j = 0; j < stride * w; ++j, ++idx)
        {
            postWt[idx] = (j / stride == i) ? 1.0 : 0.0;
        }
    }
    postMul.values = postWt;
    trtWeights.push_back(postMul);
    nvinfer1::IConstantLayer* post_m = network->addConstant(postDims, postMul);
    assert(post_m != nullptr);
    std::string postLayerName = "postMul_" + std::to_string(layerIdx);
    post_m->setName(postLayerName.c_str());

    nvinfer1::IMatrixMultiplyLayer* mm1
        = network->addMatrixMultiply(*preM->getOutput(0), nvinfer1::MatrixOperation::kNONE, *input,
            nvinfer1::MatrixOperation::kNONE);
    assert(mm1 != nullptr);
    std::string mm1LayerName = "mm1_" + std::to_string(layerIdx);
    mm1->setName(mm1LayerName.c_str());
    nvinfer1::IMatrixMultiplyLayer* mm2
        = network->addMatrixMultiply(*mm1->getOutput(0), nvinfer1::MatrixOperation::kNONE,
            *post_m->getOutput(0), nvinfer1::MatrixOperation::kNONE);
    assert(mm2 != nullptr);
    std::string mm2LayerName = "mm2_" + std::to_string(layerIdx);
    mm2->setName(mm2LayerName.c_str());

    return mm2;
    nvinfer1::IResizeLayer* resize_layer = network->addResize(*input);
    resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
    float scale[3] = {1, stride, stride};
    resize_layer->setScales(scale, 3);
    std::string layer_name = "upsample_" + std::to_string(layerIdx);
    resize_layer->setName(layer_name.c_str());
    return resize_layer;
}
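The removed block above emulated upsampling with constant matrices and two matrix multiplies, which only works for square feature maps (note its assert(inpDims.d[1] == inpDims.d[2])). The new code uses TensorRT's native IResizeLayer instead, which also handles the rectangular grids of non-square models. A minimal self-contained sketch of the same idea, assuming the TensorRT API the commit itself targets (the helper name addUpsample2x is illustrative, not from the repository):

```cpp
#include "NvInfer.h"

// Hypothetical helper: 2x nearest-neighbour upsample of a CHW tensor,
// the operation upsampleLayer now emits for stride = 2.
nvinfer1::ILayer* addUpsample2x(nvinfer1::INetworkDefinition& net, nvinfer1::ITensor& feat)
{
    nvinfer1::IResizeLayer* up = net.addResize(feat);
    up->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
    float scales[3] = {1.0f, 2.0f, 2.0f};   // keep channels, double height and width
    up->setScales(scales, 3);
    return up;                               // e.g. 256x13x13 -> 256x26x26
}
```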
@@ -15,9 +15,6 @@
nvinfer1::ILayer* upsampleLayer(
    int layerIdx,
    std::map<std::string, std::string>& block,
    std::vector<float>& weights,
    std::vector<nvinfer1::Weights>& trtWeights,
    int& inputChannels,
    nvinfer1::ITensor* input,
    nvinfer1::INetworkDefinition* network);
@@ -45,9 +45,20 @@ static bool getYoloNetworkInfo (NetworkInfo &networkInfo, const NvDsInferContext
networkInfo.networkType = yoloType;
networkInfo.configFilePath = initParams->customNetworkConfigFilePath;
networkInfo.wtsFilePath = initParams->modelFilePath;
networkInfo.int8CalibPath = initParams->int8CalibrationFilePath;
networkInfo.deviceType = (initParams->useDLA ? "kDLA" : "kGPU");
networkInfo.inputBlobName = "data";

if(initParams->networkMode == 0) {
    networkInfo.networkMode = "FP32";
}
else if(initParams->networkMode == 1) {
    networkInfo.networkMode = "INT8";
}
else if(initParams->networkMode == 2) {
    networkInfo.networkMode = "FP16";
}

if (networkInfo.configFilePath.empty() ||
    networkInfo.wtsFilePath.empty()) {
    std::cerr << "YOLO config file or weights file is not specified"
@@ -302,7 +302,6 @@ static bool NvDsInferParseYolo(
const uint gridSizeH = layer.inferDims.d[1];
const uint gridSizeW = layer.inferDims.d[2];
const uint stride = DIVUP(networkInfo.width, gridSizeW);
assert(stride == DIVUP(networkInfo.height, gridSizeH));

std::vector<NvDsInferParseObjectInfo> outObjs =
    decodeYoloTensor((const float*)(layer.buffer), masks[idx], anchors, gridSizeW, gridSizeH, stride, masks[idx].size(),
@@ -344,7 +343,6 @@ static bool NvDsInferParseYoloV2(
const uint gridSizeH = layer.inferDims.d[1];
const uint gridSizeW = layer.inferDims.d[2];
const uint stride = DIVUP(networkInfo.width, gridSizeW);
assert(stride == DIVUP(networkInfo.height, gridSizeH));
for (auto& anchor : anchors) {
    anchor *= stride;
}
@@ -25,6 +25,11 @@

#include "yolo.h"
#include "yoloPlugins.h"
#include <stdlib.h>

#ifdef OPENCV
#include "calibrator.h"
#endif

void orderParams(std::vector<std::vector<int>> *maskVector) {
    std::vector<std::vector<int>> maskinput = *maskVector;
@@ -45,6 +50,8 @@ Yolo::Yolo(const NetworkInfo& networkInfo)
    : m_NetworkType(networkInfo.networkType), // YOLO type
    m_ConfigFilePath(networkInfo.configFilePath), // YOLO cfg
    m_WtsFilePath(networkInfo.wtsFilePath), // YOLO weights
    m_Int8CalibPath(networkInfo.int8CalibPath), // INT8 calibration path
    m_NetworkMode(networkInfo.networkMode), // FP32, INT8, FP16
    m_DeviceType(networkInfo.deviceType), // kDLA, kGPU
    m_InputBlobName(networkInfo.inputBlobName), // data
    m_InputH(0),
@@ -62,6 +69,38 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
{
    assert (builder);

    m_ConfigBlocks = parseConfigFile(m_ConfigFilePath);
    parseConfigBlocks();
    orderParams(&m_OutputMasks);

    if (m_NetworkMode == "INT8" && !fileExists(m_Int8CalibPath)) {
        assert(builder->platformHasFastInt8());
#ifdef OPENCV
        std::string calib_image_list;
        int calib_batch_size;
        if (getenv("INT8_CALIB_IMG_PATH")) {
            calib_image_list = getenv("INT8_CALIB_IMG_PATH");
        }
        else {
            std::cerr << "INT8_CALIB_IMG_PATH not set" << std::endl;
            std::abort();
        }
        if (getenv("INT8_CALIB_BATCH_SIZE")) {
            calib_batch_size = std::stoi(getenv("INT8_CALIB_BATCH_SIZE"));
        }
        else {
            std::cerr << "INT8_CALIB_BATCH_SIZE not set" << std::endl;
            std::abort();
        }
        nvinfer1::int8EntroyCalibrator *calibrator = new nvinfer1::int8EntroyCalibrator(calib_batch_size, m_InputC, m_InputH, m_InputW, m_LetterBox, calib_image_list, m_Int8CalibPath);
        builder->setInt8Mode(true);
        builder->setInt8Calibrator(calibrator);
#else
        std::cerr << "OpenCV is required to run INT8 calibrator" << std::endl;
        std::abort();
#endif
    }

    std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
    std::vector<nvinfer1::Weights> trtWeights;

@@ -71,8 +110,12 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
        return nullptr;
    }

    // Build the engine
    std::cout << "Building the TensorRT Engine" << std::endl;

    if (m_LetterBox == 1) {
        std::cout << "\nNOTE: letter_box is set in cfg file, make sure to set maintain-aspect-ratio=1 in config_infer file to get better accuracy\n" << std::endl;
    }

    nvinfer1::ICudaEngine * engine = builder->buildCudaEngine(*network);
    if (engine) {
        std::cout << "Building complete\n" << std::endl;
@@ -80,7 +123,6 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
        std::cerr << "Building engine failed\n" << std::endl;
    }

    // destroy
    network->destroy();
    return engine;
}
@@ -88,12 +130,7 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
NvDsInferStatus Yolo::parseModel(nvinfer1::INetworkDefinition& network) {
    destroyNetworkUtils();

    m_ConfigBlocks = parseConfigFile(m_ConfigFilePath);
    parseConfigBlocks();
    orderParams(&m_OutputMasks);

    std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
    // build yolo network
    std::cout << "Building YOLO network" << std::endl;
    NvDsInferStatus status = buildYoloNetwork(weights, network);

@@ -121,9 +158,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(
    std::vector<nvinfer1::ITensor*> tensorOutputs;
    uint outputTensorCount = 0;

    // build the network using the network API
    for (uint i = 0; i < m_ConfigBlocks.size(); ++i) {
        // check if num. of channels is correct
        assert(getNumChannels(previous) == channels);
        std::string layerIndex = "(" + std::to_string(tensorOutputs.size()) + ")";

@@ -192,7 +227,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(

        else if (m_ConfigBlocks.at(i).at("type") == "upsample") {
            std::string inputVol = dimsToString(previous->getDimensions());
            nvinfer1::ILayer* out = upsampleLayer(i - 1, m_ConfigBlocks[i], weights, m_TrtWeights, channels, previous, &network);
            nvinfer1::ILayer* out = upsampleLayer(i - 1, m_ConfigBlocks[i], previous, &network);
            previous = out->getOutput(0);
            assert(previous != nullptr);
            std::string outputVol = dimsToString(previous->getDimensions());
@@ -212,12 +247,12 @@ NvDsInferStatus Yolo::buildYoloNetwork(

        else if (m_ConfigBlocks.at(i).at("type") == "yolo") {
            nvinfer1::Dims prevTensorDims = previous->getDimensions();
            assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
            TensorInfo& curYoloTensor = m_OutputTensors.at(outputTensorCount);
            curYoloTensor.gridSize = prevTensorDims.d[1];
            curYoloTensor.stride = m_InputW / curYoloTensor.gridSize;
            m_OutputTensors.at(outputTensorCount).volume = curYoloTensor.gridSize
                * curYoloTensor.gridSize
            curYoloTensor.gridSizeY = prevTensorDims.d[1];
            curYoloTensor.gridSizeX = prevTensorDims.d[2];
            curYoloTensor.stride = m_InputH / curYoloTensor.gridSizeY;
            m_OutputTensors.at(outputTensorCount).volume = curYoloTensor.gridSizeY
                * curYoloTensor.gridSizeX
                * (curYoloTensor.numBBoxes * (5 + curYoloTensor.numClasses));
            std::string layerName = "yolo_" + std::to_string(i);
            curYoloTensor.blobName = layerName;
@@ -236,7 +271,8 @@ NvDsInferStatus Yolo::buildYoloNetwork(
            nvinfer1::IPluginV2* yoloPlugin
                = new YoloLayer(m_OutputTensors.at(outputTensorCount).numBBoxes,
                    m_OutputTensors.at(outputTensorCount).numClasses,
                    m_OutputTensors.at(outputTensorCount).gridSize,
                    m_OutputTensors.at(outputTensorCount).gridSizeX,
                    m_OutputTensors.at(outputTensorCount).gridSizeY,
                    1, new_coords, scale_x_y, beta_nms,
                    curYoloTensor.anchors,
                    m_OutputMasks);
@@ -260,12 +296,12 @@ NvDsInferStatus Yolo::buildYoloNetwork(
        //YOLOv2 support
        else if (m_ConfigBlocks.at(i).at("type") == "region") {
            nvinfer1::Dims prevTensorDims = previous->getDimensions();
            assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
            TensorInfo& curRegionTensor = m_OutputTensors.at(outputTensorCount);
            curRegionTensor.gridSize = prevTensorDims.d[1];
            curRegionTensor.stride = m_InputW / curRegionTensor.gridSize;
            m_OutputTensors.at(outputTensorCount).volume = curRegionTensor.gridSize
                * curRegionTensor.gridSize
            curRegionTensor.gridSizeY = prevTensorDims.d[1];
            curRegionTensor.gridSizeX = prevTensorDims.d[2];
            curRegionTensor.stride = m_InputH / curRegionTensor.gridSizeY;
            m_OutputTensors.at(outputTensorCount).volume = curRegionTensor.gridSizeY
                * curRegionTensor.gridSizeX
                * (curRegionTensor.numBBoxes * (5 + curRegionTensor.numClasses));
            std::string layerName = "region_" + std::to_string(i);
            curRegionTensor.blobName = layerName;
@@ -273,7 +309,8 @@ NvDsInferStatus Yolo::buildYoloNetwork(
            nvinfer1::IPluginV2* regionPlugin
                = new YoloLayer(curRegionTensor.numBBoxes,
                    curRegionTensor.numClasses,
                    curRegionTensor.gridSize,
                    curRegionTensor.gridSizeX,
                    curRegionTensor.gridSizeY,
                    0, 0, 1.0, 0,
                    curRegionTensor.anchors,
                    mask);
@@ -387,8 +424,14 @@ void Yolo::parseConfigBlocks()
            m_InputH = std::stoul(block.at("height"));
            m_InputW = std::stoul(block.at("width"));
            m_InputC = std::stoul(block.at("channels"));
            assert(m_InputW == m_InputH);
            m_InputSize = m_InputC * m_InputH * m_InputW;

            if (block.find("letter_box") != block.end()) {
                m_LetterBox = std::stoul(block.at("letter_box"));
            }
            else {
                m_LetterBox = 0;
            }
        }
        else if ((block.at("type") == "region") || (block.at("type") == "yolo"))
        {
@@ -456,10 +499,9 @@ void Yolo::parseConfigBlocks()
}

void Yolo::destroyNetworkUtils() {
    // deallocate the weights
    for (uint i = 0; i < m_TrtWeights.size(); ++i) {
        if (m_TrtWeights[i].count > 0)
            free(const_cast<void*>(m_TrtWeights[i].values));
    }
    m_TrtWeights.clear();
}
}
@@ -40,6 +40,8 @@ struct NetworkInfo
    std::string networkType;
    std::string configFilePath;
    std::string wtsFilePath;
    std::string int8CalibPath;
    std::string networkMode;
    std::string deviceType;
    std::string inputBlobName;
};
@@ -48,7 +50,8 @@ struct TensorInfo
{
    std::string blobName;
    uint stride{0};
    uint gridSize{0};
    uint gridSizeY{0};
    uint gridSizeX{0};
    uint numClasses{0};
    uint numBBoxes{0};
    uint64_t volume{0};
@@ -75,6 +78,8 @@ protected:
    const std::string m_NetworkType;
    const std::string m_ConfigFilePath;
    const std::string m_WtsFilePath;
    const std::string m_Int8CalibPath;
    const std::string m_NetworkMode;
    const std::string m_DeviceType;
    const std::string m_InputBlobName;
    std::vector<TensorInfo> m_OutputTensors;
@@ -84,6 +89,7 @@ protected:
    uint m_InputW;
    uint m_InputC;
    uint64_t m_InputSize;
    uint m_LetterBox;

    std::vector<nvinfer1::Weights> m_TrtWeights;
@@ -20,20 +20,20 @@

inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }

__global__ void gpuYoloLayer(const float* input, float* output, const uint gridSize, const uint numOutputClasses,
__global__ void gpuYoloLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
    const uint numBBoxes, const uint new_coords, const float scale_x_y)
{
    uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
    uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
    uint z_id = blockIdx.z * blockDim.z + threadIdx.z;

    if ((x_id >= gridSize) || (y_id >= gridSize) || (z_id >= numBBoxes))
    if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
    {
        return;
    }

    const int numGridCells = gridSize * gridSize;
    const int bbindex = y_id * gridSize + x_id;
    const int numGridCells = gridSizeX * gridSizeY;
    const int bbindex = y_id * gridSizeX + x_id;

    float alpha = scale_x_y;
    float beta = -0.5 * (scale_x_y - 1);
@@ -84,20 +84,20 @@ __global__ void gpuYoloLayer(const float* input, float* output, const uint gridS
    }
}

__global__ void gpuRegionLayer(const float* input, float* output, const uint gridSize, const uint numOutputClasses,
__global__ void gpuRegionLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
    const uint numBBoxes)
{
    uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
    uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
    uint z_id = blockIdx.z * blockDim.z + threadIdx.z;

    if ((x_id >= gridSize) || (y_id >= gridSize) || (z_id >= numBBoxes))
    if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
    {
        return;
    }

    const int numGridCells = gridSize * gridSize;
    const int bbindex = y_id * gridSize + x_id;
    const int numGridCells = gridSizeX * gridSizeY;
    const int bbindex = y_id * gridSizeX + x_id;

    output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
        = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]);
@@ -132,24 +132,24 @@ __global__ void gpuRegionLayer(const float* input, float* output, const uint gri
    }
}

cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSize,
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
    const uint& numOutputClasses, const uint& numBBoxes,
    uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType);

cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSize,
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
    const uint& numOutputClasses, const uint& numBBoxes,
    uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType)
{
    dim3 threads_per_block(16, 16, 4);
    dim3 number_of_blocks((gridSize / threads_per_block.x) + 1,
        (gridSize / threads_per_block.y) + 1,
    dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
        (gridSizeY / threads_per_block.y) + 1,
        (numBBoxes / threads_per_block.z) + 1);
    if (modelType == 1) {
        for (unsigned int batch = 0; batch < batchSize; ++batch)
        {
            gpuYoloLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
                reinterpret_cast<const float*>(input) + (batch * outputSize),
                reinterpret_cast<float*>(output) + (batch * outputSize), gridSize, numOutputClasses,
                reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
                numBBoxes, modelCoords, modelScale);
        }
    }
@@ -158,7 +158,7 @@ cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize
    {
        gpuRegionLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
            reinterpret_cast<const float*>(input) + (batch * outputSize),
            reinterpret_cast<float*>(output) + (batch * outputSize), gridSize, numOutputClasses,
            reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
            numBBoxes);
    }
}
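To make the split of gridSize into gridSizeX and gridSizeY concrete, here is a small self-contained illustration with a hypothetical 416x256 (width x height) network; the numbers are illustrative, only the row-major indexing scheme comes from the kernels above:

```cpp
#include <cassert>

int main()
{
    // Hypothetical non-square network of 416x256 (W x H), stride-32 YOLO head.
    const unsigned gridSizeX = 416 / 32;                 // 13
    const unsigned gridSizeY = 256 / 32;                 // 8
    const unsigned numGridCells = gridSizeX * gridSizeY; // 104 cells instead of 13*13
    const unsigned x_id = 7, y_id = 3;                   // one thread's cell coordinates
    const unsigned bbindex = y_id * gridSizeX + x_id;    // row-major cell index, as in the kernels
    assert(bbindex < numGridCells);
    return 0;
}
```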
@@ -52,7 +52,7 @@ void read(const char*& buffer, T& val)

cudaError_t cudaYoloLayer (
    const void* input, void* output, const uint& batchSize,
    const uint& gridSize, const uint& numOutputClasses,
    const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
    const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType);

YoloLayer::YoloLayer (const void* data, size_t length)
@@ -60,7 +60,8 @@ YoloLayer::YoloLayer (const void* data, size_t length)
    const char *d = static_cast<const char*>(data);
    read(d, m_NumBoxes);
    read(d, m_NumClasses);
    read(d, m_GridSize);
    read(d, m_GridSizeX);
    read(d, m_GridSizeY);
    read(d, m_OutputSize);

    read(d, m_type);
@@ -94,10 +95,11 @@ YoloLayer::YoloLayer (const void* data, size_t length)
};

YoloLayer::YoloLayer (
    const uint& numBoxes, const uint& numClasses, const uint& gridSize, const uint model_type, const uint new_coords, const float scale_x_y, const float beta_nms, const std::vector<float> anchors, std::vector<std::vector<int>> mask) :
    const uint& numBoxes, const uint& numClasses, const uint& gridSizeX, const uint& gridSizeY, const uint model_type, const uint new_coords, const float scale_x_y, const float beta_nms, const std::vector<float> anchors, std::vector<std::vector<int>> mask) :
    m_NumBoxes(numBoxes),
    m_NumClasses(numClasses),
    m_GridSize(gridSize),
    m_GridSizeX(gridSizeX),
    m_GridSizeY(gridSizeY),
    m_type(model_type),
    m_new_coords(new_coords),
    m_scale_x_y(scale_x_y),
@@ -107,8 +109,9 @@ YoloLayer::YoloLayer (
{
    assert(m_NumBoxes > 0);
    assert(m_NumClasses > 0);
    assert(m_GridSize > 0);
    m_OutputSize = m_GridSize * m_GridSize * (m_NumBoxes * (4 + 1 + m_NumClasses));
    assert(m_GridSizeX > 0);
    assert(m_GridSizeY > 0);
    m_OutputSize = m_GridSizeX * m_GridSizeY * (m_NumBoxes * (4 + 1 + m_NumClasses));
};

nvinfer1::Dims
@@ -142,7 +145,7 @@ int YoloLayer::enqueue(
    cudaStream_t stream)
{
    CHECK(cudaYoloLayer(
        inputs[0], outputs[0], batchSize, m_GridSize, m_NumClasses, m_NumBoxes,
        inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
        m_OutputSize, stream, m_new_coords, m_scale_x_y, m_type));
    return 0;
}
@@ -161,7 +164,7 @@ size_t YoloLayer::getSerializationSize() const
    }
    }

    return sizeof(m_NumBoxes) + sizeof(m_NumClasses) + sizeof(m_GridSize) + sizeof(m_OutputSize) + sizeof(m_type)
    return sizeof(m_NumBoxes) + sizeof(m_NumClasses) + sizeof(m_GridSizeX) + sizeof(m_GridSizeY) + sizeof(m_OutputSize) + sizeof(m_type)
        + sizeof(m_new_coords) + sizeof(m_scale_x_y) + sizeof(m_beta_nms) + anchorsSum * sizeof(float) + maskSum * sizeof(int);
}

@@ -170,7 +173,8 @@ void YoloLayer::serialize(void* buffer) const
    char *d = static_cast<char*>(buffer);
    write(d, m_NumBoxes);
    write(d, m_NumClasses);
    write(d, m_GridSize);
    write(d, m_GridSizeX);
    write(d, m_GridSizeY);
    write(d, m_OutputSize);

    write(d, m_type);
@@ -199,7 +203,7 @@ void YoloLayer::serialize(void* buffer) const

nvinfer1::IPluginV2* YoloLayer::clone() const
{
    return new YoloLayer (m_NumBoxes, m_NumClasses, m_GridSize, m_type, m_new_coords, m_scale_x_y, m_beta_nms, m_Anchors, m_Mask);
    return new YoloLayer (m_NumBoxes, m_NumClasses, m_GridSizeX, m_GridSizeY, m_type, m_new_coords, m_scale_x_y, m_beta_nms, m_Anchors, m_Mask);
}

REGISTER_TENSORRT_PLUGIN(YoloLayerPluginCreator);
@@ -56,7 +56,7 @@ class YoloLayer : public nvinfer1::IPluginV2
{
public:
    YoloLayer (const void* data, size_t length);
    YoloLayer (const uint& numBoxes, const uint& numClasses, const uint& gridSize,
    YoloLayer (const uint& numBoxes, const uint& numClasses, const uint& gridSizeX, const uint& gridSizeY,
        const uint model_type, const uint new_coords, const float scale_x_y, const float beta_nms,
        const std::vector<float> anchors, const std::vector<std::vector<int>> mask);
    const char* getPluginType () const override { return YOLOLAYER_PLUGIN_NAME; }
@@ -96,7 +96,8 @@ public:
private:
    uint m_NumBoxes {0};
    uint m_NumClasses {0};
    uint m_GridSize {0};
    uint m_GridSizeX {0};
    uint m_GridSizeY {0};
    uint64_t m_OutputSize {0};
    std::string m_Namespace {""};

@@ -152,4 +153,4 @@ extern float kBETA_NMS;
extern std::vector<float> kANCHORS;
extern std::vector<std::vector<int>> kMASK;

#endif // __YOLO_PLUGINS__
#endif // __YOLO_PLUGINS__