diff --git a/YOLOv5.md b/YOLOv5.md
index 5303f2f..71f312c 100644
--- a/YOLOv5.md
+++ b/YOLOv5.md
@@ -1,9 +1,9 @@
 # YOLOv5

-NVIDIA DeepStream SDK 5.1 configuration for YOLOv5 models
+NVIDIA DeepStream SDK 5.1 configuration for YOLOv5 5.0 models

-Thanks [DanaHan](https://github.com/DanaHan/Yolov5-in-Deepstream-5.0), [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5)
+Thanks [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5)

-Supported version: YOLOv5 3.0/3.1
+Supported version: YOLOv5 5.0

 ##
@@ -16,53 +16,15 @@ Supported version: YOLOv5 3.0/3.1
 ##

 ### Requirements
-* Python3
-```
-sudo apt-get install python3 python3-dev python3-pip
-pip3 install --upgrade pip
-```
+* [TensorRTX](https://github.com/wang-xinyu/tensorrtx/blob/master/tutorials/install.md)

-* OpenCV Python
-```
-sudo apt-get install libopencv-dev
-pip3 install opencv-python
-```
-
-* Matplotlib
-```
-pip3 install matplotlib
-```
+* [Ultralytics](https://github.com/ultralytics/yolov5/blob/master/requirements.txt)

 * Matplotlib (for Jetson plataform)
 ```
 sudo apt-get install python3-matplotlib
 ```

-* Scipy
-```
-pip3 install scipy
-```
-
-* tqdm
-```
-pip3 install tqdm
-```
-
-* Pandas
-```
-pip3 install pandas
-```
-
-* seaborn
-```
-pip3 install seaborn
-```
-
-* PyTorch
-```
-pip3 install torch torchvision
-```
-
 * PyTorch (for Jetson plataform)
 ```
 wget https://nvidia.box.com/shared/static/9eptse6jyly1ggt9axbja2yrmj6pbarc.whl -O torch-1.6.0-cp36-cp36m-linux_aarch64.whl
@@ -84,20 +46,13 @@ sudo python3 setup.py install
 ### Convert PyTorch model to wts file
 1. Download repositories
 ```
-git clone https://github.com/DanaHan/Yolov5-in-Deepstream-5.0.git yolov5converter
 git clone https://github.com/wang-xinyu/tensorrtx.git
 git clone https://github.com/ultralytics/yolov5.git
 ```

-Note: checkout TensorRTX repo to 3.0/3.1 YOLOv5 version
+2. Download latest YOLOv5 (YOLOv5s, YOLOv5m, YOLOv5l or YOLOv5x) weights to yolov5 folder (example for YOLOv5s)
 ```
-cd tensorrtx
-git checkout '6d0f5cb'
-```
-
-2. Download latest YoloV5 (YOLOv5s, YOLOv5m, YOLOv5l or YOLOv5x) weights to yolov5/weights directory (example for YOLOv5s)
-```
-wget https://github.com/ultralytics/yolov5/releases/download/v3.1/yolov5s.pt -P yolov5/weights/
+wget https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s.pt -P yolov5/
 ```

 3. Copy gen_wts.py file (from tensorrtx/yolov5 folder) to yolov5 (ultralytics) folder
@@ -108,36 +63,15 @@ cp tensorrtx/yolov5/gen_wts.py yolov5/gen_wts.py
 4. Generate wts file
 ```
 cd yolov5
-python3 gen_wts.py
+python3 gen_wts.py yolov5s.pt
 ```

 yolov5s.wts file will be generated in yolov5 folder
-
-
-Note: if you want to generate wts file to another YOLOv5 model (YOLOv5m, YOLOv5l or YOLOv5x), edit get_wts.py file changing yolov5s to your model name
-```
-model = torch.load('weights/yolov5s.pt', map_location=device)['model'].float() # load to FP32
-model.to(device).eval()
-
-f = open('yolov5s.wts', 'w')
-```
-

 ##

 ### Convert wts file to TensorRT model
-1. Replace yololayer files from tensorrtx/yolov5 folder to yololayer and hardswish files from yolov5converter
-```
-mv yolov5converter/yololayer.cu tensorrtx/yolov5/yololayer.cu
-mv yolov5converter/yololayer.h tensorrtx/yolov5/yololayer.h
-```
-
-2. Move generated yolov5s.wts file to tensorrtx/yolov5 folder (example for YOLOv5s)
-```
-cp yolov5/yolov5s.wts tensorrtx/yolov5/yolov5s.wts
-```
-
-3. Build tensorrtx/yolov5
+1. Build tensorrtx/yolov5
 ```
 cd tensorrtx/yolov5
 mkdir build
@@ -146,12 +80,17 @@ cmake ..
 make
 ```

-4. Convert to TensorRT model (yolov5s.engine file will be generated in tensorrtx/yolov5/build folder)
+2. Move generated yolov5s.wts file to tensorrtx/yolov5 folder (example for YOLOv5s)
 ```
-sudo ./yolov5 -s
+cp yolov5/yolov5s.wts tensorrtx/yolov5/build/yolov5s.wts
 ```

-5. Create a custom yolo folder and copy generated files (example for YOLOv5s)
+3. Convert to TensorRT model (yolov5s.engine file will be generated in tensorrtx/yolov5/build folder)
+```
+sudo ./yolov5 -s yolov5s.wts yolov5s.engine s
+```
+
+4. Create a custom yolo folder and copy generated file (example for YOLOv5s)
 ```
 mkdir /opt/nvidia/deepstream/deepstream-5.1/sources/yolo
 cp yolov5s.engine /opt/nvidia/deepstream/deepstream-5.1/sources/yolo/yolov5s.engine
 ```

-Note: by default, yolov5 script generate model with batch size = 1, FP16 mode and s model.
+Note: by default, the yolov5 script generates the model with batch size = 1 and FP16 mode.
 ```
-#define USE_FP16 // comment out this if want to use FP32
+#define USE_FP32 // set USE_INT8 or USE_FP16 or USE_FP32
 #define DEVICE 0 // GPU id
 #define NMS_THRESH 0.4
 #define CONF_THRESH 0.5
 #define BATCH_SIZE 1
-
-#define NET s // s m l x
 ```

 Edit yolov5.cpp file before compile if you want to change this parameters.
@@ -179,7 +116,7 @@ Edit yolov5.cpp file before compile if you want to change this parameters.
 sudo chmod -R 777 /opt/nvidia/deepstream/deepstream-5.1/sources/
 ```

-2. Donwload [my external/yolov5 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5) and move files to created yolo folder
+2. Download [my external/yolov5-5.0 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5-5.0) and move files to created yolo folder

 3. Compile lib

@@ -198,7 +135,7 @@ CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
 ##

 ### Testing model
-Use my edited [deepstream_app_config.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5/deepstream_app_config.txt) and [config_infer_primary.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5/config_infer_primary.txt) files available in [my external/yolov5 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5)
+Use my edited [deepstream_app_config.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5-5.0/deepstream_app_config.txt) and [config_infer_primary.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5-5.0/config_infer_primary.txt) files available in [my external/yolov5-5.0 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5-5.0)

 Run command
 ```
diff --git a/external/yolov5/config_infer_primary.txt b/external/yolov5-5.0/config_infer_primary.txt
similarity index 100%
rename from external/yolov5/config_infer_primary.txt
rename to external/yolov5-5.0/config_infer_primary.txt
diff --git a/external/yolov5/deepstream_app_config.txt b/external/yolov5-5.0/deepstream_app_config.txt
similarity index 100%
rename from external/yolov5/deepstream_app_config.txt
rename to external/yolov5-5.0/deepstream_app_config.txt
diff --git a/external/yolov5/labels.txt b/external/yolov5-5.0/labels.txt
similarity index 100%
rename from external/yolov5/labels.txt
rename to external/yolov5-5.0/labels.txt
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/Makefile
similarity index 100%
rename from external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile
rename to external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/Makefile
diff --git a/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h
new file mode 100644
index 0000000..8fbd319
--- /dev/null
+++ b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h
@@ -0,0 +1,18 @@
+#ifndef TRTX_CUDA_UTILS_H_
+#define TRTX_CUDA_UTILS_H_
+
+#include <cuda_runtime_api.h>
+
+#ifndef CUDA_CHECK
+#define CUDA_CHECK(callstr)\
+    {\
+        cudaError_t error_code = callstr;\
+        if (error_code != cudaSuccess) {\
+            std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\
+            assert(0);\
+        }\
+    }
+#endif // CUDA_CHECK
+
+#endif // TRTX_CUDA_UTILS_H_
+
diff --git
a/external/yolov5/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp similarity index 100% rename from external/yolov5/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp rename to external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.cu similarity index 55% rename from external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu rename to external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.cu index a2e6ba3..525bf8d 100644 --- a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu +++ b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.cu @@ -1,33 +1,55 @@ #include +#include +#include #include "yololayer.h" -#include "utils.h" +#include "cuda_utils.h" + +namespace Tn +{ + template + void write(char*& buffer, const T& val) + { + *reinterpret_cast(buffer) = val; + buffer += sizeof(T); + } + + template + void read(const char*& buffer, T& val) + { + val = *reinterpret_cast(buffer); + buffer += sizeof(T); + } +} using namespace Yolo; namespace nvinfer1 { - YoloLayerPlugin::YoloLayerPlugin() + YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector& vYoloKernel) { - mClassCount = CLASS_NUM; - mYoloKernel.clear(); - mYoloKernel.push_back(yolo1); - mYoloKernel.push_back(yolo2); - mYoloKernel.push_back(yolo3); - - mKernelCount = mYoloKernel.size(); + mClassCount = classCount; + mYoloV5NetWidth = netWidth; + mYoloV5NetHeight = netHeight; + mMaxOutObject = maxOut; + mYoloKernel = vYoloKernel; + mKernelCount = vYoloKernel.size(); CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*))); - size_t AnchorLen = sizeof(float)* CHECK_COUNT*2; - for(int ii = 0; ii < mKernelCount; ii ++) + size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2; + for (int ii = 0; ii < mKernelCount; ii++) { - CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen)); + CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen)); const auto& yolo = mYoloKernel[ii]; CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice)); } } - YoloLayerPlugin::~YoloLayerPlugin() { + for (int ii = 0; ii < mKernelCount; ii++) + { + CUDA_CHECK(cudaFree(mAnchor[ii])); + } + CUDA_CHECK(cudaFreeHost(mAnchor)); } // create the plugin at runtime from a byte stream @@ -38,20 +60,21 @@ namespace nvinfer1 read(d, mClassCount); read(d, mThreadCount); read(d, mKernelCount); + read(d, mYoloV5NetWidth); + read(d, mYoloV5NetHeight); + read(d, mMaxOutObject); mYoloKernel.resize(mKernelCount); - auto kernelSize = mKernelCount*sizeof(YoloKernel); - memcpy(mYoloKernel.data(),d,kernelSize); + auto kernelSize = mKernelCount * sizeof(YoloKernel); + memcpy(mYoloKernel.data(), d, kernelSize); d += kernelSize; - CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*))); - size_t AnchorLen = sizeof(float)* CHECK_COUNT*2; - for(int ii = 0; ii < mKernelCount; ii ++) + size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2; + for (int ii = 0; ii < mKernelCount; ii++) { - CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen)); + CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen)); const auto& yolo = mYoloKernel[ii]; CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice)); } - assert(d == a + length); } @@ -62,27 +85,30 @@ namespace nvinfer1 write(d, mClassCount); write(d, mThreadCount); write(d, mKernelCount); - auto kernelSize = mKernelCount*sizeof(YoloKernel); - 
memcpy(d,mYoloKernel.data(),kernelSize); + write(d, mYoloV5NetWidth); + write(d, mYoloV5NetHeight); + write(d, mMaxOutObject); + auto kernelSize = mKernelCount * sizeof(YoloKernel); + memcpy(d, mYoloKernel.data(), kernelSize); d += kernelSize; assert(d == a + getSerializationSize()); } - + size_t YoloLayerPlugin::getSerializationSize() const - { - return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size(); + { + return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size() + sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight) + sizeof(mMaxOutObject); } int YoloLayerPlugin::initialize() - { + { return 0; } - + Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) { //output the result to channel - int totalsize = MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float); + int totalsize = mMaxOutObject * sizeof(Detection) / sizeof(float); return Dims3(totalsize + 1, 1, 1); } @@ -146,26 +172,27 @@ namespace nvinfer1 // Clone the plugin IPluginV2IOExt* YoloLayerPlugin::clone() const { - YoloLayerPlugin *p = new YoloLayerPlugin(); + YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, mYoloKernel); p->setPluginNamespace(mPluginNamespace); return p; } - __device__ float Logist(float data){ return 1.0f / (1.0f + expf(-data)); }; + __device__ float Logist(float data) { return 1.0f / (1.0f + expf(-data)); }; + + __global__ void CalDetection(const float *input, float *output, int noElements, + const int netwidth, const int netheight, int maxoutobject, int yoloWidth, int yoloHeight, const float anchors[CHECK_COUNT * 2], int classes, int outputElem) + { - __global__ void CalDetection(const float *input, float *output,int noElements, - int yoloWidth,int yoloHeight,const float anchors[CHECK_COUNT*2],int classes,int outputElem) { - int idx = threadIdx.x + blockDim.x * blockIdx.x; if (idx >= noElements) return; int total_grid = yoloWidth * yoloHeight; int bnIdx = idx / total_grid; - idx = idx - total_grid*bnIdx; + idx = idx - total_grid * bnIdx; int info_len_i = 5 + classes; const float* curInput = input + bnIdx * (info_len_i * total_grid * CHECK_COUNT); - for (int k = 0; k < 3; ++k) { + for (int k = 0; k < CHECK_COUNT; ++k) { float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]); if (box_prob < IGNORE_THRESH) continue; int class_id = 0; @@ -177,51 +204,57 @@ namespace nvinfer1 class_id = i - 5; } } - float *res_count = output + bnIdx*outputElem; + float *res_count = output + bnIdx * outputElem; int count = (int)atomicAdd(res_count, 1); - if (count >= MAX_OUTPUT_BBOX_COUNT) return; - char* data = (char *)res_count + sizeof(float) + count * sizeof(Detection); - Detection* det = (Detection*)(data); + if (count >= maxoutobject) return; + char *data = (char*)res_count + sizeof(float) + count * sizeof(Detection); + Detection *det = (Detection*)(data); int row = idx / yoloWidth; int col = idx % yoloWidth; //Location - det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) * INPUT_W / yoloWidth; - det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) * INPUT_H / yoloHeight; + // pytorch: + // y = x[i].sigmoid() + // y[..., 0:2] = (y[..., 0:2] * 2. 
- 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy + // y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + // X: (sigmoid(tx) + cx)/FeaturemapW * netwidth + det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) * netwidth / yoloWidth; + det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) * netheight / yoloHeight; + + // W: (Pw * e^tw) / FeaturemapW * netwidth + // v5: https://github.com/ultralytics/yolov5/issues/471 det->bbox[2] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 2 * total_grid]); - det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2*k]; + det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2 * k]; det->bbox[3] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 3 * total_grid]); - det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2*k + 1]; + det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2 * k + 1]; det->conf = box_prob * max_cls_prob; det->class_id = class_id; } } - void YoloLayerPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) { - - int outputElem = 1 + MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float); - - for(int idx = 0 ; idx < batchSize; ++idx) { - CUDA_CHECK(cudaMemset(output + idx*outputElem, 0, sizeof(float))); + void YoloLayerPlugin::forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize) + { + int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float); + for (int idx = 0; idx < batchSize; ++idx) { + CUDA_CHECK(cudaMemset(output + idx * outputElem, 0, sizeof(float))); } int numElem = 0; - for (unsigned int i = 0; i < mYoloKernel.size(); ++i) - { + for (unsigned int i = 0; i < mYoloKernel.size(); ++i) { const auto& yolo = mYoloKernel[i]; - numElem = yolo.width*yolo.height*batchSize; - if (numElem < mThreadCount) - mThreadCount = numElem; - CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>> - (inputs[i], output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount, outputElem); - } + numElem = yolo.width * yolo.height * batchSize; + if (numElem < mThreadCount) mThreadCount = numElem; + //printf("Net: %d %d \n", mYoloV5NetWidth, mYoloV5NetHeight); + CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> > + (inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, yolo.width, yolo.height, (float*)mAnchor[i], mClassCount, outputElem); + } } - int YoloLayerPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) + int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) { - forwardGpu((const float *const *)inputs, (float*)outputs[0], stream, batchSize); + forwardGpu((const float* const*)inputs, (float*)outputs[0], stream, batchSize); return 0; } @@ -238,22 +271,32 @@ namespace nvinfer1 const char* YoloPluginCreator::getPluginName() const { - return "YoloLayer_TRT"; + return "YoloLayer_TRT"; } const char* YoloPluginCreator::getPluginVersion() const { - return "1"; + return "1"; } const PluginFieldCollection* YoloPluginCreator::getFieldNames() { - return &mFC; + return &mFC; } IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) { - YoloLayerPlugin* obj = new YoloLayerPlugin(); + assert(fc->nbFields == 2); + 
assert(strcmp(fc->fields[0].name, "netinfo") == 0); + assert(strcmp(fc->fields[1].name, "kernels") == 0); + int *p_netinfo = (int*)(fc->fields[0].data); + int class_count = p_netinfo[0]; + int input_w = p_netinfo[1]; + int input_h = p_netinfo[2]; + int max_output_object_count = p_netinfo[3]; + std::vector kernels(fc->fields[1].length); + memcpy(&kernels[0], fc->fields[1].data, kernels.size() * sizeof(Yolo::YoloKernel)); + YoloLayerPlugin* obj = new YoloLayerPlugin(class_count, input_w, input_h, max_output_object_count, kernels); obj->setPluginNamespace(mNamespace.c_str()); return obj; } @@ -261,10 +304,10 @@ namespace nvinfer1 IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) { // This object will be deleted when the network is destroyed, which will - // call MishPlugin::destroy() + // call YoloLayerPlugin::destroy() YoloLayerPlugin* obj = new YoloLayerPlugin(serialData, serialLength); obj->setPluginNamespace(mNamespace.c_str()); return obj; } - } + diff --git a/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.h b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.h new file mode 100644 index 0000000..49f6474 --- /dev/null +++ b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.h @@ -0,0 +1,137 @@ +#ifndef _YOLO_LAYER_H +#define _YOLO_LAYER_H + +#include +#include +#include "NvInfer.h" + +namespace Yolo +{ + static constexpr int CHECK_COUNT = 3; + static constexpr float IGNORE_THRESH = 0.1f; + struct YoloKernel + { + int width; + int height; + float anchors[CHECK_COUNT * 2]; + }; + static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000; + static constexpr int CLASS_NUM = 80; + static constexpr int INPUT_H = 640; // yolov5's input height and width must be divisible by 32. 
+ static constexpr int INPUT_W = 640; + + static constexpr int LOCATIONS = 4; + struct alignas(float) Detection { + //center_x center_y w h + float bbox[LOCATIONS]; + float conf; // bbox_conf * cls_conf + float class_id; + }; +} + +namespace nvinfer1 +{ + class YoloLayerPlugin : public IPluginV2IOExt + { + public: + YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector& vYoloKernel); + YoloLayerPlugin(const void* data, size_t length); + ~YoloLayerPlugin(); + + int getNbOutputs() const override + { + return 1; + } + + Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override; + + int initialize() override; + + virtual void terminate() override {}; + + virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } + + virtual int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override; + + virtual size_t getSerializationSize() const override; + + virtual void serialize(void* buffer) const override; + + bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override { + return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT; + } + + const char* getPluginType() const override; + + const char* getPluginVersion() const override; + + void destroy() override; + + IPluginV2IOExt* clone() const override; + + void setPluginNamespace(const char* pluginNamespace) override; + + const char* getPluginNamespace() const override; + + DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override; + + bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override; + + bool canBroadcastInputAcrossBatch(int inputIndex) const override; + + void attachToContext( + cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override; + + void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override; + + void detachFromContext() override; + + private: + void forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize = 1); + int mThreadCount = 256; + const char* mPluginNamespace; + int mKernelCount; + int mClassCount; + int mYoloV5NetWidth; + int mYoloV5NetHeight; + int mMaxOutObject; + std::vector mYoloKernel; + void** mAnchor; + }; + + class YoloPluginCreator : public IPluginCreator + { + public: + YoloPluginCreator(); + + ~YoloPluginCreator() override = default; + + const char* getPluginName() const override; + + const char* getPluginVersion() const override; + + const PluginFieldCollection* getFieldNames() override; + + IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override; + + IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override; + + void setPluginNamespace(const char* libNamespace) override + { + mNamespace = libNamespace; + } + + const char* getPluginNamespace() const override + { + return mNamespace.c_str(); + } + + private: + std::string mNamespace; + static PluginFieldCollection mFC; + static std::vector mPluginAttributes; + }; + REGISTER_TENSORRT_PLUGIN(YoloPluginCreator); +}; + +#endif diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h b/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h deleted file mode 100644 index 0de663c..0000000 --- a/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h 
+++ /dev/null @@ -1,94 +0,0 @@ -#ifndef __TRT_UTILS_H_ -#define __TRT_UTILS_H_ - -#include -#include -#include -#include - -#ifndef CUDA_CHECK - -#define CUDA_CHECK(callstr) \ - { \ - cudaError_t error_code = callstr; \ - if (error_code != cudaSuccess) { \ - std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ - assert(0); \ - } \ - } - -#endif - -namespace Tn -{ - class Profiler : public nvinfer1::IProfiler - { - public: - void printLayerTimes(int itrationsTimes) - { - float totalTime = 0; - for (size_t i = 0; i < mProfile.size(); i++) - { - printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(), mProfile[i].second / itrationsTimes); - totalTime += mProfile[i].second; - } - printf("Time over all layers: %4.3f\n", totalTime / itrationsTimes); - } - private: - typedef std::pair Record; - std::vector mProfile; - - virtual void reportLayerTime(const char* layerName, float ms) - { - auto record = std::find_if(mProfile.begin(), mProfile.end(), [&](const Record& r){ return r.first == layerName; }); - if (record == mProfile.end()) - mProfile.push_back(std::make_pair(layerName, ms)); - else - record->second += ms; - } - }; - - //Logger for TensorRT info/warning/errors - class Logger : public nvinfer1::ILogger - { - public: - - Logger(): Logger(Severity::kWARNING) {} - - Logger(Severity severity): reportableSeverity(severity) {} - - void log(Severity severity, const char* msg) override - { - // suppress messages with severity enum value greater than the reportable - if (severity > reportableSeverity) return; - - switch (severity) - { - case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break; - case Severity::kERROR: std::cerr << "ERROR: "; break; - case Severity::kWARNING: std::cerr << "WARNING: "; break; - case Severity::kINFO: std::cerr << "INFO: "; break; - default: std::cerr << "UNKNOWN: "; break; - } - std::cerr << msg << std::endl; - } - - Severity reportableSeverity{Severity::kWARNING}; - }; - - template - void write(char*& buffer, const T& val) - { - *reinterpret_cast(buffer) = val; - buffer += sizeof(T); - } - - template - void read(const char*& buffer, T& val) - { - val = *reinterpret_cast(buffer); - buffer += sizeof(T); - } -} - -#endif \ No newline at end of file diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h b/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h deleted file mode 100644 index 91116cd..0000000 --- a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h +++ /dev/null @@ -1,152 +0,0 @@ -#ifndef _YOLO_LAYER_H -#define _YOLO_LAYER_H - -#include -#include -#include "NvInfer.h" - -namespace Yolo -{ - static constexpr int CHECK_COUNT = 3; - static constexpr float IGNORE_THRESH = 0.1f; - static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000; - static constexpr int CLASS_NUM = 80; - static constexpr int INPUT_H = 608; - static constexpr int INPUT_W = 608; - - struct YoloKernel - { - int width; - int height; - float anchors[CHECK_COUNT*2]; - }; - - static constexpr YoloKernel yolo1 = { - INPUT_W / 32, - INPUT_H / 32, - {116,90, 156,198, 373,326} - }; - static constexpr YoloKernel yolo2 = { - INPUT_W / 16, - INPUT_H / 16, - {30,61, 62,45, 59,119} - }; - static constexpr YoloKernel yolo3 = { - INPUT_W / 8, - INPUT_H / 8, - {10,13, 16,30, 33,23} - }; - - static constexpr int LOCATIONS = 4; - struct alignas(float) Detection{ - //center_x center_y w h - float bbox[LOCATIONS]; - float conf; // bbox_conf * cls_conf - float class_id; - }; -} - -namespace nvinfer1 -{ - class YoloLayerPlugin: public IPluginV2IOExt - { - 
public: - explicit YoloLayerPlugin(); - YoloLayerPlugin(const void* data, size_t length); - - ~YoloLayerPlugin(); - - int getNbOutputs() const override - { - return 1; - } - - Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override; - - int initialize() override; - - virtual void terminate() override {}; - - virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;} - - virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override; - - virtual size_t getSerializationSize() const override; - - virtual void serialize(void* buffer) const override; - - bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override { - return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT; - } - - const char* getPluginType() const override; - - const char* getPluginVersion() const override; - - void destroy() override; - - IPluginV2IOExt* clone() const override; - - void setPluginNamespace(const char* pluginNamespace) override; - - const char* getPluginNamespace() const override; - - DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override; - - bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override; - - bool canBroadcastInputAcrossBatch(int inputIndex) const override; - - void attachToContext( - cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override; - - void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override; - - void detachFromContext() override; - - private: - void forwardGpu(const float *const * inputs,float * output, cudaStream_t stream,int batchSize = 1); - int mClassCount; - int mKernelCount; - std::vector mYoloKernel; - int mThreadCount = 256; - void** mAnchor; - const char* mPluginNamespace; - }; - - class YoloPluginCreator : public IPluginCreator - { - public: - YoloPluginCreator(); - - ~YoloPluginCreator() override = default; - - const char* getPluginName() const override; - - const char* getPluginVersion() const override; - - const PluginFieldCollection* getFieldNames() override; - - IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override; - - IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override; - - void setPluginNamespace(const char* libNamespace) override - { - mNamespace = libNamespace; - } - - const char* getPluginNamespace() const override - { - return mNamespace.c_str(); - } - - private: - std::string mNamespace; - static PluginFieldCollection mFC; - static std::vector mPluginAttributes; - }; - REGISTER_TENSORRT_PLUGIN(YoloPluginCreator); -}; - -#endif
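
For reference, the bbox decode in the new `CalDetection` kernel follows the YOLOv5 5.0 head, as the PyTorch lines quoted in its comments indicate. Below is a minimal host-side sketch of the same math; the struct and function names are illustrative only and are not part of the plugin.
```
// Host-side reference of the decode performed per anchor in CalDetection.
#include <cmath>

struct DecodedBox { float cx, cy, w, h; };   // center x/y, width, height in input pixels

static float sigmoidf(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// tx..th: raw network outputs for one anchor at grid cell (col, row).
// gridW/gridH: feature-map size; netW/netH: network input size (640 here);
// anchorW/anchorH: anchor size in input pixels (from Yolo::YoloKernel::anchors).
DecodedBox decodeYoloV5(float tx, float ty, float tw, float th,
                        int col, int row, int gridW, int gridH,
                        float anchorW, float anchorH, int netW, int netH)
{
    DecodedBox b;
    // xy: (2 * sigmoid(t) - 0.5 + cell) * stride, with stride = netW / gridW
    b.cx = (col - 0.5f + 2.0f * sigmoidf(tx)) * netW / (float)gridW;
    b.cy = (row - 0.5f + 2.0f * sigmoidf(ty)) * netH / (float)gridH;
    // wh: (2 * sigmoid(t))^2 * anchor (see ultralytics/yolov5#471)
    b.w = 2.0f * sigmoidf(tw);
    b.w = b.w * b.w * anchorW;
    b.h = 2.0f * sigmoidf(th);
    b.h = b.h * b.h * anchorH;
    return b;
}
```
With INPUT_W = INPUT_H = 640 and an 80x80 feature map, the effective stride is 8, i.e. the finest of the three detection scales.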
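
The plugin output consumed by nvdsparsebbox_Yolo.cpp keeps the same layout as before: per image, one leading float holds the detection counter, followed by packed `Yolo::Detection` structs (at most the configured maximum number of boxes is actually written). A small illustrative reader, assuming one image's output has already been copied to host memory; `readDetections` is not part of the library.
```
// Illustrative host-side reader for the plugin's output layout.
#include <algorithm>
#include <vector>
#include "yololayer.h"

std::vector<Yolo::Detection> readDetections(const float* hostOutput, float confThresh)
{
    // hostOutput[0] is the running counter written by atomicAdd in CalDetection;
    // it can exceed the number of stored Detection slots, so clamp before reading.
    int count = static_cast<int>(hostOutput[0]);
    count = std::min(count, Yolo::MAX_OUTPUT_BBOX_COUNT); // mMaxOutObject at build time
    const Yolo::Detection* src = reinterpret_cast<const Yolo::Detection*>(hostOutput + 1);

    std::vector<Yolo::Detection> dets;
    for (int i = 0; i < count; ++i) {
        if (src[i].conf >= confThresh)   // conf = box_prob * max_cls_prob
            dets.push_back(src[i]);
    }
    return dets;
}
```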
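
The updated `YoloPluginCreator::createPlugin` now expects exactly two fields: "netinfo" (four ints: class count, network input width, input height, max output objects) and "kernels" (an array of `Yolo::YoloKernel`). The sketch below shows how a builder-side caller might pack that `PluginFieldCollection`; the field names, order and lengths are taken from the asserts in `createPlugin`, while the helper name and the chosen `PluginFieldType` values are assumptions.
```
// Hypothetical builder-side helper; not copied from tensorrtx.
#include <vector>
#include "NvInfer.h"
#include "yololayer.h"

nvinfer1::IPluginV2* createYoloPlugin(const std::vector<Yolo::YoloKernel>& kernels)
{
    using namespace nvinfer1;

    // "netinfo": class count, network input W/H, max detections kept per image
    static const int netinfo[4] = {Yolo::CLASS_NUM, Yolo::INPUT_W, Yolo::INPUT_H,
                                   Yolo::MAX_OUTPUT_BBOX_COUNT};

    PluginField fields[2];
    fields[0].name = "netinfo";
    fields[0].data = netinfo;
    fields[0].type = PluginFieldType::kINT32;
    fields[0].length = 4;                                 // createPlugin reads p_netinfo[0..3]
    fields[1].name = "kernels";
    fields[1].data = kernels.data();
    fields[1].type = PluginFieldType::kUNKNOWN;           // raw Yolo::YoloKernel bytes
    fields[1].length = static_cast<int>(kernels.size());  // number of detection scales

    PluginFieldCollection pfc;
    pfc.nbFields = 2;
    pfc.fields = fields;

    // YoloPluginCreator is registered via REGISTER_TENSORRT_PLUGIN, so it can be
    // looked up by the name/version it reports ("YoloLayer_TRT", "1").
    auto* creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1");
    return creator->createPlugin("YoloLayer_TRT", &pfc);  // plugin copies what it needs
}
```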