Added YOLOR native support
YOLOR-CSP, YOLOR-CSP*, YOLOR-CSP-X, YOLOR-CSP-X*
config_infer_primary_yolor.txt | 24 (new file)
@@ -0,0 +1,24 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
custom-network-config=yolor_csp.cfg
model-file=yolor_csp.wts
model-engine-file=model_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=4
maintain-aspect-ratio=0
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet

[class-attrs-all]
pre-cluster-threshold=0.25
@@ -53,6 +53,8 @@ SRCFILES:= nvdsinfer_yolo_engine.cpp \
nvdsparsebbox_Yolo.cpp \
yoloPlugins.cpp \
layers/convolutional_layer.cpp \
layers/implicit_layer.cpp \
layers/channels_layer.cpp \
layers/dropout_layer.cpp \
layers/shortcut_layer.cpp \
layers/route_layer.cpp \
nvdsinfer_custom_impl_Yolo/layers/channels_layer.cpp | 32 (new file)
@@ -0,0 +1,32 @@
/*
 * Created by Marcos Luciano
 * https://www.github.com/marcoslucianops
 */

#include "channels_layer.h"

nvinfer1::ILayer* channelsLayer(
    std::string type,
    nvinfer1::ITensor* input,
    nvinfer1::ITensor* implicitTensor,
    nvinfer1::INetworkDefinition* network)
{
    nvinfer1::ILayer* output;

    if (type == "shift") {
        nvinfer1::IElementWiseLayer* ew = network->addElementWise(
            *input, *implicitTensor,
            nvinfer1::ElementWiseOperation::kSUM);
        assert(ew != nullptr);
        output = ew;
    }
    else if (type == "control") {
        nvinfer1::IElementWiseLayer* ew = network->addElementWise(
            *input, *implicitTensor,
            nvinfer1::ElementWiseOperation::kPROD);
        assert(ew != nullptr);
        output = ew;
    }

    return output;
}
nvdsinfer_custom_impl_Yolo/layers/channels_layer.h | 20 (new file)
@@ -0,0 +1,20 @@
/*
 * Created by Marcos Luciano
 * https://www.github.com/marcoslucianops
 */

#ifndef __CHANNELS_LAYER_H__
#define __CHANNELS_LAYER_H__

#include <map>
#include <cassert>

#include "NvInfer.h"

nvinfer1::ILayer* channelsLayer(
    std::string type,
    nvinfer1::ITensor* input,
    nvinfer1::ITensor* implicitTensor,
    nvinfer1::INetworkDefinition* network);

#endif
nvdsinfer_custom_impl_Yolo/layers/implicit_layer.cpp | 31 (new file)
@@ -0,0 +1,31 @@
/*
 * Created by Marcos Luciano
 * https://www.github.com/marcoslucianops
 */

#include <math.h>
#include "implicit_layer.h"

nvinfer1::ILayer* implicitLayer(
    int channels,
    std::vector<float>& weights,
    std::vector<nvinfer1::Weights>& trtWeights,
    int& weightPtr,
    nvinfer1::INetworkDefinition* network)
{
    nvinfer1::Weights convWt{nvinfer1::DataType::kFLOAT, nullptr, channels};

    float* val = new float[channels];
    for (int i = 0; i < channels; ++i)
    {
        val[i] = weights[weightPtr];
        weightPtr++;
    }
    convWt.values = val;
    trtWeights.push_back(convWt);

    nvinfer1::IConstantLayer* implicit = network->addConstant(nvinfer1::Dims3{static_cast<int>(channels), 1, 1}, convWt);
    assert(implicit != nullptr);

    return implicit;
}
nvdsinfer_custom_impl_Yolo/layers/implicit_layer.h | 22 (new file)
@@ -0,0 +1,22 @@
/*
 * Created by Marcos Luciano
 * https://www.github.com/marcoslucianops
 */

#ifndef __IMPLICIT_LAYER_H__
#define __IMPLICIT_LAYER_H__

#include <map>
#include <vector>
#include <cassert>

#include "NvInfer.h"

nvinfer1::ILayer* implicitLayer(
    int channels,
    std::vector<float>& weights,
    std::vector<nvinfer1::Weights>& trtWeights,
    int& weightPtr,
    nvinfer1::INetworkDefinition* network);

#endif
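To see how the two new layers fit together: YOLOR's implicit knowledge is a learned per-channel vector. `implicitLayer()` materializes it as a C x 1 x 1 TensorRT constant, and `channelsLayer()` broadcasts it over a C x H x W feature map with an element-wise sum ("shift", ImplicitA) or product ("control", ImplicitM). A minimal standalone sketch of that wiring, not part of the commit (dimensions, values and names are illustrative):

```
#include <cassert>
#include <iostream>
#include <vector>
#include "NvInfer.h"

// Minimal logger required to create a TensorRT builder.
class SketchLogger : public nvinfer1::ILogger {
    void log(Severity severity, const char* msg) noexcept override {
        if (severity <= Severity::kWARNING) std::cout << msg << std::endl;
    }
};

int main()
{
    SketchLogger logger;
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0);

    const int channels = 4;

    // Feature map input, C x H x W (implicit batch), as in the plugin's network definition.
    nvinfer1::ITensor* feature = network->addInput(
        "feature", nvinfer1::DataType::kFLOAT, nvinfer1::Dims3{channels, 8, 8});

    // Learned per-channel implicit knowledge, stored as a C x 1 x 1 constant;
    // this is what implicitLayer() builds from the .wts weights.
    std::vector<float> implicitVals(channels, 0.1f);
    nvinfer1::Weights w{nvinfer1::DataType::kFLOAT, implicitVals.data(), channels};
    nvinfer1::IConstantLayer* implicit = network->addConstant(
        nvinfer1::Dims3{channels, 1, 1}, w);

    // "shift_channels": broadcast element-wise sum (ImplicitA).
    // "control_channels" would use kPROD instead (ImplicitM).
    nvinfer1::IElementWiseLayer* shift = network->addElementWise(
        *feature, *implicit->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
    assert(shift != nullptr);

    network->markOutput(*shift->getOutput(0));
    std::cout << "output channels: " << shift->getOutput(0)->getDimensions().d[0] << std::endl;
    return 0;
}
```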
@@ -187,6 +187,51 @@ NvDsInferStatus Yolo::buildYoloNetwork(
            printLayerInfo(layerIndex, layerType, inputVol, outputVol, std::to_string(weightPtr));
        }

        else if (m_ConfigBlocks.at(i).at("type") == "implicit_add" || m_ConfigBlocks.at(i).at("type") == "implicit_mul") {
            std::string type;
            if (m_ConfigBlocks.at(i).at("type") == "implicit_add") {
                type = "add";
            }
            else if (m_ConfigBlocks.at(i).at("type") == "implicit_mul") {
                type = "mul";
            }
            assert(m_ConfigBlocks.at(i).find("filters") != m_ConfigBlocks.at(i).end());
            int filters = std::stoi(m_ConfigBlocks.at(i).at("filters"));
            nvinfer1::ILayer* out = implicitLayer(filters, weights, m_TrtWeights, weightPtr, &network);
            previous = out->getOutput(0);
            assert(previous != nullptr);
            channels = getNumChannels(previous);
            std::string outputVol = dimsToString(previous->getDimensions());
            tensorOutputs.push_back(previous);
            std::string layerType = "implicit_" + type;
            printLayerInfo(layerIndex, layerType, " -", outputVol, std::to_string(weightPtr));
        }

        else if (m_ConfigBlocks.at(i).at("type") == "shift_channels" || m_ConfigBlocks.at(i).at("type") == "control_channels") {
            std::string type;
            if (m_ConfigBlocks.at(i).at("type") == "shift_channels") {
                type = "shift";
            }
            else if (m_ConfigBlocks.at(i).at("type") == "control_channels") {
                type = "control";
            }
            assert(m_ConfigBlocks.at(i).find("from") != m_ConfigBlocks.at(i).end());
            int from = stoi(m_ConfigBlocks.at(i).at("from"));
            if (from > 0) {
                from = from - i + 1;
            }
            assert((i - 2 >= 0) && (i - 2 < tensorOutputs.size()));
            assert((i + from - 1 >= 0) && (i + from - 1 < tensorOutputs.size()));
            assert(i + from - 1 < i - 2);
            nvinfer1::ILayer* out = channelsLayer(type, previous, tensorOutputs[i + from - 1], &network);
            previous = out->getOutput(0);
            assert(previous != nullptr);
            std::string outputVol = dimsToString(previous->getDimensions());
            tensorOutputs.push_back(previous);
            std::string layerType = type + "_channels" + ": " + std::to_string(i + from - 1);
            printLayerInfo(layerIndex, layerType, " -", outputVol, " -");
        }

        else if (m_ConfigBlocks.at(i).at("type") == "dropout") {
            assert(m_ConfigBlocks.at(i).find("probability") != m_ConfigBlocks.at(i).end());
            //float probability = std::stof(m_ConfigBlocks.at(i).at("probability"));
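A note on the index arithmetic in the shift_channels / control_channels branch above: in the YOLOR cfg the `from` key points back at the matching implicit block, either as a negative relative offset or as an absolute block index; positive values are first converted to a relative offset, and the implicit tensor is then taken from `tensorOutputs[i + from - 1]`. A hypothetical helper (not part of the commit) that mirrors just that arithmetic:

```
// Illustrative only: mirrors the inline normalization in the branch above.
// i    = index of the current shift_channels / control_channels config block
// from = value of the "from" key in the cfg (negative = relative, positive = absolute)
static int resolveImplicitIndex(int i, int from)
{
    if (from > 0) {
        from = from - i + 1;  // absolute block index -> relative offset
    }
    return i + from - 1;      // index into tensorOutputs
}
```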
@@ -27,6 +27,8 @@
#define _YOLO_H_

#include "layers/convolutional_layer.h"
#include "layers/implicit_layer.h"
#include "layers/channels_layer.h"
#include "layers/dropout_layer.h"
#include "layers/shortcut_layer.h"
#include "layers/route_layer.h"
@@ -60,6 +60,28 @@ __global__ void gpuYoloLayer(const float* input, float* output, const uint gridS
                = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
        }
    }
    else if (new_coords == 0 && scale_x_y != 1) { // YOLOR incorrect param
        output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
            = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * 2.0 - 0.5;

        output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
            = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * 2.0 - 0.5;

        output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
            = pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2);

        output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
            = pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2);

        output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
            = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);

        for (uint i = 0; i < numOutputClasses; ++i)
        {
            output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
                = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
        }
    }
    else {
        output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
            = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta;
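The branch added above only applies the activations; the resulting boxes follow the scaled-YOLO / YOLOR formulation, where x/y use sigmoid(t) * 2 - 0.5 plus the grid offset and w/h use (sigmoid(t) * 2)^2 times the anchor. A host-side sketch of that full decode, for illustration only (the grid, anchor and stride handling shown here is an assumption about the rest of the pipeline, not code from this commit):

```
#include <cmath>

static float sigmoidf(float x) { return 1.0f / (1.0f + std::exp(-x)); }

struct Box { float x, y, w, h, objectness; };

// tx..to: raw network outputs; cx, cy: grid cell; pw, ph: anchor size; stride: input/grid ratio.
static Box decodeScaled(float tx, float ty, float tw, float th, float to,
                        int cx, int cy, float pw, float ph, float stride)
{
    Box b;
    b.x = (sigmoidf(tx) * 2.0f - 0.5f + cx) * stride;  // matches the "* 2.0 - 0.5" terms above
    b.y = (sigmoidf(ty) * 2.0f - 0.5f + cy) * stride;
    b.w = std::pow(sigmoidf(tw) * 2.0f, 2.0f) * pw;    // matches the "pow(... * 2, 2)" terms above
    b.h = std::pow(sigmoidf(th) * 2.0f, 2.0f) * ph;
    b.objectness = sigmoidf(to);
    return b;
}
```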
readme.md | 91
@@ -6,7 +6,6 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
* New documentation for multiple models
* DeepStream tutorials
* Native YOLOR support
* Native PP-YOLO support
* Models benchmark
* GPU NMS
@@ -22,7 +21,9 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
* Support for convolutional groups
* Support for INT8 calibration
* Support for non square models
* **Support for implicit and channel layers (YOLOR)**
* **YOLOv5 6.0 native support**
* **Initial YOLOR native support**

##
@@ -33,6 +34,7 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
* [dGPU installation](#dgpu-installation)
* [Basic usage](#basic-usage)
* [YOLOv5 usage](#yolov5-usage)
* [YOLOR usage](#yolor-usage)
* [INT8 calibration](#int8-calibration)
* [Using your custom model](docs/customModels.md)
@@ -55,6 +57,10 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
##

### Tested models
* [YOLOR-CSP](https://github.com/WongKinYiu/yolor) [[cfg]](https://raw.githubusercontent.com/WongKinYiu/yolor/main/cfg/yolor_csp.cfg) [[pt]](https://drive.google.com/file/d/1ZEqGy4kmZyD-Cj3tEFJcLSZenZBDGiyg/view?usp=sharing)
* [YOLOR-CSP*](https://github.com/WongKinYiu/yolor) [[cfg]](https://raw.githubusercontent.com/WongKinYiu/yolor/main/cfg/yolor_csp.cfg) [[pt]](https://drive.google.com/file/d/1OJKgIasELZYxkIjFoiqyn555bcmixUP2/view?usp=sharing)
* [YOLOR-CSP-X](https://github.com/WongKinYiu/yolor) [[cfg]](https://raw.githubusercontent.com/WongKinYiu/yolor/main/cfg/yolor_csp_x.cfg) [[pt]](https://drive.google.com/file/d/1L29rfIPNH1n910qQClGftknWpTBgAv6c/view?usp=sharing)
* [YOLOR-CSP-X*](https://github.com/WongKinYiu/yolor) [[cfg]](https://raw.githubusercontent.com/WongKinYiu/yolor/main/cfg/yolor_csp_x.cfg) [[pt]](https://drive.google.com/file/d/1NbMG3ivuBQ4S8kEhFJ0FIqOQXevGje_w/view?usp=sharing)
* [YOLOv5 6.0](https://github.com/ultralytics/yolov5) [[pt]](https://github.com/ultralytics/yolov5/releases/tag/v6.0)
* [YOLOv4x-Mish](https://github.com/AlexeyAB/darknet) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4x-mish.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4x-mish.weights)]
* [YOLOv4-CSP](https://github.com/WongKinYiu/ScaledYOLOv4/tree/yolov4-csp) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-csp.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-csp.weights)]
@@ -285,7 +291,7 @@ config-file=config_infer_primary_yoloV2.txt
### YOLOv5 usage

-#### 1. Copy gen_wts_yoloV5.py from DeepStream-Yolo/utils to ultralytics/yolov5 folder
+#### 1. Copy gen_wts_yoloV5.py from DeepStream-Yolo/utils to [ultralytics/yolov5](https://github.com/ultralytics/yolov5) folder

#### 2. Open the ultralytics/yolov5 folder
@@ -404,6 +410,87 @@ deepstream-app -c deepstream_app_config.txt
##

### YOLOR usage

**NOTE**: For now, available only for YOLOR-CSP, YOLOR-CSP*, YOLOR-CSP-X and YOLOR-CSP-X*.

#### 1. Copy gen_wts_yolor.py from DeepStream-Yolo/utils to [yolor](https://github.com/WongKinYiu/yolor) folder

#### 2. Open the yolor folder

#### 3. Download pt file from [yolor](https://github.com/WongKinYiu/yolor) website

#### 4. Generate wts file (example for YOLOR-CSP)

```
python3 gen_wts_yolor.py -w yolor_csp.pt -c cfg/yolor_csp.cfg
```

#### 5. Copy cfg and generated wts files to DeepStream-Yolo folder

#### 6. Open DeepStream-Yolo folder

#### 7. Compile lib

* x86 platform

```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```

* Jetson platform

```
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
```

#### 8. Edit config_infer_primary_yolor.txt for your model (example for YOLOR-CSP)

```
[property]
...
# 0=RGB, 1=BGR, 2=GRAYSCALE
model-color-format=0
# CFG
custom-network-config=yolor_csp.cfg
# WTS
model-file=yolor_csp.wts
# Generated TensorRT model (will be created if it doesn't exist)
model-engine-file=model_b1_gpu0_fp32.engine
# Model labels file
labelfile-path=labels.txt
# Batch size
batch-size=1
# 0=FP32, 1=INT8, 2=FP16 mode
network-mode=0
# Number of classes in label file
num-detected-classes=80
...
[class-attrs-all]
# CONF_THRESH
pre-cluster-threshold=0.25
```

#### 9. Change the deepstream_app_config.txt file

```
...
[primary-gie]
enable=1
gpu-id=0
gie-unique-id=1
nvbuf-memory-type=0
config-file=config_infer_primary_yolor.txt
```

#### 10. Run

```
deepstream-app -c deepstream_app_config.txt
```

##

### INT8 calibration

#### 1. Install OpenCV
@@ -8,7 +8,7 @@ from utils.torch_utils import select_device
def parse_args():
-    parser = argparse.ArgumentParser(description="PyTorch conversion")
+    parser = argparse.ArgumentParser(description="PyTorch YOLOv5 conversion")
    parser.add_argument("-w", "--weights", required=True, help="Input weights (.pt) file path (required)")
    parser.add_argument("-c", "--yaml", help="Input cfg (.yaml) file path")
    parser.add_argument("-mw", "--width", help="Model width (default = 640 / 1280 [P6])")
utils/gen_wts_yolor.py | 43 (new file)
@@ -0,0 +1,43 @@
import argparse
import os
import struct
import torch
from utils.torch_utils import select_device
from models.models import Darknet


def parse_args():
    parser = argparse.ArgumentParser(description="PyTorch YOLOR conversion (main branch)")
    parser.add_argument("-w", "--weights", required=True, help="Input weights (.pt) file path (required)")
    parser.add_argument("-c", "--cfg", required=True, help="Input cfg (.cfg) file path (required)")
    args = parser.parse_args()
    if not os.path.isfile(args.weights):
        raise SystemExit("Invalid weights file")
    if not os.path.isfile(args.cfg):
        raise SystemExit("Invalid cfg file")
    return args.weights, args.cfg


pt_file, cfg_file = parse_args()

wts_file = pt_file.split(".pt")[0] + ".wts"

device = select_device("cpu")
model = Darknet(cfg_file).to(device)
model.load_state_dict(torch.load(pt_file, map_location=device)["model"])
model.to(device).eval()

with open(wts_file, "w") as f:
    wts_write = ""
    conv_count = 0
    for k, v in model.state_dict().items():
        if not "num_batches_tracked" in k:
            vr = v.reshape(-1).cpu().numpy()
            wts_write += "{} {} ".format(k, len(vr))
            for vv in vr:
                wts_write += " "
                wts_write += struct.pack(">f", float(vv)).hex()
            wts_write += "\n"
            conv_count += 1
    f.write("{}\n".format(conv_count))
    f.write(wts_write)
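For reference, the .wts layout written here is: a first line holding the number of entries, then one line per tensor of the form `<name> <element count> <hex of big-endian float> ...`. A rough C++ reader for that layout (illustrative only; the plugin's own loader may differ in details):

```
#include <cstdint>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Parse a .wts file into name -> float values.
std::map<std::string, std::vector<float>> loadWts(const std::string& path)
{
    std::ifstream file(path);
    int count = 0;
    file >> count;                             // first line: number of tensors
    std::map<std::string, std::vector<float>> weights;
    for (int i = 0; i < count; ++i) {
        std::string name;
        size_t size = 0;
        file >> name >> size;                  // "<name> <element count>"
        std::vector<float> values(size);
        for (size_t j = 0; j < size; ++j) {
            std::string hex;
            file >> hex;                       // 8 hex chars = one big-endian float
            uint32_t bits = std::stoul(hex, nullptr, 16);
            float v;
            std::memcpy(&v, &bits, sizeof(v)); // reinterpret the 4 bytes as a float
            values[j] = v;
        }
        weights[name] = std::move(values);
    }
    return weights;
}
```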