Update to YOLOv5 5.0
Updated files for YOLOv5 5.0
This commit is contained in:
105
YOLOv5.md
105
YOLOv5.md
@@ -1,9 +1,9 @@
|
||||
# YOLOv5
|
||||
NVIDIA DeepStream SDK 5.1 configuration for YOLOv5 models
|
||||
NVIDIA DeepStream SDK 5.1 configuration for YOLOv5 5.0 models
|
||||
|
||||
Thanks [DanaHan](https://github.com/DanaHan/Yolov5-in-Deepstream-5.0), [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5)
|
||||
Thanks [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5)
|
||||
|
||||
Supported version: YOLOv5 3.0/3.1
|
||||
Supported version: YOLOv5 5.0
|
||||
|
||||
##
|
||||
|
||||
@@ -16,53 +16,15 @@ Supported version: YOLOv5 3.0/3.1
|
||||
##
|
||||
|
||||
### Requirements
|
||||
* Python3
|
||||
```
|
||||
sudo apt-get install python3 python3-dev python3-pip
|
||||
pip3 install --upgrade pip
|
||||
```
|
||||
* [TensorRTX](https://github.com/wang-xinyu/tensorrtx/blob/master/tutorials/install.md)
|
||||
|
||||
* OpenCV Python
|
||||
```
|
||||
sudo apt-get install libopencv-dev
|
||||
pip3 install opencv-python
|
||||
```
|
||||
|
||||
* Matplotlib
|
||||
```
|
||||
pip3 install matplotlib
|
||||
```
|
||||
* [Ultralytics](https://github.com/ultralytics/yolov5/blob/master/requirements.txt)
|
||||
|
||||
* Matplotlib (for Jetson plataform)
|
||||
```
|
||||
sudo apt-get install python3-matplotlib
|
||||
```
|
||||
|
||||
* Scipy
|
||||
```
|
||||
pip3 install scipy
|
||||
```
|
||||
|
||||
* tqdm
|
||||
```
|
||||
pip3 install tqdm
|
||||
```
|
||||
|
||||
* Pandas
|
||||
```
|
||||
pip3 install pandas
|
||||
```
|
||||
|
||||
* seaborn
|
||||
```
|
||||
pip3 install seaborn
|
||||
```
|
||||
|
||||
* PyTorch
|
||||
```
|
||||
pip3 install torch torchvision
|
||||
```
|
||||
|
||||
* PyTorch (for Jetson plataform)
|
||||
```
|
||||
wget https://nvidia.box.com/shared/static/9eptse6jyly1ggt9axbja2yrmj6pbarc.whl -O torch-1.6.0-cp36-cp36m-linux_aarch64.whl
|
||||
@@ -84,20 +46,13 @@ sudo python3 setup.py install
|
||||
### Convert PyTorch model to wts file
|
||||
1. Download repositories
|
||||
```
|
||||
git clone https://github.com/DanaHan/Yolov5-in-Deepstream-5.0.git yolov5converter
|
||||
git clone https://github.com/wang-xinyu/tensorrtx.git
|
||||
git clone https://github.com/ultralytics/yolov5.git
|
||||
```
|
||||
|
||||
Note: checkout TensorRTX repo to 3.0/3.1 YOLOv5 version
|
||||
2. Download latest YoloV5 (YOLOv5s, YOLOv5m, YOLOv5l or YOLOv5x) weights to yolov5 folder (example for YOLOv5s)
|
||||
```
|
||||
cd tensorrtx
|
||||
git checkout '6d0f5cb'
|
||||
```
|
||||
|
||||
2. Download latest YoloV5 (YOLOv5s, YOLOv5m, YOLOv5l or YOLOv5x) weights to yolov5/weights directory (example for YOLOv5s)
|
||||
```
|
||||
wget https://github.com/ultralytics/yolov5/releases/download/v3.1/yolov5s.pt -P yolov5/weights/
|
||||
wget https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s.pt -P yolov5/
|
||||
```
|
||||
|
||||
3. Copy gen_wts.py file (from tensorrtx/yolov5 folder) to yolov5 (ultralytics) folder
|
||||
@@ -108,36 +63,15 @@ cp tensorrtx/yolov5/gen_wts.py yolov5/gen_wts.py
|
||||
4. Generate wts file
|
||||
```
|
||||
cd yolov5
|
||||
python3 gen_wts.py
|
||||
python3 gen_wts.py yolov5s.pt
|
||||
```
|
||||
|
||||
yolov5s.wts file will be generated in yolov5 folder
|
||||
|
||||
<br />
|
||||
|
||||
Note: if you want to generate wts file to another YOLOv5 model (YOLOv5m, YOLOv5l or YOLOv5x), edit get_wts.py file changing yolov5s to your model name
|
||||
```
|
||||
model = torch.load('weights/yolov5s.pt', map_location=device)['model'].float() # load to FP32
|
||||
model.to(device).eval()
|
||||
|
||||
f = open('yolov5s.wts', 'w')
|
||||
```
|
||||
|
||||
##
|
||||
|
||||
### Convert wts file to TensorRT model
|
||||
1. Replace yololayer files from tensorrtx/yolov5 folder to yololayer and hardswish files from yolov5converter
|
||||
```
|
||||
mv yolov5converter/yololayer.cu tensorrtx/yolov5/yololayer.cu
|
||||
mv yolov5converter/yololayer.h tensorrtx/yolov5/yololayer.h
|
||||
```
|
||||
|
||||
2. Move generated yolov5s.wts file to tensorrtx/yolov5 folder (example for YOLOv5s)
|
||||
```
|
||||
cp yolov5/yolov5s.wts tensorrtx/yolov5/yolov5s.wts
|
||||
```
|
||||
|
||||
3. Build tensorrtx/yolov5
|
||||
1. Build tensorrtx/yolov5
|
||||
```
|
||||
cd tensorrtx/yolov5
|
||||
mkdir build
|
||||
@@ -146,12 +80,17 @@ cmake ..
|
||||
make
|
||||
```
|
||||
|
||||
4. Convert to TensorRT model (yolov5s.engine file will be generated in tensorrtx/yolov5/build folder)
|
||||
2. Move generated yolov5s.wts file to tensorrtx/yolov5 folder (example for YOLOv5s)
|
||||
```
|
||||
sudo ./yolov5 -s
|
||||
cp yolov5/yolov5s.wts tensorrtx/yolov5/build/yolov5s.wts
|
||||
```
|
||||
|
||||
5. Create a custom yolo folder and copy generated files (example for YOLOv5s)
|
||||
3. Convert to TensorRT model (yolov5s.engine file will be generated in tensorrtx/yolov5/build folder)
|
||||
```
|
||||
sudo ./yolov5 -s yolov5s.wts yolov5s.engine s
|
||||
```
|
||||
|
||||
4. Create a custom yolo folder and copy generated file (example for YOLOv5s)
|
||||
```
|
||||
mkdir /opt/nvidia/deepstream/deepstream-5.1/sources/yolo
|
||||
cp yolov5s.engine /opt/nvidia/deepstream/deepstream-5.1/sources/yolo/yolov5s.engine
|
||||
@@ -159,15 +98,13 @@ cp yolov5s.engine /opt/nvidia/deepstream/deepstream-5.1/sources/yolo/yolov5s.eng
|
||||
|
||||
<br />
|
||||
|
||||
Note: by default, yolov5 script generate model with batch size = 1, FP16 mode and s model.
|
||||
Note: by default, yolov5 script generate model with batch size = 1 and FP16 mode.
|
||||
```
|
||||
#define USE_FP16 // comment out this if want to use FP32
|
||||
#define USE_FP32 // set USE_INT8 or USE_FP16 or USE_FP32
|
||||
#define DEVICE 0 // GPU id
|
||||
#define NMS_THRESH 0.4
|
||||
#define CONF_THRESH 0.5
|
||||
#define BATCH_SIZE 1
|
||||
|
||||
#define NET s // s m l x
|
||||
```
|
||||
Edit yolov5.cpp file before compile if you want to change this parameters.
|
||||
|
||||
@@ -179,7 +116,7 @@ Edit yolov5.cpp file before compile if you want to change this parameters.
|
||||
sudo chmod -R 777 /opt/nvidia/deepstream/deepstream-5.1/sources/
|
||||
```
|
||||
|
||||
2. Donwload [my external/yolov5 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5) and move files to created yolo folder
|
||||
2. Donwload [my external/yolov5-5.0 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5-5.0) and move files to created yolo folder
|
||||
|
||||
3. Compile lib
|
||||
|
||||
@@ -198,7 +135,7 @@ CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
|
||||
##
|
||||
|
||||
### Testing model
|
||||
Use my edited [deepstream_app_config.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5/deepstream_app_config.txt) and [config_infer_primary.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5/config_infer_primary.txt) files available in [my external/yolov5 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5)
|
||||
Use my edited [deepstream_app_config.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5-5.0/deepstream_app_config.txt) and [config_infer_primary.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5-5.0/config_infer_primary.txt) files available in [my external/yolov5-5.0 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5-5.0)
|
||||
|
||||
Run command
|
||||
```
|
||||
|
||||
18
external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h
vendored
Normal file
18
external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/cuda_utils.h
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef TRTX_CUDA_UTILS_H_
|
||||
#define TRTX_CUDA_UTILS_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#ifndef CUDA_CHECK
|
||||
#define CUDA_CHECK(callstr)\
|
||||
{\
|
||||
cudaError_t error_code = callstr;\
|
||||
if (error_code != cudaSuccess) {\
|
||||
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\
|
||||
assert(0);\
|
||||
}\
|
||||
}
|
||||
#endif // CUDA_CHECK
|
||||
|
||||
#endif // TRTX_CUDA_UTILS_H_
|
||||
|
||||
@@ -1,33 +1,55 @@
|
||||
#include <assert.h>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include "yololayer.h"
|
||||
#include "utils.h"
|
||||
#include "cuda_utils.h"
|
||||
|
||||
namespace Tn
|
||||
{
|
||||
template<typename T>
|
||||
void write(char*& buffer, const T& val)
|
||||
{
|
||||
*reinterpret_cast<T*>(buffer) = val;
|
||||
buffer += sizeof(T);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void read(const char*& buffer, T& val)
|
||||
{
|
||||
val = *reinterpret_cast<const T*>(buffer);
|
||||
buffer += sizeof(T);
|
||||
}
|
||||
}
|
||||
|
||||
using namespace Yolo;
|
||||
|
||||
namespace nvinfer1
|
||||
{
|
||||
YoloLayerPlugin::YoloLayerPlugin()
|
||||
YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel)
|
||||
{
|
||||
mClassCount = CLASS_NUM;
|
||||
mYoloKernel.clear();
|
||||
mYoloKernel.push_back(yolo1);
|
||||
mYoloKernel.push_back(yolo2);
|
||||
mYoloKernel.push_back(yolo3);
|
||||
|
||||
mKernelCount = mYoloKernel.size();
|
||||
mClassCount = classCount;
|
||||
mYoloV5NetWidth = netWidth;
|
||||
mYoloV5NetHeight = netHeight;
|
||||
mMaxOutObject = maxOut;
|
||||
mYoloKernel = vYoloKernel;
|
||||
mKernelCount = vYoloKernel.size();
|
||||
|
||||
CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
|
||||
size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
|
||||
for(int ii = 0; ii < mKernelCount; ii ++)
|
||||
size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2;
|
||||
for (int ii = 0; ii < mKernelCount; ii++)
|
||||
{
|
||||
CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
|
||||
CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen));
|
||||
const auto& yolo = mYoloKernel[ii];
|
||||
CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
|
||||
}
|
||||
}
|
||||
|
||||
YoloLayerPlugin::~YoloLayerPlugin()
|
||||
{
|
||||
for (int ii = 0; ii < mKernelCount; ii++)
|
||||
{
|
||||
CUDA_CHECK(cudaFree(mAnchor[ii]));
|
||||
}
|
||||
CUDA_CHECK(cudaFreeHost(mAnchor));
|
||||
}
|
||||
|
||||
// create the plugin at runtime from a byte stream
|
||||
@@ -38,20 +60,21 @@ namespace nvinfer1
|
||||
read(d, mClassCount);
|
||||
read(d, mThreadCount);
|
||||
read(d, mKernelCount);
|
||||
read(d, mYoloV5NetWidth);
|
||||
read(d, mYoloV5NetHeight);
|
||||
read(d, mMaxOutObject);
|
||||
mYoloKernel.resize(mKernelCount);
|
||||
auto kernelSize = mKernelCount*sizeof(YoloKernel);
|
||||
memcpy(mYoloKernel.data(),d,kernelSize);
|
||||
auto kernelSize = mKernelCount * sizeof(YoloKernel);
|
||||
memcpy(mYoloKernel.data(), d, kernelSize);
|
||||
d += kernelSize;
|
||||
|
||||
CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
|
||||
size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
|
||||
for(int ii = 0; ii < mKernelCount; ii ++)
|
||||
size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2;
|
||||
for (int ii = 0; ii < mKernelCount; ii++)
|
||||
{
|
||||
CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
|
||||
CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen));
|
||||
const auto& yolo = mYoloKernel[ii];
|
||||
CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
assert(d == a + length);
|
||||
}
|
||||
|
||||
@@ -62,27 +85,30 @@ namespace nvinfer1
|
||||
write(d, mClassCount);
|
||||
write(d, mThreadCount);
|
||||
write(d, mKernelCount);
|
||||
auto kernelSize = mKernelCount*sizeof(YoloKernel);
|
||||
memcpy(d,mYoloKernel.data(),kernelSize);
|
||||
write(d, mYoloV5NetWidth);
|
||||
write(d, mYoloV5NetHeight);
|
||||
write(d, mMaxOutObject);
|
||||
auto kernelSize = mKernelCount * sizeof(YoloKernel);
|
||||
memcpy(d, mYoloKernel.data(), kernelSize);
|
||||
d += kernelSize;
|
||||
|
||||
assert(d == a + getSerializationSize());
|
||||
}
|
||||
|
||||
|
||||
size_t YoloLayerPlugin::getSerializationSize() const
|
||||
{
|
||||
return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size();
|
||||
{
|
||||
return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size() + sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight) + sizeof(mMaxOutObject);
|
||||
}
|
||||
|
||||
int YoloLayerPlugin::initialize()
|
||||
{
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
|
||||
{
|
||||
//output the result to channel
|
||||
int totalsize = MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);
|
||||
int totalsize = mMaxOutObject * sizeof(Detection) / sizeof(float);
|
||||
|
||||
return Dims3(totalsize + 1, 1, 1);
|
||||
}
|
||||
@@ -146,26 +172,27 @@ namespace nvinfer1
|
||||
// Clone the plugin
|
||||
IPluginV2IOExt* YoloLayerPlugin::clone() const
|
||||
{
|
||||
YoloLayerPlugin *p = new YoloLayerPlugin();
|
||||
YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, mYoloKernel);
|
||||
p->setPluginNamespace(mPluginNamespace);
|
||||
return p;
|
||||
}
|
||||
|
||||
__device__ float Logist(float data){ return 1.0f / (1.0f + expf(-data)); };
|
||||
__device__ float Logist(float data) { return 1.0f / (1.0f + expf(-data)); };
|
||||
|
||||
__global__ void CalDetection(const float *input, float *output, int noElements,
|
||||
const int netwidth, const int netheight, int maxoutobject, int yoloWidth, int yoloHeight, const float anchors[CHECK_COUNT * 2], int classes, int outputElem)
|
||||
{
|
||||
|
||||
__global__ void CalDetection(const float *input, float *output,int noElements,
|
||||
int yoloWidth,int yoloHeight,const float anchors[CHECK_COUNT*2],int classes,int outputElem) {
|
||||
|
||||
int idx = threadIdx.x + blockDim.x * blockIdx.x;
|
||||
if (idx >= noElements) return;
|
||||
|
||||
int total_grid = yoloWidth * yoloHeight;
|
||||
int bnIdx = idx / total_grid;
|
||||
idx = idx - total_grid*bnIdx;
|
||||
idx = idx - total_grid * bnIdx;
|
||||
int info_len_i = 5 + classes;
|
||||
const float* curInput = input + bnIdx * (info_len_i * total_grid * CHECK_COUNT);
|
||||
|
||||
for (int k = 0; k < 3; ++k) {
|
||||
for (int k = 0; k < CHECK_COUNT; ++k) {
|
||||
float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);
|
||||
if (box_prob < IGNORE_THRESH) continue;
|
||||
int class_id = 0;
|
||||
@@ -177,51 +204,57 @@ namespace nvinfer1
|
||||
class_id = i - 5;
|
||||
}
|
||||
}
|
||||
float *res_count = output + bnIdx*outputElem;
|
||||
float *res_count = output + bnIdx * outputElem;
|
||||
int count = (int)atomicAdd(res_count, 1);
|
||||
if (count >= MAX_OUTPUT_BBOX_COUNT) return;
|
||||
char* data = (char *)res_count + sizeof(float) + count * sizeof(Detection);
|
||||
Detection* det = (Detection*)(data);
|
||||
if (count >= maxoutobject) return;
|
||||
char *data = (char*)res_count + sizeof(float) + count * sizeof(Detection);
|
||||
Detection *det = (Detection*)(data);
|
||||
|
||||
int row = idx / yoloWidth;
|
||||
int col = idx % yoloWidth;
|
||||
|
||||
//Location
|
||||
det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) * INPUT_W / yoloWidth;
|
||||
det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) * INPUT_H / yoloHeight;
|
||||
// pytorch:
|
||||
// y = x[i].sigmoid()
|
||||
// y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
|
||||
// y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
||||
// X: (sigmoid(tx) + cx)/FeaturemapW * netwidth
|
||||
det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) * netwidth / yoloWidth;
|
||||
det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) * netheight / yoloHeight;
|
||||
|
||||
// W: (Pw * e^tw) / FeaturemapW * netwidth
|
||||
// v5: https://github.com/ultralytics/yolov5/issues/471
|
||||
det->bbox[2] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 2 * total_grid]);
|
||||
det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2*k];
|
||||
det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2 * k];
|
||||
det->bbox[3] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 3 * total_grid]);
|
||||
det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2*k + 1];
|
||||
det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2 * k + 1];
|
||||
det->conf = box_prob * max_cls_prob;
|
||||
det->class_id = class_id;
|
||||
}
|
||||
}
|
||||
|
||||
void YoloLayerPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
|
||||
|
||||
int outputElem = 1 + MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);
|
||||
|
||||
for(int idx = 0 ; idx < batchSize; ++idx) {
|
||||
CUDA_CHECK(cudaMemset(output + idx*outputElem, 0, sizeof(float)));
|
||||
void YoloLayerPlugin::forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize)
|
||||
{
|
||||
int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);
|
||||
for (int idx = 0; idx < batchSize; ++idx) {
|
||||
CUDA_CHECK(cudaMemset(output + idx * outputElem, 0, sizeof(float)));
|
||||
}
|
||||
int numElem = 0;
|
||||
for (unsigned int i = 0; i < mYoloKernel.size(); ++i)
|
||||
{
|
||||
for (unsigned int i = 0; i < mYoloKernel.size(); ++i) {
|
||||
const auto& yolo = mYoloKernel[i];
|
||||
numElem = yolo.width*yolo.height*batchSize;
|
||||
if (numElem < mThreadCount)
|
||||
mThreadCount = numElem;
|
||||
CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>
|
||||
(inputs[i], output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount, outputElem);
|
||||
}
|
||||
numElem = yolo.width * yolo.height * batchSize;
|
||||
if (numElem < mThreadCount) mThreadCount = numElem;
|
||||
|
||||
//printf("Net: %d %d \n", mYoloV5NetWidth, mYoloV5NetHeight);
|
||||
CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> >
|
||||
(inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, yolo.width, yolo.height, (float*)mAnchor[i], mClassCount, outputElem);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int YoloLayerPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
|
||||
int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream)
|
||||
{
|
||||
forwardGpu((const float *const *)inputs, (float*)outputs[0], stream, batchSize);
|
||||
forwardGpu((const float* const*)inputs, (float*)outputs[0], stream, batchSize);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -238,22 +271,32 @@ namespace nvinfer1
|
||||
|
||||
const char* YoloPluginCreator::getPluginName() const
|
||||
{
|
||||
return "YoloLayer_TRT";
|
||||
return "YoloLayer_TRT";
|
||||
}
|
||||
|
||||
const char* YoloPluginCreator::getPluginVersion() const
|
||||
{
|
||||
return "1";
|
||||
return "1";
|
||||
}
|
||||
|
||||
const PluginFieldCollection* YoloPluginCreator::getFieldNames()
|
||||
{
|
||||
return &mFC;
|
||||
return &mFC;
|
||||
}
|
||||
|
||||
IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
|
||||
{
|
||||
YoloLayerPlugin* obj = new YoloLayerPlugin();
|
||||
assert(fc->nbFields == 2);
|
||||
assert(strcmp(fc->fields[0].name, "netinfo") == 0);
|
||||
assert(strcmp(fc->fields[1].name, "kernels") == 0);
|
||||
int *p_netinfo = (int*)(fc->fields[0].data);
|
||||
int class_count = p_netinfo[0];
|
||||
int input_w = p_netinfo[1];
|
||||
int input_h = p_netinfo[2];
|
||||
int max_output_object_count = p_netinfo[3];
|
||||
std::vector<Yolo::YoloKernel> kernels(fc->fields[1].length);
|
||||
memcpy(&kernels[0], fc->fields[1].data, kernels.size() * sizeof(Yolo::YoloKernel));
|
||||
YoloLayerPlugin* obj = new YoloLayerPlugin(class_count, input_w, input_h, max_output_object_count, kernels);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
@@ -261,10 +304,10 @@ namespace nvinfer1
|
||||
IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
|
||||
{
|
||||
// This object will be deleted when the network is destroyed, which will
|
||||
// call MishPlugin::destroy()
|
||||
// call YoloLayerPlugin::destroy()
|
||||
YoloLayerPlugin* obj = new YoloLayerPlugin(serialData, serialLength);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
137
external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.h
vendored
Normal file
137
external/yolov5-5.0/nvdsinfer_custom_impl_Yolo/yololayer.h
vendored
Normal file
@@ -0,0 +1,137 @@
|
||||
#ifndef _YOLO_LAYER_H
|
||||
#define _YOLO_LAYER_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "NvInfer.h"
|
||||
|
||||
namespace Yolo
|
||||
{
|
||||
static constexpr int CHECK_COUNT = 3;
|
||||
static constexpr float IGNORE_THRESH = 0.1f;
|
||||
struct YoloKernel
|
||||
{
|
||||
int width;
|
||||
int height;
|
||||
float anchors[CHECK_COUNT * 2];
|
||||
};
|
||||
static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;
|
||||
static constexpr int CLASS_NUM = 80;
|
||||
static constexpr int INPUT_H = 640; // yolov5's input height and width must be divisible by 32.
|
||||
static constexpr int INPUT_W = 640;
|
||||
|
||||
static constexpr int LOCATIONS = 4;
|
||||
struct alignas(float) Detection {
|
||||
//center_x center_y w h
|
||||
float bbox[LOCATIONS];
|
||||
float conf; // bbox_conf * cls_conf
|
||||
float class_id;
|
||||
};
|
||||
}
|
||||
|
||||
namespace nvinfer1
|
||||
{
|
||||
class YoloLayerPlugin : public IPluginV2IOExt
|
||||
{
|
||||
public:
|
||||
YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel);
|
||||
YoloLayerPlugin(const void* data, size_t length);
|
||||
~YoloLayerPlugin();
|
||||
|
||||
int getNbOutputs() const override
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
|
||||
|
||||
int initialize() override;
|
||||
|
||||
virtual void terminate() override {};
|
||||
|
||||
virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
|
||||
|
||||
virtual int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override;
|
||||
|
||||
virtual size_t getSerializationSize() const override;
|
||||
|
||||
virtual void serialize(void* buffer) const override;
|
||||
|
||||
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
|
||||
return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
|
||||
}
|
||||
|
||||
const char* getPluginType() const override;
|
||||
|
||||
const char* getPluginVersion() const override;
|
||||
|
||||
void destroy() override;
|
||||
|
||||
IPluginV2IOExt* clone() const override;
|
||||
|
||||
void setPluginNamespace(const char* pluginNamespace) override;
|
||||
|
||||
const char* getPluginNamespace() const override;
|
||||
|
||||
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
|
||||
|
||||
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
|
||||
|
||||
bool canBroadcastInputAcrossBatch(int inputIndex) const override;
|
||||
|
||||
void attachToContext(
|
||||
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
|
||||
|
||||
void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
|
||||
|
||||
void detachFromContext() override;
|
||||
|
||||
private:
|
||||
void forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize = 1);
|
||||
int mThreadCount = 256;
|
||||
const char* mPluginNamespace;
|
||||
int mKernelCount;
|
||||
int mClassCount;
|
||||
int mYoloV5NetWidth;
|
||||
int mYoloV5NetHeight;
|
||||
int mMaxOutObject;
|
||||
std::vector<Yolo::YoloKernel> mYoloKernel;
|
||||
void** mAnchor;
|
||||
};
|
||||
|
||||
class YoloPluginCreator : public IPluginCreator
|
||||
{
|
||||
public:
|
||||
YoloPluginCreator();
|
||||
|
||||
~YoloPluginCreator() override = default;
|
||||
|
||||
const char* getPluginName() const override;
|
||||
|
||||
const char* getPluginVersion() const override;
|
||||
|
||||
const PluginFieldCollection* getFieldNames() override;
|
||||
|
||||
IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
|
||||
|
||||
IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
|
||||
|
||||
void setPluginNamespace(const char* libNamespace) override
|
||||
{
|
||||
mNamespace = libNamespace;
|
||||
}
|
||||
|
||||
const char* getPluginNamespace() const override
|
||||
{
|
||||
return mNamespace.c_str();
|
||||
}
|
||||
|
||||
private:
|
||||
std::string mNamespace;
|
||||
static PluginFieldCollection mFC;
|
||||
static std::vector<PluginField> mPluginAttributes;
|
||||
};
|
||||
REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -1,94 +0,0 @@
|
||||
#ifndef __TRT_UTILS_H_
|
||||
#define __TRT_UTILS_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cudnn.h>
|
||||
|
||||
#ifndef CUDA_CHECK
|
||||
|
||||
#define CUDA_CHECK(callstr) \
|
||||
{ \
|
||||
cudaError_t error_code = callstr; \
|
||||
if (error_code != cudaSuccess) { \
|
||||
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
|
||||
assert(0); \
|
||||
} \
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
namespace Tn
|
||||
{
|
||||
class Profiler : public nvinfer1::IProfiler
|
||||
{
|
||||
public:
|
||||
void printLayerTimes(int itrationsTimes)
|
||||
{
|
||||
float totalTime = 0;
|
||||
for (size_t i = 0; i < mProfile.size(); i++)
|
||||
{
|
||||
printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(), mProfile[i].second / itrationsTimes);
|
||||
totalTime += mProfile[i].second;
|
||||
}
|
||||
printf("Time over all layers: %4.3f\n", totalTime / itrationsTimes);
|
||||
}
|
||||
private:
|
||||
typedef std::pair<std::string, float> Record;
|
||||
std::vector<Record> mProfile;
|
||||
|
||||
virtual void reportLayerTime(const char* layerName, float ms)
|
||||
{
|
||||
auto record = std::find_if(mProfile.begin(), mProfile.end(), [&](const Record& r){ return r.first == layerName; });
|
||||
if (record == mProfile.end())
|
||||
mProfile.push_back(std::make_pair(layerName, ms));
|
||||
else
|
||||
record->second += ms;
|
||||
}
|
||||
};
|
||||
|
||||
//Logger for TensorRT info/warning/errors
|
||||
class Logger : public nvinfer1::ILogger
|
||||
{
|
||||
public:
|
||||
|
||||
Logger(): Logger(Severity::kWARNING) {}
|
||||
|
||||
Logger(Severity severity): reportableSeverity(severity) {}
|
||||
|
||||
void log(Severity severity, const char* msg) override
|
||||
{
|
||||
// suppress messages with severity enum value greater than the reportable
|
||||
if (severity > reportableSeverity) return;
|
||||
|
||||
switch (severity)
|
||||
{
|
||||
case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break;
|
||||
case Severity::kERROR: std::cerr << "ERROR: "; break;
|
||||
case Severity::kWARNING: std::cerr << "WARNING: "; break;
|
||||
case Severity::kINFO: std::cerr << "INFO: "; break;
|
||||
default: std::cerr << "UNKNOWN: "; break;
|
||||
}
|
||||
std::cerr << msg << std::endl;
|
||||
}
|
||||
|
||||
Severity reportableSeverity{Severity::kWARNING};
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
void write(char*& buffer, const T& val)
|
||||
{
|
||||
*reinterpret_cast<T*>(buffer) = val;
|
||||
buffer += sizeof(T);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void read(const char*& buffer, T& val)
|
||||
{
|
||||
val = *reinterpret_cast<const T*>(buffer);
|
||||
buffer += sizeof(T);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,152 +0,0 @@
|
||||
#ifndef _YOLO_LAYER_H
|
||||
#define _YOLO_LAYER_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "NvInfer.h"
|
||||
|
||||
namespace Yolo
|
||||
{
|
||||
static constexpr int CHECK_COUNT = 3;
|
||||
static constexpr float IGNORE_THRESH = 0.1f;
|
||||
static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;
|
||||
static constexpr int CLASS_NUM = 80;
|
||||
static constexpr int INPUT_H = 608;
|
||||
static constexpr int INPUT_W = 608;
|
||||
|
||||
struct YoloKernel
|
||||
{
|
||||
int width;
|
||||
int height;
|
||||
float anchors[CHECK_COUNT*2];
|
||||
};
|
||||
|
||||
static constexpr YoloKernel yolo1 = {
|
||||
INPUT_W / 32,
|
||||
INPUT_H / 32,
|
||||
{116,90, 156,198, 373,326}
|
||||
};
|
||||
static constexpr YoloKernel yolo2 = {
|
||||
INPUT_W / 16,
|
||||
INPUT_H / 16,
|
||||
{30,61, 62,45, 59,119}
|
||||
};
|
||||
static constexpr YoloKernel yolo3 = {
|
||||
INPUT_W / 8,
|
||||
INPUT_H / 8,
|
||||
{10,13, 16,30, 33,23}
|
||||
};
|
||||
|
||||
static constexpr int LOCATIONS = 4;
|
||||
struct alignas(float) Detection{
|
||||
//center_x center_y w h
|
||||
float bbox[LOCATIONS];
|
||||
float conf; // bbox_conf * cls_conf
|
||||
float class_id;
|
||||
};
|
||||
}
|
||||
|
||||
namespace nvinfer1
|
||||
{
|
||||
class YoloLayerPlugin: public IPluginV2IOExt
|
||||
{
|
||||
public:
|
||||
explicit YoloLayerPlugin();
|
||||
YoloLayerPlugin(const void* data, size_t length);
|
||||
|
||||
~YoloLayerPlugin();
|
||||
|
||||
int getNbOutputs() const override
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
|
||||
|
||||
int initialize() override;
|
||||
|
||||
virtual void terminate() override {};
|
||||
|
||||
virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
|
||||
|
||||
virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
|
||||
|
||||
virtual size_t getSerializationSize() const override;
|
||||
|
||||
virtual void serialize(void* buffer) const override;
|
||||
|
||||
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
|
||||
return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
|
||||
}
|
||||
|
||||
const char* getPluginType() const override;
|
||||
|
||||
const char* getPluginVersion() const override;
|
||||
|
||||
void destroy() override;
|
||||
|
||||
IPluginV2IOExt* clone() const override;
|
||||
|
||||
void setPluginNamespace(const char* pluginNamespace) override;
|
||||
|
||||
const char* getPluginNamespace() const override;
|
||||
|
||||
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
|
||||
|
||||
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
|
||||
|
||||
bool canBroadcastInputAcrossBatch(int inputIndex) const override;
|
||||
|
||||
void attachToContext(
|
||||
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
|
||||
|
||||
void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
|
||||
|
||||
void detachFromContext() override;
|
||||
|
||||
private:
|
||||
void forwardGpu(const float *const * inputs,float * output, cudaStream_t stream,int batchSize = 1);
|
||||
int mClassCount;
|
||||
int mKernelCount;
|
||||
std::vector<Yolo::YoloKernel> mYoloKernel;
|
||||
int mThreadCount = 256;
|
||||
void** mAnchor;
|
||||
const char* mPluginNamespace;
|
||||
};
|
||||
|
||||
class YoloPluginCreator : public IPluginCreator
|
||||
{
|
||||
public:
|
||||
YoloPluginCreator();
|
||||
|
||||
~YoloPluginCreator() override = default;
|
||||
|
||||
const char* getPluginName() const override;
|
||||
|
||||
const char* getPluginVersion() const override;
|
||||
|
||||
const PluginFieldCollection* getFieldNames() override;
|
||||
|
||||
IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
|
||||
|
||||
IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
|
||||
|
||||
void setPluginNamespace(const char* libNamespace) override
|
||||
{
|
||||
mNamespace = libNamespace;
|
||||
}
|
||||
|
||||
const char* getPluginNamespace() const override
|
||||
{
|
||||
return mNamespace.c_str();
|
||||
}
|
||||
|
||||
private:
|
||||
std::string mNamespace;
|
||||
static PluginFieldCollection mFC;
|
||||
static std::vector<PluginField> mPluginAttributes;
|
||||
};
|
||||
REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
|
||||
};
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user