New features

- Added support for INT8 calibration
- Added support for non-square models
- Updated mAP comparison between models
This commit is contained in:
Marcos Luciano
2021-06-18 00:30:10 -03:00
parent 312e9a448d
commit cbd9675dc2
74 changed files with 3287 additions and 700 deletions

View File

@@ -27,13 +27,25 @@ CUDA_VER?=
ifeq ($(CUDA_VER),)
$(error "CUDA_VER is not set")
endif
OPENCV?=
ifeq ($(OPENCV),)
OPENCV=0
endif
CC:= g++
NVCC:=/usr/local/cuda-$(CUDA_VER)/bin/nvcc
CFLAGS:= -Wall -std=c++11 -shared -fPIC -Wno-error=deprecated-declarations
CFLAGS+= -I../../includes -I/usr/local/cuda-$(CUDA_VER)/include
CFLAGS+= -I/opt/nvidia/deepstream/deepstream-5.1/sources/includes -I/usr/local/cuda-$(CUDA_VER)/include
LIBS:= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib64 -lcudart -lcublas -lstdc++fs
ifeq ($(OPENCV), 1)
COMMON= -DOPENCV
CFLAGS+= $(shell pkg-config --cflags opencv4 2> /dev/null || pkg-config --cflags opencv)
LIBS+= $(shell pkg-config --libs opencv4 2> /dev/null || pkg-config --libs opencv)
endif
LIBS+= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib64 -lcudart -lcublas -lstdc++fs
LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group
INCS:= $(wildcard *.h)
@@ -50,6 +62,11 @@ SRCFILES:= nvdsinfer_yolo_engine.cpp \
utils.cpp \
yolo.cpp \
yoloForward.cu
ifeq ($(OPENCV), 1)
SRCFILES+= calibrator.cpp
endif
TARGET_LIB:= libnvdsinfer_custom_impl_Yolo.so
TARGET_OBJS:= $(SRCFILES:.cpp=.o)
@@ -58,7 +75,7 @@ TARGET_OBJS:= $(TARGET_OBJS:.cu=.o)
all: $(TARGET_LIB)
%.o: %.cpp $(INCS) Makefile
$(CC) -c -o $@ $(CFLAGS) $<
$(CC) -c $(COMMON) -o $@ $(CFLAGS) $<
%.o: %.cu $(INCS) Makefile
$(NVCC) -c -o $@ --compiler-options '-fPIC' $<

View File

@@ -0,0 +1,130 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "calibrator.h"
#include <fstream>
#include <iterator>
namespace nvinfer1
{
/*
 * Builds the calibrator state.
 *
 * batchsize / channels / height / width : shape of one calibration batch; their
 *                                         product is the number of floats staged per batch.
 * letterbox                             : forwarded unchanged to prepareImage().
 * imgPath                               : text file listing one image path per line.
 * calibTablePath                        : file used by read/writeCalibrationCache().
 *
 * Allocates a host staging buffer and a device buffer of inputCount floats each.
 */
int8EntroyCalibrator::int8EntroyCalibrator(const int &batchsize, const int &channels, const int &height, const int &width, const int &letterbox, const std::string &imgPath,
const std::string &calibTablePath):batchSize(batchsize), inputC(channels), inputH(height), inputW(width), letterBox(letterbox), calibTablePath(calibTablePath), imageIndex(0),
inputCount(0), readCache(true)
{
    // Multiply in size_t so large batches cannot overflow int arithmetic.
    inputCount = static_cast<size_t>(batchsize) * channels * height * width;

    // Read-only stream: std::fstream defaults to in|out and fails on read-only list files.
    std::ifstream listFile(imgPath);
    if (listFile.is_open())
    {
        std::string line;
        while (std::getline(listFile, line))
        {
            // Tolerate CRLF line endings and skip blank lines, which are not valid image paths.
            if (!line.empty() && line.back() == '\r')
                line.pop_back();
            if (!line.empty())
                imgPaths.push_back(line);
        }
    }
    else
    {
        std::cerr << "Unable to open calibration image list: " << imgPath << std::endl;
    }

    batchData = new float[inputCount];
    CUDA_CHECK(cudaMalloc(&deviceInput, inputCount * sizeof(float)));
}
// Releases the device input buffer first, then the host staging buffer.
int8EntroyCalibrator::~int8EntroyCalibrator()
{
    CUDA_CHECK(cudaFree(deviceInput));
    // delete[] on a null pointer is a no-op, so no explicit null test is required.
    delete[] batchData;
}
/*
 * Loads the next batchSize images, preprocesses them with prepareImage(), copies
 * the batch to the device buffer and publishes that buffer as bindings[0].
 *
 * Returns false when fewer than batchSize images remain, when the batch
 * configuration is unusable, or when an image cannot be read / does not match
 * the expected tensor size; returns true after a batch has been uploaded.
 */
bool int8EntroyCalibrator::getBatch(void **bindings, const char **names, int nbBindings)
{
    // A non-positive batch would never advance imageIndex; bindings[0] must exist.
    if (batchSize <= 0 || nbBindings < 1 || bindings == nullptr)
        return false;

    const size_t batch = static_cast<size_t>(batchSize);
    if (imageIndex + batch > imgPaths.size())
        return false;

    // Number of floats one image is allowed to occupy inside batchData.
    const size_t imageSize = inputCount / batch;

    float* ptr = batchData;
    for (size_t j = imageIndex; j < imageIndex + batch; ++j)
    {
        cv::Mat img = cv::imread(imgPaths[j], cv::IMREAD_COLOR);
        if (img.empty())
        {
            // imread returns an empty matrix for missing or undecodable files.
            std::cerr << "Unable to read calibration image: " << imgPaths[j] << std::endl;
            return false;
        }

        std::vector<float> inputData = prepareImage(img, inputC, inputH, inputW, letterBox);
        if (inputData.size() != imageSize)
        {
            // Guard the memcpy below against writing past the staging buffer.
            std::cerr << "Unexpected calibration tensor size for image: " << imgPaths[j] << std::endl;
            return false;
        }

        memcpy(ptr, inputData.data(), inputData.size() * sizeof(float));
        ptr += inputData.size();

        std::cout << "Load image: " << imgPaths[j] << std::endl;
        std::cout << "Progress: " << (j + 1)*100. / imgPaths.size() << "%" << std::endl;
    }
    imageIndex += batch;

    CUDA_CHECK(cudaMemcpy(deviceInput, batchData, inputCount * sizeof(float), cudaMemcpyHostToDevice));
    bindings[0] = deviceInput;
    return true;
}
/*
 * Reloads the calibration table from calibTablePath into calibrationCache.
 * length receives the number of bytes read; the return value points at the
 * cached bytes, or is nullptr when nothing was read.
 */
const void* int8EntroyCalibrator::readCalibrationCache(std::size_t &length)
{
    typedef std::istream_iterator<char> ByteIter;

    calibrationCache.clear();

    std::ifstream input(calibTablePath, std::ios::binary);
    // Keep whitespace bytes: the stream iterator would otherwise skip them.
    input >> std::noskipws;

    const bool loadFromFile = readCache && input.good();
    if (loadFromFile)
    {
        calibrationCache.insert(calibrationCache.end(), ByteIter(input), ByteIter());
    }

    length = calibrationCache.size();
    if (length == 0)
    {
        return nullptr;
    }
    return calibrationCache.data();
}
/*
 * Writes the length bytes at cache to calibTablePath in binary mode.
 * Failures to open or write the file are reported on std::cerr instead of
 * being dropped silently; the function itself never throws for them.
 */
void int8EntroyCalibrator::writeCalibrationCache(const void *cache, std::size_t length)
{
    std::ofstream output(calibTablePath, std::ios::binary);
    if (!output.is_open())
    {
        std::cerr << "Unable to open calibration table for writing: " << calibTablePath << std::endl;
        return;
    }

    if (cache != nullptr && length > 0)
    {
        output.write(reinterpret_cast<const char*>(cache), static_cast<std::streamsize>(length));
    }

    if (!output)
    {
        std::cerr << "Failed to write calibration table: " << calibTablePath << std::endl;
    }
}
}
/*
 * Converts a BGR image into a planar float tensor for calibration.
 *
 * img        : source image; must be non-empty (3-channel BGR, as produced by
 *              cv::imread with IMREAD_COLOR).
 * input_c    : number of network input channels; 1 (grayscale) or 3 (RGB).
 * input_h/w  : network input height and width.
 * letter_box : non-zero keeps the aspect ratio and zero-pads the bottom/right
 *              edge; zero stretches the image to input_w x input_h.
 *
 * Returns input_c * input_h * input_w floats, one channel plane after another,
 * with pixel values scaled by 1/255. Throws cv::Exception on invalid arguments.
 */
std::vector<float> prepareImage(cv::Mat& img, int input_c, int input_h, int input_w, int letter_box)
{
    // Fail the same way cv::resize would, but before any integer division by the image size.
    CV_Assert(!img.empty());
    // Only these channel counts can be produced below; anything larger would index past the split planes.
    CV_Assert(input_c == 1 || input_c == 3);

    cv::Mat out;

    // The flag is treated as a boolean: the engine builder reports letter_box == 1 as "enabled".
    if (letter_box != 0)
    {
        const int image_w = img.cols;
        const int image_h = img.rows;
        int resize_w = 0;
        int resize_h = 0;
        int offset_bottom = 0;
        int offset_right = 0;

        if ((float)input_h / image_h > (float)input_w / image_w)
        {
            // Width is the limiting side: fill it and pad below.
            resize_w = input_w;
            resize_h = (input_w * image_h) / image_w;
            if (resize_h < 1)
                resize_h = 1;
            offset_bottom = input_h - resize_h;
        }
        else
        {
            // Height is the limiting side: fill it and pad to the right.
            resize_h = input_h;
            resize_w = (input_h * image_w) / image_h;
            if (resize_w < 1)
                resize_w = 1;
            offset_right = input_w - resize_w;
        }

        cv::resize(img, out, cv::Size(resize_w, resize_h), 0, 0, cv::INTER_CUBIC);
        cv::copyMakeBorder(out, out, 0, offset_bottom, 0, offset_right, cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
    }
    else
    {
        cv::resize(img, out, cv::Size(input_w, input_h), 0, 0, cv::INTER_CUBIC);
    }

    // convertTo() never changes the channel count, so the channel reduction for
    // single-channel networks has to happen in the colour conversion itself.
    if (input_c == 1)
    {
        cv::cvtColor(out, out, cv::COLOR_BGR2GRAY);
    }
    else
    {
        cv::cvtColor(out, out, cv::COLOR_BGR2RGB);
    }
    out.convertTo(out, CV_32F, 1.0 / 255.0);

    std::vector<cv::Mat> input_channels(input_c);
    cv::split(out, input_channels);

    const size_t channelLength = static_cast<size_t>(input_h) * input_w;
    std::vector<float> result(channelLength * input_c);
    float* data = result.data();
    for (int i = 0; i < input_c; ++i)
    {
        // Each split plane is a freshly allocated single-channel float matrix.
        memcpy(data, input_channels[i].data, channelLength * sizeof(float));
        data += channelLength;
    }
    return result;
}

View File

@@ -0,0 +1,62 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef CALIBRATOR_H
#define CALIBRATOR_H
#include "opencv2/opencv.hpp"
#include "cuda_runtime.h"
#include "NvInfer.h"
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>
#ifndef CUDA_CHECK
// Evaluates a CUDA runtime call exactly once. On failure it prints the numeric
// code, the runtime's error string and the call site, then aborts. Unlike
// assert(0), the abort is not compiled out when NDEBUG is defined.
// The do/while(0) wrapper makes the macro behave as a single statement.
#define CUDA_CHECK(callstr) \
    do { \
        const cudaError_t error_code_ = (callstr); \
        if (error_code_ != cudaSuccess) { \
            std::cerr << "CUDA error " << error_code_ << " (" << cudaGetErrorString(error_code_) << ") at " << __FILE__ << ":" << __LINE__ << std::endl; \
            std::abort(); \
        } \
    } while (0)
#endif
namespace nvinfer1 {
/*
 * INT8 entropy calibrator that feeds TensorRT with batches of preprocessed
 * images read from a list file and persists the calibration table on disk.
 *
 * The object owns a host staging buffer (batchData) and a device buffer
 * (deviceInput), so it is deliberately non-copyable.
 */
class int8EntroyCalibrator : public nvinfer1::IInt8EntropyCalibrator2 {
public:
    // batchsize/channels/height/width describe one calibration batch, letterbox is
    // forwarded to prepareImage(), imgPath is a text file with one image path per
    // line and calibTablePath is where the calibration table is read/written.
    int8EntroyCalibrator(const int &batchsize,
                         const int &channels,
                         const int &height,
                         const int &width,
                         const int &letterbox,
                         const std::string &imgPath,
                         const std::string &calibTablePath);
    virtual ~int8EntroyCalibrator();

    // Copying would free batchData/deviceInput twice.
    int8EntroyCalibrator(const int8EntroyCalibrator&) = delete;
    int8EntroyCalibrator& operator=(const int8EntroyCalibrator&) = delete;

    // Number of images delivered per getBatch() call.
    int getBatchSize() const override { return batchSize; }
    // Uploads the next batch and stores its device pointer in bindings[0];
    // returns false when no further full batch is available.
    bool getBatch(void *bindings[], const char *names[], int nbBindings) override;
    // Returns the cached calibration table (nullptr if none) and its size in length.
    const void *readCalibrationCache(std::size_t &length) override;
    // Persists the calibration table to calibTablePath.
    void writeCalibrationCache(const void *ptr, std::size_t length) override;

private:
    int batchSize{ 0 };
    int inputC{ 0 };
    int inputH{ 0 };
    int inputW{ 0 };
    int letterBox{ 0 };
    std::string calibTablePath;
    size_t imageIndex{ 0 };   // index of the next image in imgPaths to load
    size_t inputCount{ 0 };   // number of floats in one batch
    std::vector<std::string> imgPaths;
    float *batchData{ nullptr };   // host staging buffer, inputCount floats
    void *deviceInput{ nullptr };  // device buffer, inputCount floats
    bool readCache{ true };        // when false, readCalibrationCache() ignores the file
    std::vector<char> calibrationCache;
};
}
// Resizes img to input_w x input_h (letter_box selects an aspect-preserving
// resize with zero padding instead of a plain stretch), scales pixel values by
// 1/255 and returns input_c * input_h * input_w floats laid out channel plane
// after channel plane.
std::vector<float> prepareImage(cv::Mat& img, int input_c, int input_h, int input_w, int letter_box);
#endif //CALIBRATOR_H

View File

@@ -8,79 +8,17 @@
nvinfer1::ILayer* upsampleLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& inputChannels,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
assert(block.at("type") == "upsample");
nvinfer1::Dims inpDims = input->getDimensions();
assert(inpDims.nbDims == 3);
assert(inpDims.d[1] == inpDims.d[2]);
int h = inpDims.d[1];
int w = inpDims.d[2];
int stride = std::stoi(block.at("stride"));
nvinfer1::Dims preDims{3,
{1, stride * h, w},
{nvinfer1::DimensionType::kCHANNEL, nvinfer1::DimensionType::kSPATIAL,
nvinfer1::DimensionType::kSPATIAL}};
int size = stride * h * w;
nvinfer1::Weights preMul{nvinfer1::DataType::kFLOAT, nullptr, size};
float* preWt = new float[size];
for (int i = 0, idx = 0; i < h; ++i)
{
for (int s = 0; s < stride; ++s)
{
for (int j = 0; j < w; ++j, ++idx)
{
preWt[idx] = (i == j) ? 1.0 : 0.0;
}
}
}
preMul.values = preWt;
trtWeights.push_back(preMul);
nvinfer1::IConstantLayer* preM = network->addConstant(preDims, preMul);
assert(preM != nullptr);
std::string preLayerName = "preMul_" + std::to_string(layerIdx);
preM->setName(preLayerName.c_str());
nvinfer1::Dims postDims{3,
{1, h, stride * w},
{nvinfer1::DimensionType::kCHANNEL, nvinfer1::DimensionType::kSPATIAL,
nvinfer1::DimensionType::kSPATIAL}};
size = stride * h * w;
nvinfer1::Weights postMul{nvinfer1::DataType::kFLOAT, nullptr, size};
float* postWt = new float[size];
for (int i = 0, idx = 0; i < h; ++i)
{
for (int j = 0; j < stride * w; ++j, ++idx)
{
postWt[idx] = (j / stride == i) ? 1.0 : 0.0;
}
}
postMul.values = postWt;
trtWeights.push_back(postMul);
nvinfer1::IConstantLayer* post_m = network->addConstant(postDims, postMul);
assert(post_m != nullptr);
std::string postLayerName = "postMul_" + std::to_string(layerIdx);
post_m->setName(postLayerName.c_str());
nvinfer1::IMatrixMultiplyLayer* mm1
= network->addMatrixMultiply(*preM->getOutput(0), nvinfer1::MatrixOperation::kNONE, *input,
nvinfer1::MatrixOperation::kNONE);
assert(mm1 != nullptr);
std::string mm1LayerName = "mm1_" + std::to_string(layerIdx);
mm1->setName(mm1LayerName.c_str());
nvinfer1::IMatrixMultiplyLayer* mm2
= network->addMatrixMultiply(*mm1->getOutput(0), nvinfer1::MatrixOperation::kNONE,
*post_m->getOutput(0), nvinfer1::MatrixOperation::kNONE);
assert(mm2 != nullptr);
std::string mm2LayerName = "mm2_" + std::to_string(layerIdx);
mm2->setName(mm2LayerName.c_str());
return mm2;
nvinfer1::IResizeLayer* resize_layer = network->addResize(*input);
resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
float scale[3] = {1, stride, stride};
resize_layer->setScales(scale, 3);
std::string layer_name = "upsample_" + std::to_string(layerIdx);
resize_layer->setName(layer_name.c_str());
return resize_layer;
}

View File

@@ -15,9 +15,6 @@
nvinfer1::ILayer* upsampleLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& inputChannels,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);

View File

@@ -45,9 +45,20 @@ static bool getYoloNetworkInfo (NetworkInfo &networkInfo, const NvDsInferContext
networkInfo.networkType = yoloType;
networkInfo.configFilePath = initParams->customNetworkConfigFilePath;
networkInfo.wtsFilePath = initParams->modelFilePath;
networkInfo.int8CalibPath = initParams->int8CalibrationFilePath;
networkInfo.deviceType = (initParams->useDLA ? "kDLA" : "kGPU");
networkInfo.inputBlobName = "data";
if(initParams->networkMode == 0) {
networkInfo.networkMode = "FP32";
}
else if(initParams->networkMode == 1) {
networkInfo.networkMode = "INT8";
}
else if(initParams->networkMode == 2) {
networkInfo.networkMode = "FP16";
}
if (networkInfo.configFilePath.empty() ||
networkInfo.wtsFilePath.empty()) {
std::cerr << "YOLO config file or weights file is not specified"

View File

@@ -302,7 +302,6 @@ static bool NvDsInferParseYolo(
const uint gridSizeH = layer.inferDims.d[1];
const uint gridSizeW = layer.inferDims.d[2];
const uint stride = DIVUP(networkInfo.width, gridSizeW);
assert(stride == DIVUP(networkInfo.height, gridSizeH));
std::vector<NvDsInferParseObjectInfo> outObjs =
decodeYoloTensor((const float*)(layer.buffer), masks[idx], anchors, gridSizeW, gridSizeH, stride, masks[idx].size(),
@@ -344,7 +343,6 @@ static bool NvDsInferParseYoloV2(
const uint gridSizeH = layer.inferDims.d[1];
const uint gridSizeW = layer.inferDims.d[2];
const uint stride = DIVUP(networkInfo.width, gridSizeW);
assert(stride == DIVUP(networkInfo.height, gridSizeH));
for (auto& anchor : anchors) {
anchor *= stride;
}

View File

@@ -25,6 +25,11 @@
#include "yolo.h"
#include "yoloPlugins.h"
#include <stdlib.h>
#ifdef OPENCV
#include "calibrator.h"
#endif
void orderParams(std::vector<std::vector<int>> *maskVector) {
std::vector<std::vector<int>> maskinput = *maskVector;
@@ -45,6 +50,8 @@ Yolo::Yolo(const NetworkInfo& networkInfo)
: m_NetworkType(networkInfo.networkType), // YOLO type
m_ConfigFilePath(networkInfo.configFilePath), // YOLO cfg
m_WtsFilePath(networkInfo.wtsFilePath), // YOLO weights
m_Int8CalibPath(networkInfo.int8CalibPath), // INT8 calibration path
m_NetworkMode(networkInfo.networkMode), // FP32, INT8, FP16
m_DeviceType(networkInfo.deviceType), // kDLA, kGPU
m_InputBlobName(networkInfo.inputBlobName), // data
m_InputH(0),
@@ -62,6 +69,38 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
{
assert (builder);
m_ConfigBlocks = parseConfigFile(m_ConfigFilePath);
parseConfigBlocks();
orderParams(&m_OutputMasks);
if (m_NetworkMode == "INT8" && !fileExists(m_Int8CalibPath)) {
assert(builder->platformHasFastInt8());
#ifdef OPENCV
std::string calib_image_list;
int calib_batch_size;
if (getenv("INT8_CALIB_IMG_PATH")) {
calib_image_list = getenv("INT8_CALIB_IMG_PATH");
}
else {
std::cerr << "INT8_CALIB_IMG_PATH not set" << std::endl;
std::abort();
}
if (getenv("INT8_CALIB_BATCH_SIZE")) {
calib_batch_size = std::stoi(getenv("INT8_CALIB_BATCH_SIZE"));
}
else {
std::cerr << "INT8_CALIB_BATCH_SIZE not set" << std::endl;
std::abort();
}
nvinfer1::int8EntroyCalibrator *calibrator = new nvinfer1::int8EntroyCalibrator(calib_batch_size, m_InputC, m_InputH, m_InputW, m_LetterBox, calib_image_list, m_Int8CalibPath);
builder->setInt8Mode(true);
builder->setInt8Calibrator(calibrator);
#else
std::cerr << "OpenCV is required to run INT8 calibrator" << std::endl;
std::abort();
#endif
}
std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
std::vector<nvinfer1::Weights> trtWeights;
@@ -71,8 +110,12 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
return nullptr;
}
// Build the engine
std::cout << "Building the TensorRT Engine" << std::endl;
if (m_LetterBox == 1) {
std::cout << "\nNOTE: letter_box is set in cfg file, make sure to set maintain-aspect-ratio=1 in config_infer file to get better accuracy\n" << std::endl;
}
nvinfer1::ICudaEngine * engine = builder->buildCudaEngine(*network);
if (engine) {
std::cout << "Building complete\n" << std::endl;
@@ -80,7 +123,6 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
std::cerr << "Building engine failed\n" << std::endl;
}
// destroy
network->destroy();
return engine;
}
@@ -88,12 +130,7 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
NvDsInferStatus Yolo::parseModel(nvinfer1::INetworkDefinition& network) {
destroyNetworkUtils();
m_ConfigBlocks = parseConfigFile(m_ConfigFilePath);
parseConfigBlocks();
orderParams(&m_OutputMasks);
std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
// build yolo network
std::cout << "Building YOLO network" << std::endl;
NvDsInferStatus status = buildYoloNetwork(weights, network);
@@ -121,9 +158,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(
std::vector<nvinfer1::ITensor*> tensorOutputs;
uint outputTensorCount = 0;
// build the network using the network API
for (uint i = 0; i < m_ConfigBlocks.size(); ++i) {
// check if num. of channels is correct
assert(getNumChannels(previous) == channels);
std::string layerIndex = "(" + std::to_string(tensorOutputs.size()) + ")";
@@ -192,7 +227,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(
else if (m_ConfigBlocks.at(i).at("type") == "upsample") {
std::string inputVol = dimsToString(previous->getDimensions());
nvinfer1::ILayer* out = upsampleLayer(i - 1, m_ConfigBlocks[i], weights, m_TrtWeights, channels, previous, &network);
nvinfer1::ILayer* out = upsampleLayer(i - 1, m_ConfigBlocks[i], previous, &network);
previous = out->getOutput(0);
assert(previous != nullptr);
std::string outputVol = dimsToString(previous->getDimensions());
@@ -212,12 +247,12 @@ NvDsInferStatus Yolo::buildYoloNetwork(
else if (m_ConfigBlocks.at(i).at("type") == "yolo") {
nvinfer1::Dims prevTensorDims = previous->getDimensions();
assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
TensorInfo& curYoloTensor = m_OutputTensors.at(outputTensorCount);
curYoloTensor.gridSize = prevTensorDims.d[1];
curYoloTensor.stride = m_InputW / curYoloTensor.gridSize;
m_OutputTensors.at(outputTensorCount).volume = curYoloTensor.gridSize
* curYoloTensor.gridSize
curYoloTensor.gridSizeY = prevTensorDims.d[1];
curYoloTensor.gridSizeX = prevTensorDims.d[2];
curYoloTensor.stride = m_InputH / curYoloTensor.gridSizeY;
m_OutputTensors.at(outputTensorCount).volume = curYoloTensor.gridSizeY
* curYoloTensor.gridSizeX
* (curYoloTensor.numBBoxes * (5 + curYoloTensor.numClasses));
std::string layerName = "yolo_" + std::to_string(i);
curYoloTensor.blobName = layerName;
@@ -236,7 +271,8 @@ NvDsInferStatus Yolo::buildYoloNetwork(
nvinfer1::IPluginV2* yoloPlugin
= new YoloLayer(m_OutputTensors.at(outputTensorCount).numBBoxes,
m_OutputTensors.at(outputTensorCount).numClasses,
m_OutputTensors.at(outputTensorCount).gridSize,
m_OutputTensors.at(outputTensorCount).gridSizeX,
m_OutputTensors.at(outputTensorCount).gridSizeY,
1, new_coords, scale_x_y, beta_nms,
curYoloTensor.anchors,
m_OutputMasks);
@@ -260,12 +296,12 @@ NvDsInferStatus Yolo::buildYoloNetwork(
//YOLOv2 support
else if (m_ConfigBlocks.at(i).at("type") == "region") {
nvinfer1::Dims prevTensorDims = previous->getDimensions();
assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
TensorInfo& curRegionTensor = m_OutputTensors.at(outputTensorCount);
curRegionTensor.gridSize = prevTensorDims.d[1];
curRegionTensor.stride = m_InputW / curRegionTensor.gridSize;
m_OutputTensors.at(outputTensorCount).volume = curRegionTensor.gridSize
* curRegionTensor.gridSize
curRegionTensor.gridSizeY = prevTensorDims.d[1];
curRegionTensor.gridSizeX = prevTensorDims.d[2];
curRegionTensor.stride = m_InputH / curRegionTensor.gridSizeY;
m_OutputTensors.at(outputTensorCount).volume = curRegionTensor.gridSizeY
* curRegionTensor.gridSizeX
* (curRegionTensor.numBBoxes * (5 + curRegionTensor.numClasses));
std::string layerName = "region_" + std::to_string(i);
curRegionTensor.blobName = layerName;
@@ -273,7 +309,8 @@ NvDsInferStatus Yolo::buildYoloNetwork(
nvinfer1::IPluginV2* regionPlugin
= new YoloLayer(curRegionTensor.numBBoxes,
curRegionTensor.numClasses,
curRegionTensor.gridSize,
curRegionTensor.gridSizeX,
curRegionTensor.gridSizeY,
0, 0, 1.0, 0,
curRegionTensor.anchors,
mask);
@@ -387,8 +424,14 @@ void Yolo::parseConfigBlocks()
m_InputH = std::stoul(block.at("height"));
m_InputW = std::stoul(block.at("width"));
m_InputC = std::stoul(block.at("channels"));
assert(m_InputW == m_InputH);
m_InputSize = m_InputC * m_InputH * m_InputW;
if (block.find("letter_box") != block.end()) {
m_LetterBox = std::stoul(block.at("letter_box"));
}
else {
m_LetterBox = 0;
}
}
else if ((block.at("type") == "region") || (block.at("type") == "yolo"))
{
@@ -456,10 +499,9 @@ void Yolo::parseConfigBlocks()
}
void Yolo::destroyNetworkUtils() {
// deallocate the weights
for (uint i = 0; i < m_TrtWeights.size(); ++i) {
if (m_TrtWeights[i].count > 0)
free(const_cast<void*>(m_TrtWeights[i].values));
}
m_TrtWeights.clear();
}
}

View File

@@ -40,6 +40,8 @@ struct NetworkInfo
std::string networkType;
std::string configFilePath;
std::string wtsFilePath;
std::string int8CalibPath;
std::string networkMode;
std::string deviceType;
std::string inputBlobName;
};
@@ -48,7 +50,8 @@ struct TensorInfo
{
std::string blobName;
uint stride{0};
uint gridSize{0};
uint gridSizeY{0};
uint gridSizeX{0};
uint numClasses{0};
uint numBBoxes{0};
uint64_t volume{0};
@@ -75,6 +78,8 @@ protected:
const std::string m_NetworkType;
const std::string m_ConfigFilePath;
const std::string m_WtsFilePath;
const std::string m_Int8CalibPath;
const std::string m_NetworkMode;
const std::string m_DeviceType;
const std::string m_InputBlobName;
std::vector<TensorInfo> m_OutputTensors;
@@ -84,6 +89,7 @@ protected:
uint m_InputW;
uint m_InputC;
uint64_t m_InputSize;
uint m_LetterBox;
std::vector<nvinfer1::Weights> m_TrtWeights;

View File

@@ -20,20 +20,20 @@
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
__global__ void gpuYoloLayer(const float* input, float* output, const uint gridSize, const uint numOutputClasses,
__global__ void gpuYoloLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
const uint numBBoxes, const uint new_coords, const float scale_x_y)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
uint z_id = blockIdx.z * blockDim.z + threadIdx.z;
if ((x_id >= gridSize) || (y_id >= gridSize) || (z_id >= numBBoxes))
if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
{
return;
}
const int numGridCells = gridSize * gridSize;
const int bbindex = y_id * gridSize + x_id;
const int numGridCells = gridSizeX * gridSizeY;
const int bbindex = y_id * gridSizeX + x_id;
float alpha = scale_x_y;
float beta = -0.5 * (scale_x_y - 1);
@@ -84,20 +84,20 @@ __global__ void gpuYoloLayer(const float* input, float* output, const uint gridS
}
}
__global__ void gpuRegionLayer(const float* input, float* output, const uint gridSize, const uint numOutputClasses,
__global__ void gpuRegionLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
const uint numBBoxes)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
uint z_id = blockIdx.z * blockDim.z + threadIdx.z;
if ((x_id >= gridSize) || (y_id >= gridSize) || (z_id >= numBBoxes))
if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
{
return;
}
const int numGridCells = gridSize * gridSize;
const int bbindex = y_id * gridSize + x_id;
const int numGridCells = gridSizeX * gridSizeY;
const int bbindex = y_id * gridSizeX + x_id;
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]);
@@ -132,24 +132,24 @@ __global__ void gpuRegionLayer(const float* input, float* output, const uint gri
}
}
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSize,
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
const uint& numOutputClasses, const uint& numBBoxes,
uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType);
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSize,
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
const uint& numOutputClasses, const uint& numBBoxes,
uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType)
{
dim3 threads_per_block(16, 16, 4);
dim3 number_of_blocks((gridSize / threads_per_block.x) + 1,
(gridSize / threads_per_block.y) + 1,
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
(gridSizeY / threads_per_block.y) + 1,
(numBBoxes / threads_per_block.z) + 1);
if (modelType == 1) {
for (unsigned int batch = 0; batch < batchSize; ++batch)
{
gpuYoloLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
reinterpret_cast<const float*>(input) + (batch * outputSize),
reinterpret_cast<float*>(output) + (batch * outputSize), gridSize, numOutputClasses,
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
numBBoxes, modelCoords, modelScale);
}
}
@@ -158,7 +158,7 @@ cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize
{
gpuRegionLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
reinterpret_cast<const float*>(input) + (batch * outputSize),
reinterpret_cast<float*>(output) + (batch * outputSize), gridSize, numOutputClasses,
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
numBBoxes);
}
}

View File

@@ -52,7 +52,7 @@ void read(const char*& buffer, T& val)
cudaError_t cudaYoloLayer (
const void* input, void* output, const uint& batchSize,
const uint& gridSize, const uint& numOutputClasses,
const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType);
YoloLayer::YoloLayer (const void* data, size_t length)
@@ -60,7 +60,8 @@ YoloLayer::YoloLayer (const void* data, size_t length)
const char *d = static_cast<const char*>(data);
read(d, m_NumBoxes);
read(d, m_NumClasses);
read(d, m_GridSize);
read(d, m_GridSizeX);
read(d, m_GridSizeY);
read(d, m_OutputSize);
read(d, m_type);
@@ -94,10 +95,11 @@ YoloLayer::YoloLayer (const void* data, size_t length)
};
YoloLayer::YoloLayer (
const uint& numBoxes, const uint& numClasses, const uint& gridSize, const uint model_type, const uint new_coords, const float scale_x_y, const float beta_nms, const std::vector<float> anchors, std::vector<std::vector<int>> mask) :
const uint& numBoxes, const uint& numClasses, const uint& gridSizeX, const uint& gridSizeY, const uint model_type, const uint new_coords, const float scale_x_y, const float beta_nms, const std::vector<float> anchors, std::vector<std::vector<int>> mask) :
m_NumBoxes(numBoxes),
m_NumClasses(numClasses),
m_GridSize(gridSize),
m_GridSizeX(gridSizeX),
m_GridSizeY(gridSizeY),
m_type(model_type),
m_new_coords(new_coords),
m_scale_x_y(scale_x_y),
@@ -107,8 +109,9 @@ YoloLayer::YoloLayer (
{
assert(m_NumBoxes > 0);
assert(m_NumClasses > 0);
assert(m_GridSize > 0);
m_OutputSize = m_GridSize * m_GridSize * (m_NumBoxes * (4 + 1 + m_NumClasses));
assert(m_GridSizeX > 0);
assert(m_GridSizeY > 0);
m_OutputSize = m_GridSizeX * m_GridSizeY * (m_NumBoxes * (4 + 1 + m_NumClasses));
};
nvinfer1::Dims
@@ -142,7 +145,7 @@ int YoloLayer::enqueue(
cudaStream_t stream)
{
CHECK(cudaYoloLayer(
inputs[0], outputs[0], batchSize, m_GridSize, m_NumClasses, m_NumBoxes,
inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
m_OutputSize, stream, m_new_coords, m_scale_x_y, m_type));
return 0;
}
@@ -161,7 +164,7 @@ size_t YoloLayer::getSerializationSize() const
}
}
return sizeof(m_NumBoxes) + sizeof(m_NumClasses) + sizeof(m_GridSize) + sizeof(m_OutputSize) + sizeof(m_type)
return sizeof(m_NumBoxes) + sizeof(m_NumClasses) + sizeof(m_GridSizeX) + sizeof(m_GridSizeY) + sizeof(m_OutputSize) + sizeof(m_type)
+ sizeof(m_new_coords) + sizeof(m_scale_x_y) + sizeof(m_beta_nms) + anchorsSum * sizeof(float) + maskSum * sizeof(int);
}
@@ -170,7 +173,8 @@ void YoloLayer::serialize(void* buffer) const
char *d = static_cast<char*>(buffer);
write(d, m_NumBoxes);
write(d, m_NumClasses);
write(d, m_GridSize);
write(d, m_GridSizeX);
write(d, m_GridSizeY);
write(d, m_OutputSize);
write(d, m_type);
@@ -199,7 +203,7 @@ void YoloLayer::serialize(void* buffer) const
nvinfer1::IPluginV2* YoloLayer::clone() const
{
return new YoloLayer (m_NumBoxes, m_NumClasses, m_GridSize, m_type, m_new_coords, m_scale_x_y, m_beta_nms, m_Anchors, m_Mask);
return new YoloLayer (m_NumBoxes, m_NumClasses, m_GridSizeX, m_GridSizeY, m_type, m_new_coords, m_scale_x_y, m_beta_nms, m_Anchors, m_Mask);
}
REGISTER_TENSORRT_PLUGIN(YoloLayerPluginCreator);

View File

@@ -56,7 +56,7 @@ class YoloLayer : public nvinfer1::IPluginV2
{
public:
YoloLayer (const void* data, size_t length);
YoloLayer (const uint& numBoxes, const uint& numClasses, const uint& gridSize,
YoloLayer (const uint& numBoxes, const uint& numClasses, const uint& gridSizeX, const uint& gridSizeY,
const uint model_type, const uint new_coords, const float scale_x_y, const float beta_nms,
const std::vector<float> anchors, const std::vector<std::vector<int>> mask);
const char* getPluginType () const override { return YOLOLAYER_PLUGIN_NAME; }
@@ -96,7 +96,8 @@ public:
private:
uint m_NumBoxes {0};
uint m_NumClasses {0};
uint m_GridSize {0};
uint m_GridSizeX {0};
uint m_GridSizeY {0};
uint64_t m_OutputSize {0};
std::string m_Namespace {""};
@@ -152,4 +153,4 @@ extern float kBETA_NMS;
extern std::vector<float> kANCHORS;
extern std::vector<std::vector<int>> kMASK;
#endif // __YOLO_PLUGINS__
#endif // __YOLO_PLUGINS__