From ae849d2f8b1aee87a9c56d4564cd9eb39e349a13 Mon Sep 17 00:00:00 2001 From: Marcos Luciano Date: Mon, 3 May 2021 21:01:05 -0300 Subject: [PATCH] Support for YOLOv5 4.0 Added support for YOLOv5 4.0 --- YOLOv5-4.0.md | 172 +++++++++ YOLOv5.md => YOLOv5-5.0.md | 2 - external/yolov5-4.0/config_infer_primary.txt | 18 + external/yolov5-4.0/deepstream_app_config.txt | 63 ++++ external/yolov5-4.0/labels.txt | 80 +++++ .../nvdsinfer_custom_impl_Yolo/Makefile | 52 +++ .../nvdsinfer_custom_impl_Yolo/cuda_utils.h | 18 + .../nvdsparsebbox_Yolo.cpp | 122 +++++++ .../nvdsinfer_custom_impl_Yolo/yololayer.cu | 333 ++++++++++++++++++ .../nvdsinfer_custom_impl_Yolo/yololayer.h | 137 +++++++ 10 files changed, 995 insertions(+), 2 deletions(-) create mode 100644 YOLOv5-4.0.md rename YOLOv5.md => YOLOv5-5.0.md (99%) create mode 100644 external/yolov5-4.0/config_infer_primary.txt create mode 100644 external/yolov5-4.0/deepstream_app_config.txt create mode 100644 external/yolov5-4.0/labels.txt create mode 100644 external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/Makefile create mode 100644 external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h create mode 100644 external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp create mode 100644 external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/yololayer.cu create mode 100644 external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/yololayer.h diff --git a/YOLOv5-4.0.md b/YOLOv5-4.0.md new file mode 100644 index 0000000..af51329 --- /dev/null +++ b/YOLOv5-4.0.md @@ -0,0 +1,172 @@ +# YOLOv5 +NVIDIA DeepStream SDK 5.1 configuration for YOLOv5 4.0 models + +Thanks [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5) + +## + +* [Requirements](#requirements) +* [Convert PyTorch model to wts file](#convert-pytorch-model-to-wts-file) +* [Convert wts file to TensorRT model](#convert-wts-file-to-tensorrt-model) +* [Compile nvdsinfer_custom_impl_Yolo](#compile-nvdsinfer_custom_impl_yolo) +* [Testing model](#testing-model) + +## + +### Requirements +* [TensorRTX](https://github.com/wang-xinyu/tensorrtx/blob/master/tutorials/install.md) + +* [Ultralytics](https://github.com/ultralytics/yolov5/blob/v4.0/requirements.txt) + +* Matplotlib (for Jetson plataform) +``` +sudo apt-get install python3-matplotlib +``` + +* PyTorch (for Jetson plataform) +``` +wget https://nvidia.box.com/shared/static/9eptse6jyly1ggt9axbja2yrmj6pbarc.whl -O torch-1.6.0-cp36-cp36m-linux_aarch64.whl +sudo apt-get install python3-pip libopenblas-base libopenmpi-dev +pip3 install torch-1.6.0-cp36-cp36m-linux_aarch64.whl +``` + +* TorchVision (for Jetson platform) +``` +git clone -b v0.7.0 https://github.com/pytorch/vision torchvision +sudo apt-get install libjpeg-dev zlib1g-dev python3-pip +cd torchvision +export BUILD_VERSION=0.7.0 +sudo python3 setup.py install +``` + +## + +### Convert PyTorch model to wts file +1. Download repositories +``` +git clone -b yolov5-v4.0 https://github.com/wang-xinyu/tensorrtx.git +git clone -b v4.0 https://github.com/ultralytics/yolov5.git +``` + +2. Download latest YoloV5 (YOLOv5s, YOLOv5m, YOLOv5l or YOLOv5x) weights to yolov5 folder (example for YOLOv5s) +``` +wget https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt -P yolov5/weights +``` + +3. Copy gen_wts.py file (from tensorrtx/yolov5 folder) to yolov5 (ultralytics) folder +``` +cp tensorrtx/yolov5/gen_wts.py yolov5/gen_wts.py +``` + +4. Generate wts file +``` +cd yolov5 +python3 gen_wts.py +``` + +yolov5s.wts file will be generated in yolov5 folder + +## + +### Convert wts file to TensorRT model +1. Build tensorrtx/yolov5 +``` +cd tensorrtx/yolov5 +mkdir build +cd build +cmake .. +make +``` + +2. Move generated yolov5s.wts file to tensorrtx/yolov5 folder (example for YOLOv5s) +``` +cp yolov5/yolov5s.wts tensorrtx/yolov5/build/yolov5s.wts +``` + +3. Convert to TensorRT model (yolov5s.engine file will be generated in tensorrtx/yolov5/build folder) +``` +sudo ./yolov5 -s yolov5s.wts yolov5s.engine s +``` + +4. Create a custom yolo folder and copy generated file (example for YOLOv5s) +``` +mkdir /opt/nvidia/deepstream/deepstream-5.1/sources/yolo +cp yolov5s.engine /opt/nvidia/deepstream/deepstream-5.1/sources/yolo/yolov5s.engine +``` + +
+ +Note: by default, yolov5 script generate model with batch size = 1 and FP16 mode. +``` +#define USE_FP32 // set USE_INT8 or USE_FP16 or USE_FP32 +#define DEVICE 0 // GPU id +#define NMS_THRESH 0.4 +#define CONF_THRESH 0.5 +#define BATCH_SIZE 1 +``` +Edit yolov5.cpp file before compile if you want to change this parameters. + +## + +### Compile nvdsinfer_custom_impl_Yolo +1. Run command +``` +sudo chmod -R 777 /opt/nvidia/deepstream/deepstream-5.1/sources/ +``` + +2. Donwload [my external/yolov5-4.0 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5-4.0) and move files to created yolo folder + +3. Compile lib + +* x86 platform +``` +cd /opt/nvidia/deepstream/deepstream-5.1/sources/yolo +CUDA_VER=11.1 make -C nvdsinfer_custom_impl_Yolo +``` + +* Jetson platform +``` +cd /opt/nvidia/deepstream/deepstream-5.1/sources/yolo +CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo +``` + +## + +### Testing model +Use my edited [deepstream_app_config.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5-4.0/deepstream_app_config.txt) and [config_infer_primary.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5-4.0/config_infer_primary.txt) files available in [my external/yolov5-4.0 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5-4.0) + +Run command +``` +deepstream-app -c deepstream_app_config.txt +``` + +
+ +Note: based on selected model, edit config_infer_primary.txt file + +For example, if you using YOLOv5x + +``` +model-engine-file=yolov5s.engine +``` + +to + +``` +model-engine-file=yolov5x.engine +``` + +## + +To change NMS_THRESH, edit nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp file and recompile + +``` +#define kNMS_THRESH 0.45 +``` + +To change CONF_THRESH, edit config_infer_primary.txt file + +``` +[class-attrs-all] +pre-cluster-threshold=0.25 +``` diff --git a/YOLOv5.md b/YOLOv5-5.0.md similarity index 99% rename from YOLOv5.md rename to YOLOv5-5.0.md index 71f312c..bf62a19 100644 --- a/YOLOv5.md +++ b/YOLOv5-5.0.md @@ -3,8 +3,6 @@ NVIDIA DeepStream SDK 5.1 configuration for YOLOv5 5.0 models Thanks [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5) -Supported version: YOLOv5 5.0 - ## * [Requirements](#requirements) diff --git a/external/yolov5-4.0/config_infer_primary.txt b/external/yolov5-4.0/config_infer_primary.txt new file mode 100644 index 0000000..ee008f0 --- /dev/null +++ b/external/yolov5-4.0/config_infer_primary.txt @@ -0,0 +1,18 @@ +[property] +gpu-id=0 +net-scale-factor=0.0039215697906911373 +model-color-format=0 +model-engine-file=yolov5s.engine +labelfile-path=labels.txt +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=4 +maintain-aspect-ratio=0 +parse-bbox-func-name=NvDsInferParseCustomYoloV5 +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so + +[class-attrs-all] +pre-cluster-threshold=0.25 diff --git a/external/yolov5-4.0/deepstream_app_config.txt b/external/yolov5-4.0/deepstream_app_config.txt new file mode 100644 index 0000000..cd2a411 --- /dev/null +++ b/external/yolov5-4.0/deepstream_app_config.txt @@ -0,0 +1,63 @@ +[application] +enable-perf-measurement=1 +perf-measurement-interval-sec=1 + +[tiled-display] +enable=1 +rows=1 +columns=1 +width=1280 +height=720 +gpu-id=0 +nvbuf-memory-type=0 + +[source0] +enable=1 +type=3 +uri=file://../../samples/streams/sample_1080p_h264.mp4 +num-sources=1 +gpu-id=0 +cudadec-memtype=0 + +[sink0] +enable=1 +type=2 +sync=0 +source-id=0 +gpu-id=0 +nvbuf-memory-type=0 + +[osd] +enable=1 +gpu-id=0 +border-width=1 +text-size=15 +text-color=1;1;1;1; +text-bg-color=0.3;0.3;0.3;1 +font=Serif +show-clock=0 +clock-x-offset=800 +clock-y-offset=820 +clock-text-size=12 +clock-color=1;0;0;0 +nvbuf-memory-type=0 + +[streammux] +gpu-id=0 +live-source=0 +batch-size=1 +batched-push-timeout=40000 +width=1920 +height=1080 +enable-padding=0 +nvbuf-memory-type=0 + +[primary-gie] +enable=1 +gpu-id=0 +gie-unique-id=1 +nvbuf-memory-type=0 +config-file=config_infer_primary.txt + +[tests] +file-loop=0 diff --git a/external/yolov5-4.0/labels.txt b/external/yolov5-4.0/labels.txt new file mode 100644 index 0000000..ca76c80 --- /dev/null +++ b/external/yolov5-4.0/labels.txt @@ -0,0 +1,80 @@ +person +bicycle +car +motorbike +aeroplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +sofa +pottedplant +bed +diningtable +toilet +tvmonitor +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/Makefile b/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/Makefile new file mode 100644 index 0000000..8dc0218 --- /dev/null +++ b/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/Makefile @@ -0,0 +1,52 @@ +# +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +CUDA_VER?= +ifeq ($(CUDA_VER),) + $(error "CUDA_VER is not set") +endif +CC:= g++ +NVCC:=/usr/local/cuda-$(CUDA_VER)/bin/nvcc + +CFLAGS:= -Wall -std=c++11 -shared -fPIC -Wno-error=deprecated-declarations +CFLAGS+= -I../../includes -I/usr/local/cuda-$(CUDA_VER)/include + +LIBS:= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib64 -lcudart -lcublas -lstdc++fs +LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group + +INCS:= $(wildcard *.h) +SRCFILES:= nvdsparsebbox_Yolo.cpp \ + yololayer.cu + +TARGET_LIB:= libnvdsinfer_custom_impl_Yolo.so + +TARGET_OBJS:= $(SRCFILES:.cpp=.o) +TARGET_OBJS:= $(TARGET_OBJS:.cu=.o) + +all: $(TARGET_LIB) + +%.o: %.cpp $(INCS) Makefile + $(CC) -c -o $@ $(CFLAGS) $< + +%.o: %.cu $(INCS) Makefile + $(NVCC) -c -o $@ --compiler-options '-fPIC' $< + +$(TARGET_LIB) : $(TARGET_OBJS) + $(CC) -o $@ $(TARGET_OBJS) $(LFLAGS) + +clean: + rm -rf $(TARGET_LIB) + rm -rf $(TARGET_OBJS) diff --git a/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h b/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h new file mode 100644 index 0000000..8fbd319 --- /dev/null +++ b/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h @@ -0,0 +1,18 @@ +#ifndef TRTX_CUDA_UTILS_H_ +#define TRTX_CUDA_UTILS_H_ + +#include + +#ifndef CUDA_CHECK +#define CUDA_CHECK(callstr)\ + {\ + cudaError_t error_code = callstr;\ + if (error_code != cudaSuccess) {\ + std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\ + assert(0);\ + }\ + } +#endif // CUDA_CHECK + +#endif // TRTX_CUDA_UTILS_H_ + diff --git a/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp b/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp new file mode 100644 index 0000000..e38d4c1 --- /dev/null +++ b/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "nvdsinfer_custom_impl.h" + +#include + +#define kNMS_THRESH 0.45 + +static constexpr int LOCATIONS = 4; +struct alignas(float) Detection{ + //center_x center_y w h + float bbox[LOCATIONS]; + float conf; // bbox_conf * cls_conf + float class_id; + }; + +float iou(float lbox[4], float rbox[4]) { + float interBox[] = { + std::max(lbox[0] - lbox[2]/2.f , rbox[0] - rbox[2]/2.f), //left + std::min(lbox[0] + lbox[2]/2.f , rbox[0] + rbox[2]/2.f), //right + std::max(lbox[1] - lbox[3]/2.f , rbox[1] - rbox[3]/2.f), //top + std::min(lbox[1] + lbox[3]/2.f , rbox[1] + rbox[3]/2.f), //bottom + }; + + if(interBox[2] > interBox[3] || interBox[0] > interBox[1]) + return 0.0f; + + float interBoxS =(interBox[1]-interBox[0])*(interBox[3]-interBox[2]); + return interBoxS/(lbox[2]*lbox[3] + rbox[2]*rbox[3] -interBoxS); +} + +bool cmp(Detection& a, Detection& b) { + return a.conf > b.conf; +} + +void nms(std::vector& res, float *output, float conf_thresh, float nms_thresh) { + int det_size = sizeof(Detection) / sizeof(float); + std::map> m; + for (int i = 0; i < output[0] && i < 1000; i++) { + if (output[1 + det_size * i + 4] <= conf_thresh) continue; + Detection det; + memcpy(&det, &output[1 + det_size * i], det_size * sizeof(float)); + if (m.count(det.class_id) == 0) m.emplace(det.class_id, std::vector()); + m[det.class_id].push_back(det); + } + for (auto it = m.begin(); it != m.end(); it++) { + auto& dets = it->second; + std::sort(dets.begin(), dets.end(), cmp); + for (size_t m = 0; m < dets.size(); ++m) { + auto& item = dets[m]; + res.push_back(item); + for (size_t n = m + 1; n < dets.size(); ++n) { + if (iou(item.bbox, dets[n].bbox) > nms_thresh) { + dets.erase(dets.begin()+n); + --n; + } + } + } + } +} + +/* This is a sample bounding box parsing function for the sample YoloV5 detector model */ +static bool NvDsInferParseYoloV5( + std::vector const& outputLayersInfo, + NvDsInferNetworkInfo const& networkInfo, + NvDsInferParseDetectionParams const& detectionParams, + std::vector& objectList) +{ + const float kCONF_THRESH = detectionParams.perClassThreshold[0]; + + std::vector res; + + nms(res, (float*)(outputLayersInfo[0].buffer), kCONF_THRESH, kNMS_THRESH); + + for(auto& r : res) { + NvDsInferParseObjectInfo oinfo; + + oinfo.classId = r.class_id; + oinfo.left = static_cast(r.bbox[0]-r.bbox[2]*0.5f); + oinfo.top = static_cast(r.bbox[1]-r.bbox[3]*0.5f); + oinfo.width = static_cast(r.bbox[2]); + oinfo.height = static_cast(r.bbox[3]); + oinfo.detectionConfidence = r.conf; + objectList.push_back(oinfo); + } + + return true; +} + +extern "C" bool NvDsInferParseCustomYoloV5( + std::vector const &outputLayersInfo, + NvDsInferNetworkInfo const &networkInfo, + NvDsInferParseDetectionParams const &detectionParams, + std::vector &objectList) +{ + return NvDsInferParseYoloV5( + outputLayersInfo, networkInfo, detectionParams, objectList); +} + +/* Check that the custom function has been defined correctly */ +CHECK_CUSTOM_PARSE_FUNC_PROTOTYPE(NvDsInferParseCustomYoloV5); diff --git a/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/yololayer.cu b/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/yololayer.cu new file mode 100644 index 0000000..9d95e29 --- /dev/null +++ b/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/yololayer.cu @@ -0,0 +1,333 @@ +#include +#include +#include +#include "yololayer.h" +#include "cuda_utils.h" + +namespace Tn +{ + template + void write(char*& buffer, const T& val) + { + *reinterpret_cast(buffer) = val; + buffer += sizeof(T); + } + + template + void read(const char*& buffer, T& val) + { + val = *reinterpret_cast(buffer); + buffer += sizeof(T); + } +} + +using namespace Yolo; + +namespace nvinfer1 +{ + YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector& vYoloKernel) + { + mClassCount = classCount; + mYoloV5NetWidth = netWidth; + mYoloV5NetHeight = netHeight; + mMaxOutObject = maxOut; + mYoloKernel = vYoloKernel; + mKernelCount = vYoloKernel.size(); + + CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*))); + size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2; + for (int ii = 0; ii < mKernelCount; ii++) + { + CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen)); + const auto& yolo = mYoloKernel[ii]; + CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice)); + } + } + YoloLayerPlugin::~YoloLayerPlugin() + { + for (int ii = 0; ii < mKernelCount; ii++) + { + CUDA_CHECK(cudaFree(mAnchor[ii])); + } + CUDA_CHECK(cudaFreeHost(mAnchor)); + } + + // create the plugin at runtime from a byte stream + YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length) + { + using namespace Tn; + const char *d = reinterpret_cast(data), *a = d; + read(d, mClassCount); + read(d, mThreadCount); + read(d, mKernelCount); + read(d, mYoloV5NetWidth); + read(d, mYoloV5NetHeight); + read(d, mMaxOutObject); + mYoloKernel.resize(mKernelCount); + auto kernelSize = mKernelCount * sizeof(YoloKernel); + memcpy(mYoloKernel.data(), d, kernelSize); + d += kernelSize; + CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*))); + size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2; + for (int ii = 0; ii < mKernelCount; ii++) + { + CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen)); + const auto& yolo = mYoloKernel[ii]; + CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice)); + } + assert(d == a + length); + } + + void YoloLayerPlugin::serialize(void* buffer) const + { + using namespace Tn; + char* d = static_cast(buffer), *a = d; + write(d, mClassCount); + write(d, mThreadCount); + write(d, mKernelCount); + write(d, mYoloV5NetWidth); + write(d, mYoloV5NetHeight); + write(d, mMaxOutObject); + auto kernelSize = mKernelCount * sizeof(YoloKernel); + memcpy(d, mYoloKernel.data(), kernelSize); + d += kernelSize; + + assert(d == a + getSerializationSize()); + } + + size_t YoloLayerPlugin::getSerializationSize() const + { + return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size() + sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight) + sizeof(mMaxOutObject); + } + + int YoloLayerPlugin::initialize() + { + return 0; + } + + Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) + { + //output the result to channel + int totalsize = mMaxOutObject * sizeof(Detection) / sizeof(float); + + return Dims3(totalsize + 1, 1, 1); + } + + // Set plugin namespace + void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace) + { + mPluginNamespace = pluginNamespace; + } + + const char* YoloLayerPlugin::getPluginNamespace() const + { + return mPluginNamespace; + } + + // Return the DataType of the plugin output at the requested index + DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const + { + return DataType::kFLOAT; + } + + // Return true if output tensor is broadcast across a batch. + bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const + { + return false; + } + + // Return true if plugin can use input that is broadcast across batch without replication. + bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const + { + return false; + } + + void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) + { + } + + // Attach the plugin object to an execution context and grant the plugin the access to some context resource. + void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) + { + } + + // Detach the plugin object from its execution context. + void YoloLayerPlugin::detachFromContext() {} + + const char* YoloLayerPlugin::getPluginType() const + { + return "YoloLayer_TRT"; + } + + const char* YoloLayerPlugin::getPluginVersion() const + { + return "1"; + } + + void YoloLayerPlugin::destroy() + { + delete this; + } + + // Clone the plugin + IPluginV2IOExt* YoloLayerPlugin::clone() const + { + YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, mYoloKernel); + p->setPluginNamespace(mPluginNamespace); + return p; + } + + __device__ float Logist(float data) { return 1.0f / (1.0f + expf(-data)); }; + + __global__ void CalDetection(const float *input, float *output, int noElements, + const int netwidth, const int netheight, int maxoutobject, int yoloWidth, int yoloHeight, const float anchors[CHECK_COUNT * 2], int classes, int outputElem) + { + + int idx = threadIdx.x + blockDim.x * blockIdx.x; + if (idx >= noElements) return; + + int total_grid = yoloWidth * yoloHeight; + int bnIdx = idx / total_grid; + idx = idx - total_grid * bnIdx; + int info_len_i = 5 + classes; + const float* curInput = input + bnIdx * (info_len_i * total_grid * CHECK_COUNT); + + for (int k = 0; k < 3; ++k) { + float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]); + if (box_prob < IGNORE_THRESH) continue; + int class_id = 0; + float max_cls_prob = 0.0; + for (int i = 5; i < info_len_i; ++i) { + float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]); + if (p > max_cls_prob) { + max_cls_prob = p; + class_id = i - 5; + } + } + float *res_count = output + bnIdx * outputElem; + int count = (int)atomicAdd(res_count, 1); + if (count >= maxoutobject) return; + char* data = (char *)res_count + sizeof(float) + count * sizeof(Detection); + Detection* det = (Detection*)(data); + + int row = idx / yoloWidth; + int col = idx % yoloWidth; + + //Location + // pytorch: + // y = x[i].sigmoid() + // y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy + // y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + // X: (sigmoid(tx) + cx)/FeaturemapW * netwidth + det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) * netwidth / yoloWidth; + det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) * netheight / yoloHeight; + + // W: (Pw * e^tw) / FeaturemapW * netwidth + // v5: https://github.com/ultralytics/yolov5/issues/471 + det->bbox[2] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 2 * total_grid]); + det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2 * k]; + det->bbox[3] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 3 * total_grid]); + det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2 * k + 1]; + det->conf = box_prob * max_cls_prob; + det->class_id = class_id; + } + } + + void YoloLayerPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) + { + int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float); + for (int idx = 0; idx < batchSize; ++idx) { + CUDA_CHECK(cudaMemset(output + idx * outputElem, 0, sizeof(float))); + } + int numElem = 0; + for (unsigned int i = 0; i < mYoloKernel.size(); ++i) + { + const auto& yolo = mYoloKernel[i]; + numElem = yolo.width*yolo.height*batchSize; + if (numElem < mThreadCount) + mThreadCount = numElem; + + //printf("Net: %d %d \n", mYoloV5NetWidth, mYoloV5NetHeight); + CalDetection << < (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> > + (inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount, outputElem); + } + } + + + int YoloLayerPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) + { + forwardGpu((const float *const *)inputs, (float*)outputs[0], stream, batchSize); + return 0; + } + + PluginFieldCollection YoloPluginCreator::mFC{}; + std::vector YoloPluginCreator::mPluginAttributes; + + YoloPluginCreator::YoloPluginCreator() + { + mPluginAttributes.clear(); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* YoloPluginCreator::getPluginName() const + { + return "YoloLayer_TRT"; + } + + const char* YoloPluginCreator::getPluginVersion() const + { + return "1"; + } + + const PluginFieldCollection* YoloPluginCreator::getFieldNames() + { + return &mFC; + } + + IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) + { + int class_count = -1; + int input_w = -1; + int input_h = -1; + int max_output_object_count = -1; + std::vector yolo_kernels(3); + + const PluginField* fields = fc->fields; + for (int i = 0; i < fc->nbFields; i++) { + if (strcmp(fields[i].name, "netdata") == 0) { + assert(fields[i].type == PluginFieldType::kFLOAT32); + int *tmp = (int*)(fields[i].data); + class_count = tmp[0]; + input_w = tmp[1]; + input_h = tmp[2]; + max_output_object_count = tmp[3]; + } else if (strstr(fields[i].name, "yolodata") != NULL) { + assert(fields[i].type == PluginFieldType::kFLOAT32); + int *tmp = (int*)(fields[i].data); + YoloKernel kernel; + kernel.width = tmp[0]; + kernel.height = tmp[1]; + for (int j = 0; j < fields[i].length - 2; j++) { + kernel.anchors[j] = tmp[j + 2]; + } + yolo_kernels[2 - (fields[i].name[8] - '1')] = kernel; + } + } + assert(class_count && input_w && input_h && max_output_object_count); + YoloLayerPlugin* obj = new YoloLayerPlugin(class_count, input_w, input_h, max_output_object_count, yolo_kernels); + obj->setPluginNamespace(mNamespace.c_str()); + return obj; + } + + IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) + { + // This object will be deleted when the network is destroyed, which will + // call YoloLayerPlugin::destroy() + YoloLayerPlugin* obj = new YoloLayerPlugin(serialData, serialLength); + obj->setPluginNamespace(mNamespace.c_str()); + return obj; + } +} + diff --git a/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/yololayer.h b/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/yololayer.h new file mode 100644 index 0000000..7357259 --- /dev/null +++ b/external/yolov5-4.0/nvdsinfer_custom_impl_Yolo/yololayer.h @@ -0,0 +1,137 @@ +#ifndef _YOLO_LAYER_H +#define _YOLO_LAYER_H + +#include +#include +#include "NvInfer.h" + +namespace Yolo +{ + static constexpr int CHECK_COUNT = 3; + static constexpr float IGNORE_THRESH = 0.1f; + struct YoloKernel + { + int width; + int height; + float anchors[CHECK_COUNT * 2]; + }; + static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000; + static constexpr int CLASS_NUM = 80; + static constexpr int INPUT_H = 640; // yolov5's input height and width must be divisible by 32. + static constexpr int INPUT_W = 640; + + static constexpr int LOCATIONS = 4; + struct alignas(float) Detection { + //center_x center_y w h + float bbox[LOCATIONS]; + float conf; // bbox_conf * cls_conf + float class_id; + }; +} + +namespace nvinfer1 +{ + class YoloLayerPlugin : public IPluginV2IOExt + { + public: + YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector& vYoloKernel); + YoloLayerPlugin(const void* data, size_t length); + ~YoloLayerPlugin(); + + int getNbOutputs() const override + { + return 1; + } + + Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override; + + int initialize() override; + + virtual void terminate() override {}; + + virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } + + virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override; + + virtual size_t getSerializationSize() const override; + + virtual void serialize(void* buffer) const override; + + bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override { + return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT; + } + + const char* getPluginType() const override; + + const char* getPluginVersion() const override; + + void destroy() override; + + IPluginV2IOExt* clone() const override; + + void setPluginNamespace(const char* pluginNamespace) override; + + const char* getPluginNamespace() const override; + + DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override; + + bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override; + + bool canBroadcastInputAcrossBatch(int inputIndex) const override; + + void attachToContext( + cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override; + + void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override; + + void detachFromContext() override; + + private: + void forwardGpu(const float *const * inputs, float * output, cudaStream_t stream, int batchSize = 1); + int mThreadCount = 256; + const char* mPluginNamespace; + int mKernelCount; + int mClassCount; + int mYoloV5NetWidth; + int mYoloV5NetHeight; + int mMaxOutObject; + std::vector mYoloKernel; + void** mAnchor; + }; + + class YoloPluginCreator : public IPluginCreator + { + public: + YoloPluginCreator(); + + ~YoloPluginCreator() override = default; + + const char* getPluginName() const override; + + const char* getPluginVersion() const override; + + const PluginFieldCollection* getFieldNames() override; + + IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override; + + IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override; + + void setPluginNamespace(const char* libNamespace) override + { + mNamespace = libNamespace; + } + + const char* getPluginNamespace() const override + { + return mNamespace.c_str(); + } + + private: + std::string mNamespace; + static PluginFieldCollection mFC; + static std::vector mPluginAttributes; + }; + REGISTER_TENSORRT_PLUGIN(YoloPluginCreator); +}; + +#endif