Support for YOLOv5 3.0/3.1
Added support for YOLOv5 3.0/3.1
This commit is contained in:
191
YOLOv5-3.X.md
Normal file
191
YOLOv5-3.X.md
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
# YOLOv5
|
||||||
|
NVIDIA DeepStream SDK 5.1 configuration for YOLOv5 3.0/3.1 models
|
||||||
|
|
||||||
|
Thanks [DanaHan](https://github.com/DanaHan/Yolov5-in-Deepstream-5.0), [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5)
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
* [Requirements](#requirements)
|
||||||
|
* [Convert PyTorch model to wts file](#convert-pytorch-model-to-wts-file)
|
||||||
|
* [Convert wts file to TensorRT model](#convert-wts-file-to-tensorrt-model)
|
||||||
|
* [Compile nvdsinfer_custom_impl_Yolo](#compile-nvdsinfer_custom_impl_yolo)
|
||||||
|
* [Testing model](#testing-model)
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
### Requirements
|
||||||
|
* [TensorRTX](https://github.com/wang-xinyu/tensorrtx/blob/master/tutorials/install.md)
|
||||||
|
|
||||||
|
* [Ultralytics](https://github.com/ultralytics/yolov5/blob/v3.1/requirements.txt)
|
||||||
|
|
||||||
|
* Matplotlib (for Jetson plataform)
|
||||||
|
```
|
||||||
|
sudo apt-get install python3-matplotlib
|
||||||
|
```
|
||||||
|
|
||||||
|
* PyTorch (for Jetson plataform)
|
||||||
|
```
|
||||||
|
wget https://nvidia.box.com/shared/static/9eptse6jyly1ggt9axbja2yrmj6pbarc.whl -O torch-1.6.0-cp36-cp36m-linux_aarch64.whl
|
||||||
|
sudo apt-get install python3-pip libopenblas-base libopenmpi-dev
|
||||||
|
pip3 install torch-1.6.0-cp36-cp36m-linux_aarch64.whl
|
||||||
|
```
|
||||||
|
|
||||||
|
* TorchVision (for Jetson platform)
|
||||||
|
```
|
||||||
|
git clone -b v0.7.0 https://github.com/pytorch/vision torchvision
|
||||||
|
sudo apt-get install libjpeg-dev zlib1g-dev python3-pip
|
||||||
|
cd torchvision
|
||||||
|
export BUILD_VERSION=0.7.0
|
||||||
|
sudo python3 setup.py install
|
||||||
|
```
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
### Convert PyTorch model to wts file
|
||||||
|
1. Download repositories
|
||||||
|
```
|
||||||
|
git clone https://github.com/DanaHan/Yolov5-in-Deepstream-5.0.git yolov5converter
|
||||||
|
git clone -b yolov5-v3.1 https://github.com/wang-xinyu/tensorrtx.git
|
||||||
|
git clone -b v3.1 https://github.com/ultralytics/yolov5.git
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Download latest YoloV5 (YOLOv5s, YOLOv5m, YOLOv5l or YOLOv5x) weights to yolov5/weights directory (example for YOLOv5s)
|
||||||
|
```
|
||||||
|
wget https://github.com/ultralytics/yolov5/releases/download/v3.1/yolov5s.pt -P yolov5/weights/
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Copy gen_wts.py file (from tensorrtx/yolov5 folder) to yolov5 (ultralytics) folder
|
||||||
|
```
|
||||||
|
cp tensorrtx/yolov5/gen_wts.py yolov5/gen_wts.py
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Generate wts file
|
||||||
|
```
|
||||||
|
cd yolov5
|
||||||
|
python3 gen_wts.py
|
||||||
|
```
|
||||||
|
|
||||||
|
yolov5s.wts file will be generated in yolov5 folder
|
||||||
|
|
||||||
|
<br />
|
||||||
|
|
||||||
|
Note: if you want to generate wts file to another YOLOv5 model (YOLOv5m, YOLOv5l or YOLOv5x), edit get_wts.py file changing yolov5s to your model name
|
||||||
|
```
|
||||||
|
model = torch.load('weights/yolov5s.pt', map_location=device)['model'].float() # load to FP32
|
||||||
|
model.to(device).eval()
|
||||||
|
|
||||||
|
f = open('yolov5s.wts', 'w')
|
||||||
|
```
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
### Convert wts file to TensorRT model
|
||||||
|
1. Replace yololayer files from tensorrtx/yolov5 folder to yololayer and hardswish files from yolov5converter
|
||||||
|
```
|
||||||
|
mv yolov5converter/yololayer.cu tensorrtx/yolov5/yololayer.cu
|
||||||
|
mv yolov5converter/yololayer.h tensorrtx/yolov5/yololayer.h
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Move generated yolov5s.wts file to tensorrtx/yolov5 folder (example for YOLOv5s)
|
||||||
|
```
|
||||||
|
cp yolov5/yolov5s.wts tensorrtx/yolov5/yolov5s.wts
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Build tensorrtx/yolov5
|
||||||
|
```
|
||||||
|
cd tensorrtx/yolov5
|
||||||
|
mkdir build
|
||||||
|
cd build
|
||||||
|
cmake ..
|
||||||
|
make
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Convert to TensorRT model (yolov5s.engine file will be generated in tensorrtx/yolov5/build folder)
|
||||||
|
```
|
||||||
|
sudo ./yolov5 -s
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Create a custom yolo folder and copy generated files (example for YOLOv5s)
|
||||||
|
```
|
||||||
|
mkdir /opt/nvidia/deepstream/deepstream-5.1/sources/yolo
|
||||||
|
cp yolov5s.engine /opt/nvidia/deepstream/deepstream-5.1/sources/yolo/yolov5s.engine
|
||||||
|
```
|
||||||
|
|
||||||
|
<br />
|
||||||
|
|
||||||
|
Note: by default, yolov5 script generate model with batch size = 1, FP16 mode and s model.
|
||||||
|
```
|
||||||
|
#define USE_FP16 // comment out this if want to use FP32
|
||||||
|
#define DEVICE 0 // GPU id
|
||||||
|
#define NMS_THRESH 0.4
|
||||||
|
#define CONF_THRESH 0.5
|
||||||
|
#define BATCH_SIZE 1
|
||||||
|
|
||||||
|
#define NET s // s m l x
|
||||||
|
```
|
||||||
|
Edit yolov5.cpp file before compile if you want to change this parameters.
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
### Compile nvdsinfer_custom_impl_Yolo
|
||||||
|
1. Run command
|
||||||
|
```
|
||||||
|
sudo chmod -R 777 /opt/nvidia/deepstream/deepstream-5.1/sources/
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Donwload [my external/yolov5 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5-3.X) and move files to created yolo folder
|
||||||
|
|
||||||
|
3. Compile lib
|
||||||
|
|
||||||
|
* x86 platform
|
||||||
|
```
|
||||||
|
cd /opt/nvidia/deepstream/deepstream-5.1/sources/yolo
|
||||||
|
CUDA_VER=11.1 make -C nvdsinfer_custom_impl_Yolo
|
||||||
|
```
|
||||||
|
|
||||||
|
* Jetson platform
|
||||||
|
```
|
||||||
|
cd /opt/nvidia/deepstream/deepstream-5.1/sources/yolo
|
||||||
|
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
|
||||||
|
```
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
### Testing model
|
||||||
|
Use my edited [deepstream_app_config.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5-3.X/deepstream_app_config.txt) and [config_infer_primary.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5-3.X/config_infer_primary.txt) files available in [my external/yolov5-3.X folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5-3.X)
|
||||||
|
|
||||||
|
Run command
|
||||||
|
```
|
||||||
|
deepstream-app -c deepstream_app_config.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
<br />
|
||||||
|
|
||||||
|
Note: based on selected model, edit config_infer_primary.txt file
|
||||||
|
|
||||||
|
For example, if you using YOLOv5x
|
||||||
|
|
||||||
|
```
|
||||||
|
model-engine-file=yolov5s.engine
|
||||||
|
```
|
||||||
|
|
||||||
|
to
|
||||||
|
|
||||||
|
```
|
||||||
|
model-engine-file=yolov5x.engine
|
||||||
|
```
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
To change NMS_THRESH, edit nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp file and recompile
|
||||||
|
|
||||||
|
```
|
||||||
|
#define kNMS_THRESH 0.45
|
||||||
|
```
|
||||||
|
|
||||||
|
To change CONF_THRESH, edit config_infer_primary.txt file
|
||||||
|
|
||||||
|
```
|
||||||
|
[class-attrs-all]
|
||||||
|
pre-cluster-threshold=0.25
|
||||||
|
```
|
||||||
18
external/yolov5-3.X/config_infer_primary.txt
vendored
Normal file
18
external/yolov5-3.X/config_infer_primary.txt
vendored
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
[property]
|
||||||
|
gpu-id=0
|
||||||
|
net-scale-factor=0.0039215697906911373
|
||||||
|
model-color-format=0
|
||||||
|
model-engine-file=yolov5s.engine
|
||||||
|
labelfile-path=labels.txt
|
||||||
|
num-detected-classes=80
|
||||||
|
interval=0
|
||||||
|
gie-unique-id=1
|
||||||
|
process-mode=1
|
||||||
|
network-type=0
|
||||||
|
cluster-mode=4
|
||||||
|
maintain-aspect-ratio=0
|
||||||
|
parse-bbox-func-name=NvDsInferParseCustomYoloV5
|
||||||
|
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
|
||||||
|
|
||||||
|
[class-attrs-all]
|
||||||
|
pre-cluster-threshold=0.25
|
||||||
63
external/yolov5-3.X/deepstream_app_config.txt
vendored
Normal file
63
external/yolov5-3.X/deepstream_app_config.txt
vendored
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
[application]
|
||||||
|
enable-perf-measurement=1
|
||||||
|
perf-measurement-interval-sec=1
|
||||||
|
|
||||||
|
[tiled-display]
|
||||||
|
enable=1
|
||||||
|
rows=1
|
||||||
|
columns=1
|
||||||
|
width=1280
|
||||||
|
height=720
|
||||||
|
gpu-id=0
|
||||||
|
nvbuf-memory-type=0
|
||||||
|
|
||||||
|
[source0]
|
||||||
|
enable=1
|
||||||
|
type=3
|
||||||
|
uri=file://../../samples/streams/sample_1080p_h264.mp4
|
||||||
|
num-sources=1
|
||||||
|
gpu-id=0
|
||||||
|
cudadec-memtype=0
|
||||||
|
|
||||||
|
[sink0]
|
||||||
|
enable=1
|
||||||
|
type=2
|
||||||
|
sync=0
|
||||||
|
source-id=0
|
||||||
|
gpu-id=0
|
||||||
|
nvbuf-memory-type=0
|
||||||
|
|
||||||
|
[osd]
|
||||||
|
enable=1
|
||||||
|
gpu-id=0
|
||||||
|
border-width=1
|
||||||
|
text-size=15
|
||||||
|
text-color=1;1;1;1;
|
||||||
|
text-bg-color=0.3;0.3;0.3;1
|
||||||
|
font=Serif
|
||||||
|
show-clock=0
|
||||||
|
clock-x-offset=800
|
||||||
|
clock-y-offset=820
|
||||||
|
clock-text-size=12
|
||||||
|
clock-color=1;0;0;0
|
||||||
|
nvbuf-memory-type=0
|
||||||
|
|
||||||
|
[streammux]
|
||||||
|
gpu-id=0
|
||||||
|
live-source=0
|
||||||
|
batch-size=1
|
||||||
|
batched-push-timeout=40000
|
||||||
|
width=1920
|
||||||
|
height=1080
|
||||||
|
enable-padding=0
|
||||||
|
nvbuf-memory-type=0
|
||||||
|
|
||||||
|
[primary-gie]
|
||||||
|
enable=1
|
||||||
|
gpu-id=0
|
||||||
|
gie-unique-id=1
|
||||||
|
nvbuf-memory-type=0
|
||||||
|
config-file=config_infer_primary.txt
|
||||||
|
|
||||||
|
[tests]
|
||||||
|
file-loop=0
|
||||||
80
external/yolov5-3.X/labels.txt
vendored
Normal file
80
external/yolov5-3.X/labels.txt
vendored
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
person
|
||||||
|
bicycle
|
||||||
|
car
|
||||||
|
motorbike
|
||||||
|
aeroplane
|
||||||
|
bus
|
||||||
|
train
|
||||||
|
truck
|
||||||
|
boat
|
||||||
|
traffic light
|
||||||
|
fire hydrant
|
||||||
|
stop sign
|
||||||
|
parking meter
|
||||||
|
bench
|
||||||
|
bird
|
||||||
|
cat
|
||||||
|
dog
|
||||||
|
horse
|
||||||
|
sheep
|
||||||
|
cow
|
||||||
|
elephant
|
||||||
|
bear
|
||||||
|
zebra
|
||||||
|
giraffe
|
||||||
|
backpack
|
||||||
|
umbrella
|
||||||
|
handbag
|
||||||
|
tie
|
||||||
|
suitcase
|
||||||
|
frisbee
|
||||||
|
skis
|
||||||
|
snowboard
|
||||||
|
sports ball
|
||||||
|
kite
|
||||||
|
baseball bat
|
||||||
|
baseball glove
|
||||||
|
skateboard
|
||||||
|
surfboard
|
||||||
|
tennis racket
|
||||||
|
bottle
|
||||||
|
wine glass
|
||||||
|
cup
|
||||||
|
fork
|
||||||
|
knife
|
||||||
|
spoon
|
||||||
|
bowl
|
||||||
|
banana
|
||||||
|
apple
|
||||||
|
sandwich
|
||||||
|
orange
|
||||||
|
broccoli
|
||||||
|
carrot
|
||||||
|
hot dog
|
||||||
|
pizza
|
||||||
|
donut
|
||||||
|
cake
|
||||||
|
chair
|
||||||
|
sofa
|
||||||
|
pottedplant
|
||||||
|
bed
|
||||||
|
diningtable
|
||||||
|
toilet
|
||||||
|
tvmonitor
|
||||||
|
laptop
|
||||||
|
mouse
|
||||||
|
remote
|
||||||
|
keyboard
|
||||||
|
cell phone
|
||||||
|
microwave
|
||||||
|
oven
|
||||||
|
toaster
|
||||||
|
sink
|
||||||
|
refrigerator
|
||||||
|
book
|
||||||
|
clock
|
||||||
|
vase
|
||||||
|
scissors
|
||||||
|
teddy bear
|
||||||
|
hair drier
|
||||||
|
toothbrush
|
||||||
52
external/yolov5-3.X/nvdsinfer_custom_impl_Yolo/Makefile
vendored
Normal file
52
external/yolov5-3.X/nvdsinfer_custom_impl_Yolo/Makefile
vendored
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
CUDA_VER?=
|
||||||
|
ifeq ($(CUDA_VER),)
|
||||||
|
$(error "CUDA_VER is not set")
|
||||||
|
endif
|
||||||
|
CC:= g++
|
||||||
|
NVCC:=/usr/local/cuda-$(CUDA_VER)/bin/nvcc
|
||||||
|
|
||||||
|
CFLAGS:= -Wall -std=c++11 -shared -fPIC -Wno-error=deprecated-declarations
|
||||||
|
CFLAGS+= -I../../includes -I/usr/local/cuda-$(CUDA_VER)/include
|
||||||
|
|
||||||
|
LIBS:= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib64 -lcudart -lcublas -lstdc++fs
|
||||||
|
LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group
|
||||||
|
|
||||||
|
INCS:= $(wildcard *.h)
|
||||||
|
SRCFILES:= nvdsparsebbox_Yolo.cpp \
|
||||||
|
yololayer.cu
|
||||||
|
|
||||||
|
TARGET_LIB:= libnvdsinfer_custom_impl_Yolo.so
|
||||||
|
|
||||||
|
TARGET_OBJS:= $(SRCFILES:.cpp=.o)
|
||||||
|
TARGET_OBJS:= $(TARGET_OBJS:.cu=.o)
|
||||||
|
|
||||||
|
all: $(TARGET_LIB)
|
||||||
|
|
||||||
|
%.o: %.cpp $(INCS) Makefile
|
||||||
|
$(CC) -c -o $@ $(CFLAGS) $<
|
||||||
|
|
||||||
|
%.o: %.cu $(INCS) Makefile
|
||||||
|
$(NVCC) -c -o $@ --compiler-options '-fPIC' $<
|
||||||
|
|
||||||
|
$(TARGET_LIB) : $(TARGET_OBJS)
|
||||||
|
$(CC) -o $@ $(TARGET_OBJS) $(LFLAGS)
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -rf $(TARGET_LIB)
|
||||||
|
rm -rf $(TARGET_OBJS)
|
||||||
122
external/yolov5-3.X/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp
vendored
Normal file
122
external/yolov5-3.X/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp
vendored
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "nvdsinfer_custom_impl.h"
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#define kNMS_THRESH 0.45
|
||||||
|
|
||||||
|
static constexpr int LOCATIONS = 4;
|
||||||
|
struct alignas(float) Detection{
|
||||||
|
//center_x center_y w h
|
||||||
|
float bbox[LOCATIONS];
|
||||||
|
float conf; // bbox_conf * cls_conf
|
||||||
|
float class_id;
|
||||||
|
};
|
||||||
|
|
||||||
|
float iou(float lbox[4], float rbox[4]) {
|
||||||
|
float interBox[] = {
|
||||||
|
std::max(lbox[0] - lbox[2]/2.f , rbox[0] - rbox[2]/2.f), //left
|
||||||
|
std::min(lbox[0] + lbox[2]/2.f , rbox[0] + rbox[2]/2.f), //right
|
||||||
|
std::max(lbox[1] - lbox[3]/2.f , rbox[1] - rbox[3]/2.f), //top
|
||||||
|
std::min(lbox[1] + lbox[3]/2.f , rbox[1] + rbox[3]/2.f), //bottom
|
||||||
|
};
|
||||||
|
|
||||||
|
if(interBox[2] > interBox[3] || interBox[0] > interBox[1])
|
||||||
|
return 0.0f;
|
||||||
|
|
||||||
|
float interBoxS =(interBox[1]-interBox[0])*(interBox[3]-interBox[2]);
|
||||||
|
return interBoxS/(lbox[2]*lbox[3] + rbox[2]*rbox[3] -interBoxS);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool cmp(Detection& a, Detection& b) {
|
||||||
|
return a.conf > b.conf;
|
||||||
|
}
|
||||||
|
|
||||||
|
void nms(std::vector<Detection>& res, float *output, float conf_thresh, float nms_thresh) {
|
||||||
|
int det_size = sizeof(Detection) / sizeof(float);
|
||||||
|
std::map<float, std::vector<Detection>> m;
|
||||||
|
for (int i = 0; i < output[0] && i < 1000; i++) {
|
||||||
|
if (output[1 + det_size * i + 4] <= conf_thresh) continue;
|
||||||
|
Detection det;
|
||||||
|
memcpy(&det, &output[1 + det_size * i], det_size * sizeof(float));
|
||||||
|
if (m.count(det.class_id) == 0) m.emplace(det.class_id, std::vector<Detection>());
|
||||||
|
m[det.class_id].push_back(det);
|
||||||
|
}
|
||||||
|
for (auto it = m.begin(); it != m.end(); it++) {
|
||||||
|
auto& dets = it->second;
|
||||||
|
std::sort(dets.begin(), dets.end(), cmp);
|
||||||
|
for (size_t m = 0; m < dets.size(); ++m) {
|
||||||
|
auto& item = dets[m];
|
||||||
|
res.push_back(item);
|
||||||
|
for (size_t n = m + 1; n < dets.size(); ++n) {
|
||||||
|
if (iou(item.bbox, dets[n].bbox) > nms_thresh) {
|
||||||
|
dets.erase(dets.begin()+n);
|
||||||
|
--n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* This is a sample bounding box parsing function for the sample YoloV5 detector model */
|
||||||
|
static bool NvDsInferParseYoloV5(
|
||||||
|
std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
|
||||||
|
NvDsInferNetworkInfo const& networkInfo,
|
||||||
|
NvDsInferParseDetectionParams const& detectionParams,
|
||||||
|
std::vector<NvDsInferParseObjectInfo>& objectList)
|
||||||
|
{
|
||||||
|
const float kCONF_THRESH = detectionParams.perClassThreshold[0];
|
||||||
|
|
||||||
|
std::vector<Detection> res;
|
||||||
|
|
||||||
|
nms(res, (float*)(outputLayersInfo[0].buffer), kCONF_THRESH, kNMS_THRESH);
|
||||||
|
|
||||||
|
for(auto& r : res) {
|
||||||
|
NvDsInferParseObjectInfo oinfo;
|
||||||
|
|
||||||
|
oinfo.classId = r.class_id;
|
||||||
|
oinfo.left = static_cast<unsigned int>(r.bbox[0]-r.bbox[2]*0.5f);
|
||||||
|
oinfo.top = static_cast<unsigned int>(r.bbox[1]-r.bbox[3]*0.5f);
|
||||||
|
oinfo.width = static_cast<unsigned int>(r.bbox[2]);
|
||||||
|
oinfo.height = static_cast<unsigned int>(r.bbox[3]);
|
||||||
|
oinfo.detectionConfidence = r.conf;
|
||||||
|
objectList.push_back(oinfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" bool NvDsInferParseCustomYoloV5(
|
||||||
|
std::vector<NvDsInferLayerInfo> const &outputLayersInfo,
|
||||||
|
NvDsInferNetworkInfo const &networkInfo,
|
||||||
|
NvDsInferParseDetectionParams const &detectionParams,
|
||||||
|
std::vector<NvDsInferParseObjectInfo> &objectList)
|
||||||
|
{
|
||||||
|
return NvDsInferParseYoloV5(
|
||||||
|
outputLayersInfo, networkInfo, detectionParams, objectList);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Check that the custom function has been defined correctly */
|
||||||
|
CHECK_CUSTOM_PARSE_FUNC_PROTOTYPE(NvDsInferParseCustomYoloV5);
|
||||||
94
external/yolov5-3.X/nvdsinfer_custom_impl_Yolo/utils.h
vendored
Normal file
94
external/yolov5-3.X/nvdsinfer_custom_impl_Yolo/utils.h
vendored
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
#ifndef __TRT_UTILS_H_
|
||||||
|
#define __TRT_UTILS_H_
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cudnn.h>
|
||||||
|
|
||||||
|
#ifndef CUDA_CHECK
|
||||||
|
|
||||||
|
#define CUDA_CHECK(callstr) \
|
||||||
|
{ \
|
||||||
|
cudaError_t error_code = callstr; \
|
||||||
|
if (error_code != cudaSuccess) { \
|
||||||
|
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
|
||||||
|
assert(0); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace Tn
|
||||||
|
{
|
||||||
|
class Profiler : public nvinfer1::IProfiler
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
void printLayerTimes(int itrationsTimes)
|
||||||
|
{
|
||||||
|
float totalTime = 0;
|
||||||
|
for (size_t i = 0; i < mProfile.size(); i++)
|
||||||
|
{
|
||||||
|
printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(), mProfile[i].second / itrationsTimes);
|
||||||
|
totalTime += mProfile[i].second;
|
||||||
|
}
|
||||||
|
printf("Time over all layers: %4.3f\n", totalTime / itrationsTimes);
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
typedef std::pair<std::string, float> Record;
|
||||||
|
std::vector<Record> mProfile;
|
||||||
|
|
||||||
|
virtual void reportLayerTime(const char* layerName, float ms)
|
||||||
|
{
|
||||||
|
auto record = std::find_if(mProfile.begin(), mProfile.end(), [&](const Record& r){ return r.first == layerName; });
|
||||||
|
if (record == mProfile.end())
|
||||||
|
mProfile.push_back(std::make_pair(layerName, ms));
|
||||||
|
else
|
||||||
|
record->second += ms;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//Logger for TensorRT info/warning/errors
|
||||||
|
class Logger : public nvinfer1::ILogger
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
|
||||||
|
Logger(): Logger(Severity::kWARNING) {}
|
||||||
|
|
||||||
|
Logger(Severity severity): reportableSeverity(severity) {}
|
||||||
|
|
||||||
|
void log(Severity severity, const char* msg) override
|
||||||
|
{
|
||||||
|
// suppress messages with severity enum value greater than the reportable
|
||||||
|
if (severity > reportableSeverity) return;
|
||||||
|
|
||||||
|
switch (severity)
|
||||||
|
{
|
||||||
|
case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break;
|
||||||
|
case Severity::kERROR: std::cerr << "ERROR: "; break;
|
||||||
|
case Severity::kWARNING: std::cerr << "WARNING: "; break;
|
||||||
|
case Severity::kINFO: std::cerr << "INFO: "; break;
|
||||||
|
default: std::cerr << "UNKNOWN: "; break;
|
||||||
|
}
|
||||||
|
std::cerr << msg << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
Severity reportableSeverity{Severity::kWARNING};
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void write(char*& buffer, const T& val)
|
||||||
|
{
|
||||||
|
*reinterpret_cast<T*>(buffer) = val;
|
||||||
|
buffer += sizeof(T);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void read(const char*& buffer, T& val)
|
||||||
|
{
|
||||||
|
val = *reinterpret_cast<const T*>(buffer);
|
||||||
|
buffer += sizeof(T);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
270
external/yolov5-3.X/nvdsinfer_custom_impl_Yolo/yololayer.cu
vendored
Normal file
270
external/yolov5-3.X/nvdsinfer_custom_impl_Yolo/yololayer.cu
vendored
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
#include <assert.h>
|
||||||
|
#include "yololayer.h"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
using namespace Yolo;
|
||||||
|
|
||||||
|
namespace nvinfer1
|
||||||
|
{
|
||||||
|
YoloLayerPlugin::YoloLayerPlugin()
|
||||||
|
{
|
||||||
|
mClassCount = CLASS_NUM;
|
||||||
|
mYoloKernel.clear();
|
||||||
|
mYoloKernel.push_back(yolo1);
|
||||||
|
mYoloKernel.push_back(yolo2);
|
||||||
|
mYoloKernel.push_back(yolo3);
|
||||||
|
|
||||||
|
mKernelCount = mYoloKernel.size();
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
|
||||||
|
size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
|
||||||
|
for(int ii = 0; ii < mKernelCount; ii ++)
|
||||||
|
{
|
||||||
|
CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
|
||||||
|
const auto& yolo = mYoloKernel[ii];
|
||||||
|
CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
YoloLayerPlugin::~YoloLayerPlugin()
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the plugin at runtime from a byte stream
|
||||||
|
YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length)
|
||||||
|
{
|
||||||
|
using namespace Tn;
|
||||||
|
const char *d = reinterpret_cast<const char *>(data), *a = d;
|
||||||
|
read(d, mClassCount);
|
||||||
|
read(d, mThreadCount);
|
||||||
|
read(d, mKernelCount);
|
||||||
|
mYoloKernel.resize(mKernelCount);
|
||||||
|
auto kernelSize = mKernelCount*sizeof(YoloKernel);
|
||||||
|
memcpy(mYoloKernel.data(),d,kernelSize);
|
||||||
|
d += kernelSize;
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
|
||||||
|
size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
|
||||||
|
for(int ii = 0; ii < mKernelCount; ii ++)
|
||||||
|
{
|
||||||
|
CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
|
||||||
|
const auto& yolo = mYoloKernel[ii];
|
||||||
|
CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(d == a + length);
|
||||||
|
}
|
||||||
|
|
||||||
|
void YoloLayerPlugin::serialize(void* buffer) const
|
||||||
|
{
|
||||||
|
using namespace Tn;
|
||||||
|
char* d = static_cast<char*>(buffer), *a = d;
|
||||||
|
write(d, mClassCount);
|
||||||
|
write(d, mThreadCount);
|
||||||
|
write(d, mKernelCount);
|
||||||
|
auto kernelSize = mKernelCount*sizeof(YoloKernel);
|
||||||
|
memcpy(d,mYoloKernel.data(),kernelSize);
|
||||||
|
d += kernelSize;
|
||||||
|
|
||||||
|
assert(d == a + getSerializationSize());
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t YoloLayerPlugin::getSerializationSize() const
|
||||||
|
{
|
||||||
|
return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
int YoloLayerPlugin::initialize()
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
|
||||||
|
{
|
||||||
|
//output the result to channel
|
||||||
|
int totalsize = MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);
|
||||||
|
|
||||||
|
return Dims3(totalsize + 1, 1, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set plugin namespace
|
||||||
|
void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace)
|
||||||
|
{
|
||||||
|
mPluginNamespace = pluginNamespace;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* YoloLayerPlugin::getPluginNamespace() const
|
||||||
|
{
|
||||||
|
return mPluginNamespace;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the DataType of the plugin output at the requested index
|
||||||
|
DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
|
||||||
|
{
|
||||||
|
return DataType::kFLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return true if output tensor is broadcast across a batch.
|
||||||
|
bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return true if plugin can use input that is broadcast across batch without replication.
|
||||||
|
bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attach the plugin object to an execution context and grant the plugin the access to some context resource.
|
||||||
|
void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detach the plugin object from its execution context.
|
||||||
|
void YoloLayerPlugin::detachFromContext() {}
|
||||||
|
|
||||||
|
const char* YoloLayerPlugin::getPluginType() const
|
||||||
|
{
|
||||||
|
return "YoloLayer_TRT";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* YoloLayerPlugin::getPluginVersion() const
|
||||||
|
{
|
||||||
|
return "1";
|
||||||
|
}
|
||||||
|
|
||||||
|
void YoloLayerPlugin::destroy()
|
||||||
|
{
|
||||||
|
delete this;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone the plugin
|
||||||
|
IPluginV2IOExt* YoloLayerPlugin::clone() const
|
||||||
|
{
|
||||||
|
YoloLayerPlugin *p = new YoloLayerPlugin();
|
||||||
|
p->setPluginNamespace(mPluginNamespace);
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ float Logist(float data){ return 1.0f / (1.0f + expf(-data)); };
|
||||||
|
|
||||||
|
__global__ void CalDetection(const float *input, float *output,int noElements,
|
||||||
|
int yoloWidth,int yoloHeight,const float anchors[CHECK_COUNT*2],int classes,int outputElem) {
|
||||||
|
|
||||||
|
int idx = threadIdx.x + blockDim.x * blockIdx.x;
|
||||||
|
if (idx >= noElements) return;
|
||||||
|
|
||||||
|
int total_grid = yoloWidth * yoloHeight;
|
||||||
|
int bnIdx = idx / total_grid;
|
||||||
|
idx = idx - total_grid*bnIdx;
|
||||||
|
int info_len_i = 5 + classes;
|
||||||
|
const float* curInput = input + bnIdx * (info_len_i * total_grid * CHECK_COUNT);
|
||||||
|
|
||||||
|
for (int k = 0; k < 3; ++k) {
|
||||||
|
float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);
|
||||||
|
if (box_prob < IGNORE_THRESH) continue;
|
||||||
|
int class_id = 0;
|
||||||
|
float max_cls_prob = 0.0;
|
||||||
|
for (int i = 5; i < info_len_i; ++i) {
|
||||||
|
float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]);
|
||||||
|
if (p > max_cls_prob) {
|
||||||
|
max_cls_prob = p;
|
||||||
|
class_id = i - 5;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float *res_count = output + bnIdx*outputElem;
|
||||||
|
int count = (int)atomicAdd(res_count, 1);
|
||||||
|
if (count >= MAX_OUTPUT_BBOX_COUNT) return;
|
||||||
|
char* data = (char *)res_count + sizeof(float) + count * sizeof(Detection);
|
||||||
|
Detection* det = (Detection*)(data);
|
||||||
|
|
||||||
|
int row = idx / yoloWidth;
|
||||||
|
int col = idx % yoloWidth;
|
||||||
|
|
||||||
|
//Location
|
||||||
|
det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) * INPUT_W / yoloWidth;
|
||||||
|
det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) * INPUT_H / yoloHeight;
|
||||||
|
det->bbox[2] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 2 * total_grid]);
|
||||||
|
det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2*k];
|
||||||
|
det->bbox[3] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 3 * total_grid]);
|
||||||
|
det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2*k + 1];
|
||||||
|
det->conf = box_prob * max_cls_prob;
|
||||||
|
det->class_id = class_id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void YoloLayerPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
|
||||||
|
|
||||||
|
int outputElem = 1 + MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);
|
||||||
|
|
||||||
|
for(int idx = 0 ; idx < batchSize; ++idx) {
|
||||||
|
CUDA_CHECK(cudaMemset(output + idx*outputElem, 0, sizeof(float)));
|
||||||
|
}
|
||||||
|
int numElem = 0;
|
||||||
|
for (unsigned int i = 0; i < mYoloKernel.size(); ++i)
|
||||||
|
{
|
||||||
|
const auto& yolo = mYoloKernel[i];
|
||||||
|
numElem = yolo.width*yolo.height*batchSize;
|
||||||
|
if (numElem < mThreadCount)
|
||||||
|
mThreadCount = numElem;
|
||||||
|
CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>
|
||||||
|
(inputs[i], output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount, outputElem);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int YoloLayerPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
|
||||||
|
{
|
||||||
|
forwardGpu((const float *const *)inputs, (float*)outputs[0], stream, batchSize);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
PluginFieldCollection YoloPluginCreator::mFC{};
|
||||||
|
std::vector<PluginField> YoloPluginCreator::mPluginAttributes;
|
||||||
|
|
||||||
|
YoloPluginCreator::YoloPluginCreator()
|
||||||
|
{
|
||||||
|
mPluginAttributes.clear();
|
||||||
|
|
||||||
|
mFC.nbFields = mPluginAttributes.size();
|
||||||
|
mFC.fields = mPluginAttributes.data();
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* YoloPluginCreator::getPluginName() const
|
||||||
|
{
|
||||||
|
return "YoloLayer_TRT";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* YoloPluginCreator::getPluginVersion() const
|
||||||
|
{
|
||||||
|
return "1";
|
||||||
|
}
|
||||||
|
|
||||||
|
const PluginFieldCollection* YoloPluginCreator::getFieldNames()
|
||||||
|
{
|
||||||
|
return &mFC;
|
||||||
|
}
|
||||||
|
|
||||||
|
IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
|
||||||
|
{
|
||||||
|
YoloLayerPlugin* obj = new YoloLayerPlugin();
|
||||||
|
obj->setPluginNamespace(mNamespace.c_str());
|
||||||
|
return obj;
|
||||||
|
}
|
||||||
|
|
||||||
|
IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
|
||||||
|
{
|
||||||
|
// This object will be deleted when the network is destroyed, which will
|
||||||
|
// call MishPlugin::destroy()
|
||||||
|
YoloLayerPlugin* obj = new YoloLayerPlugin(serialData, serialLength);
|
||||||
|
obj->setPluginNamespace(mNamespace.c_str());
|
||||||
|
return obj;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
152
external/yolov5-3.X/nvdsinfer_custom_impl_Yolo/yololayer.h
vendored
Normal file
152
external/yolov5-3.X/nvdsinfer_custom_impl_Yolo/yololayer.h
vendored
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
#ifndef _YOLO_LAYER_H
|
||||||
|
#define _YOLO_LAYER_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include "NvInfer.h"
|
||||||
|
|
||||||
|
namespace Yolo
|
||||||
|
{
|
||||||
|
static constexpr int CHECK_COUNT = 3;
|
||||||
|
static constexpr float IGNORE_THRESH = 0.1f;
|
||||||
|
static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;
|
||||||
|
static constexpr int CLASS_NUM = 80;
|
||||||
|
static constexpr int INPUT_H = 608;
|
||||||
|
static constexpr int INPUT_W = 608;
|
||||||
|
|
||||||
|
struct YoloKernel
|
||||||
|
{
|
||||||
|
int width;
|
||||||
|
int height;
|
||||||
|
float anchors[CHECK_COUNT*2];
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr YoloKernel yolo1 = {
|
||||||
|
INPUT_W / 32,
|
||||||
|
INPUT_H / 32,
|
||||||
|
{116,90, 156,198, 373,326}
|
||||||
|
};
|
||||||
|
static constexpr YoloKernel yolo2 = {
|
||||||
|
INPUT_W / 16,
|
||||||
|
INPUT_H / 16,
|
||||||
|
{30,61, 62,45, 59,119}
|
||||||
|
};
|
||||||
|
static constexpr YoloKernel yolo3 = {
|
||||||
|
INPUT_W / 8,
|
||||||
|
INPUT_H / 8,
|
||||||
|
{10,13, 16,30, 33,23}
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr int LOCATIONS = 4;
|
||||||
|
struct alignas(float) Detection{
|
||||||
|
//center_x center_y w h
|
||||||
|
float bbox[LOCATIONS];
|
||||||
|
float conf; // bbox_conf * cls_conf
|
||||||
|
float class_id;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace nvinfer1
|
||||||
|
{
|
||||||
|
class YoloLayerPlugin: public IPluginV2IOExt
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
explicit YoloLayerPlugin();
|
||||||
|
YoloLayerPlugin(const void* data, size_t length);
|
||||||
|
|
||||||
|
~YoloLayerPlugin();
|
||||||
|
|
||||||
|
int getNbOutputs() const override
|
||||||
|
{
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
|
||||||
|
|
||||||
|
int initialize() override;
|
||||||
|
|
||||||
|
virtual void terminate() override {};
|
||||||
|
|
||||||
|
virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
|
||||||
|
|
||||||
|
virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
|
||||||
|
|
||||||
|
virtual size_t getSerializationSize() const override;
|
||||||
|
|
||||||
|
virtual void serialize(void* buffer) const override;
|
||||||
|
|
||||||
|
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
|
||||||
|
return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* getPluginType() const override;
|
||||||
|
|
||||||
|
const char* getPluginVersion() const override;
|
||||||
|
|
||||||
|
void destroy() override;
|
||||||
|
|
||||||
|
IPluginV2IOExt* clone() const override;
|
||||||
|
|
||||||
|
void setPluginNamespace(const char* pluginNamespace) override;
|
||||||
|
|
||||||
|
const char* getPluginNamespace() const override;
|
||||||
|
|
||||||
|
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
|
||||||
|
|
||||||
|
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
|
||||||
|
|
||||||
|
bool canBroadcastInputAcrossBatch(int inputIndex) const override;
|
||||||
|
|
||||||
|
void attachToContext(
|
||||||
|
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
|
||||||
|
|
||||||
|
void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
|
||||||
|
|
||||||
|
void detachFromContext() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void forwardGpu(const float *const * inputs,float * output, cudaStream_t stream,int batchSize = 1);
|
||||||
|
int mClassCount;
|
||||||
|
int mKernelCount;
|
||||||
|
std::vector<Yolo::YoloKernel> mYoloKernel;
|
||||||
|
int mThreadCount = 256;
|
||||||
|
void** mAnchor;
|
||||||
|
const char* mPluginNamespace;
|
||||||
|
};
|
||||||
|
|
||||||
|
class YoloPluginCreator : public IPluginCreator
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
YoloPluginCreator();
|
||||||
|
|
||||||
|
~YoloPluginCreator() override = default;
|
||||||
|
|
||||||
|
const char* getPluginName() const override;
|
||||||
|
|
||||||
|
const char* getPluginVersion() const override;
|
||||||
|
|
||||||
|
const PluginFieldCollection* getFieldNames() override;
|
||||||
|
|
||||||
|
IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
|
||||||
|
|
||||||
|
IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
|
||||||
|
|
||||||
|
void setPluginNamespace(const char* libNamespace) override
|
||||||
|
{
|
||||||
|
mNamespace = libNamespace;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* getPluginNamespace() const override
|
||||||
|
{
|
||||||
|
return mNamespace.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string mNamespace;
|
||||||
|
static PluginFieldCollection mFC;
|
||||||
|
static std::vector<PluginField> mPluginAttributes;
|
||||||
|
};
|
||||||
|
REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
Reference in New Issue
Block a user