diff --git a/YOLOv5.md b/YOLOv5.md
index 5303f2f..71f312c 100644
--- a/YOLOv5.md
+++ b/YOLOv5.md
@@ -1,9 +1,9 @@
# YOLOv5
-NVIDIA DeepStream SDK 5.1 configuration for YOLOv5 models
+NVIDIA DeepStream SDK 5.1 configuration for YOLOv5 5.0 models
-Thanks [DanaHan](https://github.com/DanaHan/Yolov5-in-Deepstream-5.0), [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5)
+Thanks [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5)
-Supported version: YOLOv5 3.0/3.1
+Supported version: YOLOv5 5.0
##
@@ -16,53 +16,15 @@ Supported version: YOLOv5 3.0/3.1
##
### Requirements
-* Python3
-```
-sudo apt-get install python3 python3-dev python3-pip
-pip3 install --upgrade pip
-```
+* [TensorRTX](https://github.com/wang-xinyu/tensorrtx/blob/master/tutorials/install.md)
-* OpenCV Python
-```
-sudo apt-get install libopencv-dev
-pip3 install opencv-python
-```
-
-* Matplotlib
-```
-pip3 install matplotlib
-```
+* [Ultralytics](https://github.com/ultralytics/yolov5/blob/master/requirements.txt)
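+These can be installed with pip once the yolov5 repository is cloned (see the steps below); a minimal sketch, assuming the repositories are cloned side by side:
+```
+pip3 install -r yolov5/requirements.txt
+```
+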
* Matplotlib (for Jetson platform)
```
sudo apt-get install python3-matplotlib
```
-* Scipy
-```
-pip3 install scipy
-```
-
-* tqdm
-```
-pip3 install tqdm
-```
-
-* Pandas
-```
-pip3 install pandas
-```
-
-* seaborn
-```
-pip3 install seaborn
-```
-
-* PyTorch
-```
-pip3 install torch torchvision
-```
-
* PyTorch (for Jetson platform)
```
wget https://nvidia.box.com/shared/static/9eptse6jyly1ggt9axbja2yrmj6pbarc.whl -O torch-1.6.0-cp36-cp36m-linux_aarch64.whl
@@ -84,20 +46,13 @@ sudo python3 setup.py install
### Convert PyTorch model to wts file
1. Download repositories
```
-git clone https://github.com/DanaHan/Yolov5-in-Deepstream-5.0.git yolov5converter
git clone https://github.com/wang-xinyu/tensorrtx.git
git clone https://github.com/ultralytics/yolov5.git
```
-Note: checkout TensorRTX repo to 3.0/3.1 YOLOv5 version
+2. Download the YOLOv5 5.0 release weights (YOLOv5s, YOLOv5m, YOLOv5l or YOLOv5x) to the yolov5 folder (example for YOLOv5s)
```
-cd tensorrtx
-git checkout '6d0f5cb'
-```
-
-2. Download latest YoloV5 (YOLOv5s, YOLOv5m, YOLOv5l or YOLOv5x) weights to yolov5/weights directory (example for YOLOv5s)
-```
-wget https://github.com/ultralytics/yolov5/releases/download/v3.1/yolov5s.pt -P yolov5/weights/
+wget https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s.pt -P yolov5/
```
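+
+The other model sizes are published under the same v5.0 release URL pattern; a minimal example for YOLOv5m, assuming that size is wanted instead:
+```
+wget https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5m.pt -P yolov5/
+```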
3. Copy the gen_wts.py file (from the tensorrtx/yolov5 folder) to the yolov5 (ultralytics) folder
@@ -108,36 +63,15 @@ cp tensorrtx/yolov5/gen_wts.py yolov5/gen_wts.py
4. Generate wts file
```
cd yolov5
-python3 gen_wts.py
+python3 gen_wts.py yolov5s.pt
```
The yolov5s.wts file will be generated in the yolov5 folder
-
-
-Note: if you want to generate wts file to another YOLOv5 model (YOLOv5m, YOLOv5l or YOLOv5x), edit get_wts.py file changing yolov5s to your model name
-```
-model = torch.load('weights/yolov5s.pt', map_location=device)['model'].float() # load to FP32
-model.to(device).eval()
-
-f = open('yolov5s.wts', 'w')
-```
-
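+
+The same call works for the other model weights; a minimal example, assuming yolov5m.pt was downloaded to the yolov5 folder:
+```
+python3 gen_wts.py yolov5m.pt
+```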
##
### Convert wts file to TensorRT model
-1. Replace yololayer files from tensorrtx/yolov5 folder to yololayer and hardswish files from yolov5converter
-```
-mv yolov5converter/yololayer.cu tensorrtx/yolov5/yololayer.cu
-mv yolov5converter/yololayer.h tensorrtx/yolov5/yololayer.h
-```
-
-2. Move generated yolov5s.wts file to tensorrtx/yolov5 folder (example for YOLOv5s)
-```
-cp yolov5/yolov5s.wts tensorrtx/yolov5/yolov5s.wts
-```
-
-3. Build tensorrtx/yolov5
+1. Build tensorrtx/yolov5
```
cd tensorrtx/yolov5
mkdir build
@@ -146,12 +80,17 @@ cmake ..
make
```
-4. Convert to TensorRT model (yolov5s.engine file will be generated in tensorrtx/yolov5/build folder)
+2. Copy the generated yolov5s.wts file to the tensorrtx/yolov5/build folder (example for YOLOv5s)
```
-sudo ./yolov5 -s
+cp yolov5/yolov5s.wts tensorrtx/yolov5/build/yolov5s.wts
```
-5. Create a custom yolo folder and copy generated files (example for YOLOv5s)
+3. Convert to a TensorRT model (the yolov5s.engine file will be generated in the tensorrtx/yolov5/build folder)
+```
+sudo ./yolov5 -s yolov5s.wts yolov5s.engine s
+```
+
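+The last argument selects the sub-model size; a minimal example for YOLOv5m, assuming yolov5m.wts was generated and copied the same way:
+```
+sudo ./yolov5 -s yolov5m.wts yolov5m.engine m
+```
+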
+4. Create a custom yolo folder and copy the generated file (example for YOLOv5s)
```
mkdir /opt/nvidia/deepstream/deepstream-5.1/sources/yolo
cp yolov5s.engine /opt/nvidia/deepstream/deepstream-5.1/sources/yolo/yolov5s.engine
@@ -159,15 +98,13 @@ cp yolov5s.engine /opt/nvidia/deepstream/deepstream-5.1/sources/yolo/yolov5s.eng
-Note: by default, yolov5 script generate model with batch size = 1, FP16 mode and s model.
+Note: by default, the yolov5 script generates the model with batch size = 1 and FP16 mode.
```
-#define USE_FP16 // comment out this if want to use FP32
+#define USE_FP16 // set USE_INT8 or USE_FP16 or USE_FP32
#define DEVICE 0 // GPU id
#define NMS_THRESH 0.4
#define CONF_THRESH 0.5
#define BATCH_SIZE 1
-
-#define NET s // s m l x
```
Edit the yolov5.cpp file before compiling if you want to change these parameters.
@@ -179,7 +116,7 @@ Edit yolov5.cpp file before compile if you want to change this parameters.
sudo chmod -R 777 /opt/nvidia/deepstream/deepstream-5.1/sources/
```
-2. Donwload [my external/yolov5 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5) and move files to created yolo folder
+2. Download [my external/yolov5-5.0 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5-5.0) and move the files to the created yolo folder
3. Compile lib
@@ -198,7 +135,7 @@ CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
##
### Testing model
-Use my edited [deepstream_app_config.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5/deepstream_app_config.txt) and [config_infer_primary.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5/config_infer_primary.txt) files available in [my external/yolov5 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5)
+Use my edited [deepstream_app_config.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5-5.0/deepstream_app_config.txt) and [config_infer_primary.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5-5.0/config_infer_primary.txt) files available in [my external/yolov5-5.0 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5-5.0)
Run command
```
diff --git a/external/yolov5/config_infer_primary.txt b/external/yolov5-5.0/config_infer_primary.txt
similarity index 100%
rename from external/yolov5/config_infer_primary.txt
rename to external/yolov5-5.0/config_infer_primary.txt
diff --git a/external/yolov5/deepstream_app_config.txt b/external/yolov5-5.0/deepstream_app_config.txt
similarity index 100%
rename from external/yolov5/deepstream_app_config.txt
rename to external/yolov5-5.0/deepstream_app_config.txt
diff --git a/external/yolov5/labels.txt b/external/yolov5-5.0/labels.txt
similarity index 100%
rename from external/yolov5/labels.txt
rename to external/yolov5-5.0/labels.txt
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/Makefile
similarity index 100%
rename from external/yolov5/nvdsinfer_custom_impl_Yolo/Makefile
rename to external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/Makefile
diff --git a/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h
new file mode 100644
index 0000000..8fbd319
--- /dev/null
+++ b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h
@@ -0,0 +1,18 @@
+#ifndef TRTX_CUDA_UTILS_H_
+#define TRTX_CUDA_UTILS_H_
+
+#include <cuda_runtime_api.h>
+
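+// CUDA_CHECK runs a CUDA runtime call and, on error, prints the error code with file/line and asserts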
+#ifndef CUDA_CHECK
+#define CUDA_CHECK(callstr)\
+ {\
+ cudaError_t error_code = callstr;\
+ if (error_code != cudaSuccess) {\
+ std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\
+ assert(0);\
+ }\
+ }
+#endif // CUDA_CHECK
+
+#endif // TRTX_CUDA_UTILS_H_
+
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp
similarity index 100%
rename from external/yolov5/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp
rename to external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.cu
similarity index 55%
rename from external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu
rename to external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.cu
index a2e6ba3..525bf8d 100644
--- a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.cu
+++ b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.cu
@@ -1,33 +1,55 @@
#include <assert.h>
+#include <vector>
+#include <iostream>
#include "yololayer.h"
-#include "utils.h"
+#include "cuda_utils.h"
+
+namespace Tn
+{
+ template <typename T>
+ void write(char*& buffer, const T& val)
+ {
+ *reinterpret_cast<T*>(buffer) = val;
+ buffer += sizeof(T);
+ }
+
+ template <typename T>
+ void read(const char*& buffer, T& val)
+ {
+ val = *reinterpret_cast<const T*>(buffer);
+ buffer += sizeof(T);
+ }
+}
using namespace Yolo;
namespace nvinfer1
{
- YoloLayerPlugin::YoloLayerPlugin()
+ YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel)
{
- mClassCount = CLASS_NUM;
- mYoloKernel.clear();
- mYoloKernel.push_back(yolo1);
- mYoloKernel.push_back(yolo2);
- mYoloKernel.push_back(yolo3);
-
- mKernelCount = mYoloKernel.size();
+ mClassCount = classCount;
+ mYoloV5NetWidth = netWidth;
+ mYoloV5NetHeight = netHeight;
+ mMaxOutObject = maxOut;
+ mYoloKernel = vYoloKernel;
+ mKernelCount = vYoloKernel.size();
CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
- size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
- for(int ii = 0; ii < mKernelCount; ii ++)
+ size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2;
+ for (int ii = 0; ii < mKernelCount; ii++)
{
- CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
+ CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen));
const auto& yolo = mYoloKernel[ii];
CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
}
}
-
YoloLayerPlugin::~YoloLayerPlugin()
{
+ for (int ii = 0; ii < mKernelCount; ii++)
+ {
+ CUDA_CHECK(cudaFree(mAnchor[ii]));
+ }
+ CUDA_CHECK(cudaFreeHost(mAnchor));
}
// create the plugin at runtime from a byte stream
@@ -38,20 +60,21 @@ namespace nvinfer1
read(d, mClassCount);
read(d, mThreadCount);
read(d, mKernelCount);
+ read(d, mYoloV5NetWidth);
+ read(d, mYoloV5NetHeight);
+ read(d, mMaxOutObject);
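+ // fields must be read back in the same order serialize() writes them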
mYoloKernel.resize(mKernelCount);
- auto kernelSize = mKernelCount*sizeof(YoloKernel);
- memcpy(mYoloKernel.data(),d,kernelSize);
+ auto kernelSize = mKernelCount * sizeof(YoloKernel);
+ memcpy(mYoloKernel.data(), d, kernelSize);
d += kernelSize;
-
CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
- size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
- for(int ii = 0; ii < mKernelCount; ii ++)
+ size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2;
+ for (int ii = 0; ii < mKernelCount; ii++)
{
- CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
+ CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen));
const auto& yolo = mYoloKernel[ii];
CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
}
-
assert(d == a + length);
}
@@ -62,27 +85,30 @@ namespace nvinfer1
write(d, mClassCount);
write(d, mThreadCount);
write(d, mKernelCount);
- auto kernelSize = mKernelCount*sizeof(YoloKernel);
- memcpy(d,mYoloKernel.data(),kernelSize);
+ write(d, mYoloV5NetWidth);
+ write(d, mYoloV5NetHeight);
+ write(d, mMaxOutObject);
+ auto kernelSize = mKernelCount * sizeof(YoloKernel);
+ memcpy(d, mYoloKernel.data(), kernelSize);
d += kernelSize;
assert(d == a + getSerializationSize());
}
-
+
size_t YoloLayerPlugin::getSerializationSize() const
- {
- return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size();
+ {
+ return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size() + sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight) + sizeof(mMaxOutObject);
}
int YoloLayerPlugin::initialize()
- {
+ {
return 0;
}
-
+
Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
{
//output the result to channel
- int totalsize = MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);
+ int totalsize = mMaxOutObject * sizeof(Detection) / sizeof(float);
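+ // the extra element holds the number of detections the kernel wrote for each image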
return Dims3(totalsize + 1, 1, 1);
}
@@ -146,26 +172,27 @@ namespace nvinfer1
// Clone the plugin
IPluginV2IOExt* YoloLayerPlugin::clone() const
{
- YoloLayerPlugin *p = new YoloLayerPlugin();
+ YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, mYoloKernel);
p->setPluginNamespace(mPluginNamespace);
return p;
}
- __device__ float Logist(float data){ return 1.0f / (1.0f + expf(-data)); };
+ __device__ float Logist(float data) { return 1.0f / (1.0f + expf(-data)); };
+
+ __global__ void CalDetection(const float *input, float *output, int noElements,
+ const int netwidth, const int netheight, int maxoutobject, int yoloWidth, int yoloHeight, const float anchors[CHECK_COUNT * 2], int classes, int outputElem)
+ {
- __global__ void CalDetection(const float *input, float *output,int noElements,
- int yoloWidth,int yoloHeight,const float anchors[CHECK_COUNT*2],int classes,int outputElem) {
-
int idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx >= noElements) return;
int total_grid = yoloWidth * yoloHeight;
int bnIdx = idx / total_grid;
- idx = idx - total_grid*bnIdx;
+ idx = idx - total_grid * bnIdx;
int info_len_i = 5 + classes;
const float* curInput = input + bnIdx * (info_len_i * total_grid * CHECK_COUNT);
- for (int k = 0; k < 3; ++k) {
+ for (int k = 0; k < CHECK_COUNT; ++k) {
float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);
if (box_prob < IGNORE_THRESH) continue;
int class_id = 0;
@@ -177,51 +204,57 @@ namespace nvinfer1
class_id = i - 5;
}
}
- float *res_count = output + bnIdx*outputElem;
+ float *res_count = output + bnIdx * outputElem;
int count = (int)atomicAdd(res_count, 1);
- if (count >= MAX_OUTPUT_BBOX_COUNT) return;
- char* data = (char *)res_count + sizeof(float) + count * sizeof(Detection);
- Detection* det = (Detection*)(data);
+ if (count >= maxoutobject) return;
+ char *data = (char*)res_count + sizeof(float) + count * sizeof(Detection);
+ Detection *det = (Detection*)(data);
int row = idx / yoloWidth;
int col = idx % yoloWidth;
//Location
- det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) * INPUT_W / yoloWidth;
- det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) * INPUT_H / yoloHeight;
+ // pytorch:
+ // y = x[i].sigmoid()
+ // y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
+ // y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
+ // X: (sigmoid(tx) + cx)/FeaturemapW * netwidth
+ det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) * netwidth / yoloWidth;
+ det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) * netheight / yoloHeight;
+
+ // W: (Pw * e^tw) / FeaturemapW * netwidth
+ // v5: https://github.com/ultralytics/yolov5/issues/471
det->bbox[2] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 2 * total_grid]);
- det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2*k];
+ det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2 * k];
det->bbox[3] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 3 * total_grid]);
- det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2*k + 1];
+ det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2 * k + 1];
det->conf = box_prob * max_cls_prob;
det->class_id = class_id;
}
}
- void YoloLayerPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
-
- int outputElem = 1 + MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);
-
- for(int idx = 0 ; idx < batchSize; ++idx) {
- CUDA_CHECK(cudaMemset(output + idx*outputElem, 0, sizeof(float)));
+ void YoloLayerPlugin::forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize)
+ {
+ int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);
+ for (int idx = 0; idx < batchSize; ++idx) {
+ CUDA_CHECK(cudaMemset(output + idx * outputElem, 0, sizeof(float)));
}
int numElem = 0;
- for (unsigned int i = 0; i < mYoloKernel.size(); ++i)
- {
+ for (unsigned int i = 0; i < mYoloKernel.size(); ++i) {
const auto& yolo = mYoloKernel[i];
- numElem = yolo.width*yolo.height*batchSize;
- if (numElem < mThreadCount)
- mThreadCount = numElem;
- CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>
- (inputs[i], output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount, outputElem);
- }
+ numElem = yolo.width * yolo.height * batchSize;
+ if (numElem < mThreadCount) mThreadCount = numElem;
+ //printf("Net: %d %d \n", mYoloV5NetWidth, mYoloV5NetHeight);
+ CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> >
+ (inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, yolo.width, yolo.height, (float*)mAnchor[i], mClassCount, outputElem);
+ }
}
- int YoloLayerPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
+ int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream)
{
- forwardGpu((const float *const *)inputs, (float*)outputs[0], stream, batchSize);
+ forwardGpu((const float* const*)inputs, (float*)outputs[0], stream, batchSize);
return 0;
}
@@ -238,22 +271,32 @@ namespace nvinfer1
const char* YoloPluginCreator::getPluginName() const
{
- return "YoloLayer_TRT";
+ return "YoloLayer_TRT";
}
const char* YoloPluginCreator::getPluginVersion() const
{
- return "1";
+ return "1";
}
const PluginFieldCollection* YoloPluginCreator::getFieldNames()
{
- return &mFC;
+ return &mFC;
}
IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
{
- YoloLayerPlugin* obj = new YoloLayerPlugin();
+ assert(fc->nbFields == 2);
+ assert(strcmp(fc->fields[0].name, "netinfo") == 0);
+ assert(strcmp(fc->fields[1].name, "kernels") == 0);
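+ // "netinfo" packs four ints: class count, network input width, network input height, max output object count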
+ int *p_netinfo = (int*)(fc->fields[0].data);
+ int class_count = p_netinfo[0];
+ int input_w = p_netinfo[1];
+ int input_h = p_netinfo[2];
+ int max_output_object_count = p_netinfo[3];
+ std::vector<Yolo::YoloKernel> kernels(fc->fields[1].length);
+ memcpy(&kernels[0], fc->fields[1].data, kernels.size() * sizeof(Yolo::YoloKernel));
+ YoloLayerPlugin* obj = new YoloLayerPlugin(class_count, input_w, input_h, max_output_object_count, kernels);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
@@ -261,10 +304,10 @@ namespace nvinfer1
IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
{
// This object will be deleted when the network is destroyed, which will
- // call MishPlugin::destroy()
+ // call YoloLayerPlugin::destroy()
YoloLayerPlugin* obj = new YoloLayerPlugin(serialData, serialLength);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
-
}
+
diff --git a/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.h b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.h
new file mode 100644
index 0000000..49f6474
--- /dev/null
+++ b/external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.h
@@ -0,0 +1,137 @@
+#ifndef _YOLO_LAYER_H
+#define _YOLO_LAYER_H
+
+#include <vector>
+#include <string>
+#include "NvInfer.h"
+
+namespace Yolo
+{
+ static constexpr int CHECK_COUNT = 3;
+ static constexpr float IGNORE_THRESH = 0.1f;
+ struct YoloKernel
+ {
+ int width;
+ int height;
+ float anchors[CHECK_COUNT * 2];
+ };
+ static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;
+ static constexpr int CLASS_NUM = 80;
+ static constexpr int INPUT_H = 640; // yolov5's input height and width must be divisible by 32.
+ static constexpr int INPUT_W = 640;
+
+ static constexpr int LOCATIONS = 4;
+ struct alignas(float) Detection {
+ //center_x center_y w h
+ float bbox[LOCATIONS];
+ float conf; // bbox_conf * cls_conf
+ float class_id;
+ };
+}
+
+namespace nvinfer1
+{
+ class YoloLayerPlugin : public IPluginV2IOExt
+ {
+ public:
+ YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel);
+ YoloLayerPlugin(const void* data, size_t length);
+ ~YoloLayerPlugin();
+
+ int getNbOutputs() const override
+ {
+ return 1;
+ }
+
+ Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
+
+ int initialize() override;
+
+ virtual void terminate() override {};
+
+ virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
+
+ virtual int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override;
+
+ virtual size_t getSerializationSize() const override;
+
+ virtual void serialize(void* buffer) const override;
+
+ bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
+ return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
+ }
+
+ const char* getPluginType() const override;
+
+ const char* getPluginVersion() const override;
+
+ void destroy() override;
+
+ IPluginV2IOExt* clone() const override;
+
+ void setPluginNamespace(const char* pluginNamespace) override;
+
+ const char* getPluginNamespace() const override;
+
+ DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
+
+ bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
+
+ bool canBroadcastInputAcrossBatch(int inputIndex) const override;
+
+ void attachToContext(
+ cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
+
+ void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
+
+ void detachFromContext() override;
+
+ private:
+ void forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize = 1);
+ int mThreadCount = 256;
+ const char* mPluginNamespace;
+ int mKernelCount;
+ int mClassCount;
+ int mYoloV5NetWidth;
+ int mYoloV5NetHeight;
+ int mMaxOutObject;
+ std::vector<Yolo::YoloKernel> mYoloKernel;
+ void** mAnchor;
+ };
+
+ class YoloPluginCreator : public IPluginCreator
+ {
+ public:
+ YoloPluginCreator();
+
+ ~YoloPluginCreator() override = default;
+
+ const char* getPluginName() const override;
+
+ const char* getPluginVersion() const override;
+
+ const PluginFieldCollection* getFieldNames() override;
+
+ IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
+
+ IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
+
+ void setPluginNamespace(const char* libNamespace) override
+ {
+ mNamespace = libNamespace;
+ }
+
+ const char* getPluginNamespace() const override
+ {
+ return mNamespace.c_str();
+ }
+
+ private:
+ std::string mNamespace;
+ static PluginFieldCollection mFC;
+ static std::vector<PluginField> mPluginAttributes;
+ };
+ REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
+};
+
+#endif
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h b/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h
deleted file mode 100644
index 0de663c..0000000
--- a/external/yolov5/nvdsinfer_custom_impl_Yolo/utils.h
+++ /dev/null
@@ -1,94 +0,0 @@
-#ifndef __TRT_UTILS_H_
-#define __TRT_UTILS_H_
-
-#include <iostream>
-#include <vector>
-#include <algorithm>
-#include <cudnn.h>
-
-#ifndef CUDA_CHECK
-
-#define CUDA_CHECK(callstr) \
- { \
- cudaError_t error_code = callstr; \
- if (error_code != cudaSuccess) { \
- std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
- assert(0); \
- } \
- }
-
-#endif
-
-namespace Tn
-{
- class Profiler : public nvinfer1::IProfiler
- {
- public:
- void printLayerTimes(int itrationsTimes)
- {
- float totalTime = 0;
- for (size_t i = 0; i < mProfile.size(); i++)
- {
- printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(), mProfile[i].second / itrationsTimes);
- totalTime += mProfile[i].second;
- }
- printf("Time over all layers: %4.3f\n", totalTime / itrationsTimes);
- }
- private:
- typedef std::pair<std::string, float> Record;
- std::vector<Record> mProfile;
-
- virtual void reportLayerTime(const char* layerName, float ms)
- {
- auto record = std::find_if(mProfile.begin(), mProfile.end(), [&](const Record& r){ return r.first == layerName; });
- if (record == mProfile.end())
- mProfile.push_back(std::make_pair(layerName, ms));
- else
- record->second += ms;
- }
- };
-
- //Logger for TensorRT info/warning/errors
- class Logger : public nvinfer1::ILogger
- {
- public:
-
- Logger(): Logger(Severity::kWARNING) {}
-
- Logger(Severity severity): reportableSeverity(severity) {}
-
- void log(Severity severity, const char* msg) override
- {
- // suppress messages with severity enum value greater than the reportable
- if (severity > reportableSeverity) return;
-
- switch (severity)
- {
- case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break;
- case Severity::kERROR: std::cerr << "ERROR: "; break;
- case Severity::kWARNING: std::cerr << "WARNING: "; break;
- case Severity::kINFO: std::cerr << "INFO: "; break;
- default: std::cerr << "UNKNOWN: "; break;
- }
- std::cerr << msg << std::endl;
- }
-
- Severity reportableSeverity{Severity::kWARNING};
- };
-
- template <typename T>
- void write(char*& buffer, const T& val)
- {
- *reinterpret_cast<T*>(buffer) = val;
- buffer += sizeof(T);
- }
-
- template <typename T>
- void read(const char*& buffer, T& val)
- {
- val = *reinterpret_cast<const T*>(buffer);
- buffer += sizeof(T);
- }
-}
-
-#endif
\ No newline at end of file
diff --git a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h b/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h
deleted file mode 100644
index 91116cd..0000000
--- a/external/yolov5/nvdsinfer_custom_impl_Yolo/yololayer.h
+++ /dev/null
@@ -1,152 +0,0 @@
-#ifndef _YOLO_LAYER_H
-#define _YOLO_LAYER_H
-
-#include <vector>
-#include <string>
-#include "NvInfer.h"
-
-namespace Yolo
-{
- static constexpr int CHECK_COUNT = 3;
- static constexpr float IGNORE_THRESH = 0.1f;
- static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;
- static constexpr int CLASS_NUM = 80;
- static constexpr int INPUT_H = 608;
- static constexpr int INPUT_W = 608;
-
- struct YoloKernel
- {
- int width;
- int height;
- float anchors[CHECK_COUNT*2];
- };
-
- static constexpr YoloKernel yolo1 = {
- INPUT_W / 32,
- INPUT_H / 32,
- {116,90, 156,198, 373,326}
- };
- static constexpr YoloKernel yolo2 = {
- INPUT_W / 16,
- INPUT_H / 16,
- {30,61, 62,45, 59,119}
- };
- static constexpr YoloKernel yolo3 = {
- INPUT_W / 8,
- INPUT_H / 8,
- {10,13, 16,30, 33,23}
- };
-
- static constexpr int LOCATIONS = 4;
- struct alignas(float) Detection{
- //center_x center_y w h
- float bbox[LOCATIONS];
- float conf; // bbox_conf * cls_conf
- float class_id;
- };
-}
-
-namespace nvinfer1
-{
- class YoloLayerPlugin: public IPluginV2IOExt
- {
- public:
- explicit YoloLayerPlugin();
- YoloLayerPlugin(const void* data, size_t length);
-
- ~YoloLayerPlugin();
-
- int getNbOutputs() const override
- {
- return 1;
- }
-
- Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
-
- int initialize() override;
-
- virtual void terminate() override {};
-
- virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
-
- virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
-
- virtual size_t getSerializationSize() const override;
-
- virtual void serialize(void* buffer) const override;
-
- bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
- return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
- }
-
- const char* getPluginType() const override;
-
- const char* getPluginVersion() const override;
-
- void destroy() override;
-
- IPluginV2IOExt* clone() const override;
-
- void setPluginNamespace(const char* pluginNamespace) override;
-
- const char* getPluginNamespace() const override;
-
- DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
-
- bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
-
- bool canBroadcastInputAcrossBatch(int inputIndex) const override;
-
- void attachToContext(
- cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
-
- void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
-
- void detachFromContext() override;
-
- private:
- void forwardGpu(const float *const * inputs,float * output, cudaStream_t stream,int batchSize = 1);
- int mClassCount;
- int mKernelCount;
- std::vector<Yolo::YoloKernel> mYoloKernel;
- int mThreadCount = 256;
- void** mAnchor;
- const char* mPluginNamespace;
- };
-
- class YoloPluginCreator : public IPluginCreator
- {
- public:
- YoloPluginCreator();
-
- ~YoloPluginCreator() override = default;
-
- const char* getPluginName() const override;
-
- const char* getPluginVersion() const override;
-
- const PluginFieldCollection* getFieldNames() override;
-
- IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
-
- IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
-
- void setPluginNamespace(const char* libNamespace) override
- {
- mNamespace = libNamespace;
- }
-
- const char* getPluginNamespace() const override
- {
- return mNamespace.c_str();
- }
-
- private:
- std::string mNamespace;
- static PluginFieldCollection mFC;
- static std::vector<PluginField> mPluginAttributes;
- };
- REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
-};
-
-#endif