diff --git a/YOLOv5.md b/YOLOv5.md
index 087a5e1..89fa2c4 100644
--- a/YOLOv5.md
+++ b/YOLOv5.md
@@ -3,6 +3,8 @@ NVIDIA DeepStream SDK 5.0.1 configuration for YOLOv5 models
Thanks [DanaHan](https://github.com/DanaHan/Yolov5-in-Deepstream-5.0), [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5)
+Supported version: YOLOv5 3.0/3.1
+
##
* [Requirements](#requirements)
@@ -46,6 +48,16 @@ pip3 install scipy
pip3 install tqdm
```
+* Pandas
+```
+pip3 install pandas
+```
+
+* seaborn
+```
+pip3 install seaborn
+```
+
* PyTorch
```
pip3 install torch torchvision
@@ -77,6 +89,12 @@ git clone https://github.com/wang-xinyu/tensorrtx.git
git clone https://github.com/ultralytics/yolov5.git
```
+Note: checkout TensorRTX repo to 3.0/3.1 YOLOv5 version
+```
+cd tensorrtx
+git checkout '6d0f5cb'
+```
+
2. Download latest YoloV5 (YOLOv5s, YOLOv5m, YOLOv5l or YOLOv5x) weights to yolov5/weights directory (example for YOLOv5s)
```
wget https://github.com/ultralytics/yolov5/releases/download/v3.1/yolov5s.pt -P yolov5/weights/
@@ -112,8 +130,6 @@ f = open('yolov5s.wts', 'w')
```
mv yolov5converter/yololayer.cu tensorrtx/yolov5/yololayer.cu
mv yolov5converter/yololayer.h tensorrtx/yolov5/yololayer.h
-mv yolov5converter/hardswish.cu tensorrtx/yolov5/hardswish.cu
-mv yolov5converter/hardswish.h tensorrtx/yolov5/hardswish.h
```
2. Move generated yolov5s.wts file to tensorrtx/yolov5 folder (example for YOLOv5s)
@@ -130,7 +146,7 @@ cmake ..
make
```
-4. Convert to TensorRT model (yolov5s.engine and libmyplugins.so files will be generated in tensorrtx/yolov5/build folder)
+4. Convert to TensorRT model (yolov5s.engine file will be generated in tensorrtx/yolov5/build folder)
```
sudo ./yolov5 -s
```
@@ -139,7 +155,6 @@ sudo ./yolov5 -s
```
mkdir /opt/nvidia/deepstream/deepstream-5.0/sources/yolo
cp yolov5s.engine /opt/nvidia/deepstream/deepstream-5.0/sources/yolo/yolov5s.engine
-cp libmyplugins.so /opt/nvidia/deepstream/deepstream-5.0/sources/yolo/libmyplugins.so
```
@@ -179,7 +194,7 @@ Use my edited [deepstream_app_config.txt](https://raw.githubusercontent.com/marc
Run command
```
-LD_PRELOAD=./libmyplugins.so deepstream-app -c deepstream_app_config.txt
+deepstream-app -c deepstream_app_config.txt
```
diff --git a/external/yolov5/config_infer_primary.txt b/external/yolov5/config_infer_primary.txt
index a28a5a8..ee008f0 100644
--- a/external/yolov5/config_infer_primary.txt
+++ b/external/yolov5/config_infer_primary.txt
@@ -13,7 +13,6 @@ cluster-mode=4
maintain-aspect-ratio=0
parse-bbox-func-name=NvDsInferParseCustomYoloV5
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
-engine-create-func-name=NvDsInferYoloCudaEngineGet
[class-attrs-all]
pre-cluster-threshold=0.25
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile b/external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile
index b93ce4d..8dc0218 100644
--- a/external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile
+++ b/external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile
@@ -28,7 +28,8 @@ LIBS:= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib6
LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group
INCS:= $(wildcard *.h)
-SRCFILES:= nvdsparsebbox_Yolo.cpp
+SRCFILES:= nvdsparsebbox_Yolo.cpp \
+ yololayer.cu
TARGET_LIB:= libnvdsinfer_custom_impl_Yolo.so
@@ -48,3 +49,4 @@ $(TARGET_LIB) : $(TARGET_OBJS)
clean:
rm -rf $(TARGET_LIB)
+ rm -rf $(TARGET_OBJS)
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h b/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h
new file mode 100644
index 0000000..0de663c
--- /dev/null
+++ b/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h
@@ -0,0 +1,94 @@
+#ifndef __TRT_UTILS_H_
+#define __TRT_UTILS_H_
+
+#include
+#include
+#include
+#include
+
+#ifndef CUDA_CHECK
+
+#define CUDA_CHECK(callstr) \
+ { \
+ cudaError_t error_code = callstr; \
+ if (error_code != cudaSuccess) { \
+ std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
+ assert(0); \
+ } \
+ }
+
+#endif
+
+namespace Tn
+{
+ class Profiler : public nvinfer1::IProfiler
+ {
+ public:
+ void printLayerTimes(int itrationsTimes)
+ {
+ float totalTime = 0;
+ for (size_t i = 0; i < mProfile.size(); i++)
+ {
+ printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(), mProfile[i].second / itrationsTimes);
+ totalTime += mProfile[i].second;
+ }
+ printf("Time over all layers: %4.3f\n", totalTime / itrationsTimes);
+ }
+ private:
+ typedef std::pair Record;
+ std::vector mProfile;
+
+ virtual void reportLayerTime(const char* layerName, float ms)
+ {
+ auto record = std::find_if(mProfile.begin(), mProfile.end(), [&](const Record& r){ return r.first == layerName; });
+ if (record == mProfile.end())
+ mProfile.push_back(std::make_pair(layerName, ms));
+ else
+ record->second += ms;
+ }
+ };
+
+ //Logger for TensorRT info/warning/errors
+ class Logger : public nvinfer1::ILogger
+ {
+ public:
+
+ Logger(): Logger(Severity::kWARNING) {}
+
+ Logger(Severity severity): reportableSeverity(severity) {}
+
+ void log(Severity severity, const char* msg) override
+ {
+ // suppress messages with severity enum value greater than the reportable
+ if (severity > reportableSeverity) return;
+
+ switch (severity)
+ {
+ case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break;
+ case Severity::kERROR: std::cerr << "ERROR: "; break;
+ case Severity::kWARNING: std::cerr << "WARNING: "; break;
+ case Severity::kINFO: std::cerr << "INFO: "; break;
+ default: std::cerr << "UNKNOWN: "; break;
+ }
+ std::cerr << msg << std::endl;
+ }
+
+ Severity reportableSeverity{Severity::kWARNING};
+ };
+
+ template
+ void write(char*& buffer, const T& val)
+ {
+ *reinterpret_cast(buffer) = val;
+ buffer += sizeof(T);
+ }
+
+ template
+ void read(const char*& buffer, T& val)
+ {
+ val = *reinterpret_cast(buffer);
+ buffer += sizeof(T);
+ }
+}
+
+#endif
\ No newline at end of file
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu b/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu
new file mode 100644
index 0000000..a2e6ba3
--- /dev/null
+++ b/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu
@@ -0,0 +1,270 @@
+#include
+#include "yololayer.h"
+#include "utils.h"
+
+using namespace Yolo;
+
+namespace nvinfer1
+{
+ YoloLayerPlugin::YoloLayerPlugin()
+ {
+ mClassCount = CLASS_NUM;
+ mYoloKernel.clear();
+ mYoloKernel.push_back(yolo1);
+ mYoloKernel.push_back(yolo2);
+ mYoloKernel.push_back(yolo3);
+
+ mKernelCount = mYoloKernel.size();
+
+ CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
+ size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
+ for(int ii = 0; ii < mKernelCount; ii ++)
+ {
+ CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
+ const auto& yolo = mYoloKernel[ii];
+ CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
+ }
+ }
+
+ YoloLayerPlugin::~YoloLayerPlugin()
+ {
+ }
+
+ // create the plugin at runtime from a byte stream
+ YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length)
+ {
+ using namespace Tn;
+ const char *d = reinterpret_cast(data), *a = d;
+ read(d, mClassCount);
+ read(d, mThreadCount);
+ read(d, mKernelCount);
+ mYoloKernel.resize(mKernelCount);
+ auto kernelSize = mKernelCount*sizeof(YoloKernel);
+ memcpy(mYoloKernel.data(),d,kernelSize);
+ d += kernelSize;
+
+ CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
+ size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
+ for(int ii = 0; ii < mKernelCount; ii ++)
+ {
+ CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
+ const auto& yolo = mYoloKernel[ii];
+ CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
+ }
+
+ assert(d == a + length);
+ }
+
+ void YoloLayerPlugin::serialize(void* buffer) const
+ {
+ using namespace Tn;
+ char* d = static_cast(buffer), *a = d;
+ write(d, mClassCount);
+ write(d, mThreadCount);
+ write(d, mKernelCount);
+ auto kernelSize = mKernelCount*sizeof(YoloKernel);
+ memcpy(d,mYoloKernel.data(),kernelSize);
+ d += kernelSize;
+
+ assert(d == a + getSerializationSize());
+ }
+
+ size_t YoloLayerPlugin::getSerializationSize() const
+ {
+ return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size();
+ }
+
+ int YoloLayerPlugin::initialize()
+ {
+ return 0;
+ }
+
+ Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
+ {
+ //output the result to channel
+ int totalsize = MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);
+
+ return Dims3(totalsize + 1, 1, 1);
+ }
+
+ // Set plugin namespace
+ void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace)
+ {
+ mPluginNamespace = pluginNamespace;
+ }
+
+ const char* YoloLayerPlugin::getPluginNamespace() const
+ {
+ return mPluginNamespace;
+ }
+
+ // Return the DataType of the plugin output at the requested index
+ DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
+ {
+ return DataType::kFLOAT;
+ }
+
+ // Return true if output tensor is broadcast across a batch.
+ bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
+ {
+ return false;
+ }
+
+ // Return true if plugin can use input that is broadcast across batch without replication.
+ bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const
+ {
+ return false;
+ }
+
+ void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput)
+ {
+ }
+
+ // Attach the plugin object to an execution context and grant the plugin the access to some context resource.
+ void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
+ {
+ }
+
+ // Detach the plugin object from its execution context.
+ void YoloLayerPlugin::detachFromContext() {}
+
+ const char* YoloLayerPlugin::getPluginType() const
+ {
+ return "YoloLayer_TRT";
+ }
+
+ const char* YoloLayerPlugin::getPluginVersion() const
+ {
+ return "1";
+ }
+
+ void YoloLayerPlugin::destroy()
+ {
+ delete this;
+ }
+
+ // Clone the plugin
+ IPluginV2IOExt* YoloLayerPlugin::clone() const
+ {
+ YoloLayerPlugin *p = new YoloLayerPlugin();
+ p->setPluginNamespace(mPluginNamespace);
+ return p;
+ }
+
+ __device__ float Logist(float data){ return 1.0f / (1.0f + expf(-data)); };
+
+ __global__ void CalDetection(const float *input, float *output,int noElements,
+ int yoloWidth,int yoloHeight,const float anchors[CHECK_COUNT*2],int classes,int outputElem) {
+
+ int idx = threadIdx.x + blockDim.x * blockIdx.x;
+ if (idx >= noElements) return;
+
+ int total_grid = yoloWidth * yoloHeight;
+ int bnIdx = idx / total_grid;
+ idx = idx - total_grid*bnIdx;
+ int info_len_i = 5 + classes;
+ const float* curInput = input + bnIdx * (info_len_i * total_grid * CHECK_COUNT);
+
+ for (int k = 0; k < 3; ++k) {
+ float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);
+ if (box_prob < IGNORE_THRESH) continue;
+ int class_id = 0;
+ float max_cls_prob = 0.0;
+ for (int i = 5; i < info_len_i; ++i) {
+ float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]);
+ if (p > max_cls_prob) {
+ max_cls_prob = p;
+ class_id = i - 5;
+ }
+ }
+ float *res_count = output + bnIdx*outputElem;
+ int count = (int)atomicAdd(res_count, 1);
+ if (count >= MAX_OUTPUT_BBOX_COUNT) return;
+ char* data = (char *)res_count + sizeof(float) + count * sizeof(Detection);
+ Detection* det = (Detection*)(data);
+
+ int row = idx / yoloWidth;
+ int col = idx % yoloWidth;
+
+ //Location
+ det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) * INPUT_W / yoloWidth;
+ det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) * INPUT_H / yoloHeight;
+ det->bbox[2] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 2 * total_grid]);
+ det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2*k];
+ det->bbox[3] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 3 * total_grid]);
+ det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2*k + 1];
+ det->conf = box_prob * max_cls_prob;
+ det->class_id = class_id;
+ }
+ }
+
+ void YoloLayerPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
+
+ int outputElem = 1 + MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);
+
+ for(int idx = 0 ; idx < batchSize; ++idx) {
+ CUDA_CHECK(cudaMemset(output + idx*outputElem, 0, sizeof(float)));
+ }
+ int numElem = 0;
+ for (unsigned int i = 0; i < mYoloKernel.size(); ++i)
+ {
+ const auto& yolo = mYoloKernel[i];
+ numElem = yolo.width*yolo.height*batchSize;
+ if (numElem < mThreadCount)
+ mThreadCount = numElem;
+ CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>
+ (inputs[i], output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount, outputElem);
+ }
+
+ }
+
+
+ int YoloLayerPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
+ {
+ forwardGpu((const float *const *)inputs, (float*)outputs[0], stream, batchSize);
+ return 0;
+ }
+
+ PluginFieldCollection YoloPluginCreator::mFC{};
+ std::vector YoloPluginCreator::mPluginAttributes;
+
+ YoloPluginCreator::YoloPluginCreator()
+ {
+ mPluginAttributes.clear();
+
+ mFC.nbFields = mPluginAttributes.size();
+ mFC.fields = mPluginAttributes.data();
+ }
+
+ const char* YoloPluginCreator::getPluginName() const
+ {
+ return "YoloLayer_TRT";
+ }
+
+ const char* YoloPluginCreator::getPluginVersion() const
+ {
+ return "1";
+ }
+
+ const PluginFieldCollection* YoloPluginCreator::getFieldNames()
+ {
+ return &mFC;
+ }
+
+ IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
+ {
+ YoloLayerPlugin* obj = new YoloLayerPlugin();
+ obj->setPluginNamespace(mNamespace.c_str());
+ return obj;
+ }
+
+ IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
+ {
+ // This object will be deleted when the network is destroyed, which will
+ // call MishPlugin::destroy()
+ YoloLayerPlugin* obj = new YoloLayerPlugin(serialData, serialLength);
+ obj->setPluginNamespace(mNamespace.c_str());
+ return obj;
+ }
+
+}
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h b/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h
new file mode 100644
index 0000000..91116cd
--- /dev/null
+++ b/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h
@@ -0,0 +1,152 @@
+#ifndef _YOLO_LAYER_H
+#define _YOLO_LAYER_H
+
+#include
+#include
+#include "NvInfer.h"
+
+namespace Yolo
+{
+ static constexpr int CHECK_COUNT = 3;
+ static constexpr float IGNORE_THRESH = 0.1f;
+ static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;
+ static constexpr int CLASS_NUM = 80;
+ static constexpr int INPUT_H = 608;
+ static constexpr int INPUT_W = 608;
+
+ struct YoloKernel
+ {
+ int width;
+ int height;
+ float anchors[CHECK_COUNT*2];
+ };
+
+ static constexpr YoloKernel yolo1 = {
+ INPUT_W / 32,
+ INPUT_H / 32,
+ {116,90, 156,198, 373,326}
+ };
+ static constexpr YoloKernel yolo2 = {
+ INPUT_W / 16,
+ INPUT_H / 16,
+ {30,61, 62,45, 59,119}
+ };
+ static constexpr YoloKernel yolo3 = {
+ INPUT_W / 8,
+ INPUT_H / 8,
+ {10,13, 16,30, 33,23}
+ };
+
+ static constexpr int LOCATIONS = 4;
+ struct alignas(float) Detection{
+ //center_x center_y w h
+ float bbox[LOCATIONS];
+ float conf; // bbox_conf * cls_conf
+ float class_id;
+ };
+}
+
+namespace nvinfer1
+{
+ class YoloLayerPlugin: public IPluginV2IOExt
+ {
+ public:
+ explicit YoloLayerPlugin();
+ YoloLayerPlugin(const void* data, size_t length);
+
+ ~YoloLayerPlugin();
+
+ int getNbOutputs() const override
+ {
+ return 1;
+ }
+
+ Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
+
+ int initialize() override;
+
+ virtual void terminate() override {};
+
+ virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
+
+ virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
+
+ virtual size_t getSerializationSize() const override;
+
+ virtual void serialize(void* buffer) const override;
+
+ bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
+ return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
+ }
+
+ const char* getPluginType() const override;
+
+ const char* getPluginVersion() const override;
+
+ void destroy() override;
+
+ IPluginV2IOExt* clone() const override;
+
+ void setPluginNamespace(const char* pluginNamespace) override;
+
+ const char* getPluginNamespace() const override;
+
+ DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
+
+ bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
+
+ bool canBroadcastInputAcrossBatch(int inputIndex) const override;
+
+ void attachToContext(
+ cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
+
+ void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
+
+ void detachFromContext() override;
+
+ private:
+ void forwardGpu(const float *const * inputs,float * output, cudaStream_t stream,int batchSize = 1);
+ int mClassCount;
+ int mKernelCount;
+ std::vector mYoloKernel;
+ int mThreadCount = 256;
+ void** mAnchor;
+ const char* mPluginNamespace;
+ };
+
+ class YoloPluginCreator : public IPluginCreator
+ {
+ public:
+ YoloPluginCreator();
+
+ ~YoloPluginCreator() override = default;
+
+ const char* getPluginName() const override;
+
+ const char* getPluginVersion() const override;
+
+ const PluginFieldCollection* getFieldNames() override;
+
+ IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
+
+ IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
+
+ void setPluginNamespace(const char* libNamespace) override
+ {
+ mNamespace = libNamespace;
+ }
+
+ const char* getPluginNamespace() const override
+ {
+ return mNamespace.c_str();
+ }
+
+ private:
+ std::string mNamespace;
+ static PluginFieldCollection mFC;
+ static std::vector mPluginAttributes;
+ };
+ REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
+};
+
+#endif