From e2257a81c00dc9fb8fb469902746f6010f5a9d3e Mon Sep 17 00:00:00 2001
From: unknown
Date: Sun, 12 Dec 2021 00:47:32 -0300
Subject: [PATCH] Added YOLOR native support

YOLOR-CSP
YOLOR-CSP*
YOLOR-CSP-X
YOLOR-CSP-X*
---
 config_infer_primary_yolor.txt            | 24 +++++
 nvdsinfer_custom_impl_Yolo/Makefile       |  2 +
 .../layers/channels_layer.cpp             | 32 +++++++
 .../layers/channels_layer.h               | 20 ++++
 .../layers/implicit_layer.cpp             | 31 +++++++
 .../layers/implicit_layer.h               | 22 +++++
 nvdsinfer_custom_impl_Yolo/yolo.cpp       | 45 +++++++++
 nvdsinfer_custom_impl_Yolo/yolo.h         |  2 +
 nvdsinfer_custom_impl_Yolo/yoloForward.cu | 22 +++++
 readme.md                                 | 91 ++++++++++++++++++-
 utils/gen_wts_yoloV5.py                   |  8 +-
 utils/gen_wts_yolor.py                    | 43 +++++++++
 12 files changed, 336 insertions(+), 6 deletions(-)
 create mode 100644 config_infer_primary_yolor.txt
 create mode 100644 nvdsinfer_custom_impl_Yolo/layers/channels_layer.cpp
 create mode 100644 nvdsinfer_custom_impl_Yolo/layers/channels_layer.h
 create mode 100644 nvdsinfer_custom_impl_Yolo/layers/implicit_layer.cpp
 create mode 100644 nvdsinfer_custom_impl_Yolo/layers/implicit_layer.h
 create mode 100644 utils/gen_wts_yolor.py

diff --git a/config_infer_primary_yolor.txt b/config_infer_primary_yolor.txt
new file mode 100644
index 0000000..fd194ef
--- /dev/null
+++ b/config_infer_primary_yolor.txt
@@ -0,0 +1,24 @@
+[property]
+gpu-id=0
+net-scale-factor=0.0039215697906911373
+model-color-format=0
+custom-network-config=yolor_csp.cfg
+model-file=yolor_csp.wts
+model-engine-file=model_b1_gpu0_fp32.engine
+#int8-calib-file=calib.table
+labelfile-path=labels.txt
+batch-size=1
+network-mode=0
+num-detected-classes=80
+interval=0
+gie-unique-id=1
+process-mode=1
+network-type=0
+cluster-mode=4
+maintain-aspect-ratio=0
+parse-bbox-func-name=NvDsInferParseYolo
+custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
+engine-create-func-name=NvDsInferYoloCudaEngineGet
+
+[class-attrs-all]
+pre-cluster-threshold=0.25
diff --git a/nvdsinfer_custom_impl_Yolo/Makefile b/nvdsinfer_custom_impl_Yolo/Makefile
index f0b95c0..b063e83 100644
--- a/nvdsinfer_custom_impl_Yolo/Makefile
+++ b/nvdsinfer_custom_impl_Yolo/Makefile
@@ -53,6 +53,8 @@ SRCFILES:= nvdsinfer_yolo_engine.cpp \
           nvdsparsebbox_Yolo.cpp \
           yoloPlugins.cpp \
           layers/convolutional_layer.cpp \
+          layers/implicit_layer.cpp \
+          layers/channels_layer.cpp \
           layers/dropout_layer.cpp \
           layers/shortcut_layer.cpp \
           layers/route_layer.cpp \
diff --git a/nvdsinfer_custom_impl_Yolo/layers/channels_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/channels_layer.cpp
new file mode 100644
index 0000000..af61bac
--- /dev/null
+++ b/nvdsinfer_custom_impl_Yolo/layers/channels_layer.cpp
@@ -0,0 +1,32 @@
+/*
+ * Created by Marcos Luciano
+ * https://www.github.com/marcoslucianops
+ */
+
+#include "channels_layer.h"
+
+nvinfer1::ILayer* channelsLayer(
+    std::string type,
+    nvinfer1::ITensor* input,
+    nvinfer1::ITensor* implicitTensor,
+    nvinfer1::INetworkDefinition* network)
+{
+    nvinfer1::ILayer* output;
+
+    if (type == "shift") {
+        nvinfer1::IElementWiseLayer* ew = network->addElementWise(
+            *input, *implicitTensor,
+            nvinfer1::ElementWiseOperation::kSUM);
+        assert(ew != nullptr);
+        output = ew;
+    }
+    else if (type == "control") {
+        nvinfer1::IElementWiseLayer* ew = network->addElementWise(
+            *input, *implicitTensor,
+            nvinfer1::ElementWiseOperation::kPROD);
+        assert(ew != nullptr);
+        output = ew;
+    }
+
+    return output;
+}
\ No newline at end of file
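Note (not part of the patch): `channelsLayer` maps YOLOR's implicit-knowledge operators onto broadcast element-wise TensorRT layers. Given a feature map x of shape (C, H, W) and an implicit tensor i of shape (C, 1, 1), "shift" computes y[c][h][w] = x[c][h][w] + i[c] (kSUM) and "control" computes y[c][h][w] = x[c][h][w] * i[c] (kPROD), with TensorRT broadcasting the singleton spatial dimensions of the constant across H x W.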
diff --git a/nvdsinfer_custom_impl_Yolo/layers/channels_layer.h b/nvdsinfer_custom_impl_Yolo/layers/channels_layer.h
new file mode 100644
index 0000000..b22f6b6
--- /dev/null
+++ b/nvdsinfer_custom_impl_Yolo/layers/channels_layer.h
@@ -0,0 +1,20 @@
+/*
+ * Created by Marcos Luciano
+ * https://www.github.com/marcoslucianops
+ */
+
+#ifndef __CHANNELS_LAYER_H__
+#define __CHANNELS_LAYER_H__
+
+#include <cassert>
+#include <string>
+
+#include "NvInfer.h"
+
+nvinfer1::ILayer* channelsLayer(
+    std::string type,
+    nvinfer1::ITensor* input,
+    nvinfer1::ITensor* implicitTensor,
+    nvinfer1::INetworkDefinition* network);
+
+#endif
diff --git a/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.cpp
new file mode 100644
index 0000000..a3a3d0e
--- /dev/null
+++ b/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.cpp
@@ -0,0 +1,31 @@
+/*
+ * Created by Marcos Luciano
+ * https://www.github.com/marcoslucianops
+ */
+
+#include <cassert>
+#include "implicit_layer.h"
+
+nvinfer1::ILayer* implicitLayer(
+    int channels,
+    std::vector<float>& weights,
+    std::vector<nvinfer1::Weights>& trtWeights,
+    int& weightPtr,
+    nvinfer1::INetworkDefinition* network)
+{
+    nvinfer1::Weights convWt{nvinfer1::DataType::kFLOAT, nullptr, channels};
+
+    float* val = new float[channels];
+    for (int i = 0; i < channels; ++i)
+    {
+        val[i] = weights[weightPtr];
+        weightPtr++;
+    }
+    convWt.values = val;
+    trtWeights.push_back(convWt);
+
+    nvinfer1::IConstantLayer* implicit = network->addConstant(nvinfer1::Dims3{static_cast<int>(channels), 1, 1}, convWt);
+    assert(implicit != nullptr);
+
+    return implicit;
+}
\ No newline at end of file
diff --git a/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.h b/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.h
new file mode 100644
index 0000000..e34d738
--- /dev/null
+++ b/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.h
@@ -0,0 +1,22 @@
+/*
+ * Created by Marcos Luciano
+ * https://www.github.com/marcoslucianops
+ */
+
+#ifndef __IMPLICIT_LAYER_H__
+#define __IMPLICIT_LAYER_H__
+
+#include <cassert>
+#include <string>
+#include <vector>
+
+#include "NvInfer.h"
+
+nvinfer1::ILayer* implicitLayer(
+    int channels,
+    std::vector<float>& weights,
+    std::vector<nvinfer1::Weights>& trtWeights,
+    int& weightPtr,
+    nvinfer1::INetworkDefinition* network);
+
+#endif
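Note (not part of the patch): `implicitLayer` consumes exactly `channels` consecutive floats from the flat `weights` vector (advancing `weightPtr`) and registers them as a (channels, 1, 1) constant, so the order of `implicit_add`/`implicit_mul` blocks in the cfg has to match the tensor order that `gen_wts_yolor.py` (added at the end of this patch) writes into the .wts file. In the yolo.cpp hunk below, each later `shift_channels`/`control_channels` block picks up one of these constants through its `from=` offset and feeds it to `channelsLayer`.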
diff --git a/nvdsinfer_custom_impl_Yolo/yolo.cpp b/nvdsinfer_custom_impl_Yolo/yolo.cpp
index a0fb9d1..9d33ff3 100644
--- a/nvdsinfer_custom_impl_Yolo/yolo.cpp
+++ b/nvdsinfer_custom_impl_Yolo/yolo.cpp
@@ -187,6 +187,51 @@ NvDsInferStatus Yolo::buildYoloNetwork(
             printLayerInfo(layerIndex, layerType, inputVol, outputVol, std::to_string(weightPtr));
         }
+        else if (m_ConfigBlocks.at(i).at("type") == "implicit_add" || m_ConfigBlocks.at(i).at("type") == "implicit_mul") {
+            std::string type;
+            if (m_ConfigBlocks.at(i).at("type") == "implicit_add") {
+                type = "add";
+            }
+            else if (m_ConfigBlocks.at(i).at("type") == "implicit_mul") {
+                type = "mul";
+            }
+            assert(m_ConfigBlocks.at(i).find("filters") != m_ConfigBlocks.at(i).end());
+            int filters = std::stoi(m_ConfigBlocks.at(i).at("filters"));
+            nvinfer1::ILayer* out = implicitLayer(filters, weights, m_TrtWeights, weightPtr, &network);
+            previous = out->getOutput(0);
+            assert(previous != nullptr);
+            channels = getNumChannels(previous);
+            std::string outputVol = dimsToString(previous->getDimensions());
+            tensorOutputs.push_back(previous);
+            std::string layerType = "implicit_" + type;
+            printLayerInfo(layerIndex, layerType, " -", outputVol, std::to_string(weightPtr));
+        }
+
+        else if (m_ConfigBlocks.at(i).at("type") == "shift_channels" || m_ConfigBlocks.at(i).at("type") == "control_channels") {
+            std::string type;
+            if (m_ConfigBlocks.at(i).at("type") == "shift_channels") {
+                type = "shift";
+            }
+            else if (m_ConfigBlocks.at(i).at("type") == "control_channels") {
+                type = "control";
+            }
+            assert(m_ConfigBlocks.at(i).find("from") != m_ConfigBlocks.at(i).end());
+            int from = stoi(m_ConfigBlocks.at(i).at("from"));
+            if (from > 0) {
+                from = from - i + 1;
+            }
+            assert((i - 2 >= 0) && (i - 2 < tensorOutputs.size()));
+            assert((i + from - 1 >= 0) && (i + from - 1 < tensorOutputs.size()));
+            assert(i + from - 1 < i - 2);
+            nvinfer1::ILayer* out = channelsLayer(type, previous, tensorOutputs[i + from - 1], &network);
+            previous = out->getOutput(0);
+            assert(previous != nullptr);
+            std::string outputVol = dimsToString(previous->getDimensions());
+            tensorOutputs.push_back(previous);
+            std::string layerType = type + "_channels" + ": " + std::to_string(i + from - 1);
+            printLayerInfo(layerIndex, layerType, " -", outputVol, " -");
+        }
+
         else if (m_ConfigBlocks.at(i).at("type") == "dropout") {
             assert(m_ConfigBlocks.at(i).find("probability") != m_ConfigBlocks.at(i).end());
             //float probability = std::stof(m_ConfigBlocks.at(i).at("probability"));
diff --git a/nvdsinfer_custom_impl_Yolo/yolo.h b/nvdsinfer_custom_impl_Yolo/yolo.h
index 053e9c7..b660459 100644
--- a/nvdsinfer_custom_impl_Yolo/yolo.h
+++ b/nvdsinfer_custom_impl_Yolo/yolo.h
@@ -27,6 +27,8 @@
 #define _YOLO_H_
 
 #include "layers/convolutional_layer.h"
+#include "layers/implicit_layer.h"
+#include "layers/channels_layer.h"
 #include "layers/dropout_layer.h"
 #include "layers/shortcut_layer.h"
 #include "layers/route_layer.h"
diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward.cu b/nvdsinfer_custom_impl_Yolo/yoloForward.cu
index dcc4b95..cbaa29f 100644
--- a/nvdsinfer_custom_impl_Yolo/yoloForward.cu
+++ b/nvdsinfer_custom_impl_Yolo/yoloForward.cu
@@ -60,6 +60,28 @@ __global__ void gpuYoloLayer(const float* input, float* output, const uint gridS
                 = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
         }
     }
+    else if (new_coords == 0 && scale_x_y != 1) { // YOLOR incorrect param
+        output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
+            = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * 2.0 - 0.5;
+
+        output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
+            = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * 2.0 - 0.5;
+
+        output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]
+            = pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2);
+
+        output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]
+            = pow(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2);
+
+        output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]
+            = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
+
+        for (uint i = 0; i < numOutputClasses; ++i)
+        {
+            output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]
+                = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
+        }
+    }
     else {
         output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
             = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta;
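Note (not part of the patch): the new branch handles the parameter combination produced by YOLOR cfgs, `new_coords=0` with `scale_x_y` set. For those heads the kernel applies `sigmoid(t) * 2 - 0.5` to the x/y terms, `(sigmoid(t) * 2)^2` to the w/h terms, and a plain `sigmoid(t)` to the objectness and class scores, instead of the generic `sigmoid(t) * alpha + beta` path below it; grid offsets and anchor scaling are applied elsewhere and are not part of this hunk.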
diff --git a/readme.md b/readme.md
index 1bd89f7..4268c74 100644
--- a/readme.md
+++ b/readme.md
@@ -6,7 +6,6 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
 * New documentation for multiple models
 * DeepStream tutorials
-* Native YOLOR support
 * Native PP-YOLO support
 * Models benchmark
 * GPU NMS
@@ -22,7 +21,9 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
 * Support for convolutional groups
 * Support for INT8 calibration
 * Support for non square models
+* **Support for implicit and channel layers (YOLOR)**
 * **YOLOv5 6.0 native support**
+* **Initial YOLOR native support**
 
 ##
@@ -33,6 +34,7 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
 * [dGPU installation](#dgpu-installation)
 * [Basic usage](#basic-usage)
 * [YOLOv5 usage](#yolov5-usage)
+* [YOLOR usage](#yolor-usage)
 * [INT8 calibration](#int8-calibration)
 * [Using your custom model](docs/customModels.md)
@@ -55,6 +57,10 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
 ##
 
 ### Tested models
+* [YOLOR-CSP](https://github.com/WongKinYiu/yolor) [[cfg]](https://raw.githubusercontent.com/WongKinYiu/yolor/main/cfg/yolor_csp.cfg) [[pt]](https://drive.google.com/file/d/1ZEqGy4kmZyD-Cj3tEFJcLSZenZBDGiyg/view?usp=sharing)
+* [YOLOR-CSP*](https://github.com/WongKinYiu/yolor) [[cfg]](https://raw.githubusercontent.com/WongKinYiu/yolor/main/cfg/yolor_csp.cfg) [[pt]](https://drive.google.com/file/d/1OJKgIasELZYxkIjFoiqyn555bcmixUP2/view?usp=sharing)
+* [YOLOR-CSP-X](https://github.com/WongKinYiu/yolor) [[cfg]](https://raw.githubusercontent.com/WongKinYiu/yolor/main/cfg/yolor_csp_x.cfg) [[pt]](https://drive.google.com/file/d/1L29rfIPNH1n910qQClGftknWpTBgAv6c/view?usp=sharing)
+* [YOLOR-CSP-X*](https://github.com/WongKinYiu/yolor) [[cfg]](https://raw.githubusercontent.com/WongKinYiu/yolor/main/cfg/yolor_csp_x.cfg) [[pt]](https://drive.google.com/file/d/1NbMG3ivuBQ4S8kEhFJ0FIqOQXevGje_w/view?usp=sharing)
 * [YOLOv5 6.0](https://github.com/ultralytics/yolov5) [[pt]](https://github.com/ultralytics/yolov5/releases/tag/v6.0)
 * [YOLOv4x-Mish](https://github.com/AlexeyAB/darknet) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4x-mish.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4x-mish.weights)]
 * [YOLOv4-CSP](https://github.com/WongKinYiu/ScaledYOLOv4/tree/yolov4-csp) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-csp.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-csp.weights)]
@@ -285,7 +291,7 @@ config-file=config_infer_primary_yoloV2.txt
 
 ### YOLOv5 usage
 
-#### 1. Copy gen_wts_yoloV5.py from DeepStream-Yolo/utils to ultralytics/yolov5 folder
+#### 1. Copy gen_wts_yoloV5.py from DeepStream-Yolo/utils to [ultralytics/yolov5](https://github.com/ultralytics/yolov5) folder
 
 #### 2. Open the ultralytics/yolov5 folder
@@ -404,6 +410,87 @@ deepstream-app -c deepstream_app_config.txt
 
 ##
 
+### YOLOR usage
+
+**NOTE**: For now, available only for YOLOR-CSP, YOLOR-CSP*, YOLOR-CSP-X and YOLOR-CSP-X*.
+
+#### 1. Copy gen_wts_yolor.py from DeepStream-Yolo/utils to [yolor](https://github.com/WongKinYiu/yolor) folder
+
+#### 2. Open the yolor folder
+
+#### 3. Download pt file from [yolor](https://github.com/WongKinYiu/yolor) website
+
+#### 4. Generate wts file (example for YOLOR-CSP)
+
+```
+python3 gen_wts_yolor.py -w yolor_csp.pt -c cfg/yolor_csp.cfg
+```
+
+#### 5. Copy cfg and generated wts files to DeepStream-Yolo folder
+
+#### 6. Open DeepStream-Yolo folder
+
+#### 7. Compile lib
+
+* x86 platform
+
+```
+CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
+```
+
+* Jetson platform
+
+```
+CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
+```
+
+#### 8. Edit config_infer_primary_yolor.txt for your model (example for YOLOR-CSP)
+
+```
+[property]
+...
+# 0=RGB, 1=BGR, 2=GRAYSCALE
+model-color-format=0
+# CFG
+custom-network-config=yolor_csp.cfg
+# WTS
+model-file=yolor_csp.wts
+# Generated TensorRT model (will be created if it doesn't exist)
+model-engine-file=model_b1_gpu0_fp32.engine
+# Model labels file
+labelfile-path=labels.txt
+# Batch size
+batch-size=1
+# 0=FP32, 1=INT8, 2=FP16 mode
+network-mode=0
+# Number of classes in label file
+num-detected-classes=80
+...
+[class-attrs-all]
+# CONF_THRESH
+pre-cluster-threshold=0.25
+```
+
+#### 9. Change the deepstream_app_config.txt file
+
+```
+...
+[primary-gie]
+enable=1
+gpu-id=0
+gie-unique-id=1
+nvbuf-memory-type=0
+config-file=config_infer_primary_yolor.txt
+```
+
+#### 10. Run
+
+```
+deepstream-app -c deepstream_app_config.txt
+```
+
+##
+
 ### INT8 calibration
 
 #### 1. Install OpenCV
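Note (not part of the patch): `network-mode` in the snippet above follows the usual DeepStream convention (0=FP32, 1=INT8, 2=FP16), so an FP16 build only needs two edits to config_infer_primary_yolor.txt. The engine file name below assumes the `model_b<batch-size>_gpu<gpu-id>_<precision>.engine` pattern used in the examples; treat it as illustrative rather than guaranteed.

```
network-mode=2
model-engine-file=model_b1_gpu0_fp16.engine
```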
diff --git a/utils/gen_wts_yoloV5.py b/utils/gen_wts_yoloV5.py
index 97adc72..f648ee4 100644
--- a/utils/gen_wts_yoloV5.py
+++ b/utils/gen_wts_yoloV5.py
@@ -8,7 +8,7 @@ from utils.torch_utils import select_device
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description="PyTorch conversion")
+    parser = argparse.ArgumentParser(description="PyTorch YOLOv5 conversion")
     parser.add_argument("-w", "--weights", required=True, help="Input weights (.pt) file path (required)")
     parser.add_argument("-c", "--yaml", help="Input cfg (.yaml) file path")
     parser.add_argument("-mw", "--width", help="Model width (default = 640 / 1280 [P6])")
@@ -76,7 +76,7 @@ with open(wts_file, "w") as f:
             cv1 += "{} {} ".format(k, len(vr))
             for vv in vr:
                 cv1 += " "
-                cv1 += struct.pack(">f" ,float(vv)).hex()
+                cv1 += struct.pack(">f", float(vv)).hex()
             cv1 += "\n"
             conv_count += 1
         elif cv1 != "" and ".m." in k:
@@ -86,7 +86,7 @@ with open(wts_file, "w") as f:
             cv3 += "{} {} ".format(k, len(vr))
             for vv in vr:
                 cv3 += " "
-                cv3 += struct.pack(">f" ,float(vv)).hex()
+                cv3 += struct.pack(">f", float(vv)).hex()
             cv3 += "\n"
             cv3_idx = idx
             conv_count += 1
@@ -98,7 +98,7 @@ with open(wts_file, "w") as f:
             wts_write += "{} {} ".format(k, len(vr))
             for vv in vr:
                 wts_write += " "
-                wts_write += struct.pack(">f" ,float(vv)).hex()
+                wts_write += struct.pack(">f", float(vv)).hex()
             wts_write += "\n"
             conv_count += 1
     f.write("{}\n".format(conv_count))
diff --git a/utils/gen_wts_yolor.py b/utils/gen_wts_yolor.py
new file mode 100644
index 0000000..6358b72
--- /dev/null
+++ b/utils/gen_wts_yolor.py
@@ -0,0 +1,43 @@
+import argparse
+import os
+import struct
+import torch
+from utils.torch_utils import select_device
+from models.models import Darknet
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="PyTorch YOLOR conversion (main branch)")
+    parser.add_argument("-w", "--weights", required=True, help="Input weights (.pt) file path (required)")
+    parser.add_argument("-c", "--cfg", required=True, help="Input cfg (.cfg) file path (required)")
+    args = parser.parse_args()
+    if not os.path.isfile(args.weights):
+        raise SystemExit("Invalid weights file")
+    if not os.path.isfile(args.cfg):
+        raise SystemExit("Invalid cfg file")
+    return args.weights, args.cfg
+
+
+pt_file, cfg_file = parse_args()
+
+wts_file = pt_file.split(".pt")[0] + ".wts"
+
+device = select_device("cpu")
+model = Darknet(cfg_file).to(device)
+model.load_state_dict(torch.load(pt_file, map_location=device)["model"])
+model.to(device).eval()
+
+with open(wts_file, "w") as f:
+    wts_write = ""
+    conv_count = 0
+    for k, v in model.state_dict().items():
+        if not "num_batches_tracked" in k:
+            vr = v.reshape(-1).cpu().numpy()
+            wts_write += "{} {} ".format(k, len(vr))
+            for vv in vr:
+                wts_write += " "
+                wts_write += struct.pack(">f", float(vv)).hex()
+            wts_write += "\n"
+            conv_count += 1
+    f.write("{}\n".format(conv_count))
+    f.write(wts_write)
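Note (not part of the patch): gen_wts_yolor.py writes a plain-text .wts file — a first line with the tensor count, then one line per tensor of the form `name length hex1 hex2 ...`, where each value is a big-endian float32 encoded as 8 hex characters. A minimal reader sketch, purely to illustrate the format (`read_wts` and the example path are illustrative, not part of the repo):

```
import struct

def read_wts(path):
    tensors = {}
    with open(path) as f:
        count = int(f.readline())  # first line: number of tensors written
        for _ in range(count):
            name, length, *hexvals = f.readline().split()
            # each value was written as struct.pack(">f", v).hex() -> 8 hex chars per float32
            tensors[name] = [struct.unpack(">f", bytes.fromhex(h))[0] for h in hexvals]
            assert len(tensors[name]) == int(length)
    return tensors

# e.g. weights = read_wts("yolor_csp.wts")
```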