From 470ed82658a5546b55185b3223f8057ecf54cf88 Mon Sep 17 00:00:00 2001
From: Marcos Luciano
Date: Sun, 10 Jan 2021 11:08:20 -0300
Subject: [PATCH] YOLOv5 files updated

* Added supported version information
* Not needed to use libmyplugins.so anymore
---
 YOLOv5.md                                     |  25 +-
 external/yolov5/config_infer_primary.txt      |   1 -
 .../nvdsinfer_custom_impl_Yolo/Makefile       |   4 +-
 .../yolov5/nvdsinfer_custom_impl_Yolo/utils.h |  94 ++++++
 .../nvdsinfer_custom_impl_Yolo/yololayer.cu   | 270 ++++++++++++++++++
 .../nvdsinfer_custom_impl_Yolo/yololayer.h    | 152 ++++++++++
 6 files changed, 539 insertions(+), 7 deletions(-)
 create mode 100644 external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h
 create mode 100644 external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu
 create mode 100644 external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h

diff --git a/YOLOv5.md b/YOLOv5.md
index 087a5e1..89fa2c4 100644
--- a/YOLOv5.md
+++ b/YOLOv5.md
@@ -3,6 +3,8 @@ NVIDIA DeepStream SDK 5.0.1 configuration for YOLOv5 models
 
 Thanks [DanaHan](https://github.com/DanaHan/Yolov5-in-Deepstream-5.0), [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5)
 
+Supported version: YOLOv5 3.0/3.1
+
 ##
 
 * [Requirements](#requirements)
@@ -46,6 +48,16 @@ pip3 install scipy
 pip3 install tqdm
 ```
 
+* Pandas
+```
+pip3 install pandas
+```
+
+* Seaborn
+```
+pip3 install seaborn
+```
+
 * PyTorch
 ```
 pip3 install torch torchvision
@@ -77,6 +89,12 @@ git clone https://github.com/wang-xinyu/tensorrtx.git
 git clone https://github.com/ultralytics/yolov5.git
 ```
 
+Note: check out the TensorRTX repo at the commit compatible with YOLOv5 3.0/3.1
+```
+cd tensorrtx
+git checkout '6d0f5cb'
+```
+
 2. Download latest YoloV5 (YOLOv5s, YOLOv5m, YOLOv5l or YOLOv5x) weights to yolov5/weights directory (example for YOLOv5s)
 ```
 wget https://github.com/ultralytics/yolov5/releases/download/v3.1/yolov5s.pt -P yolov5/weights/
@@ -112,8 +130,6 @@ f = open('yolov5s.wts', 'w')
 ```
 mv yolov5converter/yololayer.cu tensorrtx/yolov5/yololayer.cu
 mv yolov5converter/yololayer.h tensorrtx/yolov5/yololayer.h
-mv yolov5converter/hardswish.cu tensorrtx/yolov5/hardswish.cu
-mv yolov5converter/hardswish.h tensorrtx/yolov5/hardswish.h
 ```
 
 2. Move generated yolov5s.wts file to tensorrtx/yolov5 folder (example for YOLOv5s)
@@ -130,7 +146,7 @@ cmake ..
 make
 ```
 
-4. Convert to TensorRT model (yolov5s.engine and libmyplugins.so files will be generated in tensorrtx/yolov5/build folder)
+4. Convert to TensorRT model (yolov5s.engine file will be generated in tensorrtx/yolov5/build folder)
 ```
 sudo ./yolov5 -s
 ```
@@ -139,7 +155,6 @@ sudo ./yolov5 -s
 
 ```
 mkdir /opt/nvidia/deepstream/deepstream-5.0/sources/yolo
 cp yolov5s.engine /opt/nvidia/deepstream/deepstream-5.0/sources/yolo/yolov5s.engine
-cp libmyplugins.so /opt/nvidia/deepstream/deepstream-5.0/sources/yolo/libmyplugins.so
 ```
 
@@ -179,7 +194,7 @@ Use my edited [deepstream_app_config.txt](https://raw.githubusercontent.com/marc
 
 Run command
 ```
-LD_PRELOAD=./libmyplugins.so deepstream-app -c deepstream_app_config.txt
+deepstream-app -c deepstream_app_config.txt
 ```
 
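With this patch the yolo layer is compiled directly into libnvdsinfer_custom_impl_Yolo.so (see the Makefile and yololayer.cu below), and the REGISTER_TENSORRT_PLUGIN macro in yololayer.h registers the creator as soon as DeepStream loads the library given in custom-lib-path; that is why the LD_PRELOAD of libmyplugins.so and the engine-create-func-name entry can be dropped. A minimal sketch of the lookup TensorRT performs while deserializing yolov5s.engine — the creator name and version come from yololayer.cu, the surrounding function is illustrative only:

```
#include <cassert>
#include "NvInfer.h"

// Illustrative only: once the custom library is loaded, the registered creator can be
// looked up by the name/version pair that YoloLayerPlugin reports ("YoloLayer_TRT", "1").
nvinfer1::IPluginCreator* findYoloCreator()
{
    auto* creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1");
    assert(creator != nullptr && "yolo plugin not registered - is the custom lib loaded?");
    return creator;
}
```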
diff --git a/external/yolov5/config_infer_primary.txt b/external/yolov5/config_infer_primary.txt
index a28a5a8..ee008f0 100644
--- a/external/yolov5/config_infer_primary.txt
+++ b/external/yolov5/config_infer_primary.txt
@@ -13,7 +13,6 @@ cluster-mode=4
 maintain-aspect-ratio=0
 parse-bbox-func-name=NvDsInferParseCustomYoloV5
 custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
-engine-create-func-name=NvDsInferYoloCudaEngineGet
 
 [class-attrs-all]
 pre-cluster-threshold=0.25
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile b/external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile
index b93ce4d..8dc0218 100644
--- a/external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile
+++ b/external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile
@@ -28,7 +28,8 @@ LIBS:= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib6
 LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group
 
 INCS:= $(wildcard *.h)
-SRCFILES:= nvdsparsebbox_Yolo.cpp
+SRCFILES:= nvdsparsebbox_Yolo.cpp \
+           yololayer.cu
 
 TARGET_LIB:= libnvdsinfer_custom_impl_Yolo.so
@@ -48,3 +49,4 @@ $(TARGET_LIB) : $(TARGET_OBJS)
 
 clean:
 	rm -rf $(TARGET_LIB)
+	rm -rf $(TARGET_OBJS)
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h b/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h
new file mode 100644
index 0000000..0de663c
--- /dev/null
+++ b/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h
@@ -0,0 +1,94 @@
+#ifndef __TRT_UTILS_H_
+#define __TRT_UTILS_H_
+
+#include <iostream>
+#include <string>
+#include <vector>
+#include <algorithm>
+
+#ifndef CUDA_CHECK
+
+#define CUDA_CHECK(callstr)\
+    {\
+        cudaError_t error_code = callstr;\
+        if (error_code != cudaSuccess) {\
+            std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\
+            assert(0);\
+        }\
+    }
+
+#endif
+
+namespace Tn
+{
+    class Profiler : public nvinfer1::IProfiler
+    {
+    public:
+        void printLayerTimes(int iterationsTimes)
+        {
+            float totalTime = 0;
+            for (size_t i = 0; i < mProfile.size(); i++)
+            {
+                printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(), mProfile[i].second / iterationsTimes);
+                totalTime += mProfile[i].second;
+            }
+            printf("Time over all layers: %4.3f\n", totalTime / iterationsTimes);
+        }
+    private:
+        typedef std::pair<std::string, float> Record;
+        std::vector<Record> mProfile;
+
+        virtual void reportLayerTime(const char* layerName, float ms)
+        {
+            auto record = std::find_if(mProfile.begin(), mProfile.end(), [&](const Record& r){ return r.first == layerName; });
+            if (record == mProfile.end())
+                mProfile.push_back(std::make_pair(layerName, ms));
+            else
+                record->second += ms;
+        }
+    };
+
+    //Logger for TensorRT info/warning/errors
+    class Logger : public nvinfer1::ILogger
+    {
+    public:
+
+        Logger(): Logger(Severity::kWARNING) {}
+
+        Logger(Severity severity): reportableSeverity(severity) {}
+
+        void log(Severity severity, const char* msg) override
+        {
+            // suppress messages with severity enum value greater than the reportable
+            if (severity > reportableSeverity) return;
+
+            switch (severity)
+            {
+            case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break;
+            case Severity::kERROR: std::cerr << "ERROR: "; break;
+            case Severity::kWARNING: std::cerr << "WARNING: "; break;
+            case Severity::kINFO: std::cerr << "INFO: "; break;
+            default: std::cerr << "UNKNOWN: "; break;
+            }
+            std::cerr << msg << std::endl;
+        }
+
+        Severity reportableSeverity{Severity::kWARNING};
+    };
+
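+    // Raw (de)serialization helpers: copy a trivially-copyable value to/from a byte buffer
+    // and advance the pointer. YoloLayerPlugin::serialize() and the deserializing constructor
+    // in yololayer.cu use them, so both sides must write and read the fields in the same order.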
+    template <typename T>
+    void write(char*& buffer, const T& val)
+    {
+        *reinterpret_cast<T*>(buffer) = val;
+        buffer += sizeof(T);
+    }
+
+    template <typename T>
+    void read(const char*& buffer, T& val)
+    {
+        val = *reinterpret_cast<const T*>(buffer);
+        buffer += sizeof(T);
+    }
+}
+
+#endif
\ No newline at end of file
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu b/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu
new file mode 100644
index 0000000..a2e6ba3
--- /dev/null
+++ b/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu
@@ -0,0 +1,270 @@
+#include <assert.h>
+#include "yololayer.h"
+#include "utils.h"
+
+using namespace Yolo;
+
+namespace nvinfer1
+{
+    YoloLayerPlugin::YoloLayerPlugin()
+    {
+        mClassCount = CLASS_NUM;
+        mYoloKernel.clear();
+        mYoloKernel.push_back(yolo1);
+        mYoloKernel.push_back(yolo2);
+        mYoloKernel.push_back(yolo3);
+
+        mKernelCount = mYoloKernel.size();
+
+        CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
+        size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
+        for(int ii = 0; ii < mKernelCount; ii ++)
+        {
+            CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
+            const auto& yolo = mYoloKernel[ii];
+            CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
+        }
+    }
+
+    YoloLayerPlugin::~YoloLayerPlugin()
+    {
+    }
+
+    // create the plugin at runtime from a byte stream
+    YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length)
+    {
+        using namespace Tn;
+        const char *d = reinterpret_cast<const char *>(data), *a = d;
+        read(d, mClassCount);
+        read(d, mThreadCount);
+        read(d, mKernelCount);
+        mYoloKernel.resize(mKernelCount);
+        auto kernelSize = mKernelCount*sizeof(YoloKernel);
+        memcpy(mYoloKernel.data(),d,kernelSize);
+        d += kernelSize;
+
+        CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
+        size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
+        for(int ii = 0; ii < mKernelCount; ii ++)
+        {
+            CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
+            const auto& yolo = mYoloKernel[ii];
+            CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
+        }
+
+        assert(d == a + length);
+    }
+
+    void YoloLayerPlugin::serialize(void* buffer) const
+    {
+        using namespace Tn;
+        char* d = static_cast<char*>(buffer), *a = d;
+        write(d, mClassCount);
+        write(d, mThreadCount);
+        write(d, mKernelCount);
+        auto kernelSize = mKernelCount*sizeof(YoloKernel);
+        memcpy(d,mYoloKernel.data(),kernelSize);
+        d += kernelSize;
+
+        assert(d == a + getSerializationSize());
+    }
+
+    size_t YoloLayerPlugin::getSerializationSize() const
+    {
+        return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size();
+    }
+
+    int YoloLayerPlugin::initialize()
+    {
+        return 0;
+    }
+
+    Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
+    {
+        //output the result to channel
+        int totalsize = MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);
+
+        return Dims3(totalsize + 1, 1, 1);
+    }
+
+    // Set plugin namespace
+    void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace)
+    {
+        mPluginNamespace = pluginNamespace;
+    }
+
+    const char* YoloLayerPlugin::getPluginNamespace() const
+    {
+        return mPluginNamespace;
+    }
+
+    // Return the DataType of the plugin output at the requested index
+    DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
+    {
+        return DataType::kFLOAT;
+    }
+
+    // Return true if output tensor is broadcast across a batch.
+    bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
+    {
+        return false;
+    }
+
+    // Return true if plugin can use input that is broadcast across batch without replication.
+    bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const
+    {
+        return false;
+    }
+
+    void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput)
+    {
+    }
+
+    // Attach the plugin object to an execution context and grant the plugin the access to some context resource.
+    void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
+    {
+    }
+
+    // Detach the plugin object from its execution context.
+    void YoloLayerPlugin::detachFromContext() {}
+
+    const char* YoloLayerPlugin::getPluginType() const
+    {
+        return "YoloLayer_TRT";
+    }
+
+    const char* YoloLayerPlugin::getPluginVersion() const
+    {
+        return "1";
+    }
+
+    void YoloLayerPlugin::destroy()
+    {
+        delete this;
+    }
+
+    // Clone the plugin
+    IPluginV2IOExt* YoloLayerPlugin::clone() const
+    {
+        YoloLayerPlugin *p = new YoloLayerPlugin();
+        p->setPluginNamespace(mPluginNamespace);
+        return p;
+    }
+
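+    // Logist() is the sigmoid. CalDetection runs one thread per grid cell of one detection
+    // scale; for each of the three anchors at that cell it decodes the YOLOv5 v3.x box encoding
+    //   cx, cy = (col/row - 0.5 + 2 * sigmoid(tx, ty)) * stride
+    //   w, h   = (2 * sigmoid(tw, th))^2 * anchor
+    // skips cells whose objectness is below IGNORE_THRESH, and appends survivors (conf =
+    // objectness * best class score) via atomicAdd() to a per-image block whose first float
+    // holds the detection count.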
+    __device__ float Logist(float data){ return 1.0f / (1.0f + expf(-data)); };
+
+    __global__ void CalDetection(const float *input, float *output,int noElements,
+            int yoloWidth,int yoloHeight,const float anchors[CHECK_COUNT*2],int classes,int outputElem) {
+
+        int idx = threadIdx.x + blockDim.x * blockIdx.x;
+        if (idx >= noElements) return;
+
+        int total_grid = yoloWidth * yoloHeight;
+        int bnIdx = idx / total_grid;
+        idx = idx - total_grid*bnIdx;
+        int info_len_i = 5 + classes;
+        const float* curInput = input + bnIdx * (info_len_i * total_grid * CHECK_COUNT);
+
+        for (int k = 0; k < 3; ++k) {
+            float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);
+            if (box_prob < IGNORE_THRESH) continue;
+            int class_id = 0;
+            float max_cls_prob = 0.0;
+            for (int i = 5; i < info_len_i; ++i) {
+                float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]);
+                if (p > max_cls_prob) {
+                    max_cls_prob = p;
+                    class_id = i - 5;
+                }
+            }
+            float *res_count = output + bnIdx*outputElem;
+            int count = (int)atomicAdd(res_count, 1);
+            if (count >= MAX_OUTPUT_BBOX_COUNT) return;
+            char* data = (char *)res_count + sizeof(float) + count * sizeof(Detection);
+            Detection* det = (Detection*)(data);
+
+            int row = idx / yoloWidth;
+            int col = idx % yoloWidth;
+
+            //Location
+            det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) * INPUT_W / yoloWidth;
+            det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) * INPUT_H / yoloHeight;
+            det->bbox[2] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 2 * total_grid]);
+            det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2*k];
+            det->bbox[3] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 3 * total_grid]);
+            det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2*k + 1];
+            det->conf = box_prob * max_cls_prob;
+            det->class_id = class_id;
+        }
+    }
+
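+    // forwardGpu() zeroes the detection counter at the start of each image's output block,
+    // then launches CalDetection once per detection scale with ceil(numElem / mThreadCount)
+    // blocks; enqueue() is the IPluginV2 entry point TensorRT calls with the bound buffers.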
+    void YoloLayerPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
+
+        int outputElem = 1 + MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);
+
+        for(int idx = 0 ; idx < batchSize; ++idx) {
+            CUDA_CHECK(cudaMemset(output + idx*outputElem, 0, sizeof(float)));
+        }
+        int numElem = 0;
+        for (unsigned int i = 0; i < mYoloKernel.size(); ++i)
+        {
+            const auto& yolo = mYoloKernel[i];
+            numElem = yolo.width*yolo.height*batchSize;
+            if (numElem < mThreadCount)
+                mThreadCount = numElem;
+            CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>
+                (inputs[i], output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount, outputElem);
+        }
+
+    }
+
+
+    int YoloLayerPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
+    {
+        forwardGpu((const float *const *)inputs, (float*)outputs[0], stream, batchSize);
+        return 0;
+    }
+
+    PluginFieldCollection YoloPluginCreator::mFC{};
+    std::vector<PluginField> YoloPluginCreator::mPluginAttributes;
+
+    YoloPluginCreator::YoloPluginCreator()
+    {
+        mPluginAttributes.clear();
+
+        mFC.nbFields = mPluginAttributes.size();
+        mFC.fields = mPluginAttributes.data();
+    }
+
+    const char* YoloPluginCreator::getPluginName() const
+    {
+        return "YoloLayer_TRT";
+    }
+
+    const char* YoloPluginCreator::getPluginVersion() const
+    {
+        return "1";
+    }
+
+    const PluginFieldCollection* YoloPluginCreator::getFieldNames()
+    {
+        return &mFC;
+    }
+
+    IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
+    {
+        YoloLayerPlugin* obj = new YoloLayerPlugin();
+        obj->setPluginNamespace(mNamespace.c_str());
+        return obj;
+    }
+
+    IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
+    {
+        // This object will be deleted when the network is destroyed, which will
+        // call YoloLayerPlugin::destroy()
+        YoloLayerPlugin* obj = new YoloLayerPlugin(serialData, serialLength);
+        obj->setPluginNamespace(mNamespace.c_str());
+        return obj;
+    }
+
+}
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h b/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h
new file mode 100644
index 0000000..91116cd
--- /dev/null
+++ b/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h
@@ -0,0 +1,152 @@
+#ifndef _YOLO_LAYER_H
+#define _YOLO_LAYER_H
+
+#include <vector>
+#include <string>
+#include "NvInfer.h"
+
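+// Detection geometry and model constants used by the CUDA decode kernel. CLASS_NUM and
+// INPUT_H/INPUT_W mirror the values used when the engine was built in tensorrtx (COCO,
+// 608x608); the three YoloKernel entries describe strides 32/16/8 with their anchor pairs,
+// and MAX_OUTPUT_BBOX_COUNT caps how many Detection records fit in each image's output block.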
+namespace Yolo
+{
+    static constexpr int CHECK_COUNT = 3;
+    static constexpr float IGNORE_THRESH = 0.1f;
+    static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;
+    static constexpr int CLASS_NUM = 80;
+    static constexpr int INPUT_H = 608;
+    static constexpr int INPUT_W = 608;
+
+    struct YoloKernel
+    {
+        int width;
+        int height;
+        float anchors[CHECK_COUNT*2];
+    };
+
+    static constexpr YoloKernel yolo1 = {
+        INPUT_W / 32,
+        INPUT_H / 32,
+        {116,90, 156,198, 373,326}
+    };
+    static constexpr YoloKernel yolo2 = {
+        INPUT_W / 16,
+        INPUT_H / 16,
+        {30,61, 62,45, 59,119}
+    };
+    static constexpr YoloKernel yolo3 = {
+        INPUT_W / 8,
+        INPUT_H / 8,
+        {10,13, 16,30, 33,23}
+    };
+
+    static constexpr int LOCATIONS = 4;
+    struct alignas(float) Detection{
+        //center_x center_y w h
+        float bbox[LOCATIONS];
+        float conf;  // bbox_conf * cls_conf
+        float class_id;
+    };
+}
+
+namespace nvinfer1
+{
+    class YoloLayerPlugin: public IPluginV2IOExt
+    {
+    public:
+        explicit YoloLayerPlugin();
+        YoloLayerPlugin(const void* data, size_t length);
+
+        ~YoloLayerPlugin();
+
+        int getNbOutputs() const override
+        {
+            return 1;
+        }
+
+        Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
+
+        int initialize() override;
+
+        virtual void terminate() override {};
+
+        virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
+
+        virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
+
+        virtual size_t getSerializationSize() const override;
+
+        virtual void serialize(void* buffer) const override;
+
+        bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
+            return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
+        }
+
+        const char* getPluginType() const override;
+
+        const char* getPluginVersion() const override;
+
+        void destroy() override;
+
+        IPluginV2IOExt* clone() const override;
+
+        void setPluginNamespace(const char* pluginNamespace) override;
+
+        const char* getPluginNamespace() const override;
+
+        DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
+
+        bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
+
+        bool canBroadcastInputAcrossBatch(int inputIndex) const override;
+
+        void attachToContext(
+            cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
+
+        void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
+
+        void detachFromContext() override;
+
+    private:
+        void forwardGpu(const float *const * inputs, float * output, cudaStream_t stream, int batchSize = 1);
+        int mClassCount;
+        int mKernelCount;
+        std::vector<Yolo::YoloKernel> mYoloKernel;
+        int mThreadCount = 256;
+        void** mAnchor;
+        const char* mPluginNamespace;
+    };
+
+    class YoloPluginCreator : public IPluginCreator
+    {
+    public:
+        YoloPluginCreator();
+
+        ~YoloPluginCreator() override = default;
+
+        const char* getPluginName() const override;
+
+        const char* getPluginVersion() const override;
+
+        const PluginFieldCollection* getFieldNames() override;
+
+        IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
+
+        IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
+
+        void setPluginNamespace(const char* libNamespace) override
+        {
+            mNamespace = libNamespace;
+        }
+
+        const char* getPluginNamespace() const override
+        {
+            return mNamespace.c_str();
+        }
+
+    private:
+        std::string mNamespace;
+        static PluginFieldCollection mFC;
+        static std::vector<PluginField> mPluginAttributes;
+    };
+    REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
+};
+
+#endif
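For every image in the batch, the yolo layer writes one float holding the number of detections followed by up to MAX_OUTPUT_BBOX_COUNT flattened Detection records (center x, center y, width, height in network-input pixels, confidence, class id) — the same layout getOutputDimensions() advertises as Dims3(totalsize + 1, 1, 1). A minimal host-side sketch of decoding one such block is shown below; the struct and constant are copied from yololayer.h above, while parseYoloOutput and its threshold parameter are illustrative only — the parsing DeepStream actually uses lives in nvdsparsebbox_Yolo.cpp (NvDsInferParseCustomYoloV5).

```
#include <cstring>
#include <vector>

// Mirrors yololayer.h: 4 floats for the box (center x, center y, w, h), confidence, class id.
struct Detection {
    float bbox[4];
    float conf;
    float class_id;
};

static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;

// Hypothetical helper: decode one image's output block produced by the yolo layer.
// `block` points at outputElem floats: block[0] = detection count, then Detection records.
std::vector<Detection> parseYoloOutput(const float* block, float confThreshold)
{
    std::vector<Detection> result;
    int count = static_cast<int>(block[0]);
    if (count > MAX_OUTPUT_BBOX_COUNT)
        count = MAX_OUTPUT_BBOX_COUNT;

    for (int i = 0; i < count; ++i) {
        Detection det;
        // Records are packed immediately after the leading count float.
        std::memcpy(&det, block + 1 + i * (sizeof(Detection) / sizeof(float)), sizeof(Detection));
        if (det.conf < confThreshold)
            continue;
        // det.bbox is center-based; e.g. left = det.bbox[0] - det.bbox[2] / 2 before clipping to the frame.
        result.push_back(det);
    }
    return result;
}
```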