Removed previous commits

This commit is contained in:
marcoslucianops
2020-12-20 13:39:54 -03:00
commit c34158664c
35 changed files with 3441 additions and 0 deletions

196
YOLOv5.md Normal file
View File

@@ -0,0 +1,196 @@
# YOLOv5
NVIDIA DeepStream SDK 5.0.1 configuration for YOLOv5 models
Thanks [DanaHan](https://github.com/DanaHan/Yolov5-in-Deepstream-5.0), [wang-xinyu](https://github.com/wang-xinyu/tensorrtx) and [Ultralytics](https://github.com/ultralytics/yolov5)
##
* [Requirements](#requirements)
* [Convert PyTorch model to wts file](#convert-pytorch-model-to-wts-file)
* [Convert wts file to TensorRT model](#convert-wts-file-to-tensorrt-model)
* [Compile nvdsinfer_custom_impl_Yolo](#compile-nvdsinfer_custom_impl_yolo)
* [Testing model](#testing-model)
##
### Requirements
* Python3
```
sudo apt-get install python3 python3-dev python3-pip
pip3 install --upgrade pip
```
* OpenCV Python
```
sudo apt-get install libopencv-dev
pip3 install opencv-python
```
* Matplotlib
```
pip3 install matplotlib
```
* Scipy
```
pip3 install scipy
```
* tqdm
```
pip3 install tqdm
```
* PyTorch
```
pip3 install torch torchvision
```
* PyTorch (for Jetson plataform)
```
wget https://nvidia.box.com/shared/static/9eptse6jyly1ggt9axbja2yrmj6pbarc.whl -O torch-1.6.0-cp36-cp36m-linux_aarch64.whl
sudo apt-get install python3-pip libopenblas-base libopenmpi-dev
pip3 install torch-1.6.0-cp36-cp36m-linux_aarch64.whl
```
* TorchVision (for Jetson platform)
```
git clone -b v0.7.0 https://github.com/pytorch/vision torchvision
sudo apt-get install libjpeg-dev zlib1g-dev python3-pip
cd torchvision
export BUILD_VERSION=0.7.0
sudo python3 setup.py install
```
##
### Convert PyTorch model to wts file
1. Download repositories
```
git clone https://github.com/DanaHan/Yolov5-in-Deepstream-5.0.git yolov5converter
git clone https://github.com/wang-xinyu/tensorrtx.git
git clone https://github.com/ultralytics/yolov5.git
```
2. Download latest YoloV5 (YOLOv5s, YOLOv5m, YOLOv5l or YOLOv5x) weights to yolov5/weights directory (example for YOLOv5s)
```
wget https://github.com/ultralytics/yolov5/releases/download/v3.1/yolov5s.pt -P yolov5/weights/
```
3. Copy gen_wts.py file (from tensorrtx/yolov5 folder) to yolov5 (ultralytics) folder
```
cp tensorrtx/yolov5/gen_wts.py yolov5/gen_wts.py
```
4. Generate wts file
```
cd yolov5
python3 gen_wts.py
```
yolov5s.wts file will be generated in yolov5 folder
<br />
Note: if you want to generate wts file to another YOLOv5 model (YOLOv5m, YOLOv5l or YOLOv5x), edit get_wts.py file changing yolov5s to your model name
```
model = torch.load('weights/yolov5s.pt', map_location=device)['model'].float() # load to FP32
model.to(device).eval()
f = open('yolov5s.wts', 'w')
```
##
### Convert wts file to TensorRT model
1. Replace yololayer files from tensorrtx/yolov5 folder to yololayer and hardswish files from yolov5converter
```
mv yolov5converter/yololayer.cu tensorrtx/yolov5/yololayer.cu
mv yolov5converter/yololayer.h tensorrtx/yolov5/yololayer.h
mv yolov5converter/hardswish.cu tensorrtx/yolov5/hardswish.cu
mv yolov5converter/hardswish.h tensorrtx/yolov5/hardswish.h
```
2. Move generated yolov5s.wts file to tensorrtx/yolov5 folder (example for YOLOv5s)
```
cp yolov5/yolov5s.wts tensorrtx/yolov5/yolov5s.wts
```
3. Build tensorrtx/yolov5
```
cd tensorrtx/yolov5
mkdir build
cd build
cmake ..
make
```
4. Convert to TensorRT model (yolov5s.engine and libmyplugins.so files will be generated in tensorrtx/yolov5/build folder)
```
sudo ./yolov5 -s
```
5. Create a custom yolo folder and copy generated files (example for YOLOv5s)
```
mkdir /opt/nvidia/deepstream/deepstream-5.0/sources/yolo
cp yolov5s.engine /opt/nvidia/deepstream/deepstream-5.0/sources/yolo/yolov5s.engine
cp libmyplugins.so /opt/nvidia/deepstream/deepstream-5.0/sources/yolo/libmyplugins.so
```
<br />
Note: by default, yolov5 script generate model with batch size = 1, FP16 mode and s model.
```
#define USE_FP16 // comment out this if want to use FP32
#define DEVICE 0 // GPU id
#define NMS_THRESH 0.4
#define CONF_THRESH 0.5
#define BATCH_SIZE 1
#define NET s // s m l x
```
Edit yolov5.cpp file before compile if you want to change this parameters.
##
### Compile nvdsinfer_custom_impl_Yolo
1. Donwload [my external/yolov5 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5) and move files to created yolo folder
2. Compile lib
```
cd /opt/nvidia/deepstream/deepstream-5.0/sources/yolo
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
```
##
### Testing model
Use my edited [deepstream_app_config.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5/deepstream_app_config.txt) and [config_infer_primary.txt](https://raw.githubusercontent.com/marcoslucianops/DeepStream-Yolo/master/external/yolov5/config_infer_primary.txt) files available in [my external/yolov5 folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external/yolov5)
Run command
```
LD_PRELOAD=./libmyplugins.so deepstream-app -c deepstream_app_config.txt
```
<br />
Note: based on selected model, edit config_infer_primary.txt file
For example, if you using YOLOv5x
```
model-engine-file=yolov5s.engine
```
to
```
model-engine-file=yolov5x.engine
```
To change NMS_THRESH and THRESH_CONF, edit nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp file and recompile
```
#define NMS_THRESH 0.45
#define CONF_THRESH 0.25
```

View File

@@ -0,0 +1,16 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
model-engine-file=yolov5s.engine
labelfile-path=labels.txt
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=4
maintain-aspect-ratio=0
parse-bbox-func-name=NvDsInferParseCustomYoloV5
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet

View File

@@ -0,0 +1,63 @@
[application]
enable-perf-measurement=1
perf-measurement-interval-sec=1
[tiled-display]
enable=1
rows=1
columns=1
width=1280
height=720
gpu-id=0
nvbuf-memory-type=0
[source0]
enable=1
type=3
uri=file://../../samples/streams/sample_1080p_h264.mp4
num-sources=1
gpu-id=0
cudadec-memtype=0
[sink0]
enable=1
type=2
sync=0
source-id=0
gpu-id=0
nvbuf-memory-type=0
[osd]
enable=1
gpu-id=0
border-width=1
text-size=15
text-color=1;1;1;1;
text-bg-color=0.3;0.3;0.3;1
font=Serif
show-clock=0
clock-x-offset=800
clock-y-offset=820
clock-text-size=12
clock-color=1;0;0;0
nvbuf-memory-type=0
[streammux]
gpu-id=0
live-source=0
batch-size=1
batched-push-timeout=40000
width=1920
height=1080
enable-padding=0
nvbuf-memory-type=0
[primary-gie]
enable=1
gpu-id=0
gie-unique-id=1
nvbuf-memory-type=0
config-file=config_infer_primary.txt
[tests]
file-loop=0

80
external/yolov5/labels.txt vendored Normal file
View File

@@ -0,0 +1,80 @@
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
pottedplant
bed
diningtable
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush

View File

@@ -0,0 +1,50 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
CUDA_VER?=
ifeq ($(CUDA_VER),)
$(error "CUDA_VER is not set")
endif
CC:= g++
NVCC:=/usr/local/cuda-$(CUDA_VER)/bin/nvcc
CFLAGS:= -Wall -std=c++11 -shared -fPIC -Wno-error=deprecated-declarations
CFLAGS+= -I../../includes -I/usr/local/cuda-$(CUDA_VER)/include
LIBS:= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib64 -lcudart -lcublas -lstdc++fs
LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group
INCS:= $(wildcard *.h)
SRCFILES:= nvdsparsebbox_Yolo.cpp
TARGET_LIB:= libnvdsinfer_custom_impl_Yolo.so
TARGET_OBJS:= $(SRCFILES:.cpp=.o)
TARGET_OBJS:= $(TARGET_OBJS:.cu=.o)
all: $(TARGET_LIB)
%.o: %.cpp $(INCS) Makefile
$(CC) -c -o $@ $(CFLAGS) $<
%.o: %.cu $(INCS) Makefile
$(NVCC) -c -o $@ --compiler-options '-fPIC' $<
$(TARGET_LIB) : $(TARGET_OBJS)
$(CC) -o $@ $(TARGET_OBJS) $(LFLAGS)
clean:
rm -rf $(TARGET_LIB)

View File

@@ -0,0 +1,130 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <unordered_map>
#include "nvdsinfer_custom_impl.h"
#include <map>
#define NMS_THRESH 0.45
#define CONF_THRESH 0.25
static const int NUM_CLASSES_YOLO = 80;
static constexpr int LOCATIONS = 4;
struct alignas(float) Detection{
//center_x center_y w h
float bbox[LOCATIONS];
float conf; // bbox_conf * cls_conf
float class_id;
};
float iou(float lbox[4], float rbox[4]) {
float interBox[] = {
std::max(lbox[0] - lbox[2]/2.f , rbox[0] - rbox[2]/2.f), //left
std::min(lbox[0] + lbox[2]/2.f , rbox[0] + rbox[2]/2.f), //right
std::max(lbox[1] - lbox[3]/2.f , rbox[1] - rbox[3]/2.f), //top
std::min(lbox[1] + lbox[3]/2.f , rbox[1] + rbox[3]/2.f), //bottom
};
if(interBox[2] > interBox[3] || interBox[0] > interBox[1])
return 0.0f;
float interBoxS =(interBox[1]-interBox[0])*(interBox[3]-interBox[2]);
return interBoxS/(lbox[2]*lbox[3] + rbox[2]*rbox[3] -interBoxS);
}
bool cmp(Detection& a, Detection& b) {
return a.conf > b.conf;
}
void nms(std::vector<Detection>& res, float *output, float conf_thresh, float nms_thresh) {
int det_size = sizeof(Detection) / sizeof(float);
std::map<float, std::vector<Detection>> m;
for (int i = 0; i < output[0] && i < 1000; i++) {
if (output[1 + det_size * i + 4] <= conf_thresh) continue;
Detection det;
memcpy(&det, &output[1 + det_size * i], det_size * sizeof(float));
if (m.count(det.class_id) == 0) m.emplace(det.class_id, std::vector<Detection>());
m[det.class_id].push_back(det);
}
for (auto it = m.begin(); it != m.end(); it++) {
auto& dets = it->second;
std::sort(dets.begin(), dets.end(), cmp);
for (size_t m = 0; m < dets.size(); ++m) {
auto& item = dets[m];
res.push_back(item);
for (size_t n = m + 1; n < dets.size(); ++n) {
if (iou(item.bbox, dets[n].bbox) > nms_thresh) {
dets.erase(dets.begin()+n);
--n;
}
}
}
}
}
/* This is a sample bounding box parsing function for the sample YoloV5 detector model */
static bool NvDsInferParseYoloV5(
std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList)
{
if (NUM_CLASSES_YOLO != detectionParams.numClassesConfigured)
{
std::cerr << "WARNING: Num classes mismatch. Configured:"
<< detectionParams.numClassesConfigured
<< ", detected by network: " << NUM_CLASSES_YOLO << std::endl;
}
std::vector<Detection> res;
nms(res, (float*)(outputLayersInfo[0].buffer), CONF_THRESH, NMS_THRESH);
for(auto& r : res) {
NvDsInferParseObjectInfo oinfo;
oinfo.classId = r.class_id;
oinfo.left = static_cast<unsigned int>(r.bbox[0]-r.bbox[2]*0.5f);
oinfo.top = static_cast<unsigned int>(r.bbox[1]-r.bbox[3]*0.5f);
oinfo.width = static_cast<unsigned int>(r.bbox[2]);
oinfo.height = static_cast<unsigned int>(r.bbox[3]);
oinfo.detectionConfidence = r.conf;
objectList.push_back(oinfo);
}
return true;
}
extern "C" bool NvDsInferParseCustomYoloV5(
std::vector<NvDsInferLayerInfo> const &outputLayersInfo,
NvDsInferNetworkInfo const &networkInfo,
NvDsInferParseDetectionParams const &detectionParams,
std::vector<NvDsInferParseObjectInfo> &objectList)
{
return NvDsInferParseYoloV5(
outputLayersInfo, networkInfo, detectionParams, objectList);
}
/* Check that the custom function has been defined correctly */
CHECK_CUSTOM_PARSE_FUNC_PROTOTYPE(NvDsInferParseCustomYoloV5);

View File

@@ -0,0 +1,23 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
custom-network-config=yolov4.cfg
model-file=yolov4.weights
model-engine-file=model_b1_gpu0_fp16.engine
labelfile-path=labels.txt
batch-size=1
network-mode=2
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=4
maintain-aspect-ratio=0
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet
[class-attrs-all]
pre-cluster-threshold=0.25

View File

@@ -0,0 +1,24 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
custom-network-config=yolov2.cfg
model-file=yolov2.weights
model-engine-file=model_b1_gpu0_fp16.engine
labelfile-path=labels.txt
batch-size=1
network-mode=2
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=0
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25

View File

@@ -0,0 +1,63 @@
[application]
enable-perf-measurement=1
perf-measurement-interval-sec=1
[tiled-display]
enable=1
rows=1
columns=1
width=1280
height=720
gpu-id=0
nvbuf-memory-type=0
[source0]
enable=1
type=3
uri=file://../../samples/streams/sample_1080p_h264.mp4
num-sources=1
gpu-id=0
cudadec-memtype=0
[sink0]
enable=1
type=2
sync=0
source-id=0
gpu-id=0
nvbuf-memory-type=0
[osd]
enable=1
gpu-id=0
border-width=1
text-size=15
text-color=1;1;1;1;
text-bg-color=0.3;0.3;0.3;1
font=Serif
show-clock=0
clock-x-offset=800
clock-y-offset=820
clock-text-size=12
clock-color=1;0;0;0
nvbuf-memory-type=0
[streammux]
gpu-id=0
live-source=0
batch-size=1
batched-push-timeout=40000
width=1920
height=1080
enable-padding=0
nvbuf-memory-type=0
[primary-gie]
enable=1
gpu-id=0
gie-unique-id=1
nvbuf-memory-type=0
config-file=config_infer_primary.txt
[tests]
file-loop=0

80
native/labels.txt Normal file
View File

@@ -0,0 +1,80 @@
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
pottedplant
bed
diningtable
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush

View File

@@ -0,0 +1,71 @@
################################################################################
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# Edited by Marcos Luciano
# https://www.github.com/marcoslucianops
################################################################################
CUDA_VER?=
ifeq ($(CUDA_VER),)
$(error "CUDA_VER is not set")
endif
CC:= g++
NVCC:=/usr/local/cuda-$(CUDA_VER)/bin/nvcc
CFLAGS:= -Wall -std=c++11 -shared -fPIC -Wno-error=deprecated-declarations
CFLAGS+= -I../../includes -I/usr/local/cuda-$(CUDA_VER)/include
LIBS:= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib64 -lcudart -lcublas -lstdc++fs
LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group
INCS:= $(wildcard *.h)
SRCFILES:= nvdsinfer_yolo_engine.cpp \
nvdsparsebbox_Yolo.cpp \
yoloPlugins.cpp \
layers/convolutional_layer.cpp \
layers/dropout_layer.cpp \
layers/shortcut_layer.cpp \
layers/route_layer.cpp \
layers/upsample_layer.cpp \
layers/maxpool_layer.cpp \
layers/activation_layer.cpp \
utils.cpp \
yolo.cpp \
yoloForward.cu
TARGET_LIB:= libnvdsinfer_custom_impl_Yolo.so
TARGET_OBJS:= $(SRCFILES:.cpp=.o)
TARGET_OBJS:= $(TARGET_OBJS:.cu=.o)
all: $(TARGET_LIB)
%.o: %.cpp $(INCS) Makefile
$(CC) -c -o $@ $(CFLAGS) $<
%.o: %.cu $(INCS) Makefile
$(NVCC) -c -o $@ --compiler-options '-fPIC' $<
$(TARGET_LIB) : $(TARGET_OBJS)
$(CC) -o $@ $(TARGET_OBJS) $(LFLAGS)
clean:
rm -rf $(TARGET_LIB)
rm -rf $(TARGET_OBJS)

View File

@@ -0,0 +1,82 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "activation_layer.h"
nvinfer1::ILayer* activationLayer(
int layerIdx,
std::string activation,
nvinfer1::ILayer* output,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
if (activation == "relu")
{
nvinfer1::IActivationLayer* relu = network->addActivation(
*input, nvinfer1::ActivationType::kRELU);
assert(relu != nullptr);
std::string reluLayerName = "relu_" + std::to_string(layerIdx);
relu->setName(reluLayerName.c_str());
output = relu;
}
else if (activation == "sigmoid" || activation == "logistic")
{
nvinfer1::IActivationLayer* sigmoid = network->addActivation(
*input, nvinfer1::ActivationType::kSIGMOID);
assert(sigmoid != nullptr);
std::string sigmoidLayerName = "sigmoid_" + std::to_string(layerIdx);
sigmoid->setName(sigmoidLayerName.c_str());
output = sigmoid;
}
else if (activation == "tanh")
{
nvinfer1::IActivationLayer* tanh = network->addActivation(
*input, nvinfer1::ActivationType::kTANH);
assert(tanh != nullptr);
std::string tanhLayerName = "tanh_" + std::to_string(layerIdx);
tanh->setName(tanhLayerName.c_str());
output = tanh;
}
else if (activation == "leaky")
{
nvinfer1::IActivationLayer* leaky = network->addActivation(
*input, nvinfer1::ActivationType::kLEAKY_RELU);
leaky->setAlpha(0.1);
assert(leaky != nullptr);
std::string leakyLayerName = "leaky_" + std::to_string(layerIdx);
leaky->setName(leakyLayerName.c_str());
output = leaky;
}
else if (activation == "softplus")
{
nvinfer1::IActivationLayer* softplus = network->addActivation(
*input, nvinfer1::ActivationType::kSOFTPLUS);
assert(softplus != nullptr);
std::string softplusLayerName = "softplus_" + std::to_string(layerIdx);
softplus->setName(softplusLayerName.c_str());
output = softplus;
}
else if (activation == "mish")
{
nvinfer1::IActivationLayer* softplus = network->addActivation(
*input, nvinfer1::ActivationType::kSOFTPLUS);
assert(softplus != nullptr);
std::string softplusLayerName = "softplus_" + std::to_string(layerIdx);
softplus->setName(softplusLayerName.c_str());
nvinfer1::IActivationLayer* tanh = network->addActivation(
*softplus->getOutput(0), nvinfer1::ActivationType::kTANH);
assert(tanh != nullptr);
std::string tanhLayerName = "tanh_" + std::to_string(layerIdx);
tanh->setName(tanhLayerName.c_str());
nvinfer1::IElementWiseLayer* mish = network->addElementWise(
*tanh->getOutput(0), *input,
nvinfer1::ElementWiseOperation::kPROD);
assert(mish != nullptr);
std::string mishLayerName = "mish_" + std::to_string(layerIdx);
mish->setName(mishLayerName.c_str());
output = mish;
}
return output;
}

View File

@@ -0,0 +1,23 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef __ACTIVATION_LAYER_H__
#define __ACTIVATION_LAYER_H__
#include <string>
#include <cassert>
#include "NvInfer.h"
#include "activation_layer.h"
nvinfer1::ILayer* activationLayer(
int layerIdx,
std::string activation,
nvinfer1::ILayer* output,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -0,0 +1,168 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include <math.h>
#include "convolutional_layer.h"
nvinfer1::ILayer* convolutionalLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr,
int& inputChannels,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
assert(block.at("type") == "convolutional");
assert(block.find("filters") != block.end());
assert(block.find("pad") != block.end());
assert(block.find("size") != block.end());
assert(block.find("stride") != block.end());
int filters = std::stoi(block.at("filters"));
int padding = std::stoi(block.at("pad"));
int kernelSize = std::stoi(block.at("size"));
int stride = std::stoi(block.at("stride"));
std::string activation = block.at("activation");
int bias = filters;
bool batchNormalize = false;
if (block.find("batch_normalize") != block.end())
{
bias = 0;
batchNormalize = (block.at("batch_normalize") == "1");
}
int groups = 1;
if (block.find("groups") != block.end())
{
groups = std::stoi(block.at("groups"));
}
int pad;
if (padding)
pad = (kernelSize - 1) / 2;
else
pad = 0;
int size = filters * inputChannels * kernelSize * kernelSize / groups;
std::vector<float> bnBiases;
std::vector<float> bnWeights;
std::vector<float> bnRunningMean;
std::vector<float> bnRunningVar;
nvinfer1::Weights convWt{nvinfer1::DataType::kFLOAT, nullptr, size};
nvinfer1::Weights convBias{nvinfer1::DataType::kFLOAT, nullptr, bias};
if (batchNormalize == false)
{
float* val = new float[filters];
for (int i = 0; i < filters; ++i)
{
val[i] = weights[weightPtr];
weightPtr++;
}
convBias.values = val;
trtWeights.push_back(convBias);
val = new float[size];
for (int i = 0; i < size; ++i)
{
val[i] = weights[weightPtr];
weightPtr++;
}
convWt.values = val;
trtWeights.push_back(convWt);
}
else
{
for (int i = 0; i < filters; ++i)
{
bnBiases.push_back(weights[weightPtr]);
weightPtr++;
}
for (int i = 0; i < filters; ++i)
{
bnWeights.push_back(weights[weightPtr]);
weightPtr++;
}
for (int i = 0; i < filters; ++i)
{
bnRunningMean.push_back(weights[weightPtr]);
weightPtr++;
}
for (int i = 0; i < filters; ++i)
{
bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5));
weightPtr++;
}
float* val = new float[size];
for (int i = 0; i < size; ++i)
{
val[i] = weights[weightPtr];
weightPtr++;
}
convWt.values = val;
trtWeights.push_back(convWt);
trtWeights.push_back(convBias);
}
nvinfer1::IConvolutionLayer* conv = network->addConvolution(
*input, filters, nvinfer1::DimsHW{kernelSize, kernelSize}, convWt, convBias);
assert(conv != nullptr);
std::string convLayerName = "conv_" + std::to_string(layerIdx);
conv->setName(convLayerName.c_str());
conv->setStride(nvinfer1::DimsHW{stride, stride});
conv->setPadding(nvinfer1::DimsHW{pad, pad});
if (block.find("groups") != block.end())
{
conv->setNbGroups(groups);
}
nvinfer1::ILayer* output = conv;
if (batchNormalize == true)
{
size = filters;
nvinfer1::Weights shift{nvinfer1::DataType::kFLOAT, nullptr, size};
nvinfer1::Weights scale{nvinfer1::DataType::kFLOAT, nullptr, size};
nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, size};
float* shiftWt = new float[size];
for (int i = 0; i < size; ++i)
{
shiftWt[i]
= bnBiases.at(i) - ((bnRunningMean.at(i) * bnWeights.at(i)) / bnRunningVar.at(i));
}
shift.values = shiftWt;
float* scaleWt = new float[size];
for (int i = 0; i < size; ++i)
{
scaleWt[i] = bnWeights.at(i) / bnRunningVar[i];
}
scale.values = scaleWt;
float* powerWt = new float[size];
for (int i = 0; i < size; ++i)
{
powerWt[i] = 1.0;
}
power.values = powerWt;
trtWeights.push_back(shift);
trtWeights.push_back(scale);
trtWeights.push_back(power);
nvinfer1::IScaleLayer* bn = network->addScale(
*output->getOutput(0), nvinfer1::ScaleMode::kCHANNEL, shift, scale, power);
assert(bn != nullptr);
std::string bnLayerName = "batch_norm_" + std::to_string(layerIdx);
bn->setName(bnLayerName.c_str());
output = bn;
}
output = activationLayer(layerIdx, activation, output, output->getOutput(0), network);
assert(output != nullptr);
return output;
}

View File

@@ -0,0 +1,26 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef __CONVOLUTIONAL_LAYER_H__
#define __CONVOLUTIONAL_LAYER_H__
#include <map>
#include <vector>
#include "NvInfer.h"
#include "activation_layer.h"
nvinfer1::ILayer* convolutionalLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr,
int& inputChannels,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -0,0 +1,15 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "dropout_layer.h"
nvinfer1::ILayer* dropoutLayer(
float probability,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ILayer* output;
return output;
}

View File

@@ -0,0 +1,16 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef __DROPOUT_LAYER_H__
#define __DROPOUT_LAYER_H__
#include "NvInfer.h"
nvinfer1::ILayer* dropoutLayer(
float probability,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -0,0 +1,30 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "maxpool_layer.h"
nvinfer1::ILayer* maxpoolLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
assert(block.at("type") == "maxpool");
assert(block.find("size") != block.end());
assert(block.find("stride") != block.end());
int size = std::stoi(block.at("size"));
int stride = std::stoi(block.at("stride"));
nvinfer1::IPoolingLayer* pool
= network->addPooling(*input, nvinfer1::PoolingType::kMAX, nvinfer1::DimsHW{size, size});
assert(pool);
std::string maxpoolLayerName = "maxpool_" + std::to_string(layerIdx);
pool->setStride(nvinfer1::DimsHW{stride, stride});
pool->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
pool->setName(maxpoolLayerName.c_str());
return pool;
}

View File

@@ -0,0 +1,20 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef __MAXPOOL_LAYER_H__
#define __MAXPOOL_LAYER_H__
#include <map>
#include <cassert>
#include "NvInfer.h"
nvinfer1::ILayer* maxpoolLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -0,0 +1,63 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "route_layer.h"
nvinfer1::ILayer* routeLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<nvinfer1::ITensor*> tensorOutputs,
nvinfer1::INetworkDefinition* network)
{
std::string strLayers = block.at("layers");
std::vector<int> idxLayers;
size_t lastPos = 0, pos = 0;
while ((pos = strLayers.find(',', lastPos)) != std::string::npos) {
int vL = std::stoi(trim(strLayers.substr(lastPos, pos - lastPos)));
idxLayers.push_back (vL);
lastPos = pos + 1;
}
if (lastPos < strLayers.length()) {
std::string lastV = trim(strLayers.substr(lastPos));
if (!lastV.empty()) {
idxLayers.push_back (std::stoi(lastV));
}
}
assert (!idxLayers.empty());
std::vector<nvinfer1::ITensor*> concatInputs;
for (int idxLayer : idxLayers) {
if (idxLayer < 0) {
idxLayer = tensorOutputs.size() + idxLayer;
}
assert (idxLayer >= 0 && idxLayer < (int)tensorOutputs.size());
concatInputs.push_back (tensorOutputs[idxLayer]);
}
nvinfer1::IConcatenationLayer* concat =
network->addConcatenation(concatInputs.data(), concatInputs.size());
assert(concat != nullptr);
std::string concatLayerName = "route_" + std::to_string(layerIdx - 1);
concat->setName(concatLayerName.c_str());
concat->setAxis(0);
nvinfer1::ILayer* output = concat;
if (block.find("groups") != block.end()) {
nvinfer1::Dims prevTensorDims = output->getOutput(0)->getDimensions();
int groups = stoi(block.at("groups"));
int group_id = stoi(block.at("group_id"));
int startSlice = (prevTensorDims.d[0] / groups) * group_id;
int channelSlice = (prevTensorDims.d[0] / groups);
nvinfer1::ISliceLayer* sl = network->addSlice(
*output->getOutput(0),
nvinfer1::Dims3{startSlice, 0, 0},
nvinfer1::Dims3{channelSlice, prevTensorDims.d[1], prevTensorDims.d[2]},
nvinfer1::Dims3{1, 1, 1});
assert(sl != nullptr);
output = sl;
}
return output;
}

View File

@@ -0,0 +1,18 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef __ROUTE_LAYER_H__
#define __ROUTE_LAYER_H__
#include "NvInfer.h"
#include "../utils.h"
nvinfer1::ILayer* routeLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<nvinfer1::ITensor*> tensorOutputs,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -0,0 +1,45 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "shortcut_layer.h"
nvinfer1::ILayer* shortcutLayer(
int layerIdx,
std::string activation,
std::string inputVol,
std::string shortcutVol,
nvinfer1::ITensor* input,
nvinfer1::ITensor* shortcutTensor,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ILayer* output;
nvinfer1::ITensor* outputTensor;
if (inputVol != shortcutVol)
{
nvinfer1::ISliceLayer* sl = network->addSlice(
*shortcutTensor,
nvinfer1::Dims3{0, 0, 0},
input->getDimensions(),
nvinfer1::Dims3{1, 1, 1});
assert(sl != nullptr);
outputTensor = sl->getOutput(0);
assert(outputTensor != nullptr);
} else
{
outputTensor = shortcutTensor;
assert(outputTensor != nullptr);
}
nvinfer1::IElementWiseLayer* ew = network->addElementWise(
*input, *outputTensor,
nvinfer1::ElementWiseOperation::kSUM);
assert(ew != nullptr);
output = activationLayer(layerIdx, activation, ew, ew->getOutput(0), network);
assert(output != nullptr);
return output;
}

View File

@@ -0,0 +1,22 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef __SHORTCUT_LAYER_H__
#define __SHORTCUT_LAYER_H__
#include "NvInfer.h"
#include "activation_layer.h"
nvinfer1::ILayer* shortcutLayer(
int layerIdx,
std::string activation,
std::string inputVol,
std::string shortcutVol,
nvinfer1::ITensor* input,
nvinfer1::ITensor* shortcutTensor,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -0,0 +1,86 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "upsample_layer.h"
nvinfer1::ILayer* upsampleLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& inputChannels,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
assert(block.at("type") == "upsample");
nvinfer1::Dims inpDims = input->getDimensions();
assert(inpDims.nbDims == 3);
assert(inpDims.d[1] == inpDims.d[2]);
int h = inpDims.d[1];
int w = inpDims.d[2];
int stride = std::stoi(block.at("stride"));
nvinfer1::Dims preDims{3,
{1, stride * h, w},
{nvinfer1::DimensionType::kCHANNEL, nvinfer1::DimensionType::kSPATIAL,
nvinfer1::DimensionType::kSPATIAL}};
int size = stride * h * w;
nvinfer1::Weights preMul{nvinfer1::DataType::kFLOAT, nullptr, size};
float* preWt = new float[size];
for (int i = 0, idx = 0; i < h; ++i)
{
for (int s = 0; s < stride; ++s)
{
for (int j = 0; j < w; ++j, ++idx)
{
preWt[idx] = (i == j) ? 1.0 : 0.0;
}
}
}
preMul.values = preWt;
trtWeights.push_back(preMul);
nvinfer1::IConstantLayer* preM = network->addConstant(preDims, preMul);
assert(preM != nullptr);
std::string preLayerName = "preMul_" + std::to_string(layerIdx);
preM->setName(preLayerName.c_str());
nvinfer1::Dims postDims{3,
{1, h, stride * w},
{nvinfer1::DimensionType::kCHANNEL, nvinfer1::DimensionType::kSPATIAL,
nvinfer1::DimensionType::kSPATIAL}};
size = stride * h * w;
nvinfer1::Weights postMul{nvinfer1::DataType::kFLOAT, nullptr, size};
float* postWt = new float[size];
for (int i = 0, idx = 0; i < h; ++i)
{
for (int j = 0; j < stride * w; ++j, ++idx)
{
postWt[idx] = (j / stride == i) ? 1.0 : 0.0;
}
}
postMul.values = postWt;
trtWeights.push_back(postMul);
nvinfer1::IConstantLayer* post_m = network->addConstant(postDims, postMul);
assert(post_m != nullptr);
std::string postLayerName = "postMul_" + std::to_string(layerIdx);
post_m->setName(postLayerName.c_str());
nvinfer1::IMatrixMultiplyLayer* mm1
= network->addMatrixMultiply(*preM->getOutput(0), nvinfer1::MatrixOperation::kNONE, *input,
nvinfer1::MatrixOperation::kNONE);
assert(mm1 != nullptr);
std::string mm1LayerName = "mm1_" + std::to_string(layerIdx);
mm1->setName(mm1LayerName.c_str());
nvinfer1::IMatrixMultiplyLayer* mm2
= network->addMatrixMultiply(*mm1->getOutput(0), nvinfer1::MatrixOperation::kNONE,
*post_m->getOutput(0), nvinfer1::MatrixOperation::kNONE);
assert(mm2 != nullptr);
std::string mm2LayerName = "mm2_" + std::to_string(layerIdx);
mm2->setName(mm2LayerName.c_str());
return mm2;
}

View File

@@ -0,0 +1,24 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef __UPSAMPLE_LAYER_H__
#define __UPSAMPLE_LAYER_H__
#include <map>
#include <vector>
#include <cassert>
#include "NvInfer.h"
nvinfer1::ILayer* upsampleLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& inputChannels,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -0,0 +1,107 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
* Edited by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "nvdsinfer_custom_impl.h"
#include "nvdsinfer_context.h"
#include "yoloPlugins.h"
#include "yolo.h"
#include <algorithm>
#define USE_CUDA_ENGINE_GET_API 1
static bool getYoloNetworkInfo (NetworkInfo &networkInfo, const NvDsInferContextInitParams* initParams)
{
std::string yoloCfg = initParams->customNetworkConfigFilePath;
std::string yoloType;
std::transform (yoloCfg.begin(), yoloCfg.end(), yoloCfg.begin(), [] (uint8_t c) {
return std::tolower (c);});
yoloType = yoloCfg.substr(0, yoloCfg.find(".cfg"));
networkInfo.networkType = yoloType;
networkInfo.configFilePath = initParams->customNetworkConfigFilePath;
networkInfo.wtsFilePath = initParams->modelFilePath;
networkInfo.deviceType = (initParams->useDLA ? "kDLA" : "kGPU");
networkInfo.inputBlobName = "data";
if (networkInfo.configFilePath.empty() ||
networkInfo.wtsFilePath.empty()) {
std::cerr << "YOLO config file or weights file is not specified"
<< std::endl;
return false;
}
if (!fileExists(networkInfo.configFilePath) ||
!fileExists(networkInfo.wtsFilePath)) {
std::cerr << "YOLO config file or weights file is not exist"
<< std::endl;
return false;
}
return true;
}
#if !USE_CUDA_ENGINE_GET_API
IModelParser* NvDsInferCreateModelParser(
const NvDsInferContextInitParams* initParams) {
NetworkInfo networkInfo;
if (!getYoloNetworkInfo(networkInfo, initParams)) {
return nullptr;
}
return new Yolo(networkInfo);
}
#else
extern "C"
bool NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder * const builder,
const NvDsInferContextInitParams * const initParams,
nvinfer1::DataType dataType,
nvinfer1::ICudaEngine *& cudaEngine);
extern "C"
bool NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder * const builder,
const NvDsInferContextInitParams * const initParams,
nvinfer1::DataType dataType,
nvinfer1::ICudaEngine *& cudaEngine)
{
NetworkInfo networkInfo;
if (!getYoloNetworkInfo(networkInfo, initParams)) {
return false;
}
Yolo yolo(networkInfo);
cudaEngine = yolo.createEngine (builder);
if (cudaEngine == nullptr)
{
std::cerr << "Failed to build CUDA engine on "
<< networkInfo.configFilePath << std::endl;
return false;
}
return true;
}
#endif

View File

@@ -0,0 +1,380 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
* Edited by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include <algorithm>
#include <cmath>
#include <sstream>
#include "nvdsinfer_custom_impl.h"
#include "utils.h"
#include "yoloPlugins.h"
extern "C" bool NvDsInferParseYolo(
std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList);
static std::vector<NvDsInferParseObjectInfo>
nonMaximumSuppression(const float nmsThresh, std::vector<NvDsInferParseObjectInfo> binfo)
{
auto overlap1D = [](float x1min, float x1max, float x2min, float x2max) -> float {
if (x1min > x2min)
{
std::swap(x1min, x2min);
std::swap(x1max, x2max);
}
return x1max < x2min ? 0 : std::min(x1max, x2max) - x2min;
};
auto computeIoU
= [&overlap1D](NvDsInferParseObjectInfo& bbox1, NvDsInferParseObjectInfo& bbox2) -> float {
float overlapX
= overlap1D(bbox1.left, bbox1.left + bbox1.width, bbox2.left, bbox2.left + bbox2.width);
float overlapY
= overlap1D(bbox1.top, bbox1.top + bbox1.height, bbox2.top, bbox2.top + bbox2.height);
float area1 = (bbox1.width) * (bbox1.height);
float area2 = (bbox2.width) * (bbox2.height);
float overlap2D = overlapX * overlapY;
float u = area1 + area2 - overlap2D;
return u == 0 ? 0 : overlap2D / u;
};
std::stable_sort(binfo.begin(), binfo.end(),
[](const NvDsInferParseObjectInfo& b1, const NvDsInferParseObjectInfo& b2) {
return b1.detectionConfidence > b2.detectionConfidence;
});
std::vector<NvDsInferParseObjectInfo> out;
for (auto i : binfo)
{
bool keep = true;
for (auto j : out)
{
if (keep)
{
float overlap = computeIoU(i, j);
keep = overlap <= nmsThresh;
}
else
break;
}
if (keep) out.push_back(i);
}
return out;
}
static std::vector<NvDsInferParseObjectInfo>
nmsAllClasses(const float nmsThresh,
std::vector<NvDsInferParseObjectInfo>& binfo,
const uint numClasses)
{
std::vector<NvDsInferParseObjectInfo> result;
std::vector<std::vector<NvDsInferParseObjectInfo>> splitBoxes(numClasses);
for (auto& box : binfo)
{
splitBoxes.at(box.classId).push_back(box);
}
for (auto& boxes : splitBoxes)
{
boxes = nonMaximumSuppression(nmsThresh, boxes);
result.insert(result.end(), boxes.begin(), boxes.end());
}
return result;
}
static NvDsInferParseObjectInfo convertBBox(const float& bx, const float& by, const float& bw,
const float& bh, const int& stride, const uint& netW,
const uint& netH)
{
NvDsInferParseObjectInfo b;
float xCenter = bx * stride;
float yCenter = by * stride;
float x0 = xCenter - bw / 2;
float y0 = yCenter - bh / 2;
float x1 = x0 + bw;
float y1 = y0 + bh;
x0 = clamp(x0, 0, netW);
y0 = clamp(y0, 0, netH);
x1 = clamp(x1, 0, netW);
y1 = clamp(y1, 0, netH);
b.left = x0;
b.width = clamp(x1 - x0, 0, netW);
b.top = y0;
b.height = clamp(y1 - y0, 0, netH);
return b;
}
static void addBBoxProposal(const float bx, const float by, const float bw, const float bh,
const uint stride, const uint& netW, const uint& netH, const int maxIndex,
const float maxProb, std::vector<NvDsInferParseObjectInfo>& binfo)
{
NvDsInferParseObjectInfo bbi = convertBBox(bx, by, bw, bh, stride, netW, netH);
if (bbi.width < 1 || bbi.height < 1) return;
bbi.detectionConfidence = maxProb;
bbi.classId = maxIndex;
binfo.push_back(bbi);
}
static std::vector<NvDsInferParseObjectInfo>
decodeYoloTensor(
const float* detections, const std::vector<int> &mask, const std::vector<float> &anchors,
const uint gridSizeW, const uint gridSizeH, const uint stride, const uint numBBoxes,
const uint numOutputClasses, const uint& netW,
const uint& netH,
const float confThresh)
{
std::vector<NvDsInferParseObjectInfo> binfo;
for (uint y = 0; y < gridSizeH; ++y) {
for (uint x = 0; x < gridSizeW; ++x) {
for (uint b = 0; b < numBBoxes; ++b)
{
const float pw = anchors[mask[b] * 2];
const float ph = anchors[mask[b] * 2 + 1];
const int numGridCells = gridSizeH * gridSizeW;
const int bbindex = y * gridSizeW + x;
const float bx
= x + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 0)];
const float by
= y + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 1)];
const float bw
= pw * detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 2)];
const float bh
= ph * detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 3)];
const float objectness
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 4)];
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numOutputClasses; ++i)
{
float prob
= (detections[bbindex
+ numGridCells * (b * (5 + numOutputClasses) + (5 + i))]);
if (prob > maxProb)
{
maxProb = prob;
maxIndex = i;
}
}
maxProb = objectness * maxProb;
if (maxProb > confThresh)
{
addBBoxProposal(bx, by, bw, bh, stride, netW, netH, maxIndex, maxProb, binfo);
}
}
}
}
return binfo;
}
static std::vector<NvDsInferParseObjectInfo>
decodeYoloV2Tensor(
const float* detections, const std::vector<float> &anchors,
const uint gridSizeW, const uint gridSizeH, const uint stride, const uint numBBoxes,
const uint numOutputClasses, const uint& netW,
const uint& netH)
{
std::vector<NvDsInferParseObjectInfo> binfo;
for (uint y = 0; y < gridSizeH; ++y) {
for (uint x = 0; x < gridSizeW; ++x) {
for (uint b = 0; b < numBBoxes; ++b)
{
const float pw = anchors[b * 2];
const float ph = anchors[b * 2 + 1];
const int numGridCells = gridSizeH * gridSizeW;
const int bbindex = y * gridSizeW + x;
const float bx
= x + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 0)];
const float by
= y + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 1)];
const float bw
= pw * detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 2)];
const float bh
= ph * detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 3)];
const float objectness
= detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 4)];
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numOutputClasses; ++i)
{
float prob
= (detections[bbindex
+ numGridCells * (b * (5 + numOutputClasses) + (5 + i))]);
if (prob > maxProb)
{
maxProb = prob;
maxIndex = i;
}
}
maxProb = objectness * maxProb;
addBBoxProposal(bx, by, bw, bh, stride, netW, netH, maxIndex, maxProb, binfo);
}
}
}
return binfo;
}
static inline std::vector<const NvDsInferLayerInfo*>
SortLayers(const std::vector<NvDsInferLayerInfo> & outputLayersInfo)
{
std::vector<const NvDsInferLayerInfo*> outLayers;
for (auto const &layer : outputLayersInfo) {
outLayers.push_back (&layer);
}
std::sort(outLayers.begin(), outLayers.end(),
[](const NvDsInferLayerInfo* a, const NvDsInferLayerInfo* b) {
return a->inferDims.d[1] < b->inferDims.d[1];
});
return outLayers;
}
static bool NvDsInferParseYolo(
std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList,
const std::vector<float> &anchors,
const std::vector<std::vector<int>> &masks,
const uint &num_classes,
const float &beta_nms)
{
const float kCONF_THRESH = detectionParams.perClassThreshold[0];
const std::vector<const NvDsInferLayerInfo*> sortedLayers =
SortLayers (outputLayersInfo);
if (sortedLayers.size() != masks.size()) {
std::cerr << "ERROR: YOLO output layer.size: " << sortedLayers.size()
<< " does not match mask.size: " << masks.size() << std::endl;
return false;
}
if (num_classes != detectionParams.numClassesConfigured)
{
std::cerr << "WARNING: Num classes mismatch. Configured: "
<< detectionParams.numClassesConfigured
<< ", detected by network: " << num_classes << std::endl;
}
std::vector<NvDsInferParseObjectInfo> objects;
for (uint idx = 0; idx < masks.size(); ++idx) {
const NvDsInferLayerInfo &layer = *sortedLayers[idx]; // 255 x Grid x Grid
assert(layer.inferDims.numDims == 3);
const uint gridSizeH = layer.inferDims.d[1];
const uint gridSizeW = layer.inferDims.d[2];
const uint stride = DIVUP(networkInfo.width, gridSizeW);
assert(stride == DIVUP(networkInfo.height, gridSizeH));
std::vector<NvDsInferParseObjectInfo> outObjs =
decodeYoloTensor((const float*)(layer.buffer), masks[idx], anchors, gridSizeW, gridSizeH, stride, masks[idx].size(),
num_classes, networkInfo.width, networkInfo.height, kCONF_THRESH);
objects.insert(objects.end(), outObjs.begin(), outObjs.end());
}
objectList.clear();
objectList = nmsAllClasses(beta_nms, objects, num_classes);
return true;
}
static bool NvDsInferParseYoloV2(
std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList,
std::vector<float> &anchors,
const uint &num_classes)
{
if (outputLayersInfo.empty()) {
std::cerr << "Could not find output layer in bbox parsing" << std::endl;;
return false;
}
const uint kNUM_BBOXES = anchors.size() / 2;
const NvDsInferLayerInfo &layer = outputLayersInfo[0];
if (num_classes != detectionParams.numClassesConfigured)
{
std::cerr << "WARNING: Num classes mismatch. Configured: "
<< detectionParams.numClassesConfigured
<< ", detected by network: " << num_classes << std::endl;
}
assert(layer.inferDims.numDims == 3);
const uint gridSizeH = layer.inferDims.d[1];
const uint gridSizeW = layer.inferDims.d[2];
const uint stride = DIVUP(networkInfo.width, gridSizeW);
assert(stride == DIVUP(networkInfo.height, gridSizeH));
for (auto& anchor : anchors) {
anchor *= stride;
}
std::vector<NvDsInferParseObjectInfo> objects =
decodeYoloV2Tensor((const float*)(layer.buffer), anchors, gridSizeW, gridSizeH, stride, kNUM_BBOXES,
num_classes, networkInfo.width, networkInfo.height);
objectList = objects;
return true;
}
extern "C" bool NvDsInferParseYolo(
std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList)
{
int num_classes = kNUM_CLASSES;
float beta_nms = kBETA_NMS;
std::vector<float> anchors = kANCHORS;
std::vector<std::vector<int>> mask = kMASK;
if (mask.size() > 0) {
return NvDsInferParseYolo (outputLayersInfo, networkInfo, detectionParams, objectList, anchors, mask, num_classes, beta_nms);
}
else {
return NvDsInferParseYoloV2 (outputLayersInfo, networkInfo, detectionParams, objectList, anchors, num_classes);
}
}
CHECK_CUSTOM_PARSE_FUNC_PROTOTYPE(NvDsInferParseYolo);

View File

@@ -0,0 +1,150 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
* Edited by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "utils.h"
#include <experimental/filesystem>
#include <iomanip>
#include <algorithm>
#include <math.h>
static void leftTrim(std::string& s)
{
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !isspace(ch); }));
}
static void rightTrim(std::string& s)
{
s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !isspace(ch); }).base(), s.end());
}
std::string trim(std::string s)
{
leftTrim(s);
rightTrim(s);
return s;
}
float clamp(const float val, const float minVal, const float maxVal)
{
assert(minVal <= maxVal);
return std::min(maxVal, std::max(minVal, val));
}
bool fileExists(const std::string fileName, bool verbose)
{
if (!std::experimental::filesystem::exists(std::experimental::filesystem::path(fileName)))
{
if (verbose) std::cout << "File does not exist: " << fileName << std::endl;
return false;
}
return true;
}
std::vector<float> loadWeights(const std::string weightsFilePath, const std::string& networkType)
{
assert(fileExists(weightsFilePath));
std::cout << "\nLoading pre-trained weights" << std::endl;
std::ifstream file(weightsFilePath, std::ios_base::binary);
assert(file.good());
std::string line;
if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos)
{
// Remove 4 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 4);
}
else
{
// Remove 5 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 5);
}
std::vector<float> weights;
char floatWeight[4];
while (!file.eof())
{
file.read(floatWeight, 4);
assert(file.gcount() == 4);
weights.push_back(*reinterpret_cast<float*>(floatWeight));
if (file.peek() == std::istream::traits_type::eof()) break;
}
std::cout << "Loading weights of " << networkType << " complete"
<< std::endl;
std::cout << "Total weights read: " << weights.size() << std::endl;
return weights;
}
std::string dimsToString(const nvinfer1::Dims d)
{
std::stringstream s;
assert(d.nbDims >= 1);
for (int i = 0; i < d.nbDims - 1; ++i)
{
s << std::setw(4) << d.d[i] << " x";
}
s << std::setw(4) << d.d[d.nbDims - 1];
return s.str();
}
void displayDimType(const nvinfer1::Dims d)
{
std::cout << "(" << d.nbDims << ") ";
for (int i = 0; i < d.nbDims; ++i)
{
switch (d.type[i])
{
case nvinfer1::DimensionType::kSPATIAL: std::cout << "kSPATIAL "; break;
case nvinfer1::DimensionType::kCHANNEL: std::cout << "kCHANNEL "; break;
case nvinfer1::DimensionType::kINDEX: std::cout << "kINDEX "; break;
case nvinfer1::DimensionType::kSEQUENCE: std::cout << "kSEQUENCE "; break;
}
}
std::cout << std::endl;
}
int getNumChannels(nvinfer1::ITensor* t)
{
nvinfer1::Dims d = t->getDimensions();
assert(d.nbDims == 3);
return d.d[0];
}
uint64_t get3DTensorVolume(nvinfer1::Dims inputDims)
{
assert(inputDims.nbDims == 3);
return inputDims.d[0] * inputDims.d[1] * inputDims.d[2];
}
void printLayerInfo(std::string layerIndex, std::string layerName, std::string layerInput,
std::string layerOutput, std::string weightPtr)
{
std::cout << std::setw(6) << std::left << layerIndex << std::setw(24) << std::left << layerName;
std::cout << std::setw(20) << std::left << layerInput << std::setw(20) << std::left
<< layerOutput;
std::cout << std::setw(7) << std::left << weightPtr << std::endl;
}

View File

@@ -0,0 +1,53 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
* Edited by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef __UTILS_H__
#define __UTILS_H__
#include <map>
#include <vector>
#include <cassert>
#include <iostream>
#include <fstream>
#include "NvInfer.h"
#define UNUSED(expr) (void)(expr)
#define DIVUP(n, d) ((n) + (d)-1) / (d)
std::string trim(std::string s);
float clamp(const float val, const float minVal, const float maxVal);
bool fileExists(const std::string fileName, bool verbose = true);
std::vector<float> loadWeights(const std::string weightsFilePath, const std::string& networkType);
std::string dimsToString(const nvinfer1::Dims d);
void displayDimType(const nvinfer1::Dims d);
int getNumChannels(nvinfer1::ITensor* t);
uint64_t get3DTensorVolume(nvinfer1::Dims inputDims);
void printLayerInfo(std::string layerIndex, std::string layerName, std::string layerInput,
std::string layerOutput, std::string weightPtr);
#endif

View File

@@ -0,0 +1,465 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
* Edited by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "yolo.h"
#include "yoloPlugins.h"
void orderParams(std::vector<std::vector<int>> *maskVector) {
std::vector<std::vector<int>> maskinput = *maskVector;
std::vector<int> maskPartial;
for (uint i = 0; i < maskinput.size(); i++) {
for (uint j = i + 1; j < maskinput.size(); j++) {
if (maskinput[i][0] <= maskinput[j][0]) {
maskPartial = maskinput[i];
maskinput[i] = maskinput[j];
maskinput[j] = maskPartial;
}
}
}
*maskVector = maskinput;
}
Yolo::Yolo(const NetworkInfo& networkInfo)
: m_NetworkType(networkInfo.networkType), // YOLO type
m_ConfigFilePath(networkInfo.configFilePath), // YOLO cfg
m_WtsFilePath(networkInfo.wtsFilePath), // YOLO weights
m_DeviceType(networkInfo.deviceType), // kDLA, kGPU
m_InputBlobName(networkInfo.inputBlobName), // data
m_InputH(0),
m_InputW(0),
m_InputC(0),
m_InputSize(0)
{}
Yolo::~Yolo()
{
destroyNetworkUtils();
}
nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
{
assert (builder);
std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
std::vector<nvinfer1::Weights> trtWeights;
nvinfer1::INetworkDefinition *network = builder->createNetwork();
if (parseModel(*network) != NVDSINFER_SUCCESS) {
network->destroy();
return nullptr;
}
// Build the engine
std::cout << "Building the TensorRT Engine" << std::endl;
nvinfer1::ICudaEngine * engine = builder->buildCudaEngine(*network);
if (engine) {
std::cout << "Building complete\n" << std::endl;
} else {
std::cerr << "Building engine failed\n" << std::endl;
}
// destroy
network->destroy();
return engine;
}
NvDsInferStatus Yolo::parseModel(nvinfer1::INetworkDefinition& network) {
destroyNetworkUtils();
m_ConfigBlocks = parseConfigFile(m_ConfigFilePath);
parseConfigBlocks();
orderParams(&m_OutputMasks);
std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
// build yolo network
std::cout << "Building YOLO network" << std::endl;
NvDsInferStatus status = buildYoloNetwork(weights, network);
if (status == NVDSINFER_SUCCESS) {
std::cout << "Building YOLO network complete" << std::endl;
} else {
std::cerr << "Building YOLO network failed" << std::endl;
}
return status;
}
NvDsInferStatus Yolo::buildYoloNetwork(
std::vector<float>& weights, nvinfer1::INetworkDefinition& network) {
int weightPtr = 0;
int channels = m_InputC;
nvinfer1::ITensor* data =
network.addInput(m_InputBlobName.c_str(), nvinfer1::DataType::kFLOAT,
nvinfer1::DimsCHW{static_cast<int>(m_InputC),
static_cast<int>(m_InputH), static_cast<int>(m_InputW)});
assert(data != nullptr && data->getDimensions().nbDims > 0);
nvinfer1::ITensor* previous = data;
std::vector<nvinfer1::ITensor*> tensorOutputs;
uint outputTensorCount = 0;
// build the network using the network API
for (uint i = 0; i < m_ConfigBlocks.size(); ++i) {
// check if num. of channels is correct
assert(getNumChannels(previous) == channels);
std::string layerIndex = "(" + std::to_string(tensorOutputs.size()) + ")";
if (m_ConfigBlocks.at(i).at("type") == "net") {
printLayerInfo("", "layer", " input", " outup", "weightPtr");
}
else if (m_ConfigBlocks.at(i).at("type") == "convolutional") {
std::string inputVol = dimsToString(previous->getDimensions());
nvinfer1::ILayer* out = convolutionalLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, channels, previous, &network);
previous = out->getOutput(0);
assert(previous != nullptr);
channels = getNumChannels(previous);
std::string outputVol = dimsToString(previous->getDimensions());
tensorOutputs.push_back(previous);
std::string layerType = "conv_" + m_ConfigBlocks.at(i).at("activation");
printLayerInfo(layerIndex, layerType, inputVol, outputVol, std::to_string(weightPtr));
}
else if (m_ConfigBlocks.at(i).at("type") == "dropout") {
assert(m_ConfigBlocks.at(i).find("probability") != m_ConfigBlocks.at(i).end());
//float probability = std::stof(m_ConfigBlocks.at(i).at("probability"));
//nvinfer1::ILayer* out = dropoutLayer(probability, previous, &network);
//previous = out->getOutput(0);
//Skip dropout layer
assert(previous != nullptr);
tensorOutputs.push_back(previous);
printLayerInfo(layerIndex, "dropout", " -", " -", " -");
}
else if (m_ConfigBlocks.at(i).at("type") == "shortcut") {
assert(m_ConfigBlocks.at(i).find("activation") != m_ConfigBlocks.at(i).end());
assert(m_ConfigBlocks.at(i).find("from") != m_ConfigBlocks.at(i).end());
std::string activation = m_ConfigBlocks.at(i).at("activation");
int from = stoi(m_ConfigBlocks.at(i).at("from"));
if (from > 0) {
from = from - i + 1;
}
assert((i - 2 >= 0) && (i - 2 < tensorOutputs.size()));
assert((i + from - 1 >= 0) && (i + from - 1 < tensorOutputs.size()));
assert(i + from - 1 < i - 2);
std::string inputVol = dimsToString(previous->getDimensions());
std::string shortcutVol = dimsToString(tensorOutputs[i + from - 1]->getDimensions());
nvinfer1::ILayer* out = shortcutLayer(i, activation, inputVol, shortcutVol, previous, tensorOutputs[i + from - 1], &network);
previous = out->getOutput(0);
assert(previous != nullptr);
std::string outputVol = dimsToString(previous->getDimensions());
tensorOutputs.push_back(previous);
std::string layerType = "shortcut_" + m_ConfigBlocks.at(i).at("activation") + ": " + std::to_string(i + from - 1);
printLayerInfo(layerIndex, layerType, " -", outputVol, " -");
if (inputVol != shortcutVol) {
std::cout << inputVol << " +" << shortcutVol << std::endl;
}
}
else if (m_ConfigBlocks.at(i).at("type") == "route") {
assert(m_ConfigBlocks.at(i).find("layers") != m_ConfigBlocks.at(i).end());
nvinfer1::ILayer* out = routeLayer(i, m_ConfigBlocks.at(i), tensorOutputs, &network);
previous = out->getOutput(0);
assert(previous != nullptr);
channels = getNumChannels(previous);
std::string outputVol = dimsToString(previous->getDimensions());
tensorOutputs.push_back(previous);
printLayerInfo(layerIndex, "route", " -", outputVol, std::to_string(weightPtr));
}
else if (m_ConfigBlocks.at(i).at("type") == "upsample") {
std::string inputVol = dimsToString(previous->getDimensions());
nvinfer1::ILayer* out = upsampleLayer(i - 1, m_ConfigBlocks[i], weights, m_TrtWeights, channels, previous, &network);
previous = out->getOutput(0);
assert(previous != nullptr);
std::string outputVol = dimsToString(previous->getDimensions());
tensorOutputs.push_back(previous);
printLayerInfo(layerIndex, "upsample", inputVol, outputVol, " -");
}
else if (m_ConfigBlocks.at(i).at("type") == "maxpool") {
std::string inputVol = dimsToString(previous->getDimensions());
nvinfer1::ILayer* out = maxpoolLayer(i, m_ConfigBlocks.at(i), previous, &network);
previous = out->getOutput(0);
assert(previous != nullptr);
std::string outputVol = dimsToString(previous->getDimensions());
tensorOutputs.push_back(previous);
printLayerInfo(layerIndex, "maxpool", inputVol, outputVol, std::to_string(weightPtr));
}
else if (m_ConfigBlocks.at(i).at("type") == "yolo") {
nvinfer1::Dims prevTensorDims = previous->getDimensions();
assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
TensorInfo& curYoloTensor = m_OutputTensors.at(outputTensorCount);
curYoloTensor.gridSize = prevTensorDims.d[1];
curYoloTensor.stride = m_InputW / curYoloTensor.gridSize;
m_OutputTensors.at(outputTensorCount).volume = curYoloTensor.gridSize
* curYoloTensor.gridSize
* (curYoloTensor.numBBoxes * (5 + curYoloTensor.numClasses));
std::string layerName = "yolo_" + std::to_string(i);
curYoloTensor.blobName = layerName;
int new_coords = 0;
float scale_x_y = 1;
float beta_nms = 0.45;
if (m_ConfigBlocks.at(i).find("new_coords") != m_ConfigBlocks.at(i).end()) {
new_coords = std::stoi(m_ConfigBlocks.at(i).at("new_coords"));
}
if (m_ConfigBlocks.at(i).find("scale_x_y") != m_ConfigBlocks.at(i).end()) {
scale_x_y = std::stof(m_ConfigBlocks.at(i).at("scale_x_y"));
}
if (m_ConfigBlocks.at(i).find("beta_nms") != m_ConfigBlocks.at(i).end()) {
beta_nms = std::stof(m_ConfigBlocks.at(i).at("beta_nms"));
}
nvinfer1::IPluginV2* yoloPlugin
= new YoloLayer(m_OutputTensors.at(outputTensorCount).numBBoxes,
m_OutputTensors.at(outputTensorCount).numClasses,
m_OutputTensors.at(outputTensorCount).gridSize,
'y', new_coords, scale_x_y, beta_nms,
curYoloTensor.anchors,
m_OutputMasks);
assert(yoloPlugin != nullptr);
nvinfer1::IPluginV2Layer* yolo =
network.addPluginV2(&previous, 1, *yoloPlugin);
assert(yolo != nullptr);
yolo->setName(layerName.c_str());
std::string inputVol = dimsToString(previous->getDimensions());
previous = yolo->getOutput(0);
assert(previous != nullptr);
previous->setName(layerName.c_str());
std::string outputVol = dimsToString(previous->getDimensions());
network.markOutput(*previous);
channels = getNumChannels(previous);
tensorOutputs.push_back(yolo->getOutput(0));
printLayerInfo(layerIndex, "yolo", inputVol, outputVol, std::to_string(weightPtr));
++outputTensorCount;
}
//YOLOv2 support
else if (m_ConfigBlocks.at(i).at("type") == "region") {
nvinfer1::Dims prevTensorDims = previous->getDimensions();
assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
TensorInfo& curRegionTensor = m_OutputTensors.at(outputTensorCount);
curRegionTensor.gridSize = prevTensorDims.d[1];
curRegionTensor.stride = m_InputW / curRegionTensor.gridSize;
m_OutputTensors.at(outputTensorCount).volume = curRegionTensor.gridSize
* curRegionTensor.gridSize
* (curRegionTensor.numBBoxes * (5 + curRegionTensor.numClasses));
std::string layerName = "region_" + std::to_string(i);
curRegionTensor.blobName = layerName;
std::vector<std::vector<int>> mask;
nvinfer1::IPluginV2* regionPlugin
= new YoloLayer(curRegionTensor.numBBoxes,
curRegionTensor.numClasses,
curRegionTensor.gridSize,
'r', 0, 1.0, 0,
curRegionTensor.anchors,
mask);
assert(regionPlugin != nullptr);
nvinfer1::IPluginV2Layer* region =
network.addPluginV2(&previous, 1, *regionPlugin);
assert(region != nullptr);
region->setName(layerName.c_str());
std::string inputVol = dimsToString(previous->getDimensions());
previous = region->getOutput(0);
assert(previous != nullptr);
previous->setName(layerName.c_str());
std::string outputVol = dimsToString(previous->getDimensions());
network.markOutput(*previous);
channels = getNumChannels(previous);
tensorOutputs.push_back(region->getOutput(0));
printLayerInfo(layerIndex, "region", inputVol, outputVol, std::to_string(weightPtr));
++outputTensorCount;
}
else if (m_ConfigBlocks.at(i).at("type") == "reorg") {
std::string inputVol = dimsToString(previous->getDimensions());
nvinfer1::IPluginV2* reorgPlugin = createReorgPlugin(2);
assert(reorgPlugin != nullptr);
nvinfer1::IPluginV2Layer* reorg =
network.addPluginV2(&previous, 1, *reorgPlugin);
assert(reorg != nullptr);
std::string layerName = "reorg_" + std::to_string(i);
reorg->setName(layerName.c_str());
previous = reorg->getOutput(0);
assert(previous != nullptr);
std::string outputVol = dimsToString(previous->getDimensions());
channels = getNumChannels(previous);
tensorOutputs.push_back(reorg->getOutput(0));
printLayerInfo(layerIndex, "reorg", inputVol, outputVol, std::to_string(weightPtr));
}
else
{
std::cout << "Unsupported layer type --> \""
<< m_ConfigBlocks.at(i).at("type") << "\"" << std::endl;
assert(0);
}
}
if ((int)weights.size() != weightPtr)
{
std::cout << "Number of unused weights left: " << weights.size() - weightPtr << std::endl;
assert(0);
}
std::cout << "Output YOLO blob names: " << std::endl;
for (auto& tensor : m_OutputTensors) {
std::cout << tensor.blobName << std::endl;
}
int nbLayers = network.getNbLayers();
std::cout << "Total number of YOLO layers: " << nbLayers << std::endl;
return NVDSINFER_SUCCESS;
}
std::vector<std::map<std::string, std::string>>
Yolo::parseConfigFile (const std::string cfgFilePath)
{
assert(fileExists(cfgFilePath));
std::ifstream file(cfgFilePath);
assert(file.good());
std::string line;
std::vector<std::map<std::string, std::string>> blocks;
std::map<std::string, std::string> block;
while (getline(file, line))
{
if (line.size() == 0) continue;
if (line.front() == '#') continue;
line = trim(line);
if (line.front() == '[')
{
if (block.size() > 0)
{
blocks.push_back(block);
block.clear();
}
std::string key = "type";
std::string value = trim(line.substr(1, line.size() - 2));
block.insert(std::pair<std::string, std::string>(key, value));
}
else
{
int cpos = line.find('=');
std::string key = trim(line.substr(0, cpos));
std::string value = trim(line.substr(cpos + 1));
block.insert(std::pair<std::string, std::string>(key, value));
}
}
blocks.push_back(block);
return blocks;
}
void Yolo::parseConfigBlocks()
{
for (auto block : m_ConfigBlocks) {
if (block.at("type") == "net")
{
assert((block.find("height") != block.end())
&& "Missing 'height' param in network cfg");
assert((block.find("width") != block.end()) && "Missing 'width' param in network cfg");
assert((block.find("channels") != block.end())
&& "Missing 'channels' param in network cfg");
m_InputH = std::stoul(block.at("height"));
m_InputW = std::stoul(block.at("width"));
m_InputC = std::stoul(block.at("channels"));
assert(m_InputW == m_InputH);
m_InputSize = m_InputC * m_InputH * m_InputW;
}
else if ((block.at("type") == "region") || (block.at("type") == "yolo"))
{
assert((block.find("num") != block.end())
&& std::string("Missing 'num' param in " + block.at("type") + " layer").c_str());
assert((block.find("classes") != block.end())
&& std::string("Missing 'classes' param in " + block.at("type") + " layer")
.c_str());
assert((block.find("anchors") != block.end())
&& std::string("Missing 'anchors' param in " + block.at("type") + " layer")
.c_str());
TensorInfo outputTensor;
std::string anchorString = block.at("anchors");
while (!anchorString.empty())
{
int npos = anchorString.find_first_of(',');
if (npos != -1)
{
float anchor = std::stof(trim(anchorString.substr(0, npos)));
outputTensor.anchors.push_back(anchor);
anchorString.erase(0, npos + 1);
}
else
{
float anchor = std::stof(trim(anchorString));
outputTensor.anchors.push_back(anchor);
break;
}
}
if (block.find("mask") != block.end()) {
std::string maskString = block.at("mask");
std::vector<int> pMASKS;
while (!maskString.empty())
{
int npos = maskString.find_first_of(',');
if (npos != -1)
{
int mask = std::stoul(trim(maskString.substr(0, npos)));
pMASKS.push_back(mask);
outputTensor.masks.push_back(mask);
maskString.erase(0, npos + 1);
}
else
{
int mask = std::stoul(trim(maskString));
pMASKS.push_back(mask);
outputTensor.masks.push_back(mask);
break;
}
}
m_OutputMasks.push_back(pMASKS);
}
outputTensor.numBBoxes = outputTensor.masks.size() > 0
? outputTensor.masks.size()
: std::stoul(trim(block.at("num")));
outputTensor.numClasses = std::stoul(block.at("classes"));
m_OutputTensors.push_back(outputTensor);
}
}
}
void Yolo::destroyNetworkUtils() {
// deallocate the weights
for (uint i = 0; i < m_TrtWeights.size(); ++i) {
if (m_TrtWeights[i].count > 0)
free(const_cast<void*>(m_TrtWeights[i].values));
}
m_TrtWeights.clear();
}

View File

@@ -0,0 +1,99 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
* Edited by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef _YOLO_H_
#define _YOLO_H_
#include "layers/convolutional_layer.h"
#include "layers/dropout_layer.h"
#include "layers/shortcut_layer.h"
#include "layers/route_layer.h"
#include "layers/upsample_layer.h"
#include "layers/maxpool_layer.h"
#include "nvdsinfer_custom_impl.h"
struct NetworkInfo
{
std::string networkType;
std::string configFilePath;
std::string wtsFilePath;
std::string deviceType;
std::string inputBlobName;
};
struct TensorInfo
{
std::string blobName;
uint stride{0};
uint gridSize{0};
uint numClasses{0};
uint numBBoxes{0};
uint64_t volume{0};
std::vector<uint> masks;
std::vector<float> anchors;
int bindingIndex{-1};
float* hostBuffer{nullptr};
};
class Yolo : public IModelParser {
public:
Yolo(const NetworkInfo& networkInfo);
~Yolo() override;
bool hasFullDimsSupported() const override { return false; }
const char* getModelName() const override {
return m_ConfigFilePath.empty() ? m_NetworkType.c_str()
: m_ConfigFilePath.c_str();
}
NvDsInferStatus parseModel(nvinfer1::INetworkDefinition& network) override;
nvinfer1::ICudaEngine *createEngine (nvinfer1::IBuilder* builder);
protected:
const std::string m_NetworkType;
const std::string m_ConfigFilePath;
const std::string m_WtsFilePath;
const std::string m_DeviceType;
const std::string m_InputBlobName;
std::vector<TensorInfo> m_OutputTensors;
std::vector<std::vector<int>> m_OutputMasks;
std::vector<std::map<std::string, std::string>> m_ConfigBlocks;
uint m_InputH;
uint m_InputW;
uint m_InputC;
uint64_t m_InputSize;
std::vector<nvinfer1::Weights> m_TrtWeights;
private:
NvDsInferStatus buildYoloNetwork(
std::vector<float>& weights, nvinfer1::INetworkDefinition& network);
std::vector<std::map<std::string, std::string>> parseConfigFile(
const std::string cfgFilePath);
void parseConfigBlocks();
void destroyNetworkUtils();
};
#endif // _YOLO_H_

View File

@@ -0,0 +1,142 @@
/*
* Copyright (c) 2018-2019 NVIDIA Corporation. All rights reserved.
*
* NVIDIA Corporation and its licensors retain all intellectual property
* and proprietary rights in and to this software, related documentation
* and any modifications thereto. Any use, reproduction, disclosure or
* distribution of this software and related documentation without an express
* license agreement from NVIDIA Corporation is strictly prohibited.
*
* Edited by Marcos Luciano
* https://www.github.com/marcoslucianops
*
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
__global__ void gpuYoloLayer(const float* input, float* output, const uint gridSize, const uint numOutputClasses,
const uint numBBoxes, const uint new_coords, const float scale_x_y, char type)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
uint z_id = blockIdx.z * blockDim.z + threadIdx.z;
if ((x_id >= gridSize) || (y_id >= gridSize) || (z_id >= numBBoxes))
{
return;
}
const int numGridCells = gridSize * gridSize;
const int bbindex = y_id * gridSize + x_id;
float alpha = scale_x_y;
float beta = -0.5 * (scale_x_y - 1);
if (type == 'y') {
if (new_coords == 1) {
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] * alpha + beta;
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] * alpha + beta;
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
= pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] * 2, 2);
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
= pow(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] * 2, 2);
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)];
for (uint i = 0; i < numOutputClasses; ++i)
{
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
}
}
else {
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta;
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta;
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]);
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]);
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
for (uint i = 0; i < numOutputClasses; ++i)
{
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
}
}
}
else if (type == 'r') {
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta;
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta;
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]);
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]);
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
float temp = 1.0;
int i;
float sum = 0;
float largest = -INFINITY;
for(i = 0; i < numOutputClasses; ++i){
int val = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
largest = (val>largest) ? val : largest;
}
for(i = 0; i < numOutputClasses; ++i){
float e = exp(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] / temp - largest / temp);
sum += e;
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] = e;
}
for(i = 0; i < numOutputClasses; ++i){
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] /= sum;
}
}
}
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSize,
const uint& numOutputClasses, const uint& numBBoxes,
uint64_t outputSize, cudaStream_t stream, const uint new_coords, const float scale_x_y, char type);
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSize,
const uint& numOutputClasses, const uint& numBBoxes,
uint64_t outputSize, cudaStream_t stream, const uint new_coords, const float scale_x_y, char type)
{
dim3 threads_per_block(16, 16, 4);
dim3 number_of_blocks((gridSize / threads_per_block.x) + 1,
(gridSize / threads_per_block.y) + 1,
(numBBoxes / threads_per_block.z) + 1);
for (unsigned int batch = 0; batch < batchSize; ++batch)
{
gpuYoloLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
reinterpret_cast<const float*>(input) + (batch * outputSize),
reinterpret_cast<float*>(output) + (batch * outputSize), gridSize, numOutputClasses,
numBBoxes, new_coords, scale_x_y, type);
}
return cudaGetLastError();
}

View File

@@ -0,0 +1,205 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
* Edited by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "yoloPlugins.h"
#include "NvInferPlugin.h"
#include <cassert>
#include <iostream>
#include <memory>
int kNUM_CLASSES;
float kBETA_NMS;
std::vector<float> kANCHORS;
std::vector<std::vector<int>> kMASK;
namespace {
template <typename T>
void write(char*& buffer, const T& val)
{
*reinterpret_cast<T*>(buffer) = val;
buffer += sizeof(T);
}
template <typename T>
void read(const char*& buffer, T& val)
{
val = *reinterpret_cast<const T*>(buffer);
buffer += sizeof(T);
}
}
cudaError_t cudaYoloLayer (
const void* input, void* output, const uint& batchSize,
const uint& gridSize, const uint& numOutputClasses,
const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream, const uint new_coords, const float scale_x_y, char type);
YoloLayer::YoloLayer (const void* data, size_t length)
{
const char *d = static_cast<const char*>(data);
read(d, m_NumBoxes);
read(d, m_NumClasses);
read(d, m_GridSize);
read(d, m_OutputSize);
read(d, m_Type);
read(d, m_new_coords);
read(d, m_scale_x_y);
read(d, m_beta_nms);
uint anchorsSize;
read(d, anchorsSize);
for (uint i = 0; i < anchorsSize; i++) {
float result;
read(d, result);
m_Anchors.push_back(result);
}
uint maskSize;
read(d, maskSize);
for (uint i = 0; i < maskSize; i++) {
uint nMask;
read(d, nMask);
std::vector<int> pMask;
for (uint f = 0; f < nMask; f++) {
int result;
read(d, result);
pMask.push_back(result);
}
m_Mask.push_back(pMask);
}
kNUM_CLASSES = m_NumClasses;
kBETA_NMS = m_beta_nms;
kANCHORS = m_Anchors;
kMASK = m_Mask;
};
YoloLayer::YoloLayer (
const uint& numBoxes, const uint& numClasses, const uint& gridSize, char type, int new_coords, float scale_x_y, float beta_nms, std::vector<float> anchors, std::vector<std::vector<int>> mask) :
m_NumBoxes(numBoxes),
m_NumClasses(numClasses),
m_GridSize(gridSize),
m_Type(type),
m_new_coords(new_coords),
m_scale_x_y(scale_x_y),
m_beta_nms(beta_nms),
m_Anchors(anchors),
m_Mask(mask)
{
assert(m_NumBoxes > 0);
assert(m_NumClasses > 0);
assert(m_GridSize > 0);
m_OutputSize = m_GridSize * m_GridSize * (m_NumBoxes * (4 + 1 + m_NumClasses));
};
nvinfer1::Dims
YoloLayer::getOutputDimensions(
int index, const nvinfer1::Dims* inputs, int nbInputDims)
{
assert(index == 0);
assert(nbInputDims == 1);
return inputs[0];
}
bool YoloLayer::supportsFormat (
nvinfer1::DataType type, nvinfer1::PluginFormat format) const {
return (type == nvinfer1::DataType::kFLOAT &&
format == nvinfer1::PluginFormat::kNCHW);
}
void
YoloLayer::configureWithFormat (
const nvinfer1::Dims* inputDims, int nbInputs,
const nvinfer1::Dims* outputDims, int nbOutputs,
nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize)
{
assert(nbInputs == 1);
assert (format == nvinfer1::PluginFormat::kNCHW);
assert(inputDims != nullptr);
}
int YoloLayer::enqueue(
int batchSize, const void* const* inputs, void** outputs, void* workspace,
cudaStream_t stream)
{
CHECK(cudaYoloLayer(
inputs[0], outputs[0], batchSize, m_GridSize, m_NumClasses, m_NumBoxes,
m_OutputSize, stream, m_new_coords, m_scale_x_y, m_Type));
return 0;
}
size_t YoloLayer::getSerializationSize() const
{
int anchorsSum = 1;
for (uint i = 0; i < m_Anchors.size(); i++) {
anchorsSum += 1;
}
int maskSum = 1;
for (uint i = 0; i < m_Mask.size(); i++) {
maskSum += 1;
for (uint f = 0; f < m_Mask[i].size(); f++) {
maskSum += 1;
}
}
return sizeof(m_NumBoxes) + sizeof(m_NumClasses) + sizeof(m_GridSize) + sizeof(m_OutputSize) + sizeof(m_Type)
+ sizeof(m_new_coords) + sizeof(m_scale_x_y) + sizeof(m_beta_nms) + anchorsSum * sizeof(float) + maskSum * sizeof(int);
}
void YoloLayer::serialize(void* buffer) const
{
char *d = static_cast<char*>(buffer);
write(d, m_NumBoxes);
write(d, m_NumClasses);
write(d, m_GridSize);
write(d, m_OutputSize);
write(d, m_Type);
write(d, m_new_coords);
write(d, m_scale_x_y);
write(d, m_beta_nms);
uint anchorsSize = m_Anchors.size();
write(d, anchorsSize);
for (uint i = 0; i < anchorsSize; i++) {
write(d, m_Anchors[i]);
}
uint maskSize = m_Mask.size();
write(d, maskSize);
for (uint i = 0; i < maskSize; i++) {
uint pMaskSize = m_Mask[i].size();
write(d, pMaskSize);
for (uint f = 0; f < pMaskSize; f++) {
write(d, m_Mask[i][f]);
}
}
kNUM_CLASSES = m_NumClasses;
kBETA_NMS = m_beta_nms;
kANCHORS = m_Anchors;
kMASK = m_Mask;
}
nvinfer1::IPluginV2* YoloLayer::clone() const
{
return new YoloLayer (m_NumBoxes, m_NumClasses, m_GridSize, m_Type, m_new_coords, m_scale_x_y, m_beta_nms, m_Anchors, m_Mask);
}
REGISTER_TENSORRT_PLUGIN(YoloLayerPluginCreator);

View File

@@ -0,0 +1,155 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
* Edited by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef __YOLO_PLUGINS__
#define __YOLO_PLUGINS__
#include <cassert>
#include <cstring>
#include <cuda_runtime_api.h>
#include <iostream>
#include <memory>
#include <vector>
#include "NvInferPlugin.h"
#define CHECK(status) \
{ \
if (status != 0) \
{ \
std::cout << "CUDA failure: " << cudaGetErrorString(status) << " in file " << __FILE__ \
<< " at line " << __LINE__ << std::endl; \
abort(); \
} \
}
namespace
{
const char* YOLOLAYER_PLUGIN_VERSION {"1"};
const char* YOLOLAYER_PLUGIN_NAME {"YoloLayer_TRT"};
} // namespace
class YoloLayer : public nvinfer1::IPluginV2
{
public:
YoloLayer (const void* data, size_t length);
YoloLayer (const uint& numBoxes, const uint& numClasses, const uint& gridSize,
char type, int new_coords, float scale_x_y, float beta_nms,
std::vector<float> anchors, std::vector<std::vector<int>> mask);
const char* getPluginType () const override { return YOLOLAYER_PLUGIN_NAME; }
const char* getPluginVersion () const override { return YOLOLAYER_PLUGIN_VERSION; }
int getNbOutputs () const override { return 1; }
nvinfer1::Dims getOutputDimensions (
int index, const nvinfer1::Dims* inputs,
int nbInputDims) override;
bool supportsFormat (
nvinfer1::DataType type, nvinfer1::PluginFormat format) const override;
void configureWithFormat (
const nvinfer1::Dims* inputDims, int nbInputs,
const nvinfer1::Dims* outputDims, int nbOutputs,
nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) override;
int initialize () override { return 0; }
void terminate () override {}
size_t getWorkspaceSize (int maxBatchSize) const override { return 0; }
int enqueue (
int batchSize, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) override;
size_t getSerializationSize() const override;
void serialize (void* buffer) const override;
void destroy () override { delete this; }
nvinfer1::IPluginV2* clone() const override;
void setPluginNamespace (const char* pluginNamespace)override {
m_Namespace = pluginNamespace;
}
virtual const char* getPluginNamespace () const override {
return m_Namespace.c_str();
}
private:
uint m_NumBoxes {0};
uint m_NumClasses {0};
uint m_GridSize {0};
uint64_t m_OutputSize {0};
std::string m_Namespace {""};
char m_Type;
uint m_new_coords {0};
float m_scale_x_y {0};
float m_beta_nms {0};
std::vector<float> m_Anchors;
std::vector<std::vector<int>> m_Mask;
};
class YoloLayerPluginCreator : public nvinfer1::IPluginCreator
{
public:
YoloLayerPluginCreator () {}
~YoloLayerPluginCreator () {}
const char* getPluginName () const override { return YOLOLAYER_PLUGIN_NAME; }
const char* getPluginVersion () const override { return YOLOLAYER_PLUGIN_VERSION; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
std::cerr<< "YoloLayerPluginCreator::getFieldNames is not implemented" << std::endl;
return nullptr;
}
nvinfer1::IPluginV2* createPlugin (
const char* name, const nvinfer1::PluginFieldCollection* fc) override
{
std::cerr<< "YoloLayerPluginCreator::getFieldNames is not implemented";
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin (
const char* name, const void* serialData, size_t serialLength) override
{
std::cout << "Deserialize yoloLayer plugin: " << name << std::endl;
return new YoloLayer(serialData, serialLength);
}
void setPluginNamespace(const char* libNamespace) override {
m_Namespace = libNamespace;
}
const char* getPluginNamespace() const override {
return m_Namespace.c_str();
}
private:
std::string m_Namespace {""};
};
extern int kNUM_CLASSES;
extern float kBETA_NMS;
extern std::vector<float> kANCHORS;
extern std::vector<std::vector<int>> kMASK;
#endif // __YOLO_PLUGINS__

251
readme.md Normal file
View File

@@ -0,0 +1,251 @@
# DeepStream-Yolo
NVIDIA DeepStream SDK 5.0.1 configuration for YOLO models
##
### Improvements on this repository
* Darknet CFG params parser (not need to edit nvdsparsebbox_Yolo.cpp or another file)
* Support to new_coords, beta_nms and scale_x_y params
* Support to new models not supported in official DeepStream SDK YOLO.
* Support to layers not supported in official DeepStream SDK YOLO.
* Support to activations not supported in official DeepStream SDK YOLO.
* Support to Convolutional groups
##
Tutorial
* Configuring to your custom model
* Using VOC models
Benchmark
* [mAP/FPS comparison between models](#mapfps-comparison-between-models)
[Native TensorRT conversion](#native-tensorrt-conversion) (tested models below)
* [YOLOv4x-Mish](https://github.com/AlexeyAB/darknet)
* [YOLOv4-CSP](https://github.com/WongKinYiu/ScaledYOLOv4/tree/yolov4-csp)
* [YOLOv4](https://github.com/AlexeyAB/darknet)
* [YOLOv4-Tiny](https://github.com/AlexeyAB/darknet)
* [YOLOv3-SSP](https://github.com/pjreddie/darknet)
* [YOLOv3](https://github.com/pjreddie/darknet)
* [YOLOv3-Tiny-PRN](https://github.com/WongKinYiu/PartialResidualNetworks)
* [YOLOv3-Tiny](https://github.com/pjreddie/darknet)
* [YOLOv3-Lite](https://github.com/dog-qiuqiu/MobileNet-Yolo)
* [YOLOv3-Nano](https://github.com/dog-qiuqiu/MobileNet-Yolo)
* [YOLO-Fastest](https://github.com/dog-qiuqiu/Yolo-Fastest)
* [YOLO-Fastest-XL](https://github.com/dog-qiuqiu/Yolo-Fastest)
* [YOLOv2](https://github.com/pjreddie/darknet)
* [YOLOv2-Tiny](https://github.com/pjreddie/darknet)
External TensorRT conversion
* [YOLOv5](https://github.com/marcoslucianops/DeepStream-Yolo/blob/master/YOLOv5.md)
Request
* [Request native TensorRT conversion for your YOLO-based model](#request-native-tensorrt-conversion-for-your-yolo-based-model)
##
### Requirements
* [NVIDIA DeepStream SDK 5.0.1](https://developer.nvidia.com/deepstream-sdk)
* [DeepStream-Yolo Native](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/native) (for Darknet YOLO based models)
* [DeepStream-Yolo External](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/external) (for PyTorch YOLOv5 based model)
##
### mAP/FPS comparison between models
<details><summary>NVIDIA GTX 1050 (4GB Mobile)</summary>
```
CUDA 10.2
Driver 440.33
TensorRT 7.2.1
cuDNN 8.0.5
OpenCV 3.2.0 (libopencv-dev)
OpenCV 4.4.0 (opencv-python)
PyTorch 1.7.0
Torchvision 0.8.1
```
| TensorRT | Precision | Resolution | IoU=0.5:0.95 | IoU=0.5 | IoU=0.75 | FPS<br />(with display) | FPS<br />(without display) |
|:---------------:|:---------:|:----------:|:------------:|:-------:|:--------:|:-----------------------:|:--------------------------:|
| YOLOv5x | FP32 | 608 | 0.406 | 0.562 | 0.441 | 7.91 | 7.99 |
| YOLOv5l | FP32 | 608 | 0.385 | 0.540 | 0.419 | 12.82 | 12.97 |
| YOLOv5m | FP32 | 608 | 0.354 | 0.507 | 0.388 | 25.09 | 25.97 |
| YOLOv5s | FP32 | 608 | 0.281 | 0.430 | 0.307 | 52.02 | 56.21 |
| YOLOv4x-MISH | FP32 | 640 | 0.454 | 0.644 | 0.491 | 7.45 | 7.56 |
| YOLOv4x-MISH | FP32 | 608 | 0.450 | 0.644 | 0.482 | 7.93 | 8.05 |
| YOLOv4-CSP | FP32 | 608 | 0.434 | 0.628 | 0.465 | 13.74 | 14.11 |
| YOLOv4-CSP | FP32 | 512 | 0.427 | 0.618 | 0.459 | 21.69 | 22.75 |
| YOLOv4 | FP32 | 608 | 0.490 | 0.734 | 0.538 | 11.72 | 12.09 |
| YOLOv4 | FP32 | 512 | 0.484 | 0.725 | 0.533 | 19.00 | 19.70 |
| YOLOv4 | FP32 | 416 | 0.456 | 0.693 | 0.491 | 22.63 | 23.81 |
| YOLOv4 | FP32 | 320 | 0.400 | 0.623 | 0.424 | 32.46 | 35.07 |
| YOLOv3-SPP | FP32 | 608 | 0.411 | 0.680 | 0.436 | 11.85 | 12.12 |
| YOLOv3 | FP32 | 608 | 0.374 | 0.654 | 0.387 | 12.00 | 12.33 |
| YOLOv3 | FP32 | 416 | 0.369 | 0.651 | 0.379 | 23.19 | 24.55 |
| YOLOv4-Tiny | FP32 | 416 | 0.195 | 0.382 | 0.175 | 144.55 | 176.31 |
| YOLOv3-Tiny-PRN | FP32 | 416 | 0.168 | 0.369 | 0.130 | 181.71 | 244.47 |
| YOLOv3-Tiny | FP32 | 416 | 0.165 | 0.357 | 0.128 | 154.19 | 190.42 |
| YOLOv3-Lite | FP32 | 416 | 0.165 | 0.350 | 0.131 | 122.40 | 146.19 |
| YOLOv3-Lite | FP32 | 320 | 0.155 | 0.324 | 0.128 | 163.76 | 204.21 |
| YOLOv3-Nano | FP32 | 416 | 0.127 | 0.277 | 0.098 | 191.77 | 264.59 |
| YOLOv3-Nano | FP32 | 320 | 0.122 | 0.258 | 0.099 | 207.04 | 269.89 |
| YOLO-Fastest | FP32 | 416 | 0.092 | 0.213 | 0.062 | 174.26 | 221.05 |
| YOLO-Fastest | FP32 | 320 | 0.090 | 0.201 | 0.068 | 199.48 | 258.56 |
| YOLO-FastestXL | FP32 | 416 | 0.144 | 0.306 | 0.115 | 121.89 | 145.13 |
| YOLO-FastestXL | FP32 | 320 | 0.136 | 0.279 | 0.117 | 162.65 | 199.75 |
| YOLOv2 | FP32 | 608 | 0.286 | 0.534 | 0.274 | 23.92 | 25.47 |
| YOLOv2-Tiny | FP32 | 416 | 0.103 | 0.251 | 0.064 | 165.01 | 203.02 |
| Darknet | Precision | Resolution | IoU=0.5:0.95 | IoU=0.5 | IoU=0.75 | FPS<br />(with display) | FPS<br />(without display) |
|:---------------:|:---------:|:----------:|:------------:|:-------:|:--------:|:-----------------------:|:--------------------------:|
| YOLOv4x-MISH | FP32 | 640 | 0.495 | 0.682 | 0.538 | 5.3 | 5.5 |
| YOLOv4x-MISH | FP32 | 608 | 0.493 | 0.680 | 0.535 | 5.4 | 5.6 |
| YOLOv4-CSP | FP32 | 608 | 0.473 | 0.661 | 0.515 | 9.2 | 9.5 |
| YOLOv4-CSP | FP32 | 512 | 0.458 | 0.645 | 0.496 | 13.6 | 14.0 |
| YOLOv4 | FP32 | 608 | 0.513 | 0.748 | 0.574 | 7.3 | 7.5 |
| YOLOv4 | FP32 | 512 | 0.506 | 0.738 | 0.564 | 11.8 | 12.3 |
| YOLOv4 | FP32 | 416 | 0.479 | 0.709 | 0.527 | 15.4 | 15.8 |
| YOLOv4 | FP32 | 320 | 0.421 | 0.638 | 0.454 | 21.0 | 21.7 |
| YOLOv3-SPP | FP32 | 608 | 0.432 | 0.701 | 0.465 | 6.9 | 7.1 |
| YOLOv3 | FP32 | 608 | 0.391 | 0.672 | 0.412 | 7.0 | 7.3 |
| YOLOv3 | FP32 | 416 | 0.384 | 0.668 | 0.402 | 16.3 | 16.9 |
| YOLOv4-Tiny | FP32 | 416 | 0.203 | 0.388 | 0.189 | 68.0 | 112.5 |
| YOLOv3-Tiny-PRN | FP32 | 416 | 0.172 | 0.378 | 0.133 | 71.6 | 143.9 |
| YOLOv3-Tiny | FP32 | 416 | 0.171 | 0.367 | 0.137 | 71.5 | 117.9 |
| YOLOv3-Lite | FP32 | 416 | 0.169 | 0.349 | 0.144 | 53.8 | 63.4 |
| YOLOv3-Lite | FP32 | 320 | 0.159 | 0.326 | 0.139 | 55.2 | 97.5 |
| YOLOv3-Nano | FP32 | 416 | 0.129 | 0.275 | 0.102 | 58.0 | 113.1 |
| YOLOv3-Nano | FP32 | 320 | 0.124 | 0.259 | 0.106 | 61.6 | 156.8 |
| YOLO-Fastest | FP32 | 416 | 0.095 | 0.213 | 0.068 | 61.7 | 104.1 |
| YOLO-Fastest | FP32 | 320 | 0.093 | 0.202 | 0.074 | 65.8 | 143.3 |
| YOLO-FastestXL | FP32 | 416 | 0.148 | 0.308 | 0.125 | 62.0 | 75.9 |
| YOLO-FastestXL | FP32 | 320 | 0.141 | 0.284 | 0.125 | 63.9 | 112.3 |
| YOLOv2 | FP32 | 608 | 0.297 | 0.548 | 0.291 | 12.1 | 12.1 |
| YOLOv2-Tiny | FP32 | 416 | 0.105 | 0.255 | 0.068 | 34.5 | 40.7 |
| PyTorch | Precision | Resolution | IoU=0.5:0.95 | IoU=0.5 | IoU=0.75 | FPS<br />(with output) | FPS<br />(without output) |
|:-------:|:---------:|:----------:|:------------:|:-------:|:--------:|:----------------------:|:-------------------------:|
| YOLOv5x | FP32 | 608 | 0.487 | 0.676 | 0.527 | 8.25 | 9.49 |
| YOLOv5l | FP32 | 608 | 0.471 | 0.662 | 0.512 | 12.67 | 15.77 |
| YOLOv5m | FP32 | 608 | 0.439 | 0.631 | 0.474 | 18.13 | 24.80 |
| YOLOv5s | FP32 | 608 | 0.369 | 0.567 | 0.395 | 28.03 | 49.52 |
<br />
</details>
<details><summary>NVIDIA Jetson Nano (4GB)</summary>
Coming soon
</details>
<br />
#### DeepStream settings
* General
```
width = 1920
height = 1080
maintain-aspect-ratio = 0
batch-size = 1
```
* Evaluate mAP
```
valid = val2017 (COCO)
nms-iou-threshold = 0.6
pre-cluster-threshold = 0.001 (CONF_THRESH)
```
* Evaluate FPS and Demo
```
nms-iou-threshold = 0.45 (NMS; changed to beta_nms when available)
pre-cluster-threshold = 0.25 (CONF_THRESH)
```
##
### Native TensorRT conversion
Donwload [my native folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/native), rename to yolo and move to your deepstream/sources folder.
Donwload cfg and weights files from your model and move to deepstream/sources/yolo folder.
* YOLOv4x-Mish [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4x-mish.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4x-mish.weights)]
* YOLOv4-CSP [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-csp.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-csp.weights)]
* YOLOv4 [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v3_optimal/yolov4.weights)]
* YOLOv4-Tiny [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-tiny.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-tiny.weights)]
* YOLOv3-SPP [[cfg](https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3-spp.cfg)] [[weights](https://pjreddie.com/media/files/yolov3-spp.weights)]
* YOLOv3 [[cfg](https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg)] [[weights](https://pjreddie.com/media/files/yolov3.weights)]
* YOLOv3-Tiny-PRN [[cfg](https://raw.githubusercontent.com/WongKinYiu/PartialResidualNetworks/master/cfg/yolov3-tiny-prn.cfg)] [[weights](https://github.com/WongKinYiu/PartialResidualNetworks/raw/master/model/yolov3-tiny-prn.weights)]
* YOLOv3-Tiny [[cfg](https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3-tiny.cfg)] [[weights](https://pjreddie.com/media/files/yolov3-tiny.weights)]
* YOLOv3-Lite [[cfg](https://raw.githubusercontent.com/dog-qiuqiu/MobileNet-Yolo/master/MobileNetV2-YOLOv3-Lite/COCO/MobileNetV2-YOLOv3-Lite-coco.cfg)] [[weights](https://github.com/dog-qiuqiu/MobileNet-Yolo/raw/master/MobileNetV2-YOLOv3-Lite/COCO/MobileNetV2-YOLOv3-Lite-coco.weights)]
* YOLOv3-Nano [[cfg](https://raw.githubusercontent.com/dog-qiuqiu/MobileNet-Yolo/master/MobileNetV2-YOLOv3-Nano/COCO/MobileNetV2-YOLOv3-Nano-coco.cfg)] [[weights](https://github.com/dog-qiuqiu/MobileNet-Yolo/raw/master/MobileNetV2-YOLOv3-Nano/COCO/MobileNetV2-YOLOv3-Nano-coco.weights)]
* YOLO-Fastest [[cfg](https://raw.githubusercontent.com/dog-qiuqiu/Yolo-Fastest/master/Yolo-Fastest/COCO/yolo-fastest.cfg)] [[weights](https://github.com/dog-qiuqiu/Yolo-Fastest/raw/master/Yolo-Fastest/COCO/yolo-fastest.weights)]
* YOLO-Fastest-XL [[cfg](https://raw.githubusercontent.com/dog-qiuqiu/Yolo-Fastest/master/Yolo-Fastest/COCO/yolo-fastest-xl.cfg)] [[weights](https://github.com/dog-qiuqiu/Yolo-Fastest/raw/master/Yolo-Fastest/COCO/yolo-fastest-xl.weights)]
* YOLOv2 [[cfg](https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov2.cfg)] [[weights](https://pjreddie.com/media/files/yolov2.weights)]
* YOLOv2-Tiny [[cfg](https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov2-tiny.cfg)] [[weights](https://pjreddie.com/media/files/yolov2-tiny.weights)]
Compile
```
cd /opt/nvidia/deepstream/deepstream-5.0/sources/yolo
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
```
Edit config_infer_primary.txt for your model (example for YOLOv4)
```
[property]
...
# 0=RGB, 1=BGR
model-color-format=0
# CFG
custom-network-config=yolov4.cfg
# Weights
model-file=yolov4.weights
# Generated TensorRT model (will be created if it doesn't exist)
model-engine-file=model_b1_gpu0_fp32.engine
# Model labels file
labelfile-path=labels.txt
# Batch size
batch-size=1
# 0=FP32, 1=INT8, 2=FP16 mode
network-mode=0
# Number of classes in label file
num-detected-classes=80
...
[class-attrs-all]
# CONF_THRESH
pre-cluster-threshold=0.25
```
Run
```
deepstream-app -c deepstream_app_config.txt
```
If you want to use YOLOv2 or YOLOv2-Tiny models, change, before run, deepstream_app_config.txt
```
[primary-gie]
enable=1
gpu-id=0
gie-unique-id=1
nvbuf-memory-type=0
config-file=config_infer_primary_yoloV2.txt
```
Note: config_infer_primary.txt uses cluster-mode=4 and NMS = 0.45 (via code) when beta_nms isn't available (when beta_nms is available, NMS = beta_nms), while config_infer_primary_yoloV2.txt uses cluster-mode=2 and nms-iou-threshold=0.45 to set NMS.
##
### Request native TensorRT conversion for your YOLO-based model
To request moded files for native TensorRT conversion to use in DeepStream SDK, send me the model cfg and weights files via Issues tab.
<br />
Note: If your model are listed in native tab, you can use [my native folder](https://github.com/marcoslucianops/DeepStream-Yolo/tree/master/native) to run your model in DeepStream.
##
For commercial DeepStream SDK projects, contact me at email address available in GitHub.