From f9c7a4dfca3b5a1e2b4bdd63675e25d17d50b607 Mon Sep 17 00:00:00 2001 From: Marcos Luciano Date: Fri, 27 Jan 2023 15:56:00 -0300 Subject: [PATCH] Add YOLOv8 support --- README.md | 16 +- config_infer_primary.txt | 1 + config_infer_primary_yoloV5.txt | 2 +- config_infer_primary_yoloV7.txt | 1 + config_infer_primary_yoloV8.txt | 27 + config_infer_primary_yolor.txt | 1 + docs/YOLOv8.md | 139 +++ nvdsinfer_custom_impl_Yolo/calibrator.cpp | 247 ++-- nvdsinfer_custom_impl_Yolo/calibrator.h | 85 +- .../layers/activation_layer.cpp | 215 ++-- .../layers/activation_layer.h | 10 +- .../layers/batchnorm_layer.cpp | 166 ++- .../layers/batchnorm_layer.h | 11 +- .../layers/c2f_layer.cpp | 82 ++ nvdsinfer_custom_impl_Yolo/layers/c2f_layer.h | 18 + .../layers/channels_layer.cpp | 49 +- .../layers/channels_layer.h | 9 +- .../layers/cls_layer.cpp | 32 +- nvdsinfer_custom_impl_Yolo/layers/cls_layer.h | 6 +- .../layers/convolutional_layer.cpp | 379 +++--- .../layers/convolutional_layer.h | 14 +- .../layers/detect_v8_layer.cpp | 196 +++ .../layers/detect_v8_layer.h | 18 + .../layers/implicit_layer.cpp | 49 +- .../layers/implicit_layer.h | 10 +- .../layers/pooling_layer.cpp | 79 +- .../layers/pooling_layer.h | 7 +- .../layers/reduce_layer.cpp | 80 +- .../layers/reduce_layer.h | 6 +- .../layers/reg_layer.cpp | 174 ++- nvdsinfer_custom_impl_Yolo/layers/reg_layer.h | 10 +- .../layers/reorg_layer.cpp | 83 +- .../layers/reorg_layer.h | 11 +- .../layers/route_layer.cpp | 126 +- .../layers/route_layer.h | 9 +- .../layers/shortcut_layer.cpp | 63 +- .../layers/shortcut_layer.h | 13 +- .../layers/shuffle_layer.cpp | 212 ++-- .../layers/shuffle_layer.h | 10 +- .../layers/softmax_layer.cpp | 30 +- .../layers/softmax_layer.h | 6 +- .../layers/upsample_layer.cpp | 34 +- .../layers/upsample_layer.h | 6 +- .../nvdsinfer_yolo_engine.cpp | 123 +- .../nvdsparsebbox_Yolo.cpp | 155 ++- nvdsinfer_custom_impl_Yolo/utils.cpp | 200 +-- nvdsinfer_custom_impl_Yolo/utils.h | 7 +- nvdsinfer_custom_impl_Yolo/yolo.cpp | 1070 ++++++++--------- nvdsinfer_custom_impl_Yolo/yolo.h | 51 +- nvdsinfer_custom_impl_Yolo/yoloForward.cu | 132 +- nvdsinfer_custom_impl_Yolo/yoloForward_e.cu | 92 +- nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu | 132 +- nvdsinfer_custom_impl_Yolo/yoloForward_r.cu | 134 +-- nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu | 165 ++- nvdsinfer_custom_impl_Yolo/yoloForward_v8.cu | 62 + nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp | 495 ++++---- nvdsinfer_custom_impl_Yolo/yoloPlugins.h | 128 +- utils/gen_wts_ppyoloe.py | 20 +- utils/gen_wts_yoloV8.py | 315 +++++ 59 files changed, 3260 insertions(+), 2763 deletions(-) create mode 100644 config_infer_primary_yoloV8.txt create mode 100644 docs/YOLOv8.md create mode 100644 nvdsinfer_custom_impl_Yolo/layers/c2f_layer.cpp create mode 100644 nvdsinfer_custom_impl_Yolo/layers/c2f_layer.h create mode 100644 nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.cpp create mode 100644 nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.h create mode 100644 nvdsinfer_custom_impl_Yolo/yoloForward_v8.cu create mode 100644 utils/gen_wts_yoloV8.py diff --git a/README.md b/README.md index 95e4b9f..96a700c 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models * YOLOX support * YOLOv6 support * Dynamic batch-size +* PP-YOLOE+ support ### Improvements on this repository @@ -22,11 +23,12 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models * New documentation for multiple models * YOLOv5 support * YOLOR 
support -* **GPU YOLO Decoder** [#138](https://github.com/marcoslucianops/DeepStream-Yolo/issues/138) -* **PP-YOLOE support** -* **YOLOv7 support** -* **Optimized NMS** [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142) -* **Models benchmarks** +* GPU YOLO Decoder [#138](https://github.com/marcoslucianops/DeepStream-Yolo/issues/138) +* PP-YOLOE support +* YOLOv7 support +* Optimized NMS [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142) +* Models benchmarks +* **YOLOv8 support** ## @@ -44,6 +46,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models * [YOLOR usage](docs/YOLOR.md) * [PP-YOLOE usage](docs/PPYOLOE.md) * [YOLOv7 usage](docs/YOLOv7.md) +* [YOLOv8 usage](docs/YOLOv8.md) * [Using your custom model](docs/customModels.md) * [Multiple YOLO GIEs](docs/multipleGIEs.md) @@ -108,6 +111,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models * [YOLOR](https://github.com/WongKinYiu/yolor) * [PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe) * [YOLOv7](https://github.com/WongKinYiu/yolov7) +* [YOLOv8](https://github.com/ultralytics/ultralytics) * [MobileNet-YOLO](https://github.com/dog-qiuqiu/MobileNet-Yolo) * [YOLO-Fastest](https://github.com/dog-qiuqiu/Yolo-Fastest) @@ -131,7 +135,7 @@ sample = 1920x1080 video - Eval ``` -nms-iou-threshold = 0.6 (Darknet) / 0.65 (PyTorch) / 0.7 (Paddle) +nms-iou-threshold = 0.6 (Darknet and YOLOv8) / 0.65 (YOLOR, YOLOv5 and YOLOv7) / 0.7 (Paddle) pre-cluster-threshold = 0.001 topk = 300 ``` diff --git a/config_infer_primary.txt b/config_infer_primary.txt index 4989489..fa5788d 100644 --- a/config_infer_primary.txt +++ b/config_infer_primary.txt @@ -16,6 +16,7 @@ process-mode=1 network-type=0 cluster-mode=2 maintain-aspect-ratio=0 +symmetric-padding=1 parse-bbox-func-name=NvDsInferParseYolo custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so engine-create-func-name=NvDsInferYoloCudaEngineGet diff --git a/config_infer_primary_yoloV5.txt b/config_infer_primary_yoloV5.txt index b3f6eb7..601ffb4 100644 --- a/config_infer_primary_yoloV5.txt +++ b/config_infer_primary_yoloV5.txt @@ -16,10 +16,10 @@ process-mode=1 network-type=0 cluster-mode=2 maintain-aspect-ratio=1 +symmetric-padding=1 parse-bbox-func-name=NvDsInferParseYolo custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so engine-create-func-name=NvDsInferYoloCudaEngineGet -symmetric-padding=1 [class-attrs-all] nms-iou-threshold=0.45 diff --git a/config_infer_primary_yoloV7.txt b/config_infer_primary_yoloV7.txt index d063418..0e35f08 100644 --- a/config_infer_primary_yoloV7.txt +++ b/config_infer_primary_yoloV7.txt @@ -16,6 +16,7 @@ process-mode=1 network-type=0 cluster-mode=2 maintain-aspect-ratio=1 +symmetric-padding=1 parse-bbox-func-name=NvDsInferParseYolo custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so engine-create-func-name=NvDsInferYoloCudaEngineGet diff --git a/config_infer_primary_yoloV8.txt b/config_infer_primary_yoloV8.txt new file mode 100644 index 0000000..3214bd3 --- /dev/null +++ b/config_infer_primary_yoloV8.txt @@ -0,0 +1,27 @@ +[property] +gpu-id=0 +net-scale-factor=0.0039215697906911373 +model-color-format=0 +custom-network-config=yolov8s.cfg +model-file=yolov8s.wts +model-engine-file=model_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 
+cluster-mode=2
+maintain-aspect-ratio=1
+symmetric-padding=1
+parse-bbox-func-name=NvDsInferParseYolo
+custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
+engine-create-func-name=NvDsInferYoloCudaEngineGet
+
+[class-attrs-all]
+nms-iou-threshold=0.45
+pre-cluster-threshold=0.25
+topk=300
diff --git a/config_infer_primary_yolor.txt b/config_infer_primary_yolor.txt
index 3c5490a..4e178de 100644
--- a/config_infer_primary_yolor.txt
+++ b/config_infer_primary_yolor.txt
@@ -16,6 +16,7 @@ process-mode=1
 network-type=0
 cluster-mode=2
 maintain-aspect-ratio=1
+symmetric-padding=1
 parse-bbox-func-name=NvDsInferParseYolo
 custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
 engine-create-func-name=NvDsInferYoloCudaEngineGet
diff --git a/docs/YOLOv8.md b/docs/YOLOv8.md
new file mode 100644
index 0000000..bddffc5
--- /dev/null
+++ b/docs/YOLOv8.md
@@ -0,0 +1,139 @@
+# YOLOv8 usage
+
+**NOTE**: The yaml file is not required.
+
+* [Convert model](#convert-model)
+* [Compile the lib](#compile-the-lib)
+* [Edit the config_infer_primary_yoloV8 file](#edit-the-config_infer_primary_yolov8-file)
+* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file)
+* [Testing the model](#testing-the-model)
+
+##
+
+### Convert model
+
+#### 1. Download the YOLOv8 repo and install the requirements
+
+```
+git clone https://github.com/ultralytics/ultralytics.git
+cd ultralytics
+pip3 install -r requirements.txt
+```
+
+**NOTE**: It is recommended to use a Python virtualenv.
+
+#### 2. Copy the converter
+
+Copy the `gen_wts_yoloV8.py` file from the `DeepStream-Yolo/utils` directory to the `ultralytics` folder.
+
+#### 3. Download the model
+
+Download the `pt` file from the [YOLOv8](https://github.com/ultralytics/assets/releases/) releases (example for YOLOv8s)
+
+```
+wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s.pt
+```
+
+**NOTE**: You can use your custom model, but it is important to keep the YOLO model reference (`yolov8_`) in your `cfg` and `weights`/`wts` filenames to generate the engine correctly.
+
+#### 4. Convert model
+
+Generate the `cfg` and `wts` files (example for YOLOv8s)
+
+```
+python3 gen_wts_yoloV8.py -w yolov8s.pt
+```
+
+**NOTE**: To change the inference size (default: 640)
+
+```
+-s SIZE
+--size SIZE
+-s HEIGHT WIDTH
+--size HEIGHT WIDTH
+```
+
+Example for 1280
+
+```
+-s 1280
+```
+
+or
+
+```
+-s 1280 1280
+```
+
+#### 5. Copy generated files
+
+Copy the generated `cfg` and `wts` files to the `DeepStream-Yolo` folder.
+
+##
+
+### Compile the lib
+
+Open the `DeepStream-Yolo` folder and compile the lib
+
+* DeepStream 6.1.1 on x86 platform
+
+  ```
+  CUDA_VER=11.7 make -C nvdsinfer_custom_impl_Yolo
+  ```
+
+* DeepStream 6.1 on x86 platform
+
+  ```
+  CUDA_VER=11.6 make -C nvdsinfer_custom_impl_Yolo
+  ```
+
+* DeepStream 6.0.1 / 6.0 on x86 platform
+
+  ```
+  CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
+  ```
+
+* DeepStream 6.1.1 / 6.1 on Jetson platform
+
+  ```
+  CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
+  ```
+
+* DeepStream 6.0.1 / 6.0 on Jetson platform
+
+  ```
+  CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
+  ```
+
+##
+
+### Edit the config_infer_primary_yoloV8 file
+
+Edit the `config_infer_primary_yoloV8.txt` file according to your model (example for YOLOv8s)
+
+```
+[property]
+...
+custom-network-config=yolov8s.cfg
+model-file=yolov8s.wts
+...
+```
+
+##
+
+### Edit the deepstream_app_config file
+
+```
+...
+[primary-gie]
+...
+config-file=config_infer_primary_yoloV8.txt +``` + +## + +### Testing the model + +``` +deepstream-app -c deepstream_app_config.txt +``` diff --git a/nvdsinfer_custom_impl_Yolo/calibrator.cpp b/nvdsinfer_custom_impl_Yolo/calibrator.cpp index 2cc1e53..c445de7 100644 --- a/nvdsinfer_custom_impl_Yolo/calibrator.cpp +++ b/nvdsinfer_custom_impl_Yolo/calibrator.cpp @@ -4,139 +4,130 @@ */ #include "calibrator.h" + #include #include -namespace nvinfer1 +Int8EntropyCalibrator2::Int8EntropyCalibrator2(const int& batchsize, const int& channels, const int& height, + const int& width, const int& letterbox, const std::string& imgPath, + const std::string& calibTablePath) : batchSize(batchsize), inputC(channels), inputH(height), inputW(width), + letterBox(letterbox), calibTablePath(calibTablePath), imageIndex(0) { - Int8EntropyCalibrator2::Int8EntropyCalibrator2(const int &batchsize, const int &channels, const int &height, const int &width, const int &letterbox, const std::string &imgPath, - const std::string &calibTablePath):batchSize(batchsize), inputC(channels), inputH(height), inputW(width), letterBox(letterbox), calibTablePath(calibTablePath), imageIndex(0) - { - inputCount = batchsize * channels * height * width; - std::fstream f(imgPath); - if (f.is_open()) - { - std::string temp; - while (std::getline(f, temp)) imgPaths.push_back(temp); - } - batchData = new float[inputCount]; - CUDA_CHECK(cudaMalloc(&deviceInput, inputCount * sizeof(float))); - } - - Int8EntropyCalibrator2::~Int8EntropyCalibrator2() - { - CUDA_CHECK(cudaFree(deviceInput)); - if (batchData) - delete[] batchData; - } - - int Int8EntropyCalibrator2::getBatchSize() const noexcept - { - return batchSize; - } - - bool Int8EntropyCalibrator2::getBatch(void **bindings, const char **names, int nbBindings) noexcept - { - if (imageIndex + batchSize > uint(imgPaths.size())) - return false; - - float* ptr = batchData; - for (size_t j = imageIndex; j < imageIndex + batchSize; ++j) - { - cv::Mat img = cv::imread(imgPaths[j], cv::IMREAD_COLOR); - std::vectorinputData = prepareImage(img, inputC, inputH, inputW, letterBox); - - int len = (int)(inputData.size()); - memcpy(ptr, inputData.data(), len * sizeof(float)); - - ptr += inputData.size(); - std::cout << "Load image: " << imgPaths[j] << std::endl; - std::cout << "Progress: " << (j + 1)*100. / imgPaths.size() << "%" << std::endl; - } - imageIndex += batchSize; - CUDA_CHECK(cudaMemcpy(deviceInput, batchData, inputCount * sizeof(float), cudaMemcpyHostToDevice)); - bindings[0] = deviceInput; - return true; - } - - const void* Int8EntropyCalibrator2::readCalibrationCache(std::size_t &length) noexcept - { - calibrationCache.clear(); - std::ifstream input(calibTablePath, std::ios::binary); - input >> std::noskipws; - if (readCache && input.good()) - { - std::copy(std::istream_iterator(input), std::istream_iterator(), - std::back_inserter(calibrationCache)); - } - length = calibrationCache.size(); - return length ? 
calibrationCache.data() : nullptr; - } - - void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, std::size_t length) noexcept - { - std::ofstream output(calibTablePath, std::ios::binary); - output.write(reinterpret_cast(cache), length); - } + inputCount = batchsize * channels * height * width; + std::fstream f(imgPath); + if (f.is_open()) { + std::string temp; + while (std::getline(f, temp)) + imgPaths.push_back(temp); + } + batchData = new float[inputCount]; + CUDA_CHECK(cudaMalloc(&deviceInput, inputCount * sizeof(float))); } -std::vector prepareImage(cv::Mat& img, int input_c, int input_h, int input_w, int letter_box) +Int8EntropyCalibrator2::~Int8EntropyCalibrator2() { - cv::Mat out; - int image_w = img.cols; - int image_h = img.rows; - if (image_w != input_w || image_h != input_h) - { - if (letter_box == 1) - { - float ratio_w = (float)image_w / (float)input_w; - float ratio_h = (float)image_h / (float)input_h; - if (ratio_w > ratio_h) - { - int new_width = input_w * ratio_h; - int x = (image_w - new_width) / 2; - cv::Rect roi(abs(x), 0, new_width, image_h); - out = img(roi); - } - else if (ratio_w < ratio_h) - { - int new_height = input_h * ratio_w; - int y = (image_h - new_height) / 2; - cv::Rect roi(0, abs(y), image_w, new_height); - out = img(roi); - } - else { - out = img; - } - cv::resize(out, out, cv::Size(input_w, input_h), 0, 0, cv::INTER_CUBIC); - } - else - { - cv::resize(img, out, cv::Size(input_w, input_h), 0, 0, cv::INTER_CUBIC); - } - cv::cvtColor(out, out, cv::COLOR_BGR2RGB); - } - else - { - cv::cvtColor(img, out, cv::COLOR_BGR2RGB); - } - if (input_c == 3) - { - out.convertTo(out, CV_32FC3, 1.0 / 255.0); - } - else - { - out.convertTo(out, CV_32FC1, 1.0 / 255.0); - } - std::vector input_channels(input_c); - cv::split(out, input_channels); - std::vector result(input_h * input_w * input_c); - auto data = result.data(); - int channelLength = input_h * input_w; - for (int i = 0; i < input_c; ++i) - { - memcpy(data, input_channels[i].data, channelLength * sizeof(float)); - data += channelLength; - } - return result; + CUDA_CHECK(cudaFree(deviceInput)); + if (batchData) + delete[] batchData; +} + +int +Int8EntropyCalibrator2::getBatchSize() const noexcept +{ + return batchSize; +} + +bool +Int8EntropyCalibrator2::getBatch(void** bindings, const char** names, int nbBindings) noexcept +{ + if (imageIndex + batchSize > uint(imgPaths.size())) + return false; + + float* ptr = batchData; + for (size_t i = imageIndex; i < imageIndex + batchSize; ++i) { + cv::Mat img = cv::imread(imgPaths[i], cv::IMREAD_COLOR); + std::vector inputData = prepareImage(img, inputC, inputH, inputW, letterBox); + + int len = (int) (inputData.size()); + memcpy(ptr, inputData.data(), len * sizeof(float)); + + ptr += inputData.size(); + std::cout << "Load image: " << imgPaths[i] << std::endl; + std::cout << "Progress: " << (i + 1)*100. / imgPaths.size() << "%" << std::endl; + } + imageIndex += batchSize; + CUDA_CHECK(cudaMemcpy(deviceInput, batchData, inputCount * sizeof(float), cudaMemcpyHostToDevice)); + bindings[0] = deviceInput; + return true; +} + +const void* +Int8EntropyCalibrator2::readCalibrationCache(std::size_t &length) noexcept +{ + calibrationCache.clear(); + std::ifstream input(calibTablePath, std::ios::binary); + input >> std::noskipws; + if (readCache && input.good()) + std::copy(std::istream_iterator(input), std::istream_iterator(), std::back_inserter(calibrationCache)); + length = calibrationCache.size(); + return length ? 
calibrationCache.data() : nullptr; +} + +void +Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, std::size_t length) noexcept +{ + std::ofstream output(calibTablePath, std::ios::binary); + output.write(reinterpret_cast(cache), length); +} + +std::vector +prepareImage(cv::Mat& img, int input_c, int input_h, int input_w, int letter_box) +{ + cv::Mat out; + int image_w = img.cols; + int image_h = img.rows; + if (image_w != input_w || image_h != input_h) { + if (letter_box == 1) { + float ratio_w = (float) image_w / (float) input_w; + float ratio_h = (float) image_h / (float) input_h; + if (ratio_w > ratio_h) { + int new_width = input_w * ratio_h; + int x = (image_w - new_width) / 2; + cv::Rect roi(abs(x), 0, new_width, image_h); + out = img(roi); + } + else if (ratio_w < ratio_h) { + int new_height = input_h * ratio_w; + int y = (image_h - new_height) / 2; + cv::Rect roi(0, abs(y), image_w, new_height); + out = img(roi); + } + else + out = img; + cv::resize(out, out, cv::Size(input_w, input_h), 0, 0, cv::INTER_CUBIC); + } + else { + cv::resize(img, out, cv::Size(input_w, input_h), 0, 0, cv::INTER_CUBIC); + } + cv::cvtColor(out, out, cv::COLOR_BGR2RGB); + } + else + cv::cvtColor(img, out, cv::COLOR_BGR2RGB); + + if (input_c == 3) + out.convertTo(out, CV_32FC3, 1.0 / 255.0); + else + out.convertTo(out, CV_32FC1, 1.0 / 255.0); + + std::vector input_channels(input_c); + cv::split(out, input_channels); + std::vector result(input_h * input_w * input_c); + auto data = result.data(); + int channelLength = input_h * input_w; + for (int i = 0; i < input_c; ++i) { + memcpy(data, input_channels[i].data, channelLength * sizeof(float)); + data += channelLength; + } + + return result; } diff --git a/nvdsinfer_custom_impl_Yolo/calibrator.h b/nvdsinfer_custom_impl_Yolo/calibrator.h index 3d06865..2a18f83 100644 --- a/nvdsinfer_custom_impl_Yolo/calibrator.h +++ b/nvdsinfer_custom_impl_Yolo/calibrator.h @@ -6,57 +6,50 @@ #ifndef CALIBRATOR_H #define CALIBRATOR_H -#include "opencv2/opencv.hpp" -#include "cuda_runtime.h" -#include "NvInfer.h" #include -#include -#ifndef CUDA_CHECK -#define CUDA_CHECK(callstr) \ - { \ - cudaError_t error_code = callstr; \ - if (error_code != cudaSuccess) { \ - std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ - assert(0); \ - } \ - } -#endif +#include "NvInfer.h" +#include "opencv2/opencv.hpp" -namespace nvinfer1 { - class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 { - public: - Int8EntropyCalibrator2(const int &batchsize, - const int &channels, - const int &height, - const int &width, - const int &letterbox, - const std::string &imgPath, - const std::string &calibTablePath); - - virtual ~Int8EntropyCalibrator2(); - int getBatchSize() const noexcept override; - bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override; - const void* readCalibrationCache(std::size_t& length) noexcept override; - void writeCalibrationCache(const void* cache, size_t length) noexcept override; - - private: - int batchSize; - int inputC; - int inputH; - int inputW; - int letterBox; - std::string calibTablePath; - size_t imageIndex; - size_t inputCount; - std::vector imgPaths; - float *batchData{ nullptr }; - void *deviceInput{ nullptr }; - bool readCache; - std::vector calibrationCache; - }; +#define CUDA_CHECK(status) { \ + if (status != 0) { \ + std::cout << "CUDA failure: " << cudaGetErrorString(status) << " in file " << __FILE__ << " at line " << __LINE__ << \ + std::endl; \ + abort(); \ + } \ } 
+class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 { + public: + Int8EntropyCalibrator2(const int& batchsize, const int& channels, const int& height, const int& width, + const int& letterbox, const std::string& imgPath, const std::string& calibTablePath); + + virtual ~Int8EntropyCalibrator2(); + + int getBatchSize() const noexcept override; + + bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override; + + const void* readCalibrationCache(std::size_t& length) noexcept override; + + void writeCalibrationCache(const void* cache, size_t length) noexcept override; + + private: + int batchSize; + int inputC; + int inputH; + int inputW; + int letterBox; + std::string calibTablePath; + size_t imageIndex; + size_t inputCount; + std::vector imgPaths; + float* batchData {nullptr}; + void* deviceInput {nullptr}; + bool readCache; + std::vector calibrationCache; +}; + std::vector prepareImage(cv::Mat& img, int input_c, int input_h, int input_w, int letter_box); #endif //CALIBRATOR_H diff --git a/nvdsinfer_custom_impl_Yolo/layers/activation_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/activation_layer.cpp index 139b42f..ce3d9a1 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/activation_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/activation_layer.cpp @@ -5,118 +5,107 @@ #include "activation_layer.h" -nvinfer1::ITensor* activationLayer( - int layerIdx, - std::string activation, - nvinfer1::ITensor* input, - nvinfer1::INetworkDefinition* network) -{ - nvinfer1::ITensor* output; +#include +#include - if (activation == "linear") - { - output = input; - } - else if (activation == "relu") - { - nvinfer1::IActivationLayer* relu = network->addActivation(*input, nvinfer1::ActivationType::kRELU); - assert(relu != nullptr); - std::string reluLayerName = "relu_" + std::to_string(layerIdx); - relu->setName(reluLayerName.c_str()); - output = relu->getOutput(0); - } - else if (activation == "sigmoid" || activation == "logistic") - { - nvinfer1::IActivationLayer* sigmoid = network->addActivation(*input, nvinfer1::ActivationType::kSIGMOID); - assert(sigmoid != nullptr); - std::string sigmoidLayerName = "sigmoid_" + std::to_string(layerIdx); - sigmoid->setName(sigmoidLayerName.c_str()); - output = sigmoid->getOutput(0); - } - else if (activation == "tanh") - { - nvinfer1::IActivationLayer* tanh = network->addActivation(*input, nvinfer1::ActivationType::kTANH); - assert(tanh != nullptr); - std::string tanhLayerName = "tanh_" + std::to_string(layerIdx); - tanh->setName(tanhLayerName.c_str()); - output = tanh->getOutput(0); - } - else if (activation == "leaky") - { - nvinfer1::IActivationLayer* leaky = network->addActivation(*input, nvinfer1::ActivationType::kLEAKY_RELU); - assert(leaky != nullptr); - std::string leakyLayerName = "leaky_" + std::to_string(layerIdx); - leaky->setName(leakyLayerName.c_str()); - leaky->setAlpha(0.1); - output = leaky->getOutput(0); - } - else if (activation == "softplus") - { - nvinfer1::IActivationLayer* softplus = network->addActivation(*input, nvinfer1::ActivationType::kSOFTPLUS); - assert(softplus != nullptr); - std::string softplusLayerName = "softplus_" + std::to_string(layerIdx); - softplus->setName(softplusLayerName.c_str()); - output = softplus->getOutput(0); - } - else if (activation == "mish") - { - nvinfer1::IActivationLayer* softplus = network->addActivation(*input, nvinfer1::ActivationType::kSOFTPLUS); - assert(softplus != nullptr); - std::string softplusLayerName = "softplus_" + std::to_string(layerIdx); - 
softplus->setName(softplusLayerName.c_str()); - nvinfer1::IActivationLayer* tanh = network->addActivation(*softplus->getOutput(0), nvinfer1::ActivationType::kTANH); - assert(tanh != nullptr); - std::string tanhLayerName = "tanh_" + std::to_string(layerIdx); - tanh->setName(tanhLayerName.c_str()); - nvinfer1::IElementWiseLayer* mish - = network->addElementWise(*input, *tanh->getOutput(0), nvinfer1::ElementWiseOperation::kPROD); - assert(mish != nullptr); - std::string mishLayerName = "mish_" + std::to_string(layerIdx); - mish->setName(mishLayerName.c_str()); - output = mish->getOutput(0); - } - else if (activation == "silu" || activation == "swish") - { - nvinfer1::IActivationLayer* sigmoid = network->addActivation(*input, nvinfer1::ActivationType::kSIGMOID); - assert(sigmoid != nullptr); - std::string sigmoidLayerName = "sigmoid_" + std::to_string(layerIdx); - sigmoid->setName(sigmoidLayerName.c_str()); - nvinfer1::IElementWiseLayer* silu - = network->addElementWise(*input, *sigmoid->getOutput(0), nvinfer1::ElementWiseOperation::kPROD); - assert(silu != nullptr); - std::string siluLayerName = "silu_" + std::to_string(layerIdx); - silu->setName(siluLayerName.c_str()); - output = silu->getOutput(0); - } - else if (activation == "hardsigmoid") - { - nvinfer1::IActivationLayer* hardsigmoid = network->addActivation(*input, nvinfer1::ActivationType::kHARD_SIGMOID); - assert(hardsigmoid != nullptr); - std::string hardsigmoidLayerName = "hardsigmoid_" + std::to_string(layerIdx); - hardsigmoid->setName(hardsigmoidLayerName.c_str()); - hardsigmoid->setAlpha(1.0 / 6.0); - hardsigmoid->setBeta(0.5); - output = hardsigmoid->getOutput(0); - } - else if (activation == "hardswish") - { - nvinfer1::IActivationLayer* hardsigmoid = network->addActivation(*input, nvinfer1::ActivationType::kHARD_SIGMOID); - assert(hardsigmoid != nullptr); - std::string hardsigmoidLayerName = "hardsigmoid_" + std::to_string(layerIdx); - hardsigmoid->setName(hardsigmoidLayerName.c_str()); - hardsigmoid->setAlpha(1.0 / 6.0); - hardsigmoid->setBeta(0.5); - nvinfer1::IElementWiseLayer* hardswish - = network->addElementWise(*input, *hardsigmoid->getOutput(0), nvinfer1::ElementWiseOperation::kPROD); - assert(hardswish != nullptr); - std::string hardswishLayerName = "hardswish_" + std::to_string(layerIdx); - hardswish->setName(hardswishLayerName.c_str()); - output = hardswish->getOutput(0); - } - else - { - std::cerr << "Activation not supported: " << activation << std::endl; - std::abort(); - } - return output; +nvinfer1::ITensor* +activationLayer(int layerIdx, std::string activation, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network, + std::string layerName) +{ + nvinfer1::ITensor* output; + + if (activation == "linear") + output = input; + else if (activation == "relu") { + nvinfer1::IActivationLayer* relu = network->addActivation(*input, nvinfer1::ActivationType::kRELU); + assert(relu != nullptr); + std::string reluLayerName = "relu_" + layerName + std::to_string(layerIdx); + relu->setName(reluLayerName.c_str()); + output = relu->getOutput(0); + } + else if (activation == "sigmoid" || activation == "logistic") { + nvinfer1::IActivationLayer* sigmoid = network->addActivation(*input, nvinfer1::ActivationType::kSIGMOID); + assert(sigmoid != nullptr); + std::string sigmoidLayerName = "sigmoid_" + layerName + std::to_string(layerIdx); + sigmoid->setName(sigmoidLayerName.c_str()); + output = sigmoid->getOutput(0); + } + else if (activation == "tanh") { + nvinfer1::IActivationLayer* tanh = network->addActivation(*input, 
nvinfer1::ActivationType::kTANH); + assert(tanh != nullptr); + std::string tanhLayerName = "tanh_" + layerName + std::to_string(layerIdx); + tanh->setName(tanhLayerName.c_str()); + output = tanh->getOutput(0); + } + else if (activation == "leaky") { + nvinfer1::IActivationLayer* leaky = network->addActivation(*input, nvinfer1::ActivationType::kLEAKY_RELU); + assert(leaky != nullptr); + std::string leakyLayerName = "leaky_" + layerName + std::to_string(layerIdx); + leaky->setName(leakyLayerName.c_str()); + leaky->setAlpha(0.1); + output = leaky->getOutput(0); + } + else if (activation == "softplus") { + nvinfer1::IActivationLayer* softplus = network->addActivation(*input, nvinfer1::ActivationType::kSOFTPLUS); + assert(softplus != nullptr); + std::string softplusLayerName = "softplus_" + layerName + std::to_string(layerIdx); + softplus->setName(softplusLayerName.c_str()); + output = softplus->getOutput(0); + } + else if (activation == "mish") { + nvinfer1::IActivationLayer* softplus = network->addActivation(*input, nvinfer1::ActivationType::kSOFTPLUS); + assert(softplus != nullptr); + std::string softplusLayerName = "softplus_" + layerName + std::to_string(layerIdx); + softplus->setName(softplusLayerName.c_str()); + nvinfer1::IActivationLayer* tanh = network->addActivation(*softplus->getOutput(0), nvinfer1::ActivationType::kTANH); + assert(tanh != nullptr); + std::string tanhLayerName = "tanh_" + layerName + std::to_string(layerIdx); + tanh->setName(tanhLayerName.c_str()); + nvinfer1::IElementWiseLayer* mish = network->addElementWise(*input, *tanh->getOutput(0), + nvinfer1::ElementWiseOperation::kPROD); + assert(mish != nullptr); + std::string mishLayerName = "mish_" + layerName + std::to_string(layerIdx); + mish->setName(mishLayerName.c_str()); + output = mish->getOutput(0); + } + else if (activation == "silu" || activation == "swish") { + nvinfer1::IActivationLayer* sigmoid = network->addActivation(*input, nvinfer1::ActivationType::kSIGMOID); + assert(sigmoid != nullptr); + std::string sigmoidLayerName = "sigmoid_" + layerName + std::to_string(layerIdx); + sigmoid->setName(sigmoidLayerName.c_str()); + nvinfer1::IElementWiseLayer* silu = network->addElementWise(*input, *sigmoid->getOutput(0), + nvinfer1::ElementWiseOperation::kPROD); + assert(silu != nullptr); + std::string siluLayerName = "silu_" + layerName + std::to_string(layerIdx); + silu->setName(siluLayerName.c_str()); + output = silu->getOutput(0); + } + else if (activation == "hardsigmoid") { + nvinfer1::IActivationLayer* hardsigmoid = network->addActivation(*input, nvinfer1::ActivationType::kHARD_SIGMOID); + assert(hardsigmoid != nullptr); + std::string hardsigmoidLayerName = "hardsigmoid_" + layerName + std::to_string(layerIdx); + hardsigmoid->setName(hardsigmoidLayerName.c_str()); + hardsigmoid->setAlpha(1.0 / 6.0); + hardsigmoid->setBeta(0.5); + output = hardsigmoid->getOutput(0); + } + else if (activation == "hardswish") { + nvinfer1::IActivationLayer* hardsigmoid = network->addActivation(*input, nvinfer1::ActivationType::kHARD_SIGMOID); + assert(hardsigmoid != nullptr); + std::string hardsigmoidLayerName = "hardsigmoid_" + layerName + std::to_string(layerIdx); + hardsigmoid->setName(hardsigmoidLayerName.c_str()); + hardsigmoid->setAlpha(1.0 / 6.0); + hardsigmoid->setBeta(0.5); + nvinfer1::IElementWiseLayer* hardswish = network->addElementWise(*input, *hardsigmoid->getOutput(0), + nvinfer1::ElementWiseOperation::kPROD); + assert(hardswish != nullptr); + std::string hardswishLayerName = "hardswish_" + layerName + 
std::to_string(layerIdx); + hardswish->setName(hardswishLayerName.c_str()); + output = hardswish->getOutput(0); + } + else { + std::cerr << "Activation not supported: " << activation << std::endl; + assert(0); + } + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/activation_layer.h b/nvdsinfer_custom_impl_Yolo/layers/activation_layer.h index c5151ac..8c0b648 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/activation_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/activation_layer.h @@ -6,15 +6,11 @@ #ifndef __ACTIVATION_LAYER_H__ #define __ACTIVATION_LAYER_H__ -#include -#include +#include #include "NvInfer.h" -nvinfer1::ITensor* activationLayer( - int layerIdx, - std::string activation, - nvinfer1::ITensor* input, - nvinfer1::INetworkDefinition* network); +nvinfer1::ITensor* activationLayer(int layerIdx, std::string activation, nvinfer1::ITensor* input, + nvinfer1::INetworkDefinition* network, std::string layerName = ""); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/batchnorm_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/batchnorm_layer.cpp index e6828e7..084b22b 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/batchnorm_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/batchnorm_layer.cpp @@ -3,108 +3,94 @@ * https://www.github.com/marcoslucianops */ -#include #include "batchnorm_layer.h" -nvinfer1::ITensor* batchnormLayer( - int layerIdx, - std::map& block, - std::vector& weights, - std::vector& trtWeights, - int& weightPtr, - std::string weightsType, - float eps, - nvinfer1::ITensor* input, +#include +#include + +nvinfer1::ITensor* +batchnormLayer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, std::string weightsType, float eps, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "batchnorm"); - assert(block.find("filters") != block.end()); + assert(block.at("type") == "batchnorm"); + assert(block.find("filters") != block.end()); - int filters = std::stoi(block.at("filters")); - std::string activation = block.at("activation"); + int filters = std::stoi(block.at("filters")); + std::string activation = block.at("activation"); - std::vector bnBiases; - std::vector bnWeights; - std::vector bnRunningMean; - std::vector bnRunningVar; + std::vector bnBiases; + std::vector bnWeights; + std::vector bnRunningMean; + std::vector bnRunningVar; - if (weightsType == "weights") - { - for (int i = 0; i < filters; ++i) - { - bnBiases.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnWeights.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnRunningMean.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5)); - weightPtr++; - } + if (weightsType == "weights") { + for (int i = 0; i < filters; ++i) { + bnBiases.push_back(weights[weightPtr]); + ++weightPtr; } - else - { - for (int i = 0; i < filters; ++i) - { - bnWeights.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnBiases.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnRunningMean.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnRunningVar.push_back(sqrt(weights[weightPtr] + eps)); - weightPtr++; - } + for (int i = 0; i < filters; ++i) { + 
bnWeights.push_back(weights[weightPtr]); + ++weightPtr; } + for (int i = 0; i < filters; ++i) { + bnRunningMean.push_back(weights[weightPtr]); + ++weightPtr; + } + for (int i = 0; i < filters; ++i) { + bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5)); + ++weightPtr; + } + } + else { + for (int i = 0; i < filters; ++i) { + bnWeights.push_back(weights[weightPtr]); + ++weightPtr; + } + for (int i = 0; i < filters; ++i) { + bnBiases.push_back(weights[weightPtr]); + ++weightPtr; + } + for (int i = 0; i < filters; ++i) { + bnRunningMean.push_back(weights[weightPtr]); + ++weightPtr; + } + for (int i = 0; i < filters; ++i) { + bnRunningVar.push_back(sqrt(weights[weightPtr] + eps)); + ++weightPtr; + } + } - int size = filters; - nvinfer1::Weights shift{nvinfer1::DataType::kFLOAT, nullptr, size}; - nvinfer1::Weights scale{nvinfer1::DataType::kFLOAT, nullptr, size}; - nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, size}; - float* shiftWt = new float[size]; - for (int i = 0; i < size; ++i) - shiftWt[i] = bnBiases.at(i) - ((bnRunningMean.at(i) * bnWeights.at(i)) / bnRunningVar.at(i)); - shift.values = shiftWt; - float* scaleWt = new float[size]; - for (int i = 0; i < size; ++i) - scaleWt[i] = bnWeights.at(i) / bnRunningVar[i]; - scale.values = scaleWt; - float* powerWt = new float[size]; - for (int i = 0; i < size; ++i) - powerWt[i] = 1.0; - power.values = powerWt; - trtWeights.push_back(shift); - trtWeights.push_back(scale); - trtWeights.push_back(power); + int size = filters; + nvinfer1::Weights shift {nvinfer1::DataType::kFLOAT, nullptr, size}; + nvinfer1::Weights scale {nvinfer1::DataType::kFLOAT, nullptr, size}; + nvinfer1::Weights power {nvinfer1::DataType::kFLOAT, nullptr, size}; + float* shiftWt = new float[size]; + for (int i = 0; i < size; ++i) + shiftWt[i] = bnBiases.at(i) - ((bnRunningMean.at(i) * bnWeights.at(i)) / bnRunningVar.at(i)); + shift.values = shiftWt; + float* scaleWt = new float[size]; + for (int i = 0; i < size; ++i) + scaleWt[i] = bnWeights.at(i) / bnRunningVar[i]; + scale.values = scaleWt; + float* powerWt = new float[size]; + for (int i = 0; i < size; ++i) + powerWt[i] = 1.0; + power.values = powerWt; + trtWeights.push_back(shift); + trtWeights.push_back(scale); + trtWeights.push_back(power); - nvinfer1::IScaleLayer* batchnorm = network->addScale(*input, nvinfer1::ScaleMode::kCHANNEL, shift, scale, power); - assert(batchnorm != nullptr); - std::string batchnormLayerName = "batchnorm_" + std::to_string(layerIdx); - batchnorm->setName(batchnormLayerName.c_str()); - output = batchnorm->getOutput(0); + nvinfer1::IScaleLayer* batchnorm = network->addScale(*input, nvinfer1::ScaleMode::kCHANNEL, shift, scale, power); + assert(batchnorm != nullptr); + std::string batchnormLayerName = "batchnorm_" + std::to_string(layerIdx); + batchnorm->setName(batchnormLayerName.c_str()); + output = batchnorm->getOutput(0); - output = activationLayer(layerIdx, activation, output, network); - assert(output != nullptr); + output = activationLayer(layerIdx, activation, output, network); + assert(output != nullptr); - return output; + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/batchnorm_layer.h b/nvdsinfer_custom_impl_Yolo/layers/batchnorm_layer.h index 078b7f1..c3bfffc 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/batchnorm_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/batchnorm_layer.h @@ -13,15 +13,8 @@ #include "activation_layer.h" -nvinfer1::ITensor* batchnormLayer( - int layerIdx, - std::map& block, - std::vector& weights, - std::vector& trtWeights, 
- int& weightPtr, - std::string weightsType, - float eps, - nvinfer1::ITensor* input, +nvinfer1::ITensor* batchnormLayer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, std::string weightsType, float eps, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/c2f_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/c2f_layer.cpp new file mode 100644 index 0000000..e51c663 --- /dev/null +++ b/nvdsinfer_custom_impl_Yolo/layers/c2f_layer.cpp @@ -0,0 +1,82 @@ +/* + * Created by Marcos Luciano + * https://www.github.com/marcoslucianops + */ + +#include "c2f_layer.h" + +#include + +#include "convolutional_layer.h" + +nvinfer1::ITensor* +c2fLayer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, std::string weightsType, float eps, nvinfer1::ITensor* input, + nvinfer1::INetworkDefinition* network) +{ + nvinfer1::ITensor* output; + + assert(block.at("type") == "c2f"); + assert(block.find("n") != block.end()); + assert(block.find("shortcut") != block.end()); + assert(block.find("filters") != block.end()); + + int n = std::stoi(block.at("n")); + bool shortcut = (block.at("shortcut") == "1"); + int filters = std::stoi(block.at("filters")); + + nvinfer1::Dims inputDims = input->getDimensions(); + + nvinfer1::ISliceLayer* sliceLt = network->addSlice(*input,nvinfer1::Dims{3, {0, 0, 0}}, + nvinfer1::Dims{3, {inputDims.d[0] / 2, inputDims.d[1], inputDims.d[2]}}, nvinfer1::Dims{3, {1, 1, 1}}); + assert(sliceLt != nullptr); + std::string sliceLtLayerName = "slice_lt_" + std::to_string(layerIdx); + sliceLt->setName(sliceLtLayerName.c_str()); + nvinfer1::ITensor* lt = sliceLt->getOutput(0); + + nvinfer1::ISliceLayer* sliceRb = network->addSlice(*input,nvinfer1::Dims{3, {inputDims.d[0] / 2, 0, 0}}, + nvinfer1::Dims{3, {inputDims.d[0] / 2, inputDims.d[1], inputDims.d[2]}}, nvinfer1::Dims{3, {1, 1, 1}}); + assert(sliceRb != nullptr); + std::string sliceRbLayerName = "slice_rb_" + std::to_string(layerIdx); + sliceRb->setName(sliceRbLayerName.c_str()); + nvinfer1::ITensor* rb = sliceRb->getOutput(0); + + std::vector concatInputs; + concatInputs.push_back(lt); + concatInputs.push_back(rb); + output = rb; + + for (int i = 0; i < n; ++i) { + std::string cv1MlayerName = "c2f_1_" + std::to_string(i + 1) + "_"; + nvinfer1::ITensor* cv1M = convolutionalLayer(layerIdx, block, weights, trtWeights, weightPtr, weightsType, filters, eps, + output, network, cv1MlayerName); + assert(cv1M != nullptr); + + std::string cv2MlayerName = "c2f_2_" + std::to_string(i + 1) + "_"; + nvinfer1::ITensor* cv2M = convolutionalLayer(layerIdx, block, weights, trtWeights, weightPtr, weightsType, filters, eps, + cv1M, network, cv2MlayerName); + assert(cv2M != nullptr); + + if (shortcut) { + nvinfer1::IElementWiseLayer* ew = network->addElementWise(*rb, *cv2M, nvinfer1::ElementWiseOperation::kSUM); + assert(ew != nullptr); + std::string ewLayerName = "shortcut_c2f_" + std::to_string(i + 1) + "_" + std::to_string(layerIdx); + ew->setName(ewLayerName.c_str()); + output = ew->getOutput(0); + concatInputs.push_back(output); + } + else { + output = cv2M; + concatInputs.push_back(output); + } + } + + nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size()); + assert(concat != nullptr); + std::string concatLayerName = "route_" + std::to_string(layerIdx); + concat->setName(concatLayerName.c_str()); + concat->setAxis(0); + output = concat->getOutput(0); + + 
return output; +} diff --git a/nvdsinfer_custom_impl_Yolo/layers/c2f_layer.h b/nvdsinfer_custom_impl_Yolo/layers/c2f_layer.h new file mode 100644 index 0000000..28f373f --- /dev/null +++ b/nvdsinfer_custom_impl_Yolo/layers/c2f_layer.h @@ -0,0 +1,18 @@ +/* + * Created by Marcos Luciano + * https://www.github.com/marcoslucianops + */ + +#ifndef __C2F_LAYER_H__ +#define __C2F_LAYER_H__ + +#include +#include + +#include "NvInfer.h" + +nvinfer1::ITensor* c2fLayer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, std::string weightsType, float eps, nvinfer1::ITensor* input, + nvinfer1::INetworkDefinition* network); + +#endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/channels_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/channels_layer.cpp index 69e183a..14b661c 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/channels_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/channels_layer.cpp @@ -5,33 +5,32 @@ #include "channels_layer.h" -nvinfer1::ITensor* channelsLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, - nvinfer1::ITensor* implicitTensor, - nvinfer1::INetworkDefinition* network) +#include + +nvinfer1::ITensor* +channelsLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, + nvinfer1::ITensor* implicitTensor, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "shift_channels" || block.at("type") == "control_channels"); + assert(block.at("type") == "shift_channels" || block.at("type") == "control_channels"); - if (block.at("type") == "shift_channels") { - nvinfer1::IElementWiseLayer* shift - = network->addElementWise(*input, *implicitTensor, nvinfer1::ElementWiseOperation::kSUM); - assert(shift != nullptr); - std::string shiftLayerName = "shift_channels_" + std::to_string(layerIdx); - shift->setName(shiftLayerName.c_str()); - output = shift->getOutput(0); - } - else if (block.at("type") == "control_channels") { - nvinfer1::IElementWiseLayer* control - = network->addElementWise(*input, *implicitTensor, nvinfer1::ElementWiseOperation::kPROD); - assert(control != nullptr); - std::string controlLayerName = "control_channels_" + std::to_string(layerIdx); - control->setName(controlLayerName.c_str()); - output = control->getOutput(0); - } + if (block.at("type") == "shift_channels") { + nvinfer1::IElementWiseLayer* shift = network->addElementWise(*input, *implicitTensor, + nvinfer1::ElementWiseOperation::kSUM); + assert(shift != nullptr); + std::string shiftLayerName = "shift_channels_" + std::to_string(layerIdx); + shift->setName(shiftLayerName.c_str()); + output = shift->getOutput(0); + } + else if (block.at("type") == "control_channels") { + nvinfer1::IElementWiseLayer* control = network->addElementWise(*input, *implicitTensor, + nvinfer1::ElementWiseOperation::kPROD); + assert(control != nullptr); + std::string controlLayerName = "control_channels_" + std::to_string(layerIdx); + control->setName(controlLayerName.c_str()); + output = control->getOutput(0); + } - return output; + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/channels_layer.h b/nvdsinfer_custom_impl_Yolo/layers/channels_layer.h index 4db704c..6a2d389 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/channels_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/channels_layer.h @@ -7,15 +7,10 @@ #define __CHANNELS_LAYER_H__ #include -#include #include "NvInfer.h" -nvinfer1::ITensor* channelsLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, - 
nvinfer1::ITensor* implicitTensor, - nvinfer1::INetworkDefinition* network); +nvinfer1::ITensor* channelsLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, + nvinfer1::ITensor* implicitTensor, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/cls_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/cls_layer.cpp index c8eed52..4a6a93b 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/cls_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/cls_layer.cpp @@ -5,25 +5,25 @@ #include "cls_layer.h" -nvinfer1::ITensor* clsLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +#include + +nvinfer1::ITensor* +clsLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "cls"); + assert(block.at("type") == "cls"); - nvinfer1::IShuffleLayer* shuffle = network->addShuffle(*input); - assert(shuffle != nullptr); - std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx); - shuffle->setName(shuffleLayerName.c_str()); - nvinfer1::Permutation permutation; - permutation.order[0] = 1; - permutation.order[1] = 0; - shuffle->setFirstTranspose(permutation); - output = shuffle->getOutput(0); + nvinfer1::IShuffleLayer* shuffle = network->addShuffle(*input); + assert(shuffle != nullptr); + std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx); + shuffle->setName(shuffleLayerName.c_str()); + nvinfer1::Permutation permutation; + permutation.order[0] = 1; + permutation.order[1] = 0; + shuffle->setFirstTranspose(permutation); + output = shuffle->getOutput(0); - return output; + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/cls_layer.h b/nvdsinfer_custom_impl_Yolo/layers/cls_layer.h index cca342b..3179590 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/cls_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/cls_layer.h @@ -7,14 +7,10 @@ #define __CLS_LAYER_H__ #include -#include #include "NvInfer.h" -nvinfer1::ITensor* clsLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +nvinfer1::ITensor* clsLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp index be85379..8f3ef62 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp @@ -3,224 +3,197 @@ * https://www.github.com/marcoslucianops */ -#include #include "convolutional_layer.h" -nvinfer1::ITensor* convolutionalLayer( - int layerIdx, - std::map& block, - std::vector& weights, - std::vector& trtWeights, - int& weightPtr, - std::string weightsType, - int& inputChannels, - float eps, - nvinfer1::ITensor* input, - nvinfer1::INetworkDefinition* network) +#include +#include + +nvinfer1::ITensor* +convolutionalLayer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, std::string weightsType, int& inputChannels, float eps, + nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network, std::string layerName) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "convolutional"); - assert(block.find("filters") != block.end()); - assert(block.find("pad") != block.end()); - assert(block.find("size") != block.end()); - assert(block.find("stride") != block.end()); + assert(block.at("type") == 
"convolutional" || block.at("type") == "c2f"); + assert(block.find("filters") != block.end()); + assert(block.find("pad") != block.end()); + assert(block.find("size") != block.end()); + assert(block.find("stride") != block.end()); - int filters = std::stoi(block.at("filters")); - int padding = std::stoi(block.at("pad")); - int kernelSize = std::stoi(block.at("size")); - int stride = std::stoi(block.at("stride")); - std::string activation = block.at("activation"); - int bias = filters; + int filters = std::stoi(block.at("filters")); + int padding = std::stoi(block.at("pad")); + int kernelSize = std::stoi(block.at("size")); + int stride = std::stoi(block.at("stride")); + std::string activation = block.at("activation"); + int bias = filters; - bool batchNormalize = false; - if (block.find("batch_normalize") != block.end()) - { - bias = 0; - batchNormalize = (block.at("batch_normalize") == "1"); - } + bool batchNormalize = false; + if (block.find("batch_normalize") != block.end()) { + bias = 0; + batchNormalize = (block.at("batch_normalize") == "1"); + } - int groups = 1; - if (block.find("groups") != block.end()) - groups = std::stoi(block.at("groups")); + int groups = 1; + if (block.find("groups") != block.end()) + groups = std::stoi(block.at("groups")); - if (block.find("bias") != block.end()) - bias = std::stoi(block.at("bias")); + if (block.find("bias") != block.end()) + bias = std::stoi(block.at("bias")); - int pad; - if (padding) - pad = (kernelSize - 1) / 2; - else - pad = 0; + int pad; + if (padding) + pad = (kernelSize - 1) / 2; + else + pad = 0; - int size = filters * inputChannels * kernelSize * kernelSize / groups; - std::vector bnBiases; - std::vector bnWeights; - std::vector bnRunningMean; - std::vector bnRunningVar; - nvinfer1::Weights convWt{nvinfer1::DataType::kFLOAT, nullptr, size}; - nvinfer1::Weights convBias{nvinfer1::DataType::kFLOAT, nullptr, bias}; + int size = filters * inputChannels * kernelSize * kernelSize / groups; + std::vector bnBiases; + std::vector bnWeights; + std::vector bnRunningMean; + std::vector bnRunningVar; + nvinfer1::Weights convWt {nvinfer1::DataType::kFLOAT, nullptr, size}; + nvinfer1::Weights convBias {nvinfer1::DataType::kFLOAT, nullptr, bias}; - if (weightsType == "weights") - { - if (batchNormalize == false) - { - float* val; - if (bias != 0) { - val = new float[filters]; - for (int i = 0; i < filters; ++i) - { - val[i] = weights[weightPtr]; - weightPtr++; - } - convBias.values = val; - trtWeights.push_back(convBias); - } - val = new float[size]; - for (int i = 0; i < size; ++i) - { - val[i] = weights[weightPtr]; - weightPtr++; - } - convWt.values = val; - trtWeights.push_back(convWt); - } - else - { - for (int i = 0; i < filters; ++i) - { - bnBiases.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnWeights.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnRunningMean.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5)); - weightPtr++; - } - float* val = new float[size]; - for (int i = 0; i < size; ++i) - { - val[i] = weights[weightPtr]; - weightPtr++; - } - convWt.values = val; - trtWeights.push_back(convWt); - if (bias != 0) - trtWeights.push_back(convBias); + if (weightsType == "weights") { + if (batchNormalize == false) { + float* val; + if (bias != 0) { + val = new float[filters]; + for (int i = 0; i < filters; ++i) { + val[i] = weights[weightPtr]; + 
++weightPtr; } + convBias.values = val; + trtWeights.push_back(convBias); + } + val = new float[size]; + for (int i = 0; i < size; ++i) { + val[i] = weights[weightPtr]; + ++weightPtr; + } + convWt.values = val; + trtWeights.push_back(convWt); } - else - { - if (batchNormalize == false) - { - float* val = new float[size]; - for (int i = 0; i < size; ++i) - { - val[i] = weights[weightPtr]; - weightPtr++; - } - convWt.values = val; - trtWeights.push_back(convWt); - if (bias != 0) { - val = new float[filters]; - for (int i = 0; i < filters; ++i) - { - val[i] = weights[weightPtr]; - weightPtr++; - } - convBias.values = val; - trtWeights.push_back(convBias); - } - } - else - { - float* val = new float[size]; - for (int i = 0; i < size; ++i) - { - val[i] = weights[weightPtr]; - weightPtr++; - } - convWt.values = val; - for (int i = 0; i < filters; ++i) - { - bnWeights.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnBiases.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnRunningMean.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnRunningVar.push_back(sqrt(weights[weightPtr] + eps)); - weightPtr++; - } - trtWeights.push_back(convWt); - if (bias != 0) - trtWeights.push_back(convBias); - } + else { + for (int i = 0; i < filters; ++i) { + bnBiases.push_back(weights[weightPtr]); + ++weightPtr; + } + for (int i = 0; i < filters; ++i) { + bnWeights.push_back(weights[weightPtr]); + ++weightPtr; + } + for (int i = 0; i < filters; ++i) { + bnRunningMean.push_back(weights[weightPtr]); + ++weightPtr; + } + for (int i = 0; i < filters; ++i) { + bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5)); + ++weightPtr; + } + float* val = new float[size]; + for (int i = 0; i < size; ++i) { + val[i] = weights[weightPtr]; + ++weightPtr; + } + convWt.values = val; + trtWeights.push_back(convWt); + if (bias != 0) + trtWeights.push_back(convBias); } - - nvinfer1::IConvolutionLayer* conv - = network->addConvolutionNd(*input, filters, nvinfer1::Dims{2, {kernelSize, kernelSize}}, convWt, convBias); - assert(conv != nullptr); - std::string convLayerName = "conv_" + std::to_string(layerIdx); - conv->setName(convLayerName.c_str()); - conv->setStrideNd(nvinfer1::Dims{2, {stride, stride}}); - conv->setPaddingNd(nvinfer1::Dims{2, {pad, pad}}); - - if (block.find("groups") != block.end()) - conv->setNbGroups(groups); - - output = conv->getOutput(0); - - if (batchNormalize == true) - { - size = filters; - nvinfer1::Weights shift{nvinfer1::DataType::kFLOAT, nullptr, size}; - nvinfer1::Weights scale{nvinfer1::DataType::kFLOAT, nullptr, size}; - nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, size}; - float* shiftWt = new float[size]; - for (int i = 0; i < size; ++i) - shiftWt[i] = bnBiases.at(i) - ((bnRunningMean.at(i) * bnWeights.at(i)) / bnRunningVar.at(i)); - shift.values = shiftWt; - float* scaleWt = new float[size]; - for (int i = 0; i < size; ++i) - scaleWt[i] = bnWeights.at(i) / bnRunningVar[i]; - scale.values = scaleWt; - float* powerWt = new float[size]; - for (int i = 0; i < size; ++i) - powerWt[i] = 1.0; - power.values = powerWt; - trtWeights.push_back(shift); - trtWeights.push_back(scale); - trtWeights.push_back(power); - - nvinfer1::IScaleLayer* batchnorm = network->addScale(*output, nvinfer1::ScaleMode::kCHANNEL, shift, scale, power); - assert(batchnorm != nullptr); - std::string batchnormLayerName = "batchnorm_" + std::to_string(layerIdx); - 
batchnorm->setName(batchnormLayerName.c_str()); - output = batchnorm->getOutput(0); + } + else { + if (batchNormalize == false) { + float* val = new float[size]; + for (int i = 0; i < size; ++i) { + val[i] = weights[weightPtr]; + ++weightPtr; + } + convWt.values = val; + trtWeights.push_back(convWt); + if (bias != 0) { + val = new float[filters]; + for (int i = 0; i < filters; ++i) { + val[i] = weights[weightPtr]; + ++weightPtr; + } + convBias.values = val; + trtWeights.push_back(convBias); + } } + else { + float* val = new float[size]; + for (int i = 0; i < size; ++i) { + val[i] = weights[weightPtr]; + ++weightPtr; + } + convWt.values = val; + for (int i = 0; i < filters; ++i) { + bnWeights.push_back(weights[weightPtr]); + ++weightPtr; + } + for (int i = 0; i < filters; ++i) { + bnBiases.push_back(weights[weightPtr]); + ++weightPtr; + } + for (int i = 0; i < filters; ++i) { + bnRunningMean.push_back(weights[weightPtr]); + ++weightPtr; + } + for (int i = 0; i < filters; ++i) { + bnRunningVar.push_back(sqrt(weights[weightPtr] + eps)); + ++weightPtr; + } + trtWeights.push_back(convWt); + if (bias != 0) + trtWeights.push_back(convBias); + } + } - output = activationLayer(layerIdx, activation, output, network); - assert(output != nullptr); + nvinfer1::IConvolutionLayer* conv = network->addConvolutionNd(*input, filters, nvinfer1::Dims{2, {kernelSize, kernelSize}}, + convWt, convBias); + assert(conv != nullptr); + std::string convLayerName = "conv_" + layerName + std::to_string(layerIdx); + conv->setName(convLayerName.c_str()); + conv->setStrideNd(nvinfer1::Dims{2, {stride, stride}}); + conv->setPaddingNd(nvinfer1::Dims{2, {pad, pad}}); - return output; + if (block.find("groups") != block.end()) + conv->setNbGroups(groups); + + output = conv->getOutput(0); + + if (batchNormalize == true) { + size = filters; + nvinfer1::Weights shift {nvinfer1::DataType::kFLOAT, nullptr, size}; + nvinfer1::Weights scale {nvinfer1::DataType::kFLOAT, nullptr, size}; + nvinfer1::Weights power {nvinfer1::DataType::kFLOAT, nullptr, size}; + float* shiftWt = new float[size]; + for (int i = 0; i < size; ++i) + shiftWt[i] = bnBiases.at(i) - ((bnRunningMean.at(i) * bnWeights.at(i)) / bnRunningVar.at(i)); + shift.values = shiftWt; + float* scaleWt = new float[size]; + for (int i = 0; i < size; ++i) + scaleWt[i] = bnWeights.at(i) / bnRunningVar[i]; + scale.values = scaleWt; + float* powerWt = new float[size]; + for (int i = 0; i < size; ++i) + powerWt[i] = 1.0; + power.values = powerWt; + trtWeights.push_back(shift); + trtWeights.push_back(scale); + trtWeights.push_back(power); + + nvinfer1::IScaleLayer* batchnorm = network->addScale(*output, nvinfer1::ScaleMode::kCHANNEL, shift, scale, power); + assert(batchnorm != nullptr); + std::string batchnormLayerName = "batchnorm_" + layerName + std::to_string(layerIdx); + batchnorm->setName(batchnormLayerName.c_str()); + output = batchnorm->getOutput(0); + } + + output = activationLayer(layerIdx, activation, output, network, layerName); + assert(output != nullptr); + + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.h b/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.h index 8df166a..4652bcb 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.h @@ -13,16 +13,8 @@ #include "activation_layer.h" -nvinfer1::ITensor* convolutionalLayer( - int layerIdx, - std::map& block, - std::vector& weights, - std::vector& trtWeights, - int& weightPtr, - std::string weightsType, - 
int& inputChannels, - float eps, - nvinfer1::ITensor* input, - nvinfer1::INetworkDefinition* network); +nvinfer1::ITensor* convolutionalLayer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, std::string weightsType, int& inputChannels, float eps, + nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network, std::string layerName = ""); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.cpp new file mode 100644 index 0000000..b8ad7e4 --- /dev/null +++ b/nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.cpp @@ -0,0 +1,196 @@ +/* + * Created by Marcos Luciano + * https://www.github.com/marcoslucianops + */ + +#include "detect_v8_layer.h" + +#include + +nvinfer1::ITensor* +detectV8Layer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, nvinfer1::ITensor* input, + nvinfer1::INetworkDefinition* network) +{ + nvinfer1::ITensor* output; + + assert(block.at("type") == "detect_v8"); + assert(block.find("num") != block.end()); + assert(block.find("classes") != block.end()); + + int num = std::stoi(block.at("num")); + int classes = std::stoi(block.at("classes")); + int reg_max = num / 4; + + nvinfer1::Dims inputDims = input->getDimensions(); + + nvinfer1::ISliceLayer* sliceBox = network->addSlice(*input, nvinfer1::Dims{2, {0, 0}}, + nvinfer1::Dims{2, {num, inputDims.d[1]}}, nvinfer1::Dims{2, {1, 1}}); + assert(sliceBox != nullptr); + std::string sliceBoxLayerName = "slice_box_" + std::to_string(layerIdx); + sliceBox->setName(sliceBoxLayerName.c_str()); + nvinfer1::ITensor* box = sliceBox->getOutput(0); + + nvinfer1::ISliceLayer* sliceCls = network->addSlice(*input, nvinfer1::Dims{2, {num, 0}}, + nvinfer1::Dims{2, {classes, inputDims.d[1]}}, nvinfer1::Dims{2, {1, 1}}); + assert(sliceCls != nullptr); + std::string sliceClsLayerName = "slice_cls_" + std::to_string(layerIdx); + sliceCls->setName(sliceClsLayerName.c_str()); + nvinfer1::ITensor* cls = sliceCls->getOutput(0); + + nvinfer1::IShuffleLayer* shuffle1Box = network->addShuffle(*box); + assert(shuffle1Box != nullptr); + std::string shuffle1BoxLayerName = "shuffle1_box_" + std::to_string(layerIdx); + shuffle1Box->setName(shuffle1BoxLayerName.c_str()); + nvinfer1::Dims reshape1Dims = {3, {4, reg_max, inputDims.d[1]}}; + shuffle1Box->setReshapeDimensions(reshape1Dims); + nvinfer1::Permutation permutation1; + permutation1.order[0] = 1; + permutation1.order[1] = 0; + permutation1.order[2] = 2; + shuffle1Box->setSecondTranspose(permutation1); + box = shuffle1Box->getOutput(0); + + nvinfer1::ISoftMaxLayer* softmax = network->addSoftMax(*box); + assert(softmax != nullptr); + std::string softmaxLayerName = "softmax_box_" + std::to_string(layerIdx); + softmax->setName(softmaxLayerName.c_str()); + softmax->setAxes(1 << 0); + box = softmax->getOutput(0); + + nvinfer1::Weights dflWt {nvinfer1::DataType::kFLOAT, nullptr, reg_max}; + + float* val = new float[reg_max]; + for (int i = 0; i < reg_max; ++i) { + val[i] = i; + } + dflWt.values = val; + + nvinfer1::IConvolutionLayer* conv = network->addConvolutionNd(*box, 1, nvinfer1::Dims{2, {1, 1}}, dflWt, + nvinfer1::Weights{}); + assert(conv != nullptr); + std::string convLayerName = "conv_box_" + std::to_string(layerIdx); + conv->setName(convLayerName.c_str()); + conv->setStrideNd(nvinfer1::Dims{2, {1, 1}}); + conv->setPaddingNd(nvinfer1::Dims{2, {0, 0}}); + box = conv->getOutput(0); + + nvinfer1::IShuffleLayer* shuffle2Box = 
network->addShuffle(*box); + assert(shuffle2Box != nullptr); + std::string shuffle2BoxLayerName = "shuffle2_box_" + std::to_string(layerIdx); + shuffle2Box->setName(shuffle2BoxLayerName.c_str()); + nvinfer1::Dims reshape2Dims = {2, {4, inputDims.d[1]}}; + shuffle2Box->setReshapeDimensions(reshape2Dims); + box = shuffle2Box->getOutput(0); + + nvinfer1::Dims shuffle2BoxDims = box->getDimensions(); + + nvinfer1::ISliceLayer* sliceLtBox = network->addSlice(*box, nvinfer1::Dims{2, {0, 0}}, + nvinfer1::Dims{2, {2, shuffle2BoxDims.d[1]}}, nvinfer1::Dims{2, {1, 1}}); + assert(sliceLtBox != nullptr); + std::string sliceLtBoxLayerName = "slice_lt_box_" + std::to_string(layerIdx); + sliceLtBox->setName(sliceLtBoxLayerName.c_str()); + nvinfer1::ITensor* lt = sliceLtBox->getOutput(0); + + nvinfer1::ISliceLayer* sliceRbBox = network->addSlice(*box, nvinfer1::Dims{2, {2, 0}}, + nvinfer1::Dims{2, {2, shuffle2BoxDims.d[1]}}, nvinfer1::Dims{2, {1, 1}}); + assert(sliceRbBox != nullptr); + std::string sliceRbBoxLayerName = "slice_rb_box_" + std::to_string(layerIdx); + sliceRbBox->setName(sliceRbBoxLayerName.c_str()); + nvinfer1::ITensor* rb = sliceRbBox->getOutput(0); + + int channels = 2 * shuffle2BoxDims.d[1]; + nvinfer1::Weights anchorPointsWt {nvinfer1::DataType::kFLOAT, nullptr, channels}; + val = new float[channels]; + for (int i = 0; i < channels; ++i) { + val[i] = weights[weightPtr]; + ++weightPtr; + } + anchorPointsWt.values = val; + trtWeights.push_back(anchorPointsWt); + + nvinfer1::IConstantLayer* anchorPoints = network->addConstant(nvinfer1::Dims{2, {2, shuffle2BoxDims.d[1]}}, + anchorPointsWt); + assert(anchorPoints != nullptr); + std::string anchorPointsLayerName = "anchor_points_" + std::to_string(layerIdx); + anchorPoints->setName(anchorPointsLayerName.c_str()); + nvinfer1::ITensor* anchorPointsTensor = anchorPoints->getOutput(0); + + nvinfer1::IElementWiseLayer* x1y1 = network->addElementWise(*anchorPointsTensor, *lt, + nvinfer1::ElementWiseOperation::kSUB); + assert(x1y1 != nullptr); + std::string x1y1LayerName = "x1y1_" + std::to_string(layerIdx); + x1y1->setName(x1y1LayerName.c_str()); + nvinfer1::ITensor* x1y1Tensor = x1y1->getOutput(0); + + nvinfer1::IElementWiseLayer* x2y2 = network->addElementWise(*rb, *anchorPointsTensor, + nvinfer1::ElementWiseOperation::kSUM); + assert(x2y2 != nullptr); + std::string x2y2LayerName = "x2y2_" + std::to_string(layerIdx); + x2y2->setName(x2y2LayerName.c_str()); + nvinfer1::ITensor* x2y2Tensor = x2y2->getOutput(0); + + std::vector concatBoxInputs; + concatBoxInputs.push_back(x1y1Tensor); + concatBoxInputs.push_back(x2y2Tensor); + + nvinfer1::IConcatenationLayer* concatBox = network->addConcatenation(concatBoxInputs.data(), concatBoxInputs.size()); + assert(concatBox != nullptr); + std::string concatBoxLayerName = "concat_box_" + std::to_string(layerIdx); + concatBox->setName(concatBoxLayerName.c_str()); + concatBox->setAxis(0); + box = concatBox->getOutput(0); + + channels = shuffle2BoxDims.d[1]; + nvinfer1::Weights stridePointsWt {nvinfer1::DataType::kFLOAT, nullptr, channels}; + val = new float[channels]; + for (int i = 0; i < channels; ++i) { + val[i] = weights[weightPtr]; + ++weightPtr; + } + stridePointsWt.values = val; + trtWeights.push_back(stridePointsWt); + + nvinfer1::IConstantLayer* stridePoints = network->addConstant(nvinfer1::Dims{2, {1, shuffle2BoxDims.d[1]}}, + stridePointsWt); + assert(stridePoints != nullptr); + std::string stridePointsLayerName = "stride_points_" + std::to_string(layerIdx); + 
stridePoints->setName(stridePointsLayerName.c_str()); + nvinfer1::ITensor* stridePointsTensor = stridePoints->getOutput(0); + + nvinfer1::IElementWiseLayer* pred = network->addElementWise(*box, *stridePointsTensor, + nvinfer1::ElementWiseOperation::kPROD); + assert(pred != nullptr); + std::string predLayerName = "pred_" + std::to_string(layerIdx); + pred->setName(predLayerName.c_str()); + box = pred->getOutput(0); + + nvinfer1::IActivationLayer* sigmoid = network->addActivation(*cls, nvinfer1::ActivationType::kSIGMOID); + assert(sigmoid != nullptr); + std::string sigmoidLayerName = "sigmoid_cls_" + std::to_string(layerIdx); + sigmoid->setName(sigmoidLayerName.c_str()); + cls = sigmoid->getOutput(0); + + std::vector concatInputs; + concatInputs.push_back(box); + concatInputs.push_back(cls); + + nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size()); + assert(concat != nullptr); + std::string concatLayerName = "concat_" + std::to_string(layerIdx); + concat->setName(concatLayerName.c_str()); + concat->setAxis(0); + output = concat->getOutput(0); + + nvinfer1::IShuffleLayer* shuffle = network->addShuffle(*output); + assert(shuffle != nullptr); + std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx); + shuffle->setName(shuffleLayerName.c_str()); + nvinfer1::Permutation permutation2; + permutation2.order[0] = 1; + permutation2.order[1] = 0; + shuffle->setFirstTranspose(permutation2); + output = shuffle->getOutput(0); + + return output; +} diff --git a/nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.h b/nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.h new file mode 100644 index 0000000..9cd9443 --- /dev/null +++ b/nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.h @@ -0,0 +1,18 @@ +/* + * Created by Marcos Luciano + * https://www.github.com/marcoslucianops + */ + +#ifndef __DETECT_V8_LAYER_H__ +#define __DETECT_V8_LAYER_H__ + +#include +#include + +#include "NvInfer.h" + +nvinfer1::ITensor* detectV8Layer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, nvinfer1::ITensor* input, + nvinfer1::INetworkDefinition* network); + +#endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.cpp index 25ce603..5553ac7 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.cpp @@ -5,37 +5,34 @@ #include "implicit_layer.h" -nvinfer1::ITensor* implicitLayer( - int layerIdx, - std::map& block, - std::vector& weights, - std::vector& trtWeights, - int& weightPtr, - nvinfer1::INetworkDefinition* network) +#include + +nvinfer1::ITensor* +implicitLayer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "implicit_add" || block.at("type") == "implicit_mul"); - assert(block.find("filters") != block.end()); + assert(block.at("type") == "implicit_add" || block.at("type") == "implicit_mul"); + assert(block.find("filters") != block.end()); - int filters = std::stoi(block.at("filters")); + int filters = std::stoi(block.at("filters")); - nvinfer1::Weights convWt{nvinfer1::DataType::kFLOAT, nullptr, filters}; + nvinfer1::Weights convWt {nvinfer1::DataType::kFLOAT, nullptr, filters}; - float* val = new float[filters]; - for (int i = 0; i < filters; ++i) - { - val[i] = weights[weightPtr]; - weightPtr++; - } - 
convWt.values = val; - trtWeights.push_back(convWt); + float* val = new float[filters]; + for (int i = 0; i < filters; ++i) { + val[i] = weights[weightPtr]; + ++weightPtr; + } + convWt.values = val; + trtWeights.push_back(convWt); - nvinfer1::IConstantLayer* implicit = network->addConstant(nvinfer1::Dims{3, {filters, 1, 1}}, convWt); - assert(implicit != nullptr); - std::string implicitLayerName = block.at("type") + "_" + std::to_string(layerIdx); - implicit->setName(implicitLayerName.c_str()); - output = implicit->getOutput(0); + nvinfer1::IConstantLayer* implicit = network->addConstant(nvinfer1::Dims{3, {filters, 1, 1}}, convWt); + assert(implicit != nullptr); + std::string implicitLayerName = block.at("type") + "_" + std::to_string(layerIdx); + implicit->setName(implicitLayerName.c_str()); + output = implicit->getOutput(0); - return output; + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.h b/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.h index a4611c9..10e50e2 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/implicit_layer.h @@ -8,16 +8,10 @@ #include #include -#include #include "NvInfer.h" -nvinfer1::ITensor* implicitLayer( - int layerIdx, - std::map& block, - std::vector& weights, - std::vector& trtWeights, - int& weightPtr, - nvinfer1::INetworkDefinition* network); +nvinfer1::ITensor* implicitLayer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/pooling_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/pooling_layer.cpp index 9f4d59c..7ebe0f3 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/pooling_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/pooling_layer.cpp @@ -5,53 +5,50 @@ #include "pooling_layer.h" -nvinfer1::ITensor* poolingLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +#include +#include + +nvinfer1::ITensor* +poolingLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "maxpool" || block.at("type") == "avgpool"); + assert(block.at("type") == "maxpool" || block.at("type") == "avgpool"); - if (block.at("type") == "maxpool") - { - assert(block.find("size") != block.end()); - assert(block.find("stride") != block.end()); + if (block.at("type") == "maxpool") { + assert(block.find("size") != block.end()); + assert(block.find("stride") != block.end()); - int size = std::stoi(block.at("size")); - int stride = std::stoi(block.at("stride")); + int size = std::stoi(block.at("size")); + int stride = std::stoi(block.at("stride")); - nvinfer1::IPoolingLayer* maxpool - = network->addPoolingNd(*input, nvinfer1::PoolingType::kMAX, nvinfer1::Dims{2, {size, size}}); - assert(maxpool != nullptr); - std::string maxpoolLayerName = "maxpool_" + std::to_string(layerIdx); - maxpool->setName(maxpoolLayerName.c_str()); - maxpool->setStrideNd(nvinfer1::Dims{2, {stride, stride}}); - maxpool->setPaddingNd(nvinfer1::Dims{2, {(size - 1) / 2, (size - 1) / 2}}); - if (size == 2 && stride == 1) - { - maxpool->setPrePadding(nvinfer1::Dims{2, {0, 0}}); - maxpool->setPostPadding(nvinfer1::Dims{2, {1, 1}}); - } - output = maxpool->getOutput(0); - } - else if (block.at("type") == "avgpool") - { - nvinfer1::Dims inputDims = input->getDimensions(); - nvinfer1::IPoolingLayer* avgpool = network->addPoolingNd( - *input, 
nvinfer1::PoolingType::kAVERAGE, nvinfer1::Dims{2, {inputDims.d[1], inputDims.d[2]}}); - assert(avgpool != nullptr); - std::string avgpoolLayerName = "avgpool_" + std::to_string(layerIdx); - avgpool->setName(avgpoolLayerName.c_str()); - output = avgpool->getOutput(0); - } - else - { - std::cerr << "Pooling not supported: " << block.at("type") << std::endl; - std::abort(); + nvinfer1::IPoolingLayer* maxpool = network->addPoolingNd(*input, nvinfer1::PoolingType::kMAX, + nvinfer1::Dims{2, {size, size}}); + assert(maxpool != nullptr); + std::string maxpoolLayerName = "maxpool_" + std::to_string(layerIdx); + maxpool->setName(maxpoolLayerName.c_str()); + maxpool->setStrideNd(nvinfer1::Dims{2, {stride, stride}}); + maxpool->setPaddingNd(nvinfer1::Dims{2, {(size - 1) / 2, (size - 1) / 2}}); + if (size == 2 && stride == 1) { + maxpool->setPrePadding(nvinfer1::Dims{2, {0, 0}}); + maxpool->setPostPadding(nvinfer1::Dims{2, {1, 1}}); } + output = maxpool->getOutput(0); + } + else if (block.at("type") == "avgpool") { + nvinfer1::Dims inputDims = input->getDimensions(); + nvinfer1::IPoolingLayer* avgpool = network->addPoolingNd(*input, nvinfer1::PoolingType::kAVERAGE, + nvinfer1::Dims{2, {inputDims.d[1], inputDims.d[2]}}); + assert(avgpool != nullptr); + std::string avgpoolLayerName = "avgpool_" + std::to_string(layerIdx); + avgpool->setName(avgpoolLayerName.c_str()); + output = avgpool->getOutput(0); + } + else { + std::cerr << "Pooling not supported: " << block.at("type") << std::endl; + assert(0); + } - return output; + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/pooling_layer.h b/nvdsinfer_custom_impl_Yolo/layers/pooling_layer.h index 0e97f24..82f1971 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/pooling_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/pooling_layer.h @@ -7,15 +7,10 @@ #define __POOLING_LAYER_H__ #include -#include -#include #include "NvInfer.h" -nvinfer1::ITensor* poolingLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +nvinfer1::ITensor* poolingLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/reduce_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/reduce_layer.cpp index 716848b..9d91178 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/reduce_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/reduce_layer.cpp @@ -5,54 +5,50 @@ #include "reduce_layer.h" -nvinfer1::ITensor* reduceLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +nvinfer1::ITensor* +reduceLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "reduce"); - assert(block.find("mode") != block.end()); - assert(block.find("axes") != block.end()); + assert(block.at("type") == "reduce"); + assert(block.find("mode") != block.end()); + assert(block.find("axes") != block.end()); - std::string mode = block.at("mode"); + std::string mode = block.at("mode"); - nvinfer1::ReduceOperation operation; - if (mode == "mean") - operation = nvinfer1::ReduceOperation::kAVG; + nvinfer1::ReduceOperation operation; + if (mode == "mean") + operation = nvinfer1::ReduceOperation::kAVG; - std::string strAxes = block.at("axes"); - std::vector axes; - size_t lastPos = 0, pos = 0; - while ((pos = strAxes.find(',', lastPos)) != std::string::npos) - { - int vL = std::stoi(trim(strAxes.substr(lastPos, pos - lastPos))); - axes.push_back(vL); - lastPos = pos + 
1; - } - if (lastPos < strAxes.length()) - { - std::string lastV = trim(strAxes.substr(lastPos)); - if (!lastV.empty()) - axes.push_back(std::stoi(lastV)); - } - assert(!axes.empty()); - - uint32_t axisMask = 0; - for (int axis : axes) - axisMask |= 1 << axis; - - bool keepDims = false; - if (block.find("keep") != block.end()) - keepDims = std::stoi(block.at("keep")) == 1 ? true : false; + std::string strAxes = block.at("axes"); + std::vector axes; + size_t lastPos = 0, pos = 0; + while ((pos = strAxes.find(',', lastPos)) != std::string::npos) { + int vL = std::stoi(trim(strAxes.substr(lastPos, pos - lastPos))); + axes.push_back(vL); + lastPos = pos + 1; + } + if (lastPos < strAxes.length()) { + std::string lastV = trim(strAxes.substr(lastPos)); + if (!lastV.empty()) + axes.push_back(std::stoi(lastV)); + } + assert(!axes.empty()); - nvinfer1::IReduceLayer* reduce = network->addReduce(*input, operation, axisMask, keepDims); - assert(reduce != nullptr); - std::string reduceLayerName = "reduce_" + std::to_string(layerIdx); - reduce->setName(reduceLayerName.c_str()); - output = reduce->getOutput(0); + uint32_t axisMask = 0; + for (int axis : axes) + axisMask |= 1 << axis; - return output; + bool keepDims = false; + if (block.find("keep") != block.end()) + keepDims = std::stoi(block.at("keep")) == 1 ? true : false; + + nvinfer1::IReduceLayer* reduce = network->addReduce(*input, operation, axisMask, keepDims); + assert(reduce != nullptr); + std::string reduceLayerName = "reduce_" + std::to_string(layerIdx); + reduce->setName(reduceLayerName.c_str()); + output = reduce->getOutput(0); + + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/reduce_layer.h b/nvdsinfer_custom_impl_Yolo/layers/reduce_layer.h index c8330a1..e68bca2 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/reduce_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/reduce_layer.h @@ -6,13 +6,9 @@ #ifndef __REDUCE_LAYER_H__ #define __REDUCE_LAYER_H__ -#include "NvInfer.h" #include "../utils.h" -nvinfer1::ITensor* reduceLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +nvinfer1::ITensor* reduceLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/reg_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/reg_layer.cpp index ea9be07..7d339e2 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/reg_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/reg_layer.cpp @@ -5,109 +5,105 @@ #include "reg_layer.h" -nvinfer1::ITensor* regLayer( - int layerIdx, - std::map& block, - std::vector& weights, - std::vector& trtWeights, - int& weightPtr, - nvinfer1::ITensor* input, +#include + +nvinfer1::ITensor* +regLayer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "reg"); + assert(block.at("type") == "reg"); - nvinfer1::IShuffleLayer* shuffle = network->addShuffle(*input); - assert(shuffle != nullptr); - std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx); - shuffle->setName(shuffleLayerName.c_str()); - nvinfer1::Permutation permutation; - permutation.order[0] = 1; - permutation.order[1] = 0; - shuffle->setFirstTranspose(permutation); - output = shuffle->getOutput(0); - nvinfer1::Dims shuffleDims = output->getDimensions(); + nvinfer1::IShuffleLayer* shuffle = network->addShuffle(*input); + assert(shuffle != nullptr); + 
std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx); + shuffle->setName(shuffleLayerName.c_str()); + nvinfer1::Permutation permutation; + permutation.order[0] = 1; + permutation.order[1] = 0; + shuffle->setFirstTranspose(permutation); + output = shuffle->getOutput(0); + nvinfer1::Dims shuffleDims = output->getDimensions(); - nvinfer1::ISliceLayer* sliceLt = network->addSlice( - *output, nvinfer1::Dims{2, {0, 0}}, nvinfer1::Dims{2, {shuffleDims.d[0], 2}}, nvinfer1::Dims{2, {1, 1}}); - assert(sliceLt != nullptr); - std::string sliceLtLayerName = "slice_lt_" + std::to_string(layerIdx); - sliceLt->setName(sliceLtLayerName.c_str()); - nvinfer1::ITensor* lt = sliceLt->getOutput(0); + nvinfer1::ISliceLayer* sliceLt = network->addSlice(*output, nvinfer1::Dims{2, {0, 0}}, + nvinfer1::Dims{2, {shuffleDims.d[0], 2}}, nvinfer1::Dims{2, {1, 1}}); + assert(sliceLt != nullptr); + std::string sliceLtLayerName = "slice_lt_" + std::to_string(layerIdx); + sliceLt->setName(sliceLtLayerName.c_str()); + nvinfer1::ITensor* lt = sliceLt->getOutput(0); - nvinfer1::ISliceLayer* sliceRb = network->addSlice( - *output, nvinfer1::Dims{2, {0, 2}}, nvinfer1::Dims{2, {shuffleDims.d[0], 2}}, nvinfer1::Dims{2, {1, 1}}); - assert(sliceRb != nullptr); - std::string sliceRbLayerName = "slice_rb_" + std::to_string(layerIdx); - sliceRb->setName(sliceRbLayerName.c_str()); - nvinfer1::ITensor* rb = sliceRb->getOutput(0); + nvinfer1::ISliceLayer* sliceRb = network->addSlice(*output, nvinfer1::Dims{2, {0, 2}}, + nvinfer1::Dims{2, {shuffleDims.d[0], 2}}, nvinfer1::Dims{2, {1, 1}}); + assert(sliceRb != nullptr); + std::string sliceRbLayerName = "slice_rb_" + std::to_string(layerIdx); + sliceRb->setName(sliceRbLayerName.c_str()); + nvinfer1::ITensor* rb = sliceRb->getOutput(0); - int channels = shuffleDims.d[0] * 2; - nvinfer1::Weights anchorPointsWt{nvinfer1::DataType::kFLOAT, nullptr, channels}; - float* val = new float[channels]; - for (int i = 0; i < channels; ++i) - { - val[i] = weights[weightPtr]; - weightPtr++; - } - anchorPointsWt.values = val; - trtWeights.push_back(anchorPointsWt); + int channels = shuffleDims.d[0] * 2; + nvinfer1::Weights anchorPointsWt {nvinfer1::DataType::kFLOAT, nullptr, channels}; + float* val = new float[channels]; + for (int i = 0; i < channels; ++i) { + val[i] = weights[weightPtr]; + ++weightPtr; + } + anchorPointsWt.values = val; + trtWeights.push_back(anchorPointsWt); - nvinfer1::IConstantLayer* anchorPoints = network->addConstant(nvinfer1::Dims{2, {shuffleDims.d[0], 2}}, anchorPointsWt); - assert(anchorPoints != nullptr); - std::string anchorPointsLayerName = "anchor_points_" + std::to_string(layerIdx); - anchorPoints->setName(anchorPointsLayerName.c_str()); - nvinfer1::ITensor* anchorPointsTensor = anchorPoints->getOutput(0); + nvinfer1::IConstantLayer* anchorPoints = network->addConstant(nvinfer1::Dims{2, {shuffleDims.d[0], 2}}, anchorPointsWt); + assert(anchorPoints != nullptr); + std::string anchorPointsLayerName = "anchor_points_" + std::to_string(layerIdx); + anchorPoints->setName(anchorPointsLayerName.c_str()); + nvinfer1::ITensor* anchorPointsTensor = anchorPoints->getOutput(0); - nvinfer1::IElementWiseLayer* x1y1 - = network->addElementWise(*anchorPointsTensor, *lt, nvinfer1::ElementWiseOperation::kSUB); - assert(x1y1 != nullptr); - std::string x1y1LayerName = "x1y1_" + std::to_string(layerIdx); - x1y1->setName(x1y1LayerName.c_str()); - nvinfer1::ITensor* x1y1Tensor = x1y1->getOutput(0); + nvinfer1::IElementWiseLayer* x1y1 = network->addElementWise(*anchorPointsTensor, 
*lt, + nvinfer1::ElementWiseOperation::kSUB); + assert(x1y1 != nullptr); + std::string x1y1LayerName = "x1y1_" + std::to_string(layerIdx); + x1y1->setName(x1y1LayerName.c_str()); + nvinfer1::ITensor* x1y1Tensor = x1y1->getOutput(0); - nvinfer1::IElementWiseLayer* x2y2 - = network->addElementWise(*rb, *anchorPointsTensor, nvinfer1::ElementWiseOperation::kSUM); - assert(x2y2 != nullptr); - std::string x2y2LayerName = "x2y2_" + std::to_string(layerIdx); - x2y2->setName(x2y2LayerName.c_str()); - nvinfer1::ITensor* x2y2Tensor = x2y2->getOutput(0); + nvinfer1::IElementWiseLayer* x2y2 = network->addElementWise(*rb, *anchorPointsTensor, + nvinfer1::ElementWiseOperation::kSUM); + assert(x2y2 != nullptr); + std::string x2y2LayerName = "x2y2_" + std::to_string(layerIdx); + x2y2->setName(x2y2LayerName.c_str()); + nvinfer1::ITensor* x2y2Tensor = x2y2->getOutput(0); - std::vector concatInputs; - concatInputs.push_back(x1y1Tensor); - concatInputs.push_back(x2y2Tensor); + std::vector concatInputs; + concatInputs.push_back(x1y1Tensor); + concatInputs.push_back(x2y2Tensor); - nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size()); - assert(concat != nullptr); - std::string concatLayerName = "concat_" + std::to_string(layerIdx); - concat->setName(concatLayerName.c_str()); - concat->setAxis(1); - output = concat->getOutput(0); + nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size()); + assert(concat != nullptr); + std::string concatLayerName = "concat_" + std::to_string(layerIdx); + concat->setName(concatLayerName.c_str()); + concat->setAxis(1); + output = concat->getOutput(0); - channels = shuffleDims.d[0]; - nvinfer1::Weights stridePointsWt{nvinfer1::DataType::kFLOAT, nullptr, channels}; - val = new float[channels]; - for (int i = 0; i < channels; ++i) - { - val[i] = weights[weightPtr]; - weightPtr++; - } - stridePointsWt.values = val; - trtWeights.push_back(stridePointsWt); + channels = shuffleDims.d[0]; + nvinfer1::Weights stridePointsWt {nvinfer1::DataType::kFLOAT, nullptr, channels}; + val = new float[channels]; + for (int i = 0; i < channels; ++i) { + val[i] = weights[weightPtr]; + ++weightPtr; + } + stridePointsWt.values = val; + trtWeights.push_back(stridePointsWt); - nvinfer1::IConstantLayer* stridePoints = network->addConstant(nvinfer1::Dims{2, {shuffleDims.d[0], 1}}, stridePointsWt); - assert(stridePoints != nullptr); - std::string stridePointsLayerName = "stride_points_" + std::to_string(layerIdx); - stridePoints->setName(stridePointsLayerName.c_str()); - nvinfer1::ITensor* stridePointsTensor = stridePoints->getOutput(0); + nvinfer1::IConstantLayer* stridePoints = network->addConstant(nvinfer1::Dims{2, {shuffleDims.d[0], 1}}, stridePointsWt); + assert(stridePoints != nullptr); + std::string stridePointsLayerName = "stride_points_" + std::to_string(layerIdx); + stridePoints->setName(stridePointsLayerName.c_str()); + nvinfer1::ITensor* stridePointsTensor = stridePoints->getOutput(0); - nvinfer1::IElementWiseLayer* pred - = network->addElementWise(*output, *stridePointsTensor, nvinfer1::ElementWiseOperation::kPROD); - assert(pred != nullptr); - std::string predLayerName = "pred_" + std::to_string(layerIdx); - pred->setName(predLayerName.c_str()); - output = pred->getOutput(0); + nvinfer1::IElementWiseLayer* pred = network->addElementWise(*output, *stridePointsTensor, + nvinfer1::ElementWiseOperation::kPROD); + assert(pred != nullptr); + std::string predLayerName = "pred_" + 
std::to_string(layerIdx); + pred->setName(predLayerName.c_str()); + output = pred->getOutput(0); - return output; + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/reg_layer.h b/nvdsinfer_custom_impl_Yolo/layers/reg_layer.h index b8addb3..270c659 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/reg_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/reg_layer.h @@ -8,17 +8,11 @@ #include #include -#include #include "NvInfer.h" -nvinfer1::ITensor* regLayer( - int layerIdx, - std::map& block, - std::vector& weights, - std::vector& trtWeights, - int& weightPtr, - nvinfer1::ITensor* input, +nvinfer1::ITensor* regLayer(int layerIdx, std::map& block, std::vector& weights, + std::vector& trtWeights, int& weightPtr, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/reorg_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/reorg_layer.cpp index 9633ebb..e5688c4 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/reorg_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/reorg_layer.cpp @@ -5,58 +5,55 @@ #include "reorg_layer.h" -nvinfer1::ITensor* reorgLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +#include +#include + +nvinfer1::ITensor* +reorgLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "reorg"); + assert(block.at("type") == "reorg"); - nvinfer1::Dims inputDims = input->getDimensions(); + nvinfer1::Dims inputDims = input->getDimensions(); - nvinfer1::ISliceLayer *slice1 = network->addSlice( - *input, nvinfer1::Dims{3, {0, 0, 0}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, - nvinfer1::Dims{3, {1, 2, 2}}); - assert(slice1 != nullptr); - std::string slice1LayerName = "slice1_" + std::to_string(layerIdx); - slice1->setName(slice1LayerName.c_str()); + nvinfer1::ISliceLayer *slice1 = network->addSlice(*input, nvinfer1::Dims{3, {0, 0, 0}}, + nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, nvinfer1::Dims{3, {1, 2, 2}}); + assert(slice1 != nullptr); + std::string slice1LayerName = "slice1_" + std::to_string(layerIdx); + slice1->setName(slice1LayerName.c_str()); - nvinfer1::ISliceLayer *slice2 = network->addSlice( - *input, nvinfer1::Dims{3, {0, 1, 0}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, - nvinfer1::Dims{3, {1, 2, 2}}); - assert(slice2 != nullptr); - std::string slice2LayerName = "slice2_" + std::to_string(layerIdx); - slice2->setName(slice2LayerName.c_str()); + nvinfer1::ISliceLayer *slice2 = network->addSlice(*input, nvinfer1::Dims{3, {0, 1, 0}}, + nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, nvinfer1::Dims{3, {1, 2, 2}}); + assert(slice2 != nullptr); + std::string slice2LayerName = "slice2_" + std::to_string(layerIdx); + slice2->setName(slice2LayerName.c_str()); - nvinfer1::ISliceLayer *slice3 = network->addSlice( - *input, nvinfer1::Dims{3, {0, 0, 1}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, - nvinfer1::Dims{3, {1, 2, 2}}); - assert(slice3 != nullptr); - std::string slice3LayerName = "slice3_" + std::to_string(layerIdx); - slice3->setName(slice3LayerName.c_str()); + nvinfer1::ISliceLayer *slice3 = network->addSlice(*input, nvinfer1::Dims{3, {0, 0, 1}}, + nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, nvinfer1::Dims{3, {1, 2, 2}}); + assert(slice3 != nullptr); + 
std::string slice3LayerName = "slice3_" + std::to_string(layerIdx); + slice3->setName(slice3LayerName.c_str()); - nvinfer1::ISliceLayer *slice4 = network->addSlice( - *input, nvinfer1::Dims{3, {0, 1, 1}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, - nvinfer1::Dims{3, {1, 2, 2}}); - assert(slice4 != nullptr); - std::string slice4LayerName = "slice4_" + std::to_string(layerIdx); - slice4->setName(slice4LayerName.c_str()); + nvinfer1::ISliceLayer *slice4 = network->addSlice(*input, nvinfer1::Dims{3, {0, 1, 1}}, + nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, nvinfer1::Dims{3, {1, 2, 2}}); + assert(slice4 != nullptr); + std::string slice4LayerName = "slice4_" + std::to_string(layerIdx); + slice4->setName(slice4LayerName.c_str()); - std::vector concatInputs; - concatInputs.push_back(slice1->getOutput(0)); - concatInputs.push_back(slice2->getOutput(0)); - concatInputs.push_back(slice3->getOutput(0)); - concatInputs.push_back(slice4->getOutput(0)); + std::vector concatInputs; + concatInputs.push_back(slice1->getOutput(0)); + concatInputs.push_back(slice2->getOutput(0)); + concatInputs.push_back(slice3->getOutput(0)); + concatInputs.push_back(slice4->getOutput(0)); - nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size()); - assert(concat != nullptr); - std::string concatLayerName = "concat_" + std::to_string(layerIdx); - concat->setName(concatLayerName.c_str()); - concat->setAxis(0); - output = concat->getOutput(0); + nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size()); + assert(concat != nullptr); + std::string concatLayerName = "concat_" + std::to_string(layerIdx); + concat->setName(concatLayerName.c_str()); + concat->setAxis(0); + output = concat->getOutput(0); - return output; + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/reorg_layer.h b/nvdsinfer_custom_impl_Yolo/layers/reorg_layer.h index fca09fa..585b91d 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/reorg_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/reorg_layer.h @@ -3,19 +3,14 @@ * https://www.github.com/marcoslucianops */ -#ifndef __REORGV5_LAYER_H__ -#define __REORGV5_LAYER_H__ +#ifndef __REORG_LAYER_H__ +#define __REORG_LAYER_H__ #include -#include -#include #include "NvInfer.h" -nvinfer1::ITensor* reorgLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +nvinfer1::ITensor* reorgLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/route_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/route_layer.cpp index 6115e9e..2222841 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/route_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/route_layer.cpp @@ -5,78 +5,70 @@ #include "route_layer.h" -nvinfer1::ITensor* routeLayer( - int layerIdx, - std::string& layers, - std::map& block, - std::vector tensorOutputs, - nvinfer1::INetworkDefinition* network) +nvinfer1::ITensor* +routeLayer(int layerIdx, std::string& layers, std::map& block, + std::vector tensorOutputs, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "route"); - assert(block.find("layers") != block.end()); + assert(block.at("type") == "route"); + assert(block.find("layers") != block.end()); - std::string strLayers = block.at("layers"); - std::vector idxLayers; - size_t lastPos = 0, pos = 0; - while 
((pos = strLayers.find(',', lastPos)) != std::string::npos) - { - int vL = std::stoi(trim(strLayers.substr(lastPos, pos - lastPos))); - idxLayers.push_back(vL); - lastPos = pos + 1; - } - if (lastPos < strLayers.length()) - { - std::string lastV = trim(strLayers.substr(lastPos)); - if (!lastV.empty()) - idxLayers.push_back(std::stoi(lastV)); - } - assert (!idxLayers.empty()); - std::vector concatInputs; - for (uint i = 0; i < idxLayers.size(); ++i) - { - if (idxLayers[i] < 0) - idxLayers[i] = tensorOutputs.size() + idxLayers[i]; - assert (idxLayers[i] >= 0 && idxLayers[i] < (int)tensorOutputs.size()); - concatInputs.push_back(tensorOutputs[idxLayers[i]]); - if (i < idxLayers.size() - 1) - layers += std::to_string(idxLayers[i]) + ", "; - } - layers += std::to_string(idxLayers[idxLayers.size() - 1]); + std::string strLayers = block.at("layers"); + std::vector idxLayers; + size_t lastPos = 0, pos = 0; + while ((pos = strLayers.find(',', lastPos)) != std::string::npos) { + int vL = std::stoi(trim(strLayers.substr(lastPos, pos - lastPos))); + idxLayers.push_back(vL); + lastPos = pos + 1; + } + if (lastPos < strLayers.length()) { + std::string lastV = trim(strLayers.substr(lastPos)); + if (!lastV.empty()) + idxLayers.push_back(std::stoi(lastV)); + } + assert (!idxLayers.empty()); + std::vector concatInputs; + for (uint i = 0; i < idxLayers.size(); ++i) { + if (idxLayers[i] < 0) + idxLayers[i] = tensorOutputs.size() + idxLayers[i]; + assert (idxLayers[i] >= 0 && idxLayers[i] < (int)tensorOutputs.size()); + concatInputs.push_back(tensorOutputs[idxLayers[i]]); + if (i < idxLayers.size() - 1) + layers += std::to_string(idxLayers[i]) + ", "; + } + layers += std::to_string(idxLayers[idxLayers.size() - 1]); - if (concatInputs.size() == 1) - output = concatInputs[0]; - else { - int axis = 0; - if (block.find("axis") != block.end()) - axis = std::stoi(block.at("axis")); - if (axis < 0) - axis = concatInputs[0]->getDimensions().nbDims + axis; + if (concatInputs.size() == 1) + output = concatInputs[0]; + else { + int axis = 0; + if (block.find("axis") != block.end()) + axis = std::stoi(block.at("axis")); + if (axis < 0) + axis = concatInputs[0]->getDimensions().nbDims + axis; - nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size()); - assert(concat != nullptr); - std::string concatLayerName = "route_" + std::to_string(layerIdx); - concat->setName(concatLayerName.c_str()); - concat->setAxis(axis); - output = concat->getOutput(0); - } + nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size()); + assert(concat != nullptr); + std::string concatLayerName = "route_" + std::to_string(layerIdx); + concat->setName(concatLayerName.c_str()); + concat->setAxis(axis); + output = concat->getOutput(0); + } - if (block.find("groups") != block.end()) - { - nvinfer1::Dims prevTensorDims = output->getDimensions(); - int groups = stoi(block.at("groups")); - int group_id = stoi(block.at("group_id")); - int startSlice = (prevTensorDims.d[0] / groups) * group_id; - int channelSlice = (prevTensorDims.d[0] / groups); - nvinfer1::ISliceLayer* slice = network->addSlice( - *output, nvinfer1::Dims{3, {startSlice, 0, 0}}, - nvinfer1::Dims{3, {channelSlice, prevTensorDims.d[1], prevTensorDims.d[2]}}, nvinfer1::Dims{3, {1, 1, 1}}); - assert(slice != nullptr); - std::string sliceLayerName = "slice_" + std::to_string(layerIdx); - slice->setName(sliceLayerName.c_str()); - output = slice->getOutput(0); - } + if (block.find("groups") != 
block.end()) { + nvinfer1::Dims prevTensorDims = output->getDimensions(); + int groups = stoi(block.at("groups")); + int group_id = stoi(block.at("group_id")); + int startSlice = (prevTensorDims.d[0] / groups) * group_id; + int channelSlice = (prevTensorDims.d[0] / groups); + nvinfer1::ISliceLayer* slice = network->addSlice(*output, nvinfer1::Dims{3, {startSlice, 0, 0}}, + nvinfer1::Dims{3, {channelSlice, prevTensorDims.d[1], prevTensorDims.d[2]}}, nvinfer1::Dims{3, {1, 1, 1}}); + assert(slice != nullptr); + std::string sliceLayerName = "slice_" + std::to_string(layerIdx); + slice->setName(sliceLayerName.c_str()); + output = slice->getOutput(0); + } - return output; + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/route_layer.h b/nvdsinfer_custom_impl_Yolo/layers/route_layer.h index 9679365..f3103ec 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/route_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/route_layer.h @@ -6,14 +6,9 @@ #ifndef __ROUTE_LAYER_H__ #define __ROUTE_LAYER_H__ -#include "NvInfer.h" #include "../utils.h" -nvinfer1::ITensor* routeLayer( - int layerIdx, - std::string& layers, - std::map& block, - std::vector tensorOutputs, - nvinfer1::INetworkDefinition* network); +nvinfer1::ITensor* routeLayer(int layerIdx, std::string& layers, std::map& block, + std::vector tensorOutputs, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/shortcut_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/shortcut_layer.cpp index 3e53cbe..929f037 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/shortcut_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/shortcut_layer.cpp @@ -5,48 +5,41 @@ #include "shortcut_layer.h" -nvinfer1::ITensor* shortcutLayer( - int layerIdx, - std::string mode, - std::string activation, - std::string inputVol, - std::string shortcutVol, - std::map& block, - nvinfer1::ITensor* input, - nvinfer1::ITensor* shortcutInput, +#include + +nvinfer1::ITensor* +shortcutLayer(int layerIdx, std::string mode, std::string activation, std::string inputVol, std::string shortcutVol, + std::map& block, nvinfer1::ITensor* input, nvinfer1::ITensor* shortcutInput, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "shortcut"); + assert(block.at("type") == "shortcut"); - nvinfer1::ElementWiseOperation operation = nvinfer1::ElementWiseOperation::kSUM; + nvinfer1::ElementWiseOperation operation = nvinfer1::ElementWiseOperation::kSUM; - if (mode == "mul") - operation = nvinfer1::ElementWiseOperation::kPROD; + if (mode == "mul") + operation = nvinfer1::ElementWiseOperation::kPROD; - if (mode == "add" && inputVol != shortcutVol) - { - nvinfer1::ISliceLayer* slice = network->addSlice( - *shortcutInput, nvinfer1::Dims{3, {0, 0, 0}}, input->getDimensions(), nvinfer1::Dims{3, {1, 1, 1}}); - assert(slice != nullptr); - std::string sliceLayerName = "slice_" + std::to_string(layerIdx); - slice->setName(sliceLayerName.c_str()); - output = slice->getOutput(0); - } - else - { - output = shortcutInput; - } + if (mode == "add" && inputVol != shortcutVol) { + nvinfer1::ISliceLayer* slice = network->addSlice(*shortcutInput, nvinfer1::Dims{3, {0, 0, 0}}, input->getDimensions(), + nvinfer1::Dims{3, {1, 1, 1}}); + assert(slice != nullptr); + std::string sliceLayerName = "slice_" + std::to_string(layerIdx); + slice->setName(sliceLayerName.c_str()); + output = slice->getOutput(0); + } + else + output = shortcutInput; - nvinfer1::IElementWiseLayer* shortcut = network->addElementWise(*input, 
*output, operation); - assert(shortcut != nullptr); - std::string shortcutLayerName = "shortcut_" + std::to_string(layerIdx); - shortcut->setName(shortcutLayerName.c_str()); - output = shortcut->getOutput(0); + nvinfer1::IElementWiseLayer* shortcut = network->addElementWise(*input, *output, operation); + assert(shortcut != nullptr); + std::string shortcutLayerName = "shortcut_" + std::to_string(layerIdx); + shortcut->setName(shortcutLayerName.c_str()); + output = shortcut->getOutput(0); - output = activationLayer(layerIdx, activation, output, network); - assert(output != nullptr); + output = activationLayer(layerIdx, activation, output, network); + assert(output != nullptr); - return output; + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/shortcut_layer.h b/nvdsinfer_custom_impl_Yolo/layers/shortcut_layer.h index 22195e8..c7b2bcf 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/shortcut_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/shortcut_layer.h @@ -12,15 +12,8 @@ #include "activation_layer.h" -nvinfer1::ITensor* shortcutLayer( - int layerIdx, - std::string mode, - std::string activation, - std::string inputVol, - std::string shortcutVol, - std::map& block, - nvinfer1::ITensor* input, - nvinfer1::ITensor* shortcut, - nvinfer1::INetworkDefinition* network); +nvinfer1::ITensor* shortcutLayer(int layerIdx, std::string mode, std::string activation, std::string inputVol, + std::string shortcutVol, std::map& block, nvinfer1::ITensor* input, + nvinfer1::ITensor* shortcut, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.cpp index a967e46..b844b50 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.cpp @@ -5,119 +5,133 @@ #include "shuffle_layer.h" -nvinfer1::ITensor* shuffleLayer( - int layerIdx, - std::string& layer, - std::map& block, - nvinfer1::ITensor* input, - std::vector tensorOutputs, - nvinfer1::INetworkDefinition* network) +nvinfer1::ITensor* +shuffleLayer(int layerIdx, std::string& layer, std::map& block, nvinfer1::ITensor* input, + std::vector tensorOutputs, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "shuffle"); + assert(block.at("type") == "shuffle"); - nvinfer1::IShuffleLayer* shuffle = network->addShuffle(*input); - assert(shuffle != nullptr); - std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx); - shuffle->setName(shuffleLayerName.c_str()); + nvinfer1::IShuffleLayer* shuffle = network->addShuffle(*input); + assert(shuffle != nullptr); + std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx); + shuffle->setName(shuffleLayerName.c_str()); - if (block.find("reshape") != block.end()) - { - std::string strReshape = block.at("reshape"); - std::vector reshape; - size_t lastPos = 0, pos = 0; - while ((pos = strReshape.find(',', lastPos)) != std::string::npos) - { - int vL = std::stoi(trim(strReshape.substr(lastPos, pos - lastPos))); - reshape.push_back(vL); - lastPos = pos + 1; - } - if (lastPos < strReshape.length()) - { - std::string lastV = trim(strReshape.substr(lastPos)); - if (!lastV.empty()) - reshape.push_back(std::stoi(lastV)); - } - assert(!reshape.empty()); + if (block.find("reshape") != block.end()) { + int from = -1; + if (block.find("from") != block.end()) + from = std::stoi(block.at("from")); - int from = -1; - if (block.find("from") != block.end()) - from = 
std::stoi(block.at("from")); + if (from < 0) + from = tensorOutputs.size() + from; - if (from < 0) - from = tensorOutputs.size() + from; + layer = std::to_string(from); - layer = std::to_string(from); + nvinfer1::Dims inputTensorDims = tensorOutputs[from]->getDimensions(); - nvinfer1::Dims inputTensorDims = tensorOutputs[from]->getDimensions(); - int32_t l = inputTensorDims.d[1] * inputTensorDims.d[2]; - - nvinfer1::Dims reshapeDims; - reshapeDims.nbDims = reshape.size(); - - for (uint i = 0; i < reshape.size(); ++i) - if (reshape[i] == 0) - reshapeDims.d[i] = l; - else - reshapeDims.d[i] = reshape[i]; - - shuffle->setReshapeDimensions(reshapeDims); + std::string strReshape = block.at("reshape"); + std::vector reshape; + size_t lastPos = 0, pos = 0; + while ((pos = strReshape.find(',', lastPos)) != std::string::npos) { + std::string V = trim(strReshape.substr(lastPos, pos - lastPos)); + if (V == "c") + reshape.push_back(inputTensorDims.d[0]); + else if (V == "ch") + reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[1]); + else if (V == "cw") + reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[2]); + else if (V == "h") + reshape.push_back(inputTensorDims.d[1]); + else if (V == "hw") + reshape.push_back(inputTensorDims.d[1] * inputTensorDims.d[2]); + else if (V == "w") + reshape.push_back(inputTensorDims.d[2]); + else if (V == "chw") + reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[1] * inputTensorDims.d[2]); + else + reshape.push_back(std::stoi(V)); + lastPos = pos + 1; } - - if (block.find("transpose1") != block.end()) - { - std::string strTranspose1 = block.at("transpose1"); - std::vector transpose1; - size_t lastPos = 0, pos = 0; - while ((pos = strTranspose1.find(',', lastPos)) != std::string::npos) - { - int vL = std::stoi(trim(strTranspose1.substr(lastPos, pos - lastPos))); - transpose1.push_back(vL); - lastPos = pos + 1; - } - if (lastPos < strTranspose1.length()) - { - std::string lastV = trim(strTranspose1.substr(lastPos)); - if (!lastV.empty()) - transpose1.push_back(std::stoi(lastV)); - } - assert(!transpose1.empty()); - - nvinfer1::Permutation permutation1; - for (uint i = 0; i < transpose1.size(); ++i) - permutation1.order[i] = transpose1[i]; - - shuffle->setFirstTranspose(permutation1); + if (lastPos < strReshape.length()) { + std::string lastV = trim(strReshape.substr(lastPos)); + if (!lastV.empty()) { + if (lastV == "c") + reshape.push_back(inputTensorDims.d[0]); + else if (lastV == "ch") + reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[1]); + else if (lastV == "cw") + reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[2]); + else if (lastV == "h") + reshape.push_back(inputTensorDims.d[1]); + else if (lastV == "hw") + reshape.push_back(inputTensorDims.d[1] * inputTensorDims.d[2]); + else if (lastV == "w") + reshape.push_back(inputTensorDims.d[2]); + else if (lastV == "chw") + reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[1] * inputTensorDims.d[2]); + else + reshape.push_back(std::stoi(lastV)); + } } + assert(!reshape.empty()); - if (block.find("transpose2") != block.end()) - { - std::string strTranspose2 = block.at("transpose2"); - std::vector transpose2; - size_t lastPos = 0, pos = 0; - while ((pos = strTranspose2.find(',', lastPos)) != std::string::npos) - { - int vL = std::stoi(trim(strTranspose2.substr(lastPos, pos - lastPos))); - transpose2.push_back(vL); - lastPos = pos + 1; - } - if (lastPos < strTranspose2.length()) - { - std::string lastV = trim(strTranspose2.substr(lastPos)); - if (!lastV.empty()) - 
transpose2.push_back(std::stoi(lastV)); - } - assert(!transpose2.empty()); + nvinfer1::Dims reshapeDims; + reshapeDims.nbDims = reshape.size(); - nvinfer1::Permutation permutation2; - for (uint i = 0; i < transpose2.size(); ++i) - permutation2.order[i] = transpose2[i]; + for (uint i = 0; i < reshape.size(); ++i) + reshapeDims.d[i] = reshape[i]; - shuffle->setSecondTranspose(permutation2); + shuffle->setReshapeDimensions(reshapeDims); + } + + if (block.find("transpose1") != block.end()) { + std::string strTranspose1 = block.at("transpose1"); + std::vector transpose1; + size_t lastPos = 0, pos = 0; + while ((pos = strTranspose1.find(',', lastPos)) != std::string::npos) { + int vL = std::stoi(trim(strTranspose1.substr(lastPos, pos - lastPos))); + transpose1.push_back(vL); + lastPos = pos + 1; } + if (lastPos < strTranspose1.length()) { + std::string lastV = trim(strTranspose1.substr(lastPos)); + if (!lastV.empty()) + transpose1.push_back(std::stoi(lastV)); + } + assert(!transpose1.empty()); - output = shuffle->getOutput(0); + nvinfer1::Permutation permutation1; + for (uint i = 0; i < transpose1.size(); ++i) + permutation1.order[i] = transpose1[i]; - return output; + shuffle->setFirstTranspose(permutation1); + } + + if (block.find("transpose2") != block.end()) { + std::string strTranspose2 = block.at("transpose2"); + std::vector transpose2; + size_t lastPos = 0, pos = 0; + while ((pos = strTranspose2.find(',', lastPos)) != std::string::npos) { + int vL = std::stoi(trim(strTranspose2.substr(lastPos, pos - lastPos))); + transpose2.push_back(vL); + lastPos = pos + 1; + } + if (lastPos < strTranspose2.length()) { + std::string lastV = trim(strTranspose2.substr(lastPos)); + if (!lastV.empty()) + transpose2.push_back(std::stoi(lastV)); + } + assert(!transpose2.empty()); + + nvinfer1::Permutation permutation2; + for (uint i = 0; i < transpose2.size(); ++i) + permutation2.order[i] = transpose2[i]; + + shuffle->setSecondTranspose(permutation2); + } + + output = shuffle->getOutput(0); + + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.h b/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.h index 53aa3ce..2e5a4ef 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.h @@ -6,15 +6,9 @@ #ifndef __SHUFFLE_LAYER_H__ #define __SHUFFLE_LAYER_H__ -#include "NvInfer.h" #include "../utils.h" -nvinfer1::ITensor* shuffleLayer( - int layerIdx, - std::string& layer, - std::map& block, - nvinfer1::ITensor* input, - std::vector tensorOutputs, - nvinfer1::INetworkDefinition* network); +nvinfer1::ITensor* shuffleLayer(int layerIdx, std::string& layer, std::map& block, + nvinfer1::ITensor* input, std::vector tensorOutputs, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/softmax_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/softmax_layer.cpp index cb6348d..da73810 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/softmax_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/softmax_layer.cpp @@ -5,25 +5,25 @@ #include "softmax_layer.h" -nvinfer1::ITensor* softmaxLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +#include + +nvinfer1::ITensor* +softmaxLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "softmax"); - assert(block.find("axes") != block.end()); + assert(block.at("type") == "softmax"); + assert(block.find("axes") != 
block.end()); - int axes = std::stoi(block.at("axes")); + int axes = std::stoi(block.at("axes")); - nvinfer1::ISoftMaxLayer* softmax = network->addSoftMax(*input); - assert(softmax != nullptr); - std::string softmaxLayerName = "softmax_" + std::to_string(layerIdx); - softmax->setName(softmaxLayerName.c_str()); - softmax->setAxes(1 << axes); - output = softmax->getOutput(0); + nvinfer1::ISoftMaxLayer* softmax = network->addSoftMax(*input); + assert(softmax != nullptr); + std::string softmaxLayerName = "softmax_" + std::to_string(layerIdx); + softmax->setName(softmaxLayerName.c_str()); + softmax->setAxes(1 << axes); + output = softmax->getOutput(0); - return output; + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/softmax_layer.h b/nvdsinfer_custom_impl_Yolo/layers/softmax_layer.h index 0ca208e..62ddf50 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/softmax_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/softmax_layer.h @@ -7,14 +7,10 @@ #define __SOFTMAX_LAYER_H__ #include -#include #include "NvInfer.h" -nvinfer1::ITensor* softmaxLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +nvinfer1::ITensor* softmaxLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.cpp index 4e3614c..e5e1caa 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.cpp @@ -5,28 +5,28 @@ #include "upsample_layer.h" -nvinfer1::ITensor* upsampleLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +#include + +nvinfer1::ITensor* +upsampleLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) { - nvinfer1::ITensor* output; + nvinfer1::ITensor* output; - assert(block.at("type") == "upsample"); - assert(block.find("stride") != block.end()); + assert(block.at("type") == "upsample"); + assert(block.find("stride") != block.end()); - int stride = std::stoi(block.at("stride")); + int stride = std::stoi(block.at("stride")); - float scale[3] = {1, static_cast(stride), static_cast(stride)}; + float scale[3] = {1, static_cast(stride), static_cast(stride)}; - nvinfer1::IResizeLayer* resize = network->addResize(*input); - assert(resize != nullptr); - std::string resizeLayerName = "upsample_" + std::to_string(layerIdx); - resize->setName(resizeLayerName.c_str()); - resize->setResizeMode(nvinfer1::ResizeMode::kNEAREST); - resize->setScales(scale, 3); - output = resize->getOutput(0); + nvinfer1::IResizeLayer* resize = network->addResize(*input); + assert(resize != nullptr); + std::string resizeLayerName = "upsample_" + std::to_string(layerIdx); + resize->setName(resizeLayerName.c_str()); + resize->setResizeMode(nvinfer1::ResizeMode::kNEAREST); + resize->setScales(scale, 3); + output = resize->getOutput(0); - return output; + return output; } diff --git a/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.h b/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.h index 89e69bf..546db2f 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/upsample_layer.h @@ -7,14 +7,10 @@ #define __UPSAMPLE_LAYER_H__ #include -#include #include "NvInfer.h" -nvinfer1::ITensor* upsampleLayer( - int layerIdx, - std::map& block, - nvinfer1::ITensor* input, +nvinfer1::ITensor* upsampleLayer(int layerIdx, std::map& block, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* 
network); #endif diff --git a/nvdsinfer_custom_impl_Yolo/nvdsinfer_yolo_engine.cpp b/nvdsinfer_custom_impl_Yolo/nvdsinfer_yolo_engine.cpp index e2cfbf5..3510feb 100644 --- a/nvdsinfer_custom_impl_Yolo/nvdsinfer_yolo_engine.cpp +++ b/nvdsinfer_custom_impl_Yolo/nvdsinfer_yolo_engine.cpp @@ -23,94 +23,87 @@ * https://www.github.com/marcoslucianops */ +#include + #include "nvdsinfer_custom_impl.h" #include "nvdsinfer_context.h" -#include "yoloPlugins.h" -#include "yolo.h" -#include +#include "yolo.h" #define USE_CUDA_ENGINE_GET_API 1 -static bool getYoloNetworkInfo(NetworkInfo &networkInfo, const NvDsInferContextInitParams* initParams) +static bool +getYoloNetworkInfo(NetworkInfo& networkInfo, const NvDsInferContextInitParams* initParams) { - std::string yoloCfg = initParams->customNetworkConfigFilePath; - std::string yoloType; + std::string yoloCfg = initParams->customNetworkConfigFilePath; + std::string yoloType; - std::transform(yoloCfg.begin(), yoloCfg.end(), yoloCfg.begin(), [] (uint8_t c) { - return std::tolower(c); - }); + std::transform(yoloCfg.begin(), yoloCfg.end(), yoloCfg.begin(), [] (uint8_t c) { + return std::tolower(c); + }); - yoloType = yoloCfg.substr(0, yoloCfg.find(".cfg")); + yoloType = yoloCfg.substr(0, yoloCfg.find(".cfg")); - networkInfo.inputBlobName = "data"; - networkInfo.networkType = yoloType; - networkInfo.configFilePath = initParams->customNetworkConfigFilePath; - networkInfo.wtsFilePath = initParams->modelFilePath; - networkInfo.int8CalibPath = initParams->int8CalibrationFilePath; - networkInfo.deviceType = (initParams->useDLA ? "kDLA" : "kGPU"); - networkInfo.numDetectedClasses = initParams->numDetectedClasses; - networkInfo.clusterMode = initParams->clusterMode; - networkInfo.scoreThreshold = initParams->perClassDetectionParams->preClusterThreshold; + networkInfo.inputBlobName = "data"; + networkInfo.networkType = yoloType; + networkInfo.configFilePath = initParams->customNetworkConfigFilePath; + networkInfo.wtsFilePath = initParams->modelFilePath; + networkInfo.int8CalibPath = initParams->int8CalibrationFilePath; + networkInfo.deviceType = (initParams->useDLA ? 
"kDLA" : "kGPU"); + networkInfo.numDetectedClasses = initParams->numDetectedClasses; + networkInfo.clusterMode = initParams->clusterMode; + networkInfo.scoreThreshold = initParams->perClassDetectionParams->preClusterThreshold; - if (initParams->networkMode == 0) - networkInfo.networkMode = "FP32"; - else if (initParams->networkMode == 1) - networkInfo.networkMode = "INT8"; - else if (initParams->networkMode == 2) - networkInfo.networkMode = "FP16"; + if (initParams->networkMode == 0) + networkInfo.networkMode = "FP32"; + else if (initParams->networkMode == 1) + networkInfo.networkMode = "INT8"; + else if (initParams->networkMode == 2) + networkInfo.networkMode = "FP16"; - if (networkInfo.configFilePath.empty() || networkInfo.wtsFilePath.empty()) - { - std::cerr << "YOLO config file or weights file is not specified\n" << std::endl; - return false; - } + if (networkInfo.configFilePath.empty() || networkInfo.wtsFilePath.empty()) { + std::cerr << "YOLO config file or weights file is not specified\n" << std::endl; + return false; + } - if (!fileExists(networkInfo.configFilePath) || !fileExists(networkInfo.wtsFilePath)) - { - std::cerr << "YOLO config file or weights file is not exist\n" << std::endl; - return false; - } + if (!fileExists(networkInfo.configFilePath) || !fileExists(networkInfo.wtsFilePath)) { + std::cerr << "YOLO config file or weights file is not exist\n" << std::endl; + return false; + } - return true; + return true; } #if !USE_CUDA_ENGINE_GET_API -IModelParser* NvDsInferCreateModelParser( - const NvDsInferContextInitParams* initParams) { - NetworkInfo networkInfo; - if (!getYoloNetworkInfo(networkInfo, initParams)) - return nullptr; +IModelParser* +NvDsInferCreateModelParser(const NvDsInferContextInitParams* initParams) +{ + NetworkInfo networkInfo; + if (!getYoloNetworkInfo(networkInfo, initParams)) + return nullptr; - return new Yolo(networkInfo); + return new Yolo(networkInfo); } #else -extern "C" -bool NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder * const builder, - nvinfer1::IBuilderConfig * const builderConfig, - const NvDsInferContextInitParams * const initParams, - nvinfer1::DataType dataType, - nvinfer1::ICudaEngine *& cudaEngine); +extern "C" bool +NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder* const builder, nvinfer1::IBuilderConfig* const builderConfig, + const NvDsInferContextInitParams* const initParams, nvinfer1::DataType dataType, nvinfer1::ICudaEngine*& cudaEngine); -extern "C" -bool NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder * const builder, - nvinfer1::IBuilderConfig * const builderConfig, - const NvDsInferContextInitParams * const initParams, - nvinfer1::DataType dataType, - nvinfer1::ICudaEngine *& cudaEngine) +extern "C" bool +NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder* const builder, nvinfer1::IBuilderConfig* const builderConfig, + const NvDsInferContextInitParams* const initParams, nvinfer1::DataType dataType, nvinfer1::ICudaEngine*& cudaEngine) { - NetworkInfo networkInfo; - if (!getYoloNetworkInfo(networkInfo, initParams)) - return false; + NetworkInfo networkInfo; + if (!getYoloNetworkInfo(networkInfo, initParams)) + return false; - Yolo yolo(networkInfo); - cudaEngine = yolo.createEngine (builder, builderConfig); - if (cudaEngine == nullptr) - { - std::cerr << "Failed to build CUDA engine on " << networkInfo.configFilePath << std::endl; - return false; - } + Yolo yolo(networkInfo); + cudaEngine = yolo.createEngine(builder, builderConfig); + if (cudaEngine == nullptr) { + std::cerr << "Failed to build CUDA engine on " << 
networkInfo.configFilePath << std::endl; + return false; + } - return true; + return true; } #endif diff --git a/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp b/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp index cb6976c..eefa76d 100644 --- a/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp +++ b/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp @@ -23,118 +23,103 @@ * https://www.github.com/marcoslucianops */ -#include -#include -#include #include "nvdsinfer_custom_impl.h" -#include "utils.h" +#include "utils.h" #include "yoloPlugins.h" -extern "C" bool NvDsInferParseYolo( - std::vector const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo, +extern "C" bool +NvDsInferParseYolo(std::vector const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams, std::vector& objectList); -static NvDsInferParseObjectInfo convertBBox( - const float& bx1, const float& by1, const float& bx2, const float& by2, const uint& netW, const uint& netH) +static NvDsInferParseObjectInfo +convertBBox(const float& bx1, const float& by1, const float& bx2, const float& by2, const uint& netW, const uint& netH) { - NvDsInferParseObjectInfo b; + NvDsInferParseObjectInfo b; - float x1 = bx1; - float y1 = by1; - float x2 = bx2; - float y2 = by2; + float x1 = bx1; + float y1 = by1; + float x2 = bx2; + float y2 = by2; - x1 = clamp(x1, 0, netW); - y1 = clamp(y1, 0, netH); - x2 = clamp(x2, 0, netW); - y2 = clamp(y2, 0, netH); + x1 = clamp(x1, 0, netW); + y1 = clamp(y1, 0, netH); + x2 = clamp(x2, 0, netW); + y2 = clamp(y2, 0, netH); - b.left = x1; - b.width = clamp(x2 - x1, 0, netW); - b.top = y1; - b.height = clamp(y2 - y1, 0, netH); + b.left = x1; + b.width = clamp(x2 - x1, 0, netW); + b.top = y1; + b.height = clamp(y2 - y1, 0, netH); - return b; + return b; } -static void addBBoxProposal( - const float bx1, const float by1, const float bx2, const float by2, const uint& netW, const uint& netH, +static void +addBBoxProposal(const float bx1, const float by1, const float bx2, const float by2, const uint& netW, const uint& netH, const int maxIndex, const float maxProb, std::vector& binfo) { - NvDsInferParseObjectInfo bbi = convertBBox(bx1, by1, bx2, by2, netW, netH); - if (bbi.width < 1 || bbi.height < 1) return; + NvDsInferParseObjectInfo bbi = convertBBox(bx1, by1, bx2, by2, netW, netH); + if (bbi.width < 1 || bbi.height < 1) return; - bbi.detectionConfidence = maxProb; - bbi.classId = maxIndex; - binfo.push_back(bbi); + bbi.detectionConfidence = maxProb; + bbi.classId = maxIndex; + binfo.push_back(bbi); } -static std::vector decodeYoloTensor( - const int* counts, const float* boxes, const float* scores, const int* classes, const uint& netW, const uint& netH) +static std::vector +decodeYoloTensor(const int* counts, const float* boxes, const float* scores, const int* classes, const uint& netW, + const uint& netH) { - std::vector binfo; + std::vector binfo; - uint numBoxes = counts[0]; - for (uint b = 0; b < numBoxes; ++b) - { - float bx1 = boxes[b * 4 + 0]; - float by1 = boxes[b * 4 + 1]; - float bx2 = boxes[b * 4 + 2]; - float by2 = boxes[b * 4 + 3]; + uint numBoxes = counts[0]; + for (uint b = 0; b < numBoxes; ++b) { + float bx1 = boxes[b * 4 + 0]; + float by1 = boxes[b * 4 + 1]; + float bx2 = boxes[b * 4 + 2]; + float by2 = boxes[b * 4 + 3]; - float maxProb = scores[b]; - int maxIndex = classes[b]; + float maxProb = scores[b]; + int maxIndex = classes[b]; - addBBoxProposal(bx1, by1, bx2, by2, netW, netH, maxIndex, maxProb, binfo); - } - return binfo; 
+ addBBoxProposal(bx1, by1, bx2, by2, netW, netH, maxIndex, maxProb, binfo); + } + return binfo; } -static bool NvDsInferParseCustomYolo( - std::vector const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo, - NvDsInferParseDetectionParams const& detectionParams, std::vector& objectList, - const uint &numClasses) -{ - if (outputLayersInfo.empty()) - { - std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl; - return false; - } - - if (numClasses != detectionParams.numClassesConfigured) - { - std::cerr << "WARNING: Num classes mismatch. Configured: " << detectionParams.numClassesConfigured - << ", detected by network: " << numClasses << std::endl; - } - - std::vector objects; - - const NvDsInferLayerInfo &counts = outputLayersInfo[0]; - const NvDsInferLayerInfo &boxes = outputLayersInfo[1]; - const NvDsInferLayerInfo &scores = outputLayersInfo[2]; - const NvDsInferLayerInfo &classes = outputLayersInfo[3]; - - std::vector outObjs = - decodeYoloTensor( - (const int*)(counts.buffer), (const float*)(boxes.buffer), (const float*)(scores.buffer), - (const int*)(classes.buffer), networkInfo.width, networkInfo.height); - - objects.insert(objects.end(), outObjs.begin(), outObjs.end()); - - objectList = objects; - - return true; -} - -extern "C" bool NvDsInferParseYolo( - std::vector const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo, +static bool +NvDsInferParseCustomYolo(std::vector const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo, NvDsInferParseDetectionParams const& detectionParams, std::vector& objectList) { - int num_classes = kNUM_CLASSES; + if (outputLayersInfo.empty()) { + std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl; + return false; + } - return NvDsInferParseCustomYolo ( - outputLayersInfo, networkInfo, detectionParams, objectList, num_classes); + std::vector objects; + + const NvDsInferLayerInfo& counts = outputLayersInfo[0]; + const NvDsInferLayerInfo& boxes = outputLayersInfo[1]; + const NvDsInferLayerInfo& scores = outputLayersInfo[2]; + const NvDsInferLayerInfo& classes = outputLayersInfo[3]; + + std::vector outObjs = decodeYoloTensor((const int*) (counts.buffer), + (const float*) (boxes.buffer), (const float*) (scores.buffer), (const int*) (classes.buffer), networkInfo.width, + networkInfo.height); + + objects.insert(objects.end(), outObjs.begin(), outObjs.end()); + + objectList = objects; + + return true; +} + +extern "C" bool +NvDsInferParseYolo(std::vector const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo, + NvDsInferParseDetectionParams const& detectionParams, std::vector& objectList) +{ + return NvDsInferParseCustomYolo(outputLayersInfo, networkInfo, detectionParams, objectList); } CHECK_CUSTOM_PARSE_FUNC_PROTOTYPE(NvDsInferParseYolo); diff --git a/nvdsinfer_custom_impl_Yolo/utils.cpp b/nvdsinfer_custom_impl_Yolo/utils.cpp index a91739a..1e1689f 100644 --- a/nvdsinfer_custom_impl_Yolo/utils.cpp +++ b/nvdsinfer_custom_impl_Yolo/utils.cpp @@ -25,133 +25,137 @@ #include "utils.h" -#include #include #include -#include +#include -static void leftTrim(std::string& s) +static void +leftTrim(std::string& s) { - s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !isspace(ch); })); + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !isspace(ch); })); } -static void rightTrim(std::string& s) +static void +rightTrim(std::string& s) { - s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !isspace(ch); }).base(), s.end()); + 
s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !isspace(ch); }).base(), s.end()); } -std::string trim(std::string s) +std::string +trim(std::string s) { - leftTrim(s); - rightTrim(s); - return s; + leftTrim(s); + rightTrim(s); + return s; } -float clamp(const float val, const float minVal, const float maxVal) +float +clamp(const float val, const float minVal, const float maxVal) { - assert(minVal <= maxVal); - return std::min(maxVal, std::max(minVal, val)); + assert(minVal <= maxVal); + return std::min(maxVal, std::max(minVal, val)); } -bool fileExists(const std::string fileName, bool verbose) +bool +fileExists(const std::string fileName, bool verbose) { - if (!std::experimental::filesystem::exists(std::experimental::filesystem::path(fileName))) - { - if (verbose) std::cout << "\nFile does not exist: " << fileName << std::endl; - return false; + if (!std::experimental::filesystem::exists(std::experimental::filesystem::path(fileName))) { + if (verbose) + std::cout << "\nFile does not exist: " << fileName << std::endl; + return false; + } + return true; +} + +std::vector +loadWeights(const std::string weightsFilePath, const std::string& networkType) +{ + assert(fileExists(weightsFilePath)); + std::cout << "\nLoading pre-trained weights" << std::endl; + + std::vector weights; + + if (weightsFilePath.find(".weights") != std::string::npos) { + std::ifstream file(weightsFilePath, std::ios_base::binary); + assert(file.good()); + std::string line; + + if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos) { + // Remove 4 int32 bytes of data from the stream belonging to the header + file.ignore(4 * 4); } - return true; -} - -std::vector loadWeights(const std::string weightsFilePath, const std::string& networkType) -{ - assert(fileExists(weightsFilePath)); - std::cout << "\nLoading pre-trained weights" << std::endl; - - std::vector weights; - - if (weightsFilePath.find(".weights") != std::string::npos) { - std::ifstream file(weightsFilePath, std::ios_base::binary); - assert(file.good()); - std::string line; - - if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos) - { - // Remove 4 int32 bytes of data from the stream belonging to the header - file.ignore(4 * 4); - } - else - { - // Remove 5 int32 bytes of data from the stream belonging to the header - file.ignore(4 * 5); - } - - char floatWeight[4]; - while (!file.eof()) - { - file.read(floatWeight, 4); - assert(file.gcount() == 4); - weights.push_back(*reinterpret_cast(floatWeight)); - if (file.peek() == std::istream::traits_type::eof()) break; - } - } - - else if (weightsFilePath.find(".wts") != std::string::npos) { - std::ifstream file(weightsFilePath); - assert(file.good()); - int32_t count; - file >> count; - assert(count > 0 && "\nInvalid .wts file."); - - uint32_t floatWeight; - std::string name; - uint32_t size; - - while (count--) { - file >> name >> std::dec >> size; - for (uint32_t x = 0, y = size; x < y; ++x) - { - file >> std::hex >> floatWeight; - weights.push_back(*reinterpret_cast(&floatWeight)); - }; - } - } - else { - std::cerr << "\nFile " << weightsFilePath << " is not supported" << std::endl; - std::abort(); + // Remove 5 int32 bytes of data from the stream belonging to the header + file.ignore(4 * 5); } - std::cout << "Loading weights of " << networkType << " complete" - << std::endl; - std::cout << "Total weights read: " << weights.size() << std::endl; - return weights; + char floatWeight[4]; + while 
(!file.eof()) { + file.read(floatWeight, 4); + assert(file.gcount() == 4); + weights.push_back(*reinterpret_cast(floatWeight)); + if (file.peek() == std::istream::traits_type::eof()) + break; + } + } + else if (weightsFilePath.find(".wts") != std::string::npos) { + std::ifstream file(weightsFilePath); + assert(file.good()); + int32_t count; + file >> count; + assert(count > 0 && "\nInvalid .wts file."); + + uint32_t floatWeight; + std::string name; + uint32_t size; + + while (count--) { + file >> name >> std::dec >> size; + for (uint32_t x = 0, y = size; x < y; ++x) { + file >> std::hex >> floatWeight; + weights.push_back(*reinterpret_cast(&floatWeight)); + }; + } + } + else { + std::cerr << "\nFile " << weightsFilePath << " is not supported" << std::endl; + assert(0); + } + + std::cout << "Loading weights of " << networkType << " complete" << std::endl; + std::cout << "Total weights read: " << weights.size() << std::endl; + + return weights; } -std::string dimsToString(const nvinfer1::Dims d) +std::string +dimsToString(const nvinfer1::Dims d) { - std::stringstream s; - assert(d.nbDims >= 1); - s << "["; - for (int i = 0; i < d.nbDims - 1; ++i) - s << d.d[i] << ", "; - s << d.d[d.nbDims - 1] << "]"; + assert(d.nbDims >= 1); - return s.str(); + std::stringstream s; + s << "["; + for (int i = 0; i < d.nbDims - 1; ++i) + s << d.d[i] << ", "; + s << d.d[d.nbDims - 1] << "]"; + + return s.str(); } -int getNumChannels(nvinfer1::ITensor* t) +int +getNumChannels(nvinfer1::ITensor* t) { - nvinfer1::Dims d = t->getDimensions(); - assert(d.nbDims == 3); + nvinfer1::Dims d = t->getDimensions(); + assert(d.nbDims == 3); - return d.d[0]; + return d.d[0]; } -void printLayerInfo( - std::string layerIndex, std::string layerName, std::string layerInput, std::string layerOutput, std::string weightPtr) +void +printLayerInfo(std::string layerIndex, std::string layerName, std::string layerInput, std::string layerOutput, + std::string weightPtr) { - std::cout << std::setw(8) << std::left << layerIndex << std::setw(30) << std::left << layerName; - std::cout << std::setw(20) << std::left << layerInput << std::setw(20) << std::left << layerOutput; - std::cout << weightPtr << std::endl; + std::cout << std::setw(8) << std::left << layerIndex << std::setw(30) << std::left << layerName; + std::cout << std::setw(20) << std::left << layerInput << std::setw(20) << std::left << layerOutput; + std::cout << weightPtr << std::endl; } diff --git a/nvdsinfer_custom_impl_Yolo/utils.h b/nvdsinfer_custom_impl_Yolo/utils.h index 6b124eb..f50f954 100644 --- a/nvdsinfer_custom_impl_Yolo/utils.h +++ b/nvdsinfer_custom_impl_Yolo/utils.h @@ -23,7 +23,6 @@ * https://www.github.com/marcoslucianops */ - #ifndef __UTILS_H__ #define __UTILS_H__ @@ -36,11 +35,17 @@ #include "NvInfer.h" std::string trim(std::string s); + float clamp(const float val, const float minVal, const float maxVal); + bool fileExists(const std::string fileName, bool verbose = true); + std::vector loadWeights(const std::string weightsFilePath, const std::string& networkType); + std::string dimsToString(const nvinfer1::Dims d); + int getNumChannels(nvinfer1::ITensor* t); + void printLayerInfo( std::string layerIndex, std::string layerName, std::string layerInput, std::string layerOutput, std::string weightPtr); diff --git a/nvdsinfer_custom_impl_Yolo/yolo.cpp b/nvdsinfer_custom_impl_Yolo/yolo.cpp index f5dc68e..f412df6 100644 --- a/nvdsinfer_custom_impl_Yolo/yolo.cpp +++ b/nvdsinfer_custom_impl_Yolo/yolo.cpp @@ -25,620 +25,576 @@ #include "yolo.h" #include "yoloPlugins.h" 
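As a reading aid for loadWeights above: a .wts file is plain text, starting with the number of weight blocks, followed by one line per tensor with its name, its element count, and that many space-separated hex-encoded float32 words. The tiny example below is purely illustrative; the tensor names and values are made up.

```
2
model.0.conv.weight 4 3f800000 3f000000 bf800000 00000000
model.0.conv.bias 1 3dcccccd
```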
-#include #ifdef OPENCV #include "calibrator.h" #endif -Yolo::Yolo(const NetworkInfo& networkInfo) - : m_InputBlobName(networkInfo.inputBlobName), - m_NetworkType(networkInfo.networkType), - m_ConfigFilePath(networkInfo.configFilePath), - m_WtsFilePath(networkInfo.wtsFilePath), - m_Int8CalibPath(networkInfo.int8CalibPath), - m_DeviceType(networkInfo.deviceType), - m_NumDetectedClasses(networkInfo.numDetectedClasses), - m_ClusterMode(networkInfo.clusterMode), - m_NetworkMode(networkInfo.networkMode), - m_ScoreThreshold(networkInfo.scoreThreshold), - m_InputH(0), - m_InputW(0), - m_InputC(0), - m_InputSize(0), - m_NumClasses(0), - m_LetterBox(0), - m_NewCoords(0), - m_YoloCount(0) -{} +Yolo::Yolo(const NetworkInfo& networkInfo) : m_InputBlobName(networkInfo.inputBlobName), + m_NetworkType(networkInfo.networkType), m_ConfigFilePath(networkInfo.configFilePath), + m_WtsFilePath(networkInfo.wtsFilePath), m_Int8CalibPath(networkInfo.int8CalibPath), m_DeviceType(networkInfo.deviceType), + m_NumDetectedClasses(networkInfo.numDetectedClasses), m_ClusterMode(networkInfo.clusterMode), + m_NetworkMode(networkInfo.networkMode), m_ScoreThreshold(networkInfo.scoreThreshold), m_InputH(0), m_InputW(0), + m_InputC(0), m_InputSize(0), m_NumClasses(0), m_LetterBox(0), m_NewCoords(0), m_YoloCount(0) +{ +} Yolo::~Yolo() { - destroyNetworkUtils(); + destroyNetworkUtils(); } -nvinfer1::ICudaEngine *Yolo::createEngine(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config) +nvinfer1::ICudaEngine* +Yolo::createEngine(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config) { - assert (builder); + assert (builder); - m_ConfigBlocks = parseConfigFile(m_ConfigFilePath); - parseConfigBlocks(); - - nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0); - if (parseModel(*network) != NVDSINFER_SUCCESS) - { - delete network; - return nullptr; - } - - std::cout << "Building the TensorRT Engine\n" << std::endl; - - if (m_NumClasses != m_NumDetectedClasses) - { - std::cout << "NOTE: Number of classes mismatch, make sure to set num-detected-classes=" << m_NumClasses - << " in config_infer file\n" << std::endl; - } - if (m_LetterBox == 1) - { - std::cout << "NOTE: letter_box is set in cfg file, make sure to set maintain-aspect-ratio=1 in config_infer file" - << " to get better accuracy\n" << std::endl; - } - if (m_ClusterMode != 2) - { - std::cout << "NOTE: Wrong cluster-mode is set, make sure to set cluster-mode=2 in config_infer file\n" - << std::endl; - } - - if (m_NetworkMode == "INT8" && !fileExists(m_Int8CalibPath)) - { - assert(builder->platformHasFastInt8()); -#ifdef OPENCV - std::string calib_image_list; - int calib_batch_size; - if (getenv("INT8_CALIB_IMG_PATH")) - calib_image_list = getenv("INT8_CALIB_IMG_PATH"); - else - { - std::cerr << "INT8_CALIB_IMG_PATH not set" << std::endl; - std::abort(); - } - if (getenv("INT8_CALIB_BATCH_SIZE")) - calib_batch_size = std::stoi(getenv("INT8_CALIB_BATCH_SIZE")); - else - { - std::cerr << "INT8_CALIB_BATCH_SIZE not set" << std::endl; - std::abort(); - } - nvinfer1::Int8EntropyCalibrator2 *calibrator = new nvinfer1::Int8EntropyCalibrator2( - calib_batch_size, m_InputC, m_InputH, m_InputW, m_LetterBox, calib_image_list, m_Int8CalibPath); - config->setFlag(nvinfer1::BuilderFlag::kINT8); - config->setInt8Calibrator(calibrator); -#else - std::cerr << "OpenCV is required to run INT8 calibrator\n" << std::endl; - assert(0); -#endif - } - - nvinfer1::ICudaEngine *engine = builder->buildEngineWithConfig(*network, *config); - if (engine) - std::cout << 
"Building complete\n" << std::endl; - else - std::cerr << "Building engine failed\n" << std::endl; + m_ConfigBlocks = parseConfigFile(m_ConfigFilePath); + parseConfigBlocks(); + nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0); + if (parseModel(*network) != NVDSINFER_SUCCESS) { delete network; - return engine; + return nullptr; + } + + std::cout << "Building the TensorRT Engine\n" << std::endl; + + if (m_NumClasses != m_NumDetectedClasses) { + std::cout << "NOTE: Number of classes mismatch, make sure to set num-detected-classes=" << m_NumClasses + << " in config_infer file\n" << std::endl; + } + if (m_LetterBox == 1) { + std::cout << "NOTE: letter_box is set in cfg file, make sure to set maintain-aspect-ratio=1 in config_infer file" + << " to get better accuracy\n" << std::endl; + } + if (m_ClusterMode != 2) { + std::cout << "NOTE: Wrong cluster-mode is set, make sure to set cluster-mode=2 in config_infer file\n" << std::endl; + } + + if (m_NetworkMode == "INT8" && !fileExists(m_Int8CalibPath)) { + assert(builder->platformHasFastInt8()); +#ifdef OPENCV + std::string calib_image_list; + int calib_batch_size; + if (getenv("INT8_CALIB_IMG_PATH")) + calib_image_list = getenv("INT8_CALIB_IMG_PATH"); + else { + std::cerr << "INT8_CALIB_IMG_PATH not set" << std::endl; + assert(0); + } + if (getenv("INT8_CALIB_BATCH_SIZE")) + calib_batch_size = std::stoi(getenv("INT8_CALIB_BATCH_SIZE")); + else { + std::cerr << "INT8_CALIB_BATCH_SIZE not set" << std::endl; + assert(0); + } + nvinfer1::IInt8EntropyCalibrator2 *calibrator = new Int8EntropyCalibrator2(calib_batch_size, m_InputC, m_InputH, + m_InputW, m_LetterBox, calib_image_list, m_Int8CalibPath); + config->setFlag(nvinfer1::BuilderFlag::kINT8); + config->setInt8Calibrator(calibrator); +#else + std::cerr << "OpenCV is required to run INT8 calibrator\n" << std::endl; + assert(0); +#endif + } + + nvinfer1::ICudaEngine *engine = builder->buildEngineWithConfig(*network, *config); + if (engine) + std::cout << "Building complete\n" << std::endl; + else + std::cerr << "Building engine failed\n" << std::endl; + + delete network; + return engine; } -NvDsInferStatus Yolo::parseModel(nvinfer1::INetworkDefinition& network) { - destroyNetworkUtils(); +NvDsInferStatus +Yolo::parseModel(nvinfer1::INetworkDefinition& network) { + destroyNetworkUtils(); - std::vector weights = loadWeights(m_WtsFilePath, m_NetworkType); - std::cout << "Building YOLO network\n" << std::endl; - NvDsInferStatus status = buildYoloNetwork(weights, network); + std::vector weights = loadWeights(m_WtsFilePath, m_NetworkType); + std::cout << "Building YOLO network\n" << std::endl; + NvDsInferStatus status = buildYoloNetwork(weights, network); - if (status == NVDSINFER_SUCCESS) - std::cout << "Building YOLO network complete" << std::endl; - else - std::cerr << "Building YOLO network failed" << std::endl; + if (status == NVDSINFER_SUCCESS) + std::cout << "Building YOLO network complete" << std::endl; + else + std::cerr << "Building YOLO network failed" << std::endl; - return status; + return status; } -NvDsInferStatus Yolo::buildYoloNetwork(std::vector& weights, nvinfer1::INetworkDefinition& network) +NvDsInferStatus +Yolo::buildYoloNetwork(std::vector& weights, nvinfer1::INetworkDefinition& network) { - int weightPtr = 0; + int weightPtr = 0; - std::string weightsType; - if (m_WtsFilePath.find(".weights") != std::string::npos) - weightsType = "weights"; - else - weightsType = "wts"; + std::string weightsType = "wts"; + if (m_WtsFilePath.find(".weights") != std::string::npos) 
+ weightsType = "weights"; - float eps = 1.0e-5; - if (m_NetworkType.find("yolov5") != std::string::npos || m_NetworkType.find("yolov7") != std::string::npos) - eps = 1.0e-3; - else if (m_NetworkType.find("yolor") != std::string::npos) - eps = 1.0e-4; + float eps = 1.0e-5; + if (m_NetworkType.find("yolov5") != std::string::npos || m_NetworkType.find("yolov7") != std::string::npos || + m_NetworkType.find("yolov8") != std::string::npos) + eps = 1.0e-3; + else if (m_NetworkType.find("yolor") != std::string::npos) + eps = 1.0e-4; - nvinfer1::ITensor* data = network.addInput( - m_InputBlobName.c_str(), nvinfer1::DataType::kFLOAT, - nvinfer1::Dims{3, {static_cast(m_InputC), static_cast(m_InputH), static_cast(m_InputW)}}); - assert(data != nullptr && data->getDimensions().nbDims > 0); + nvinfer1::ITensor* data = network.addInput(m_InputBlobName.c_str(), nvinfer1::DataType::kFLOAT, + nvinfer1::Dims{3, {static_cast(m_InputC), static_cast(m_InputH), static_cast(m_InputW)}}); + assert(data != nullptr && data->getDimensions().nbDims > 0); - nvinfer1::ITensor* previous = data; - std::vector tensorOutputs; + nvinfer1::ITensor* previous = data; + std::vector tensorOutputs; - nvinfer1::ITensor* yoloTensorInputs[m_YoloCount]; - uint yoloCountInputs = 0; + nvinfer1::ITensor* yoloTensorInputs[m_YoloCount]; + uint yoloCountInputs = 0; - int modelType = -1; + int modelType = -1; - for (uint i = 0; i < m_ConfigBlocks.size(); ++i) - { - std::string layerIndex = "(" + std::to_string(tensorOutputs.size()) + ")"; + for (uint i = 0; i < m_ConfigBlocks.size(); ++i) { + std::string layerIndex = "(" + std::to_string(tensorOutputs.size()) + ")"; - if (m_ConfigBlocks.at(i).at("type") == "net") - printLayerInfo("", "Layer", "Input Shape", "Output Shape", "WeightPtr"); - - else if (m_ConfigBlocks.at(i).at("type") == "convolutional") - { - int channels = getNumChannels(previous); - std::string inputVol = dimsToString(previous->getDimensions()); - previous = convolutionalLayer( - i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, channels, eps, previous, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = "conv_" + m_ConfigBlocks.at(i).at("activation"); - printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr)); - } + if (m_ConfigBlocks.at(i).at("type") == "net") + printLayerInfo("", "Layer", "Input Shape", "Output Shape", "WeightPtr"); + else if (m_ConfigBlocks.at(i).at("type") == "convolutional") { + int channels = getNumChannels(previous); + std::string inputVol = dimsToString(previous->getDimensions()); + previous = convolutionalLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, channels, eps, + previous, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = "conv_" + m_ConfigBlocks.at(i).at("activation"); + printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr)); + } + else if (m_ConfigBlocks.at(i).at("type") == "c2f") { + std::string inputVol = dimsToString(previous->getDimensions()); + previous = c2fLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, eps, previous, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = "c2f_" + 
m_ConfigBlocks.at(i).at("activation"); + printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr)); + } + else if (m_ConfigBlocks.at(i).at("type") == "batchnorm") { + std::string inputVol = dimsToString(previous->getDimensions()); + previous = batchnormLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, eps, previous, + &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = "batchnorm_" + m_ConfigBlocks.at(i).at("activation"); + printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr)); + } + else if (m_ConfigBlocks.at(i).at("type") == "implicit_add" || m_ConfigBlocks.at(i).at("type") == "implicit_mul") { + previous = implicitLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = m_ConfigBlocks.at(i).at("type"); + printLayerInfo(layerIndex, layerName, "-", outputVol, std::to_string(weightPtr)); + } + else if (m_ConfigBlocks.at(i).at("type") == "shift_channels" || m_ConfigBlocks.at(i).at("type") == "control_channels") { + assert(m_ConfigBlocks.at(i).find("from") != m_ConfigBlocks.at(i).end()); + int from = stoi(m_ConfigBlocks.at(i).at("from")); + if (from > 0) + from = from - i + 1; + assert((i - 2 >= 0) && (i - 2 < tensorOutputs.size())); + assert((i + from - 1 >= 0) && (i + from - 1 < tensorOutputs.size())); + assert(i + from - 1 < i - 2); - else if (m_ConfigBlocks.at(i).at("type") == "batchnorm") - { - std::string inputVol = dimsToString(previous->getDimensions()); - previous = batchnormLayer( - i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, eps, previous, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = "batchnorm_" + m_ConfigBlocks.at(i).at("activation"); - printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr)); - } + std::string inputVol = dimsToString(previous->getDimensions()); + previous = channelsLayer(i, m_ConfigBlocks.at(i), previous, tensorOutputs[i + from - 1], &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = m_ConfigBlocks.at(i).at("type") + ": " + std::to_string(i + from - 1); + printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); + } + else if (m_ConfigBlocks.at(i).at("type") == "shortcut") { + assert(m_ConfigBlocks.at(i).find("from") != m_ConfigBlocks.at(i).end()); + int from = stoi(m_ConfigBlocks.at(i).at("from")); + if (from > 0) + from = from - i + 1; + assert((i - 2 >= 0) && (i - 2 < tensorOutputs.size())); + assert((i + from - 1 >= 0) && (i + from - 1 < tensorOutputs.size())); + assert(i + from - 1 < i - 2); - else if (m_ConfigBlocks.at(i).at("type") == "implicit_add" || m_ConfigBlocks.at(i).at("type") == "implicit_mul") - { - previous = implicitLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = m_ConfigBlocks.at(i).at("type"); - printLayerInfo(layerIndex, layerName, "-", outputVol, 
std::to_string(weightPtr)); - } + std::string mode = "add"; + if (m_ConfigBlocks.at(i).find("mode") != m_ConfigBlocks.at(i).end()) + mode = m_ConfigBlocks.at(i).at("mode"); - else if (m_ConfigBlocks.at(i).at("type") == "shift_channels" || - m_ConfigBlocks.at(i).at("type") == "control_channels") - { - assert(m_ConfigBlocks.at(i).find("from") != m_ConfigBlocks.at(i).end()); - int from = stoi(m_ConfigBlocks.at(i).at("from")); - if (from > 0) - from = from - i + 1; - assert((i - 2 >= 0) && (i - 2 < tensorOutputs.size())); - assert((i + from - 1 >= 0) && (i + from - 1 < tensorOutputs.size())); - assert(i + from - 1 < i - 2); + std::string activation = "linear"; + if (m_ConfigBlocks.at(i).find("activation") != m_ConfigBlocks.at(i).end()) + activation = m_ConfigBlocks.at(i).at("activation"); - std::string inputVol = dimsToString(previous->getDimensions()); - previous = channelsLayer(i, m_ConfigBlocks.at(i), previous, tensorOutputs[i + from - 1], &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = m_ConfigBlocks.at(i).at("type") + ": " + std::to_string(i + from - 1); - printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); - } - - else if (m_ConfigBlocks.at(i).at("type") == "shortcut") - { - assert(m_ConfigBlocks.at(i).find("from") != m_ConfigBlocks.at(i).end()); - int from = stoi(m_ConfigBlocks.at(i).at("from")); - if (from > 0) - from = from - i + 1; - assert((i - 2 >= 0) && (i - 2 < tensorOutputs.size())); - assert((i + from - 1 >= 0) && (i + from - 1 < tensorOutputs.size())); - assert(i + from - 1 < i - 2); - - std::string mode = "add"; - if (m_ConfigBlocks.at(i).find("mode") != m_ConfigBlocks.at(i).end()) - mode = m_ConfigBlocks.at(i).at("mode"); - - std::string activation = "linear"; - if (m_ConfigBlocks.at(i).find("activation") != m_ConfigBlocks.at(i).end()) - activation = m_ConfigBlocks.at(i).at("activation"); - - std::string inputVol = dimsToString(previous->getDimensions()); - std::string shortcutVol = dimsToString(tensorOutputs[i + from - 1]->getDimensions()); - previous = shortcutLayer( - i, mode, activation, inputVol, shortcutVol, m_ConfigBlocks.at(i), previous, tensorOutputs[i + from - 1], - &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = "shortcut_" + mode + "_" + activation + ": " + std::to_string(i + from - 1); - printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); - - if (mode == "add" && inputVol != shortcutVol) - std::cout << inputVol << " +" << shortcutVol << std::endl; - } - - else if (m_ConfigBlocks.at(i).at("type") == "route") - { - std::string layers; - previous = routeLayer(i, layers, m_ConfigBlocks.at(i), tensorOutputs, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = "route: " + layers; - printLayerInfo(layerIndex, layerName, "-", outputVol, "-"); - } - - else if (m_ConfigBlocks.at(i).at("type") == "upsample") - { - std::string inputVol = dimsToString(previous->getDimensions()); - previous = upsampleLayer(i, m_ConfigBlocks[i], previous, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = "upsample"; - printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); - } - 
- else if (m_ConfigBlocks.at(i).at("type") == "maxpool" || m_ConfigBlocks.at(i).at("type") == "avgpool") - { - std::string inputVol = dimsToString(previous->getDimensions()); - previous = poolingLayer(i, m_ConfigBlocks.at(i), previous, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = m_ConfigBlocks.at(i).at("type"); - printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); - } - - else if (m_ConfigBlocks.at(i).at("type") == "reorg") - { - std::string inputVol = dimsToString(previous->getDimensions()); - if (m_NetworkType.find("yolov2") != std::string::npos) { - nvinfer1::IPluginV2* reorgPlugin = createReorgPlugin(2); - assert(reorgPlugin != nullptr); - nvinfer1::IPluginV2Layer* reorg = network.addPluginV2(&previous, 1, *reorgPlugin); - assert(reorg != nullptr); - std::string reorglayerName = "reorg_" + std::to_string(i); - reorg->setName(reorglayerName.c_str()); - previous = reorg->getOutput(0); - } - else - previous = reorgLayer(i, m_ConfigBlocks.at(i), previous, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = "reorg"; - printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); - } - - else if (m_ConfigBlocks.at(i).at("type") == "reduce") - { - std::string inputVol = dimsToString(previous->getDimensions()); - previous = reduceLayer(i, m_ConfigBlocks.at(i), previous, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = "reduce"; - printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); - } - - else if (m_ConfigBlocks.at(i).at("type") == "shuffle") - { - std::string layer; - std::string inputVol = dimsToString(previous->getDimensions()); - previous = shuffleLayer(i, layer, m_ConfigBlocks.at(i), previous, tensorOutputs, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = "shuffle: " + layer; - printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); - } - - else if (m_ConfigBlocks.at(i).at("type") == "softmax") - { - std::string inputVol = dimsToString(previous->getDimensions()); - previous = softmaxLayer(i, m_ConfigBlocks.at(i), previous, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - std::string layerName = "softmax"; - printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); - } - - else if (m_ConfigBlocks.at(i).at("type") == "yolo" || m_ConfigBlocks.at(i).at("type") == "region") - { - if (m_ConfigBlocks.at(i).at("type") == "yolo") - if (m_NetworkType.find("yolor") != std::string::npos) - modelType = 2; - else - modelType = 1; - else - modelType = 0; - - std::string blobName = modelType != 0 ? 
"yolo_" + std::to_string(i) : "region_" + std::to_string(i); - nvinfer1::Dims prevTensorDims = previous->getDimensions(); - TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs); - curYoloTensor.blobName = blobName; - curYoloTensor.gridSizeX = prevTensorDims.d[2]; - curYoloTensor.gridSizeY = prevTensorDims.d[1]; - - std::string inputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - yoloTensorInputs[yoloCountInputs] = previous; - ++yoloCountInputs; - std::string layerName = modelType != 0 ? "yolo" : "region"; - printLayerInfo(layerIndex, layerName, inputVol, "-", "-"); - } - - else if (m_ConfigBlocks.at(i).at("type") == "cls") - { - modelType = 3; - - std::string blobName = "cls_" + std::to_string(i); - nvinfer1::Dims prevTensorDims = previous->getDimensions(); - TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs); - curYoloTensor.blobName = blobName; - curYoloTensor.numBBoxes = prevTensorDims.d[1]; - m_NumClasses = prevTensorDims.d[0]; - - std::string inputVol = dimsToString(previous->getDimensions()); - previous = clsLayer(i, m_ConfigBlocks.at(i), previous, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - yoloTensorInputs[yoloCountInputs] = previous; - ++yoloCountInputs; - std::string layerName = "cls"; - printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); - } - - else if (m_ConfigBlocks.at(i).at("type") == "reg") - { - modelType = 3; - - std::string blobName = "reg_" + std::to_string(i); - nvinfer1::Dims prevTensorDims = previous->getDimensions(); - TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs); - curYoloTensor.blobName = blobName; - curYoloTensor.numBBoxes = prevTensorDims.d[1]; - - std::string inputVol = dimsToString(previous->getDimensions()); - previous = regLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, previous, &network); - assert(previous != nullptr); - std::string outputVol = dimsToString(previous->getDimensions()); - tensorOutputs.push_back(previous); - yoloTensorInputs[yoloCountInputs] = previous; - ++yoloCountInputs; - std::string layerName = "reg"; - printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr)); - } + std::string inputVol = dimsToString(previous->getDimensions()); + std::string shortcutVol = dimsToString(tensorOutputs[i + from - 1]->getDimensions()); + previous = shortcutLayer(i, mode, activation, inputVol, shortcutVol, m_ConfigBlocks.at(i), previous, + tensorOutputs[i + from - 1], &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = "shortcut_" + mode + "_" + activation + ": " + std::to_string(i + from - 1); + printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); + if (mode == "add" && inputVol != shortcutVol) + std::cout << inputVol << " +" << shortcutVol << std::endl; + } + else if (m_ConfigBlocks.at(i).at("type") == "route") { + std::string layers; + previous = routeLayer(i, layers, m_ConfigBlocks.at(i), tensorOutputs, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = "route: " + layers; + printLayerInfo(layerIndex, layerName, "-", outputVol, "-"); + } + else if (m_ConfigBlocks.at(i).at("type") == "upsample") { + std::string inputVol = dimsToString(previous->getDimensions()); + previous 
= upsampleLayer(i, m_ConfigBlocks[i], previous, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = "upsample"; + printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); + } + else if (m_ConfigBlocks.at(i).at("type") == "maxpool" || m_ConfigBlocks.at(i).at("type") == "avgpool") { + std::string inputVol = dimsToString(previous->getDimensions()); + previous = poolingLayer(i, m_ConfigBlocks.at(i), previous, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = m_ConfigBlocks.at(i).at("type"); + printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); + } + else if (m_ConfigBlocks.at(i).at("type") == "reorg") { + std::string inputVol = dimsToString(previous->getDimensions()); + if (m_NetworkType.find("yolov2") != std::string::npos) { + nvinfer1::IPluginV2* reorgPlugin = createReorgPlugin(2); + assert(reorgPlugin != nullptr); + nvinfer1::IPluginV2Layer* reorg = network.addPluginV2(&previous, 1, *reorgPlugin); + assert(reorg != nullptr); + std::string reorglayerName = "reorg_" + std::to_string(i); + reorg->setName(reorglayerName.c_str()); + previous = reorg->getOutput(0); + } + else + previous = reorgLayer(i, m_ConfigBlocks.at(i), previous, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = "reorg"; + printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); + } + else if (m_ConfigBlocks.at(i).at("type") == "reduce") { + std::string inputVol = dimsToString(previous->getDimensions()); + previous = reduceLayer(i, m_ConfigBlocks.at(i), previous, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = "reduce"; + printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); + } + else if (m_ConfigBlocks.at(i).at("type") == "shuffle") { + std::string layer; + std::string inputVol = dimsToString(previous->getDimensions()); + previous = shuffleLayer(i, layer, m_ConfigBlocks.at(i), previous, tensorOutputs, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = "shuffle: " + layer; + printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); + } + else if (m_ConfigBlocks.at(i).at("type") == "softmax") { + std::string inputVol = dimsToString(previous->getDimensions()); + previous = softmaxLayer(i, m_ConfigBlocks.at(i), previous, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + std::string layerName = "softmax"; + printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); + } + else if (m_ConfigBlocks.at(i).at("type") == "yolo" || m_ConfigBlocks.at(i).at("type") == "region") { + if (m_ConfigBlocks.at(i).at("type") == "yolo") + if (m_NetworkType.find("yolor") != std::string::npos) + modelType = 2; else - { - std::cout << "\nUnsupported layer type --> \"" << m_ConfigBlocks.at(i).at("type") << "\"" << std::endl; - assert(0); - } + modelType = 1; + else + modelType = 0; + + std::string blobName = modelType != 0 ? 
"yolo_" + std::to_string(i) : "region_" + std::to_string(i); + nvinfer1::Dims prevTensorDims = previous->getDimensions(); + TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs); + curYoloTensor.blobName = blobName; + curYoloTensor.gridSizeX = prevTensorDims.d[2]; + curYoloTensor.gridSizeY = prevTensorDims.d[1]; + + std::string inputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + yoloTensorInputs[yoloCountInputs] = previous; + ++yoloCountInputs; + std::string layerName = modelType != 0 ? "yolo" : "region"; + printLayerInfo(layerIndex, layerName, inputVol, "-", "-"); } + else if (m_ConfigBlocks.at(i).at("type") == "cls") { + modelType = 3; - if ((int)weights.size() != weightPtr) - { - std::cout << "\nNumber of unused weights left: " << weights.size() - weightPtr << std::endl; - assert(0); + std::string blobName = "cls_" + std::to_string(i); + nvinfer1::Dims prevTensorDims = previous->getDimensions(); + TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs); + curYoloTensor.blobName = blobName; + curYoloTensor.numBBoxes = prevTensorDims.d[1]; + m_NumClasses = prevTensorDims.d[0]; + + std::string inputVol = dimsToString(previous->getDimensions()); + previous = clsLayer(i, m_ConfigBlocks.at(i), previous, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + yoloTensorInputs[yoloCountInputs] = previous; + ++yoloCountInputs; + std::string layerName = "cls"; + printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-"); } + else if (m_ConfigBlocks.at(i).at("type") == "reg") { + modelType = 3; - if (m_YoloCount == yoloCountInputs) - { - assert((modelType != -1) && "\nCould not determine model type"); + std::string blobName = "reg_" + std::to_string(i); + nvinfer1::Dims prevTensorDims = previous->getDimensions(); + TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs); + curYoloTensor.blobName = blobName; + curYoloTensor.numBBoxes = prevTensorDims.d[1]; - uint64_t outputSize = 0; - for (uint j = 0; j < yoloCountInputs; ++j) - { - TensorInfo& curYoloTensor = m_YoloTensors.at(j); - if (modelType == 3) - outputSize = curYoloTensor.numBBoxes; - else - outputSize += curYoloTensor.gridSizeX * curYoloTensor.gridSizeY * curYoloTensor.numBBoxes; - } + std::string inputVol = dimsToString(previous->getDimensions()); + previous = regLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, previous, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + yoloTensorInputs[yoloCountInputs] = previous; + ++yoloCountInputs; + std::string layerName = "reg"; + printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr)); + } + else if (m_ConfigBlocks.at(i).at("type") == "detect_v8") { + modelType = 4; - nvinfer1::IPluginV2* yoloPlugin = new YoloLayer( - m_InputW, m_InputH, m_NumClasses, m_NewCoords, m_YoloTensors, outputSize, modelType, m_ScoreThreshold); - assert(yoloPlugin != nullptr); - nvinfer1::IPluginV2Layer* yolo = network.addPluginV2(yoloTensorInputs, m_YoloCount, *yoloPlugin); - assert(yolo != nullptr); - std::string yoloLayerName = "yolo"; - yolo->setName(yoloLayerName.c_str()); + std::string blobName = "detect_v8_" + std::to_string(i); + nvinfer1::Dims prevTensorDims = previous->getDimensions(); + TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs); + curYoloTensor.blobName = blobName; + curYoloTensor.numBBoxes = 
prevTensorDims.d[1]; - std::string outputlayerName; - nvinfer1::ITensor* num_detections = yolo->getOutput(0); - outputlayerName = "num_detections"; - num_detections->setName(outputlayerName.c_str()); - nvinfer1::ITensor* detection_boxes = yolo->getOutput(1); - outputlayerName = "detection_boxes"; - detection_boxes->setName(outputlayerName.c_str()); - nvinfer1::ITensor* detection_scores = yolo->getOutput(2); - outputlayerName = "detection_scores"; - detection_scores->setName(outputlayerName.c_str()); - nvinfer1::ITensor* detection_classes = yolo->getOutput(3); - outputlayerName = "detection_classes"; - detection_classes->setName(outputlayerName.c_str()); - network.markOutput(*num_detections); - network.markOutput(*detection_boxes); - network.markOutput(*detection_scores); - network.markOutput(*detection_classes); + std::string inputVol = dimsToString(previous->getDimensions()); + previous = detectV8Layer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, previous, &network); + assert(previous != nullptr); + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + yoloTensorInputs[yoloCountInputs] = previous; + ++yoloCountInputs; + std::string layerName = "detect_v8"; + printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr)); } else { - std::cout << "\nError in yolo cfg file" << std::endl; - assert(0); + std::cerr << "\nUnsupported layer type --> \"" << m_ConfigBlocks.at(i).at("type") << "\"" << std::endl; + assert(0); + } + } + + if ((int) weights.size() != weightPtr) { + std::cerr << "\nNumber of unused weights left: " << weights.size() - weightPtr << std::endl; + assert(0); + } + + if (m_YoloCount == yoloCountInputs) { + assert((modelType != -1) && "\nCould not determine model type"); + + uint64_t outputSize = 0; + for (uint j = 0; j < yoloCountInputs; ++j) { + TensorInfo& curYoloTensor = m_YoloTensors.at(j); + if (modelType == 3 || modelType == 4) + outputSize = curYoloTensor.numBBoxes; + else + outputSize += curYoloTensor.gridSizeX * curYoloTensor.gridSizeY * curYoloTensor.numBBoxes; } - std::cout << "\nOutput YOLO blob names: " << std::endl; - for (auto& tensor : m_YoloTensors) - { - std::cout << tensor.blobName << std::endl; - } + nvinfer1::IPluginV2* yoloPlugin = new YoloLayer(m_InputW, m_InputH, m_NumClasses, m_NewCoords, m_YoloTensors, outputSize, + modelType, m_ScoreThreshold); + assert(yoloPlugin != nullptr); + nvinfer1::IPluginV2Layer* yolo = network.addPluginV2(yoloTensorInputs, m_YoloCount, *yoloPlugin); + assert(yolo != nullptr); + std::string yoloLayerName = "yolo"; + yolo->setName(yoloLayerName.c_str()); - int nbLayers = network.getNbLayers(); - std::cout << "\nTotal number of YOLO layers: " << nbLayers << "\n" << std::endl; + std::string outputlayerName; + nvinfer1::ITensor* num_detections = yolo->getOutput(0); + outputlayerName = "num_detections"; + num_detections->setName(outputlayerName.c_str()); + nvinfer1::ITensor* detection_boxes = yolo->getOutput(1); + outputlayerName = "detection_boxes"; + detection_boxes->setName(outputlayerName.c_str()); + nvinfer1::ITensor* detection_scores = yolo->getOutput(2); + outputlayerName = "detection_scores"; + detection_scores->setName(outputlayerName.c_str()); + nvinfer1::ITensor* detection_classes = yolo->getOutput(3); + outputlayerName = "detection_classes"; + detection_classes->setName(outputlayerName.c_str()); + network.markOutput(*num_detections); + network.markOutput(*detection_boxes); + network.markOutput(*detection_scores); + 
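// Reader's note (not part of the patch): buildYoloNetwork records the head type in modelType:
// 0 = region, 1 = yolo, 2 = yolo with a YOLOR config, 3 = cls/reg heads, 4 = detect_v8.
// For modelType 3 and 4 the plugin output size is taken directly from numBBoxes of the head
// tensor; otherwise it is accumulated as gridSizeX * gridSizeY * numBBoxes over all outputs.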
network.markOutput(*detection_classes); + } + else { + std::cerr << "\nError in yolo cfg file" << std::endl; + assert(0); + } - return NVDSINFER_SUCCESS; + std::cout << "\nOutput YOLO blob names: " << std::endl; + for (auto& tensor : m_YoloTensors) + std::cout << tensor.blobName << std::endl; + + int nbLayers = network.getNbLayers(); + std::cout << "\nTotal number of YOLO layers: " << nbLayers << "\n" << std::endl; + + return NVDSINFER_SUCCESS; } std::vector> -Yolo::parseConfigFile (const std::string cfgFilePath) +Yolo::parseConfigFile(const std::string cfgFilePath) { - assert(fileExists(cfgFilePath)); - std::ifstream file(cfgFilePath); - assert(file.good()); - std::string line; - std::vector> blocks; - std::map block; + assert(fileExists(cfgFilePath)); + std::ifstream file(cfgFilePath); + assert(file.good()); + std::string line; + std::vector> blocks; + std::map block; - while (getline(file, line)) - { - if (line.size() == 0) continue; - if (line.front() == ' ') continue; - if (line.front() == '#') continue; - line = trim(line); - if (line.front() == '[') - { - if (block.size() > 0) - { - blocks.push_back(block); - block.clear(); - } - std::string key = "type"; - std::string value = trim(line.substr(1, line.size() - 2)); - block.insert(std::pair(key, value)); - } - else - { - int cpos = line.find('='); - std::string key = trim(line.substr(0, cpos)); - std::string value = trim(line.substr(cpos + 1)); - block.insert(std::pair(key, value)); - } + while (getline(file, line)) { + if (line.size() == 0 || line.front() == ' ' || line.front() == '#') + continue; + + line = trim(line); + if (line.front() == '[') { + if (block.size() > 0) { + blocks.push_back(block); + block.clear(); + } + std::string key = "type"; + std::string value = trim(line.substr(1, line.size() - 2)); + block.insert(std::pair(key, value)); } - blocks.push_back(block); - return blocks; -} - -void Yolo::parseConfigBlocks() -{ - for (auto block : m_ConfigBlocks) - { - if (block.at("type") == "net") - { - assert((block.find("height") != block.end()) && "Missing 'height' param in network cfg"); - assert((block.find("width") != block.end()) && "Missing 'width' param in network cfg"); - assert((block.find("channels") != block.end()) && "Missing 'channels' param in network cfg"); - - m_InputH = std::stoul(block.at("height")); - m_InputW = std::stoul(block.at("width")); - m_InputC = std::stoul(block.at("channels")); - m_InputSize = m_InputC * m_InputH * m_InputW; - - if (block.find("letter_box") != block.end()) - { - m_LetterBox = std::stoul(block.at("letter_box")); - } - } - else if ((block.at("type") == "region") || (block.at("type") == "yolo")) - { - assert((block.find("num") != block.end()) - && std::string("Missing 'num' param in " + block.at("type") + " layer").c_str()); - assert((block.find("classes") != block.end()) - && std::string("Missing 'classes' param in " + block.at("type") + " layer").c_str()); - assert((block.find("anchors") != block.end()) - && std::string("Missing 'anchors' param in " + block.at("type") + " layer").c_str()); - - ++m_YoloCount; - - m_NumClasses = std::stoul(block.at("classes")); - - if (block.find("new_coords") != block.end()) - { - m_NewCoords = std::stoul(block.at("new_coords")); - } - - TensorInfo outputTensor; - - std::string anchorString = block.at("anchors"); - while (!anchorString.empty()) - { - int npos = anchorString.find_first_of(','); - if (npos != -1) - { - float anchor = std::stof(trim(anchorString.substr(0, npos))); - outputTensor.anchors.push_back(anchor); - anchorString.erase(0, npos 
+ 1); - } - else - { - float anchor = std::stof(trim(anchorString)); - outputTensor.anchors.push_back(anchor); - break; - } - } - - if (block.find("mask") != block.end()) - { - std::string maskString = block.at("mask"); - while (!maskString.empty()) - { - int npos = maskString.find_first_of(','); - if (npos != -1) - { - int mask = std::stoul(trim(maskString.substr(0, npos))); - outputTensor.mask.push_back(mask); - maskString.erase(0, npos + 1); - } - else - { - int mask = std::stoul(trim(maskString)); - outputTensor.mask.push_back(mask); - break; - } - } - } - - if (block.find("scale_x_y") != block.end()) - { - outputTensor.scaleXY = std::stof(block.at("scale_x_y")); - } - else - { - outputTensor.scaleXY = 1.0; - } - - outputTensor.numBBoxes - = outputTensor.mask.size() > 0 ? outputTensor.mask.size() : std::stoul(trim(block.at("num"))); - - m_YoloTensors.push_back(outputTensor); - } - else if ((block.at("type") == "cls") || (block.at("type") == "reg")) - { - ++m_YoloCount; - TensorInfo outputTensor; - m_YoloTensors.push_back(outputTensor); - } + else { + int cpos = line.find('='); + std::string key = trim(line.substr(0, cpos)); + std::string value = trim(line.substr(cpos + 1)); + block.insert(std::pair(key, value)); } + } + + blocks.push_back(block); + return blocks; } -void Yolo::destroyNetworkUtils() +void +Yolo::parseConfigBlocks() { - for (uint i = 0; i < m_TrtWeights.size(); ++i) - if (m_TrtWeights[i].count > 0) - free(const_cast(m_TrtWeights[i].values)); - m_TrtWeights.clear(); + for (auto block : m_ConfigBlocks) { + if (block.at("type") == "net") { + assert((block.find("height") != block.end()) && "Missing 'height' param in network cfg"); + assert((block.find("width") != block.end()) && "Missing 'width' param in network cfg"); + assert((block.find("channels") != block.end()) && "Missing 'channels' param in network cfg"); + + m_InputH = std::stoul(block.at("height")); + m_InputW = std::stoul(block.at("width")); + m_InputC = std::stoul(block.at("channels")); + m_InputSize = m_InputC * m_InputH * m_InputW; + + if (block.find("letter_box") != block.end()) + m_LetterBox = std::stoul(block.at("letter_box")); + } + else if ((block.at("type") == "region") || (block.at("type") == "yolo")) + { + assert((block.find("num") != block.end()) && + std::string("Missing 'num' param in " + block.at("type") + " layer").c_str()); + assert((block.find("classes") != block.end()) && + std::string("Missing 'classes' param in " + block.at("type") + " layer").c_str()); + assert((block.find("anchors") != block.end()) && + std::string("Missing 'anchors' param in " + block.at("type") + " layer").c_str()); + + ++m_YoloCount; + + m_NumClasses = std::stoul(block.at("classes")); + + if (block.find("new_coords") != block.end()) + m_NewCoords = std::stoul(block.at("new_coords")); + + TensorInfo outputTensor; + + std::string anchorString = block.at("anchors"); + while (!anchorString.empty()) { + int npos = anchorString.find_first_of(','); + if (npos != -1) { + float anchor = std::stof(trim(anchorString.substr(0, npos))); + outputTensor.anchors.push_back(anchor); + anchorString.erase(0, npos + 1); + } + else { + float anchor = std::stof(trim(anchorString)); + outputTensor.anchors.push_back(anchor); + break; + } + } + + if (block.find("mask") != block.end()) { + std::string maskString = block.at("mask"); + while (!maskString.empty()) { + int npos = maskString.find_first_of(','); + if (npos != -1) { + int mask = std::stoul(trim(maskString.substr(0, npos))); + outputTensor.mask.push_back(mask); + maskString.erase(0, npos + 
1); + } + else { + int mask = std::stoul(trim(maskString)); + outputTensor.mask.push_back(mask); + break; + } + } + } + + if (block.find("scale_x_y") != block.end()) + outputTensor.scaleXY = std::stof(block.at("scale_x_y")); + else + outputTensor.scaleXY = 1.0; + + outputTensor.numBBoxes = outputTensor.mask.size() > 0 ? outputTensor.mask.size() : std::stoul(trim(block.at("num"))); + + m_YoloTensors.push_back(outputTensor); + } + else if ((block.at("type") == "cls") || (block.at("type") == "reg")) { + ++m_YoloCount; + TensorInfo outputTensor; + m_YoloTensors.push_back(outputTensor); + } + else if (block.at("type") == "detect_v8") { + ++m_YoloCount; + + m_NumClasses = std::stoul(block.at("classes")); + + TensorInfo outputTensor; + m_YoloTensors.push_back(outputTensor); + } + } +} + +void +Yolo::destroyNetworkUtils() +{ + for (uint i = 0; i < m_TrtWeights.size(); ++i) + if (m_TrtWeights[i].count > 0) + free(const_cast(m_TrtWeights[i].values)); + m_TrtWeights.clear(); } diff --git a/nvdsinfer_custom_impl_Yolo/yolo.h b/nvdsinfer_custom_impl_Yolo/yolo.h index ea887cc..c915337 100644 --- a/nvdsinfer_custom_impl_Yolo/yolo.h +++ b/nvdsinfer_custom_impl_Yolo/yolo.h @@ -26,7 +26,11 @@ #ifndef _YOLO_H_ #define _YOLO_H_ +#include "NvInferPlugin.h" +#include "nvdsinfer_custom_impl.h" + #include "layers/convolutional_layer.h" +#include "layers/c2f_layer.h" #include "layers/batchnorm_layer.h" #include "layers/implicit_layer.h" #include "layers/channels_layer.h" @@ -40,36 +44,35 @@ #include "layers/softmax_layer.h" #include "layers/cls_layer.h" #include "layers/reg_layer.h" - -#include "nvdsinfer_custom_impl.h" +#include "layers/detect_v8_layer.h" struct NetworkInfo { - std::string inputBlobName; - std::string networkType; - std::string configFilePath; - std::string wtsFilePath; - std::string int8CalibPath; - std::string deviceType; - uint numDetectedClasses; - int clusterMode; - float scoreThreshold; - std::string networkMode; + std::string inputBlobName; + std::string networkType; + std::string configFilePath; + std::string wtsFilePath; + std::string int8CalibPath; + std::string deviceType; + uint numDetectedClasses; + int clusterMode; + float scoreThreshold; + std::string networkMode; }; struct TensorInfo { - std::string blobName; - uint gridSizeX {0}; - uint gridSizeY {0}; - uint numBBoxes {0}; - float scaleXY; - std::vector anchors; - std::vector mask; + std::string blobName; + uint gridSizeX {0}; + uint gridSizeY {0}; + uint numBBoxes {0}; + float scaleXY; + std::vector anchors; + std::vector mask; }; class Yolo : public IModelParser { -public: + public: Yolo(const NetworkInfo& networkInfo); ~Yolo() override; @@ -77,14 +80,14 @@ public: bool hasFullDimsSupported() const override { return false; } const char* getModelName() const override { - return m_ConfigFilePath.empty() ? m_NetworkType.c_str() : m_ConfigFilePath.c_str(); + return m_ConfigFilePath.empty() ? 
m_NetworkType.c_str() : m_ConfigFilePath.c_str(); } NvDsInferStatus parseModel(nvinfer1::INetworkDefinition& network) override; - nvinfer1::ICudaEngine *createEngine (nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config); + nvinfer1::ICudaEngine* createEngine(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config); -protected: + protected: const std::string m_InputBlobName; const std::string m_NetworkType; const std::string m_ConfigFilePath; @@ -109,7 +112,7 @@ protected: std::vector> m_ConfigBlocks; std::vector m_TrtWeights; -private: + private: NvDsInferStatus buildYoloNetwork(std::vector& weights, nvinfer1::INetworkDefinition& network); std::vector> parseConfigFile(const std::string cfgFilePath); diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward.cu b/nvdsinfer_custom_impl_Yolo/yoloForward.cu index ab65833..9d0d613 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloForward.cu +++ b/nvdsinfer_custom_impl_Yolo/yoloForward.cu @@ -7,98 +7,82 @@ inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); } -__global__ void gpuYoloLayer( - const float* input, int* num_detections, float* detection_boxes, float* detection_scores, int* detection_classes, - const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, const uint gridSizeY, - const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask) +__global__ void gpuYoloLayer(const float* input, int* num_detections, float* detection_boxes, float* detection_scores, + int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, + const uint gridSizeY, const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, + const int* mask) { - uint x_id = blockIdx.x * blockDim.x + threadIdx.x; - uint y_id = blockIdx.y * blockDim.y + threadIdx.y; - uint z_id = blockIdx.z * blockDim.z + threadIdx.z; + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + uint y_id = blockIdx.y * blockDim.y + threadIdx.y; + uint z_id = blockIdx.z * blockDim.z + threadIdx.z; - if (x_id >= gridSizeX || y_id >= gridSizeY || z_id >= numBBoxes) - return; + if (x_id >= gridSizeX || y_id >= gridSizeY || z_id >= numBBoxes) + return; - const int numGridCells = gridSizeX * gridSizeY; - const int bbindex = y_id * gridSizeX + x_id; + const int numGridCells = gridSizeX * gridSizeY; + const int bbindex = y_id * gridSizeX + x_id; - const float objectness - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]); + const float objectness = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]); - if (objectness < scoreThreshold) - return; + if (objectness < scoreThreshold) + return; - int count = (int)atomicAdd(num_detections, 1); + int count = (int)atomicAdd(num_detections, 1); - const float alpha = scaleXY; - const float beta = -0.5 * (scaleXY - 1); + const float alpha = scaleXY; + const float beta = -0.5 * (scaleXY - 1); - float x - = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) - * alpha + beta + x_id) * netWidth / gridSizeX; + float x = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta + x_id) + * netWidth / gridSizeX; - float y - = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) - * alpha + beta + y_id) * netHeight / gridSizeY; + float y = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 
1)]) * alpha + beta + y_id) + * netHeight / gridSizeY; - float w - = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) - * anchors[mask[z_id] * 2]; + float w = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * anchors[mask[z_id] * 2]; - float h - = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) - * anchors[mask[z_id] * 2 + 1]; + float h = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * anchors[mask[z_id] * 2 + 1]; - float maxProb = 0.0f; - int maxIndex = -1; + float maxProb = 0.0f; + int maxIndex = -1; - for (uint i = 0; i < numOutputClasses; ++i) - { - float prob - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]); - - if (prob > maxProb) - { - maxProb = prob; - maxIndex = i; - } + for (uint i = 0; i < numOutputClasses; ++i) { + float prob = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]); + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; } + } - detection_boxes[count * 4 + 0] = x - 0.5 * w; - detection_boxes[count * 4 + 1] = y - 0.5 * h; - detection_boxes[count * 4 + 2] = x + 0.5 * w; - detection_boxes[count * 4 + 3] = y + 0.5 * h; - detection_scores[count] = objectness * maxProb; - detection_classes[count] = maxIndex; + detection_boxes[count * 4 + 0] = x - 0.5 * w; + detection_boxes[count * 4 + 1] = y - 0.5 * h; + detection_boxes[count * 4 + 2] = x + 0.5 * w; + detection_boxes[count * 4 + 3] = y + 0.5 * h; + detection_scores[count] = objectness * maxProb; + detection_classes[count] = maxIndex; } -cudaError_t cudaYoloLayer( - const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, - const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, - const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes, - const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); +cudaError_t cudaYoloLayer(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, + const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, + const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); -cudaError_t cudaYoloLayer( - const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, - const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, - const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes, - const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream) +cudaError_t cudaYoloLayer(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, + const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, + const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream) { - dim3 threads_per_block(16, 16, 4); - dim3 
number_of_blocks((gridSizeX / threads_per_block.x) + 1, - (gridSizeY / threads_per_block.y) + 1, - (numBBoxes / threads_per_block.z) + 1); + dim3 threads_per_block(16, 16, 4); + dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, (gridSizeY / threads_per_block.y) + 1, + (numBBoxes / threads_per_block.z) + 1); - for (unsigned int batch = 0; batch < batchSize; ++batch) - { - gpuYoloLayer<<>>( - reinterpret_cast(input) + (batch * inputSize), - reinterpret_cast(num_detections) + (batch), - reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), - reinterpret_cast(detection_scores) + (batch * outputSize), - reinterpret_cast(detection_classes) + (batch * outputSize), - scoreThreshold, netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, scaleXY, - reinterpret_cast(anchors), reinterpret_cast(mask)); - } - return cudaGetLastError(); + for (unsigned int batch = 0; batch < batchSize; ++batch) { + gpuYoloLayer<<>>( + reinterpret_cast(input) + (batch * inputSize), reinterpret_cast(num_detections) + (batch), + reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), + reinterpret_cast(detection_scores) + (batch * outputSize), + reinterpret_cast(detection_classes) + (batch * outputSize), scoreThreshold, netWidth, netHeight, gridSizeX, + gridSizeY, numOutputClasses, numBBoxes, scaleXY, reinterpret_cast(anchors), + reinterpret_cast(mask)); + } + return cudaGetLastError(); } diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_e.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_e.cu index b702c99..9b7596a 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloForward_e.cu +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_e.cu @@ -4,69 +4,61 @@ */ #include -#include -__global__ void gpuYoloLayer_e( - const float* cls, const float* reg, int* num_detections, float* detection_boxes, float* detection_scores, - int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, +__global__ void gpuYoloLayer_e(const float* cls, const float* reg, int* num_detections, float* detection_boxes, + float* detection_scores, int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, const uint numOutputClasses, const uint64_t outputSize) { - uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; - if (x_id >= outputSize) - return; + if (x_id >= outputSize) + return; - float maxProb = 0.0f; - int maxIndex = -1; + float maxProb = 0.0f; + int maxIndex = -1; - for (uint i = 0; i < numOutputClasses; ++i) - { - float prob - = cls[x_id * numOutputClasses + i]; - - if (prob > maxProb) - { - maxProb = prob; - maxIndex = i; - } + for (uint i = 0; i < numOutputClasses; ++i) { + float prob = cls[x_id * numOutputClasses + i]; + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; } + } - if (maxProb < scoreThreshold) - return; + if (maxProb < scoreThreshold) + return; - int count = (int)atomicAdd(num_detections, 1); + int count = (int)atomicAdd(num_detections, 1); - detection_boxes[count * 4 + 0] = reg[x_id * 4 + 0]; - detection_boxes[count * 4 + 1] = reg[x_id * 4 + 1]; - detection_boxes[count * 4 + 2] = reg[x_id * 4 + 2]; - detection_boxes[count * 4 + 3] = reg[x_id * 4 + 3]; - detection_scores[count] = maxProb; - detection_classes[count] = maxIndex; + detection_boxes[count * 4 + 0] = reg[x_id * 4 + 0]; + detection_boxes[count * 4 + 1] = reg[x_id * 4 + 1]; + detection_boxes[count * 4 + 2] = reg[x_id * 4 + 2]; + detection_boxes[count * 4 + 3] = reg[x_id * 4 + 3]; + detection_scores[count] = maxProb; + 
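/*
 * (PP-YOLOE / "e" decode, summarized) The cls tensor holds per-candidate class scores
 * (outputSize x numOutputClasses) and the reg tensor holds boxes that are passed through
 * unchanged, i.e. they appear to be fully decoded upstream. The kernel therefore only takes
 * the arg-max class score, rejects candidates below scoreThreshold, and reserves an output
 * slot with atomicAdd(num_detections, 1) so surviving detections are written out compactly
 * without a separate compaction pass.
 */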
detection_classes[count] = maxIndex; } -cudaError_t cudaYoloLayer_e( - const void* cls, const void* reg, void* num_detections, void* detection_boxes, void* detection_scores, - void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, - const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream); +cudaError_t cudaYoloLayer_e(const void* cls, const void* reg, void* num_detections, void* detection_boxes, + void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& outputSize, + const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& numOutputClasses, + cudaStream_t stream); -cudaError_t cudaYoloLayer_e( - const void* cls, const void* reg, void* num_detections, void* detection_boxes, void* detection_scores, - void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, - const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream) +cudaError_t cudaYoloLayer_e(const void* cls, const void* reg, void* num_detections, void* detection_boxes, + void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& outputSize, + const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& numOutputClasses, + cudaStream_t stream) { - int threads_per_block = 16; - int number_of_blocks = (outputSize / threads_per_block) + 1; + int threads_per_block = 16; + int number_of_blocks = (outputSize / threads_per_block) + 1; - for (unsigned int batch = 0; batch < batchSize; ++batch) - { - gpuYoloLayer_e<<>>( - reinterpret_cast(cls) + (batch * numOutputClasses * outputSize), - reinterpret_cast(reg) + (batch * 4 * outputSize), - reinterpret_cast(num_detections) + (batch), - reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), - reinterpret_cast(detection_scores) + (batch * outputSize), - reinterpret_cast(detection_classes) + (batch * outputSize), - scoreThreshold, netWidth, netHeight, numOutputClasses, outputSize); - } - return cudaGetLastError(); + for (unsigned int batch = 0; batch < batchSize; ++batch) { + gpuYoloLayer_e<<>>( + reinterpret_cast(cls) + (batch * numOutputClasses * outputSize), + reinterpret_cast(reg) + (batch * 4 * outputSize), reinterpret_cast(num_detections) + (batch), + reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), + reinterpret_cast(detection_scores) + (batch * outputSize), + reinterpret_cast(detection_classes) + (batch * outputSize), scoreThreshold, netWidth, netHeight, + numOutputClasses, outputSize); + } + return cudaGetLastError(); } diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu index 703eb0c..45b8ca7 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu @@ -5,98 +5,82 @@ #include -__global__ void gpuYoloLayer_nc( - const float* input, int* num_detections, float* detection_boxes, float* detection_scores, int* detection_classes, - const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, const uint gridSizeY, - const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask) +__global__ void gpuYoloLayer_nc(const float* input, int* num_detections, float* detection_boxes, float* detection_scores, + int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, + const uint gridSizeY, 
const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, + const int* mask) { - uint x_id = blockIdx.x * blockDim.x + threadIdx.x; - uint y_id = blockIdx.y * blockDim.y + threadIdx.y; - uint z_id = blockIdx.z * blockDim.z + threadIdx.z; + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + uint y_id = blockIdx.y * blockDim.y + threadIdx.y; + uint z_id = blockIdx.z * blockDim.z + threadIdx.z; - if (x_id >= gridSizeX || y_id >= gridSizeY || z_id >= numBBoxes) - return; + if (x_id >= gridSizeX || y_id >= gridSizeY || z_id >= numBBoxes) + return; - const int numGridCells = gridSizeX * gridSizeY; - const int bbindex = y_id * gridSizeX + x_id; + const int numGridCells = gridSizeX * gridSizeY; + const int bbindex = y_id * gridSizeX + x_id; - const float objectness - = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]; + const float objectness = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]; - if (objectness < scoreThreshold) - return; + if (objectness < scoreThreshold) + return; - int count = (int)atomicAdd(num_detections, 1); + int count = (int)atomicAdd(num_detections, 1); - const float alpha = scaleXY; - const float beta = -0.5 * (scaleXY - 1); + const float alpha = scaleXY; + const float beta = -0.5 * (scaleXY - 1); - float x - = (input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] - * alpha + beta + x_id) * netWidth / gridSizeX; + float x = (input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] * alpha + beta + x_id) * netWidth / + gridSizeX; - float y - = (input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] - * alpha + beta + y_id) * netHeight / gridSizeY; + float y = (input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] * alpha + beta + y_id) * netHeight / + gridSizeY; - float w - = __powf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] * 2, 2) - * anchors[mask[z_id] * 2]; + float w = __powf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] * 2, 2) * anchors[mask[z_id] * 2]; - float h - = __powf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] * 2, 2) - * anchors[mask[z_id] * 2 + 1]; + float h = __powf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] * 2, 2) * anchors[mask[z_id] * 2 + 1]; - float maxProb = 0.0f; - int maxIndex = -1; + float maxProb = 0.0f; + int maxIndex = -1; - for (uint i = 0; i < numOutputClasses; ++i) - { - float prob - = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]; - - if (prob > maxProb) - { - maxProb = prob; - maxIndex = i; - } + for (uint i = 0; i < numOutputClasses; ++i) { + float prob = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]; + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; } + } - detection_boxes[count * 4 + 0] = x - 0.5 * w; - detection_boxes[count * 4 + 1] = y - 0.5 * h; - detection_boxes[count * 4 + 2] = x + 0.5 * w; - detection_boxes[count * 4 + 3] = y + 0.5 * h; - detection_scores[count] = objectness * maxProb; - detection_classes[count] = maxIndex; + detection_boxes[count * 4 + 0] = x - 0.5 * w; + detection_boxes[count * 4 + 1] = y - 0.5 * h; + detection_boxes[count * 4 + 2] = x + 0.5 * w; + detection_boxes[count * 4 + 3] = y + 0.5 * h; + detection_scores[count] = objectness * maxProb; + detection_classes[count] = maxIndex; } -cudaError_t cudaYoloLayer_nc( - const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, - 
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, - const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes, - const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); +cudaError_t cudaYoloLayer_nc(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, + const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, + const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); -cudaError_t cudaYoloLayer_nc( - const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, - const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, - const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes, - const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream) +cudaError_t cudaYoloLayer_nc(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, + const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, + const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream) { - dim3 threads_per_block(16, 16, 4); - dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, - (gridSizeY / threads_per_block.y) + 1, - (numBBoxes / threads_per_block.z) + 1); + dim3 threads_per_block(16, 16, 4); + dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, (gridSizeY / threads_per_block.y) + 1, + (numBBoxes / threads_per_block.z) + 1); - for (unsigned int batch = 0; batch < batchSize; ++batch) - { - gpuYoloLayer_nc<<>>( - reinterpret_cast(input) + (batch * inputSize), - reinterpret_cast(num_detections) + (batch), - reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), - reinterpret_cast(detection_scores) + (batch * outputSize), - reinterpret_cast(detection_classes) + (batch * outputSize), - scoreThreshold, netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, scaleXY, - reinterpret_cast(anchors), reinterpret_cast(mask)); - } - return cudaGetLastError(); + for (unsigned int batch = 0; batch < batchSize; ++batch) { + gpuYoloLayer_nc<<>>( + reinterpret_cast(input) + (batch * inputSize), reinterpret_cast(num_detections) + (batch), + reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), + reinterpret_cast(detection_scores) + (batch * outputSize), + reinterpret_cast(detection_classes) + (batch * outputSize), scoreThreshold, netWidth, netHeight, gridSizeX, + gridSizeY, numOutputClasses, numBBoxes, scaleXY, reinterpret_cast(anchors), + reinterpret_cast(mask)); + } + return cudaGetLastError(); } diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_r.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_r.cu index 3d16ad3..6a0327e 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloForward_r.cu +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_r.cu @@ -7,98 +7,84 @@ inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); } 
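/*
 * gpuYoloLayer_r below implements the YOLOR / Scaled-YOLOv4 style decode: x and y use a
 * scaled sigmoid with alpha = scale_x_y and beta = -0.5 * (scale_x_y - 1), while w and h use
 * (2 * sigmoid(t))^2 times the anchor instead of the exp(t) form used by gpuYoloLayer.
 * The helper sketched here is a hypothetical host-side reference of the same per-cell math,
 * kept only for clarity (it assumes expf/powf from <cmath> and is not part of the plugin).
 */
static inline void decodeYoloRBoxHost(const float t[4], float anchorW, float anchorH, int cx, int cy,
    int gridW, int gridH, int netW, int netH, float scaleXY, float box[4])
{
  auto sigmoid = [](float v) { return 1.0f / (1.0f + expf(-v)); };
  const float alpha = scaleXY;
  const float beta = -0.5f * (scaleXY - 1.0f);
  float x = (sigmoid(t[0]) * alpha + beta + cx) * netW / gridW;  // box center x in input pixels
  float y = (sigmoid(t[1]) * alpha + beta + cy) * netH / gridH;  // box center y in input pixels
  float w = powf(sigmoid(t[2]) * 2.0f, 2.0f) * anchorW;          // width from anchor (pixels)
  float h = powf(sigmoid(t[3]) * 2.0f, 2.0f) * anchorH;          // height from anchor (pixels)
  box[0] = x - 0.5f * w;  // x1
  box[1] = y - 0.5f * h;  // y1
  box[2] = x + 0.5f * w;  // x2
  box[3] = y + 0.5f * h;  // y2
}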
-__global__ void gpuYoloLayer_r( - const float* input, int* num_detections, float* detection_boxes, float* detection_scores, int* detection_classes, - const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, const uint gridSizeY, - const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask) +__global__ void gpuYoloLayer_r(const float* input, int* num_detections, float* detection_boxes, float* detection_scores, + int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, + const uint gridSizeY, const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, + const int* mask) { - uint x_id = blockIdx.x * blockDim.x + threadIdx.x; - uint y_id = blockIdx.y * blockDim.y + threadIdx.y; - uint z_id = blockIdx.z * blockDim.z + threadIdx.z; + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + uint y_id = blockIdx.y * blockDim.y + threadIdx.y; + uint z_id = blockIdx.z * blockDim.z + threadIdx.z; - if (x_id >= gridSizeX || y_id >= gridSizeY || z_id >= numBBoxes) - return; + if (x_id >= gridSizeX || y_id >= gridSizeY || z_id >= numBBoxes) + return; - const int numGridCells = gridSizeX * gridSizeY; - const int bbindex = y_id * gridSizeX + x_id; + const int numGridCells = gridSizeX * gridSizeY; + const int bbindex = y_id * gridSizeX + x_id; - const float objectness - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]); + const float objectness = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]); - if (objectness < scoreThreshold) - return; + if (objectness < scoreThreshold) + return; - int count = (int)atomicAdd(num_detections, 1); + int count = (int)atomicAdd(num_detections, 1); - const float alpha = scaleXY; - const float beta = -0.5 * (scaleXY - 1); + const float alpha = scaleXY; + const float beta = -0.5 * (scaleXY - 1); - float x - = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) - * alpha + beta + x_id) * netWidth / gridSizeX; + float x = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta + x_id) + * netWidth / gridSizeX; - float y - = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) - * alpha + beta + y_id) * netHeight / gridSizeY; + float y = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta + y_id) + * netHeight / gridSizeY; - float w - = __powf(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2) - * anchors[mask[z_id] * 2]; + float w = __powf(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2) + * anchors[mask[z_id] * 2]; - float h - = __powf(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2) - * anchors[mask[z_id] * 2 + 1]; + float h = __powf(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2) + * anchors[mask[z_id] * 2 + 1]; - float maxProb = 0.0f; - int maxIndex = -1; + float maxProb = 0.0f; + int maxIndex = -1; - for (uint i = 0; i < numOutputClasses; ++i) - { - float prob - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]); - - if (prob > maxProb) - { - maxProb = prob; - maxIndex = i; - } + for (uint i = 0; i < numOutputClasses; ++i) { + float prob = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) 
+ (5 + i))]); + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; } + } - detection_boxes[count * 4 + 0] = x - 0.5 * w; - detection_boxes[count * 4 + 1] = y - 0.5 * h; - detection_boxes[count * 4 + 2] = x + 0.5 * w; - detection_boxes[count * 4 + 3] = y + 0.5 * h; - detection_scores[count] = objectness * maxProb; - detection_classes[count] = maxIndex; + detection_boxes[count * 4 + 0] = x - 0.5 * w; + detection_boxes[count * 4 + 1] = y - 0.5 * h; + detection_boxes[count * 4 + 2] = x + 0.5 * w; + detection_boxes[count * 4 + 3] = y + 0.5 * h; + detection_scores[count] = objectness * maxProb; + detection_classes[count] = maxIndex; } -cudaError_t cudaYoloLayer_r( - const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, - const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, - const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes, - const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); +cudaError_t cudaYoloLayer_r(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, + const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, + const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); -cudaError_t cudaYoloLayer_r( - const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, - const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, - const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes, - const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream) +cudaError_t cudaYoloLayer_r(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, + const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, + const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream) { - dim3 threads_per_block(16, 16, 4); - dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, - (gridSizeY / threads_per_block.y) + 1, - (numBBoxes / threads_per_block.z) + 1); + dim3 threads_per_block(16, 16, 4); + dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, (gridSizeY / threads_per_block.y) + 1, + (numBBoxes / threads_per_block.z) + 1); - for (unsigned int batch = 0; batch < batchSize; ++batch) - { - gpuYoloLayer_r<<>>( - reinterpret_cast(input) + (batch * inputSize), - reinterpret_cast(num_detections) + (batch), - reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), - reinterpret_cast(detection_scores) + (batch * outputSize), - reinterpret_cast(detection_classes) + (batch * outputSize), - scoreThreshold, netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, scaleXY, - reinterpret_cast(anchors), reinterpret_cast(mask)); - } - return cudaGetLastError(); + for (unsigned int batch = 0; batch < batchSize; ++batch) { + gpuYoloLayer_r<<>>( + reinterpret_cast(input) 
+ (batch * inputSize), reinterpret_cast(num_detections) + (batch), + reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), + reinterpret_cast(detection_scores) + (batch * outputSize), + reinterpret_cast(detection_classes) + (batch * outputSize), scoreThreshold, netWidth, netHeight, gridSizeX, + gridSizeY, numOutputClasses, numBBoxes, scaleXY, reinterpret_cast(anchors), + reinterpret_cast(mask)); + } + return cudaGetLastError(); } diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu index b90d1f4..93f12da 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu @@ -7,119 +7,100 @@ inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); } -__device__ void softmaxGPU( - const float* input, const int bbindex, const int numGridCells, uint z_id, const uint numOutputClasses, float temp, - float* output) +__device__ void softmaxGPU(const float* input, const int bbindex, const int numGridCells, uint z_id, + const uint numOutputClasses, float temp, float* output) { - int i; - float sum = 0; - float largest = -INFINITY; - for (i = 0; i < numOutputClasses; ++i) { - int val = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]; - largest = (val>largest) ? val : largest; - } - for (i = 0; i < numOutputClasses; ++i) { - float e = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] / temp - largest / temp); - sum += e; - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] = e; - } - for (i = 0; i < numOutputClasses; ++i) { - output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] /= sum; - } + int i; + float sum = 0; + float largest = -INFINITY; + for (i = 0; i < numOutputClasses; ++i) { + int val = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]; + largest = (val>largest) ? 
val : largest; + } + for (i = 0; i < numOutputClasses; ++i) { + float e = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] / temp - largest / temp); + sum += e; + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] = e; + } + for (i = 0; i < numOutputClasses; ++i) { + output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))] /= sum; + } } -__global__ void gpuRegionLayer( - const float* input, float* softmax, int* num_detections, float* detection_boxes, float* detection_scores, - int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, - const uint gridSizeY, const uint numOutputClasses, const uint numBBoxes, const float* anchors) +__global__ void gpuRegionLayer(const float* input, float* softmax, int* num_detections, float* detection_boxes, + float* detection_scores, int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, + const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses, const uint numBBoxes, const float* anchors) { - uint x_id = blockIdx.x * blockDim.x + threadIdx.x; - uint y_id = blockIdx.y * blockDim.y + threadIdx.y; - uint z_id = blockIdx.z * blockDim.z + threadIdx.z; + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + uint y_id = blockIdx.y * blockDim.y + threadIdx.y; + uint z_id = blockIdx.z * blockDim.z + threadIdx.z; - if (x_id >= gridSizeX || y_id >= gridSizeY || z_id >= numBBoxes) - return; + if (x_id >= gridSizeX || y_id >= gridSizeY || z_id >= numBBoxes) + return; - const int numGridCells = gridSizeX * gridSizeY; - const int bbindex = y_id * gridSizeX + x_id; + const int numGridCells = gridSizeX * gridSizeY; + const int bbindex = y_id * gridSizeX + x_id; - const float objectness - = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]); + const float objectness = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]); - if (objectness < scoreThreshold) - return; + if (objectness < scoreThreshold) + return; - int count = (int)atomicAdd(num_detections, 1); + int count = (int)atomicAdd(num_detections, 1); - float x - = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) - + x_id) * netWidth / gridSizeX; + float x = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) + x_id) * netWidth / gridSizeX; - float y - = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) - + y_id) * netHeight / gridSizeY; + float y = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) + y_id) * netHeight / gridSizeY; - float w - = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) - * anchors[z_id * 2] * netWidth / gridSizeX; + float w = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * anchors[z_id * 2] * netWidth / + gridSizeX; - float h - = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) - * anchors[z_id * 2 + 1] * netHeight / gridSizeY; + float h = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * anchors[z_id * 2 + 1] * netHeight / + gridSizeY; - softmaxGPU(input, bbindex, numGridCells, z_id, numOutputClasses, 1.0, softmax); + softmaxGPU(input, bbindex, numGridCells, z_id, numOutputClasses, 1.0, softmax); - float maxProb = 0.0f; - int maxIndex = -1; + float maxProb = 0.0f; + int maxIndex = -1; - for (uint i = 0; i < numOutputClasses; 
++i) - { - float prob - = softmax[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]; - - if (prob > maxProb) - { - maxProb = prob; - maxIndex = i; - } + for (uint i = 0; i < numOutputClasses; ++i) { + float prob = softmax[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]; + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; } + } - detection_boxes[count * 4 + 0] = x - 0.5 * w; - detection_boxes[count * 4 + 1] = y - 0.5 * h; - detection_boxes[count * 4 + 2] = x + 0.5 * w; - detection_boxes[count * 4 + 3] = y + 0.5 * h; - detection_scores[count] = objectness * maxProb; - detection_classes[count] = maxIndex; + detection_boxes[count * 4 + 0] = x - 0.5 * w; + detection_boxes[count * 4 + 1] = y - 0.5 * h; + detection_boxes[count * 4 + 2] = x + 0.5 * w; + detection_boxes[count * 4 + 3] = y + 0.5 * h; + detection_scores[count] = objectness * maxProb; + detection_classes[count] = maxIndex; } -cudaError_t cudaRegionLayer( - const void* input, void* softmax, void* num_detections, void* detection_boxes, void* detection_scores, - void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, - const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, - const uint& numBBoxes, const void* anchors, cudaStream_t stream); +cudaError_t cudaRegionLayer(const void* input, void* softmax, void* num_detections, void* detection_boxes, + void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, + const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, + const uint& numOutputClasses, const uint& numBBoxes, const void* anchors, cudaStream_t stream); -cudaError_t cudaRegionLayer( - const void* input, void* softmax, void* num_detections, void* detection_boxes, void* detection_scores, - void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, - const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, - const uint& numBBoxes, const void* anchors, cudaStream_t stream) +cudaError_t cudaRegionLayer(const void* input, void* softmax, void* num_detections, void* detection_boxes, + void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, + const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, + const uint& numOutputClasses, const uint& numBBoxes, const void* anchors, cudaStream_t stream) { - dim3 threads_per_block(16, 16, 4); - dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, - (gridSizeY / threads_per_block.y) + 1, - (numBBoxes / threads_per_block.z) + 1); + dim3 threads_per_block(16, 16, 4); + dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, (gridSizeY / threads_per_block.y) + 1, + (numBBoxes / threads_per_block.z) + 1); - for (unsigned int batch = 0; batch < batchSize; ++batch) - { - gpuRegionLayer<<>>( - reinterpret_cast(input) + (batch * inputSize), - reinterpret_cast(softmax) + (batch * inputSize), - reinterpret_cast(num_detections) + (batch), - reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), - reinterpret_cast(detection_scores) + (batch * outputSize), - reinterpret_cast(detection_classes) + (batch * outputSize), - scoreThreshold, netWidth, netHeight, gridSizeX, 
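/*
 * Region (YOLOv2) decode, summarized: softmaxGPU above computes a numerically stable softmax
 * over the class logits of one anchor box (the running maximum is subtracted before
 * exponentiation), and gpuRegionLayer decodes x/y with a plain sigmoid plus the cell offset,
 * w/h as exp(t) times the anchor (anchors here are in grid units, so the result is scaled by
 * netWidth/gridSizeX and netHeight/gridSizeY), and reports objectness times the best class
 * probability as the detection score.
 */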
gridSizeY, numOutputClasses, numBBoxes, - reinterpret_cast(anchors)); - } - return cudaGetLastError(); + for (unsigned int batch = 0; batch < batchSize; ++batch) { + gpuRegionLayer<<>>( + reinterpret_cast(input) + (batch * inputSize), reinterpret_cast(softmax) + (batch * inputSize), + reinterpret_cast(num_detections) + (batch), + reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), + reinterpret_cast(detection_scores) + (batch * outputSize), + reinterpret_cast(detection_classes) + (batch * outputSize), scoreThreshold, netWidth, netHeight, gridSizeX, + gridSizeY, numOutputClasses, numBBoxes, reinterpret_cast(anchors)); + } + return cudaGetLastError(); } diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_v8.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_v8.cu new file mode 100644 index 0000000..8bc5413 --- /dev/null +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_v8.cu @@ -0,0 +1,62 @@ +/* + * Created by Marcos Luciano + * https://www.github.com/marcoslucianops + */ + +#include + +__global__ void gpuYoloLayer_v8(const float* input, int* num_detections, float* detection_boxes, float* detection_scores, + int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, + const uint numOutputClasses, const uint64_t outputSize) +{ + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (x_id >= outputSize) + return; + + float maxProb = 0.0f; + int maxIndex = -1; + + for (uint i = 0; i < numOutputClasses; ++i) { + float prob = input[x_id * (4 + numOutputClasses) + i + 4]; + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; + } + } + + if (maxProb < scoreThreshold) + return; + + int count = (int)atomicAdd(num_detections, 1); + + detection_boxes[count * 4 + 0] = input[x_id * (4 + numOutputClasses) + 0]; + detection_boxes[count * 4 + 1] = input[x_id * (4 + numOutputClasses) + 1]; + detection_boxes[count * 4 + 2] = input[x_id * (4 + numOutputClasses) + 2]; + detection_boxes[count * 4 + 3] = input[x_id * (4 + numOutputClasses) + 3]; + detection_scores[count] = maxProb; + detection_classes[count] = maxIndex; +} + +cudaError_t cudaYoloLayer_v8(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, + const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream); + +cudaError_t cudaYoloLayer_v8(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, + const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream) +{ + int threads_per_block = 16; + int number_of_blocks = (outputSize / threads_per_block) + 1; + + for (unsigned int batch = 0; batch < batchSize; ++batch) { + gpuYoloLayer_v8<<>>( + reinterpret_cast(input) + (batch * (4 + numOutputClasses) * outputSize), + reinterpret_cast(num_detections) + (batch), + reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), + reinterpret_cast(detection_scores) + (batch * outputSize), + reinterpret_cast(detection_classes) + (batch * outputSize), + scoreThreshold, netWidth, netHeight, numOutputClasses, outputSize); + } + return cudaGetLastError(); +} diff --git a/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp b/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp index da1c402..ebb24b5 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp +++ b/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp @@ -24,325 
+24,288 @@ */ #include "yoloPlugins.h" -#include "NvInferPlugin.h" -#include -#include -#include - -uint kNUM_CLASSES; namespace { - template - void write(char*& buffer, const T& val) - { - *reinterpret_cast(buffer) = val; - buffer += sizeof(T); - } - - template - void read(const char*& buffer, T& val) - { - val = *reinterpret_cast(buffer); - buffer += sizeof(T); - } + template + void write(char*& buffer, const T& val) { + *reinterpret_cast(buffer) = val; + buffer += sizeof(T); + } + template + void read(const char*& buffer, T& val) { + val = *reinterpret_cast(buffer); + buffer += sizeof(T); + } } -cudaError_t cudaYoloLayer_e( - const void* cls, const void* reg, void* num_detections, void* detection_boxes, void* detection_scores, +cudaError_t cudaYoloLayer_v8(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream); -cudaError_t cudaYoloLayer_r( - const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, - const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, - const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes, - const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); +cudaError_t cudaYoloLayer_e(const void* cls, const void* reg, void* num_detections, void* detection_boxes, + void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& outputSize, + const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& numOutputClasses, + cudaStream_t stream); -cudaError_t cudaYoloLayer_nc( - const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, - const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, - const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes, - const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); - -cudaError_t cudaYoloLayer( - const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, - const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, - const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes, - const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); - -cudaError_t cudaRegionLayer( - const void* input, void* softmax, void* num_detections, void* detection_boxes, void* detection_scores, +cudaError_t cudaYoloLayer_r(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, - const uint& numBBoxes, const void* anchors, cudaStream_t stream); + const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); -YoloLayer::YoloLayer (const void* data, size_t length) -{ - const char *d 
= static_cast(data); +cudaError_t cudaYoloLayer_nc(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, + const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, + const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); - read(d, m_NetWidth); - read(d, m_NetHeight); - read(d, m_NumClasses); - read(d, m_NewCoords); - read(d, m_OutputSize); - read(d, m_Type); - read(d, m_ScoreThreshold); +cudaError_t cudaYoloLayer(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, + const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, + const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream); - if (m_Type != 3) { - uint yoloTensorsSize; - read(d, yoloTensorsSize); - for (uint i = 0; i < yoloTensorsSize; ++i) - { - TensorInfo curYoloTensor; - read(d, curYoloTensor.gridSizeX); - read(d, curYoloTensor.gridSizeY); - read(d, curYoloTensor.numBBoxes); - read(d, curYoloTensor.scaleXY); +cudaError_t cudaRegionLayer(const void* input, void* softmax, void* num_detections, void* detection_boxes, + void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, + const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, + const uint& numOutputClasses, const uint& numBBoxes, const void* anchors, cudaStream_t stream); - uint anchorsSize; - read(d, anchorsSize); - for (uint j = 0; j < anchorsSize; j++) - { - float result; - read(d, result); - curYoloTensor.anchors.push_back(result); - } +YoloLayer::YoloLayer(const void* data, size_t length) { + const char* d = static_cast(data); - uint maskSize; - read(d, maskSize); - for (uint j = 0; j < maskSize; j++) - { - int result; - read(d, result); - curYoloTensor.mask.push_back(result); - } - m_YoloTensors.push_back(curYoloTensor); - } + read(d, m_NetWidth); + read(d, m_NetHeight); + read(d, m_NumClasses); + read(d, m_NewCoords); + read(d, m_OutputSize); + read(d, m_Type); + read(d, m_ScoreThreshold); + + if (m_Type != 3 && m_Type != 4) { + uint yoloTensorsSize; + read(d, yoloTensorsSize); + for (uint i = 0; i < yoloTensorsSize; ++i) { + TensorInfo curYoloTensor; + read(d, curYoloTensor.gridSizeX); + read(d, curYoloTensor.gridSizeY); + read(d, curYoloTensor.numBBoxes); + read(d, curYoloTensor.scaleXY); + + uint anchorsSize; + read(d, anchorsSize); + for (uint j = 0; j < anchorsSize; ++j) { + float result; + read(d, result); + curYoloTensor.anchors.push_back(result); + } + + uint maskSize; + read(d, maskSize); + for (uint j = 0; j < maskSize; ++j) { + int result; + read(d, result); + curYoloTensor.mask.push_back(result); + } + + m_YoloTensors.push_back(curYoloTensor); } - - kNUM_CLASSES = m_NumClasses; + } }; -YoloLayer::YoloLayer( - const uint& netWidth, const uint& netHeight, const uint& numClasses, const uint& newCoords, +YoloLayer::YoloLayer(const uint& netWidth, const uint& netHeight, const uint& numClasses, const uint& newCoords, const std::vector& yoloTensors, const uint64_t& outputSize, const uint& modelType, - const float& 
scoreThreshold) : - m_NetWidth(netWidth), - m_NetHeight(netHeight), - m_NumClasses(numClasses), - m_NewCoords(newCoords), - m_YoloTensors(yoloTensors), - m_OutputSize(outputSize), - m_Type(modelType), + const float& scoreThreshold) : m_NetWidth(netWidth), m_NetHeight(netHeight), m_NumClasses(numClasses), + m_NewCoords(newCoords), m_YoloTensors(yoloTensors), m_OutputSize(outputSize), m_Type(modelType), m_ScoreThreshold(scoreThreshold) { - assert(m_NetWidth > 0); - assert(m_NetHeight > 0); - - kNUM_CLASSES = m_NumClasses; + assert(m_NetWidth > 0); + assert(m_NetHeight > 0); }; nvinfer1::Dims -YoloLayer::getOutputDimensions( - int index, const nvinfer1::Dims* inputs, int nbInputDims) noexcept +YoloLayer::getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) noexcept { - assert(index <= 4); - if (index == 0) { - return nvinfer1::Dims{1, {1}}; - } - else if (index == 1) { - return nvinfer1::Dims{2, {static_cast(m_OutputSize), 4}}; - } - return nvinfer1::Dims{1, {static_cast(m_OutputSize)}}; + assert(index <= 4); + if (index == 0) + return nvinfer1::Dims{1, {1}}; + else if (index == 1) + return nvinfer1::Dims{2, {static_cast(m_OutputSize), 4}}; + return nvinfer1::Dims{1, {static_cast(m_OutputSize)}}; } -bool YoloLayer::supportsFormat ( - nvinfer1::DataType type, nvinfer1::PluginFormat format) const noexcept { - return (type == nvinfer1::DataType::kFLOAT && - format == nvinfer1::PluginFormat::kLINEAR); +bool +YoloLayer::supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const noexcept { + return (type == nvinfer1::DataType::kFLOAT && format == nvinfer1::PluginFormat::kLINEAR); } void -YoloLayer::configureWithFormat ( - const nvinfer1::Dims* inputDims, int nbInputs, - const nvinfer1::Dims* outputDims, int nbOutputs, - nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) noexcept +YoloLayer::configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs, const nvinfer1::Dims* outputDims, + int nbOutputs, nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) noexcept { - assert(nbInputs > 0); - assert(format == nvinfer1::PluginFormat::kLINEAR); - assert(inputDims != nullptr); + assert(nbInputs > 0); + assert(format == nvinfer1::PluginFormat::kLINEAR); + assert(inputDims != nullptr); } -int32_t YoloLayer::enqueue ( - int batchSize, void const* const* inputs, void* const* outputs, void* workspace, - cudaStream_t stream) noexcept +int32_t +YoloLayer::enqueue(int batchSize, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) + noexcept { - void* num_detections = outputs[0]; - void* detection_boxes = outputs[1]; - void* detection_scores = outputs[2]; - void* detection_classes = outputs[3]; + void* num_detections = outputs[0]; + void* detection_boxes = outputs[1]; + void* detection_scores = outputs[2]; + void* detection_classes = outputs[3]; - CUDA_CHECK(cudaMemsetAsync((int*)num_detections, 0, sizeof(int) * batchSize, stream)); - CUDA_CHECK(cudaMemsetAsync((float*)detection_boxes, 0, sizeof(float) * m_OutputSize * 4 * batchSize, stream)); - CUDA_CHECK(cudaMemsetAsync((float*)detection_scores, 0, sizeof(float) * m_OutputSize * batchSize, stream)); - CUDA_CHECK(cudaMemsetAsync((int*)detection_classes, 0, sizeof(int) * m_OutputSize * batchSize, stream)); + CUDA_CHECK(cudaMemsetAsync((int*)num_detections, 0, sizeof(int) * batchSize, stream)); + CUDA_CHECK(cudaMemsetAsync((float*)detection_boxes, 0, sizeof(float) * m_OutputSize * 4 * batchSize, stream)); + 
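/*
 * YoloLayer::enqueue zeroes the four output buffers and then dispatches on m_Type:
 *   4 -> cudaYoloLayer_v8 : YOLOv8; one input tensor laid out as outputSize x (4 + numClasses),
 *        boxes arrive already decoded and class scores already activated, so the kernel only
 *        takes the arg-max class and applies the score threshold (presumably produced upstream
 *        by the detect_v8 layer).
 *   3 -> cudaYoloLayer_e  : PP-YOLOE; separate cls and reg input tensors.
 *   2 -> cudaYoloLayer_r  : YOLOR; scale_x_y is forced to 2.0 (see the comment in the call).
 *   1 -> cudaYoloLayer_nc or cudaYoloLayer : Darknet YOLO heads, with or without new_coords.
 *   otherwise -> cudaRegionLayer : YOLOv2 region head, using a temporary softmax buffer.
 */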
CUDA_CHECK(cudaMemsetAsync((float*)detection_scores, 0, sizeof(float) * m_OutputSize * batchSize, stream)); + CUDA_CHECK(cudaMemsetAsync((int*)detection_classes, 0, sizeof(int) * m_OutputSize * batchSize, stream)); - if (m_Type == 3) - { - CUDA_CHECK(cudaYoloLayer_e( - inputs[0], inputs[1], num_detections, detection_boxes, detection_scores, detection_classes, batchSize, - m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, stream)); - } - else - { - uint yoloTensorsSize = m_YoloTensors.size(); - for (uint i = 0; i < yoloTensorsSize; ++i) - { - TensorInfo& curYoloTensor = m_YoloTensors.at(i); + if (m_Type == 4) { + CUDA_CHECK(cudaYoloLayer_v8(inputs[0], num_detections, detection_boxes, detection_scores, detection_classes, batchSize, + m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, stream)); + } + else if (m_Type == 3) { + CUDA_CHECK(cudaYoloLayer_e(inputs[0], inputs[1], num_detections, detection_boxes, detection_scores, detection_classes, + batchSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, stream)); + } + else { + uint yoloTensorsSize = m_YoloTensors.size(); + for (uint i = 0; i < yoloTensorsSize; ++i) { + TensorInfo& curYoloTensor = m_YoloTensors.at(i); - uint numBBoxes = curYoloTensor.numBBoxes; - float scaleXY = curYoloTensor.scaleXY; - uint gridSizeX = curYoloTensor.gridSizeX; - uint gridSizeY = curYoloTensor.gridSizeY; - std::vector anchors = curYoloTensor.anchors; - std::vector mask = curYoloTensor.mask; + uint numBBoxes = curYoloTensor.numBBoxes; + float scaleXY = curYoloTensor.scaleXY; + uint gridSizeX = curYoloTensor.gridSizeX; + uint gridSizeY = curYoloTensor.gridSizeY; + std::vector anchors = curYoloTensor.anchors; + std::vector mask = curYoloTensor.mask; - void* v_anchors; - void* v_mask; - if (anchors.size() > 0) { - float* f_anchors = anchors.data(); - CUDA_CHECK(cudaMalloc(&v_anchors, sizeof(float) * anchors.size())); - CUDA_CHECK(cudaMemcpyAsync(v_anchors, f_anchors, sizeof(float) * anchors.size(), cudaMemcpyHostToDevice, - stream)); - } - if (mask.size() > 0) { - int* f_mask = mask.data(); - CUDA_CHECK(cudaMalloc(&v_mask, sizeof(int) * mask.size())); - CUDA_CHECK(cudaMemcpyAsync(v_mask, f_mask, sizeof(int) * mask.size(), cudaMemcpyHostToDevice, stream)); - } + void* v_anchors; + void* v_mask; + if (anchors.size() > 0) { + float* f_anchors = anchors.data(); + CUDA_CHECK(cudaMalloc(&v_anchors, sizeof(float) * anchors.size())); + CUDA_CHECK(cudaMemcpyAsync(v_anchors, f_anchors, sizeof(float) * anchors.size(), cudaMemcpyHostToDevice, stream)); + } + if (mask.size() > 0) { + int* f_mask = mask.data(); + CUDA_CHECK(cudaMalloc(&v_mask, sizeof(int) * mask.size())); + CUDA_CHECK(cudaMemcpyAsync(v_mask, f_mask, sizeof(int) * mask.size(), cudaMemcpyHostToDevice, stream)); + } - uint64_t inputSize = gridSizeX * gridSizeY * (numBBoxes * (4 + 1 + m_NumClasses)); + uint64_t inputSize = gridSizeX * gridSizeY * (numBBoxes * (4 + 1 + m_NumClasses)); - if (m_Type == 2) { // YOLOR incorrect param: scale_x_y = 2.0 - CUDA_CHECK(cudaYoloLayer_r( - inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, batchSize, inputSize, - m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, m_NumClasses, numBBoxes, - 2.0, v_anchors, v_mask, stream)); - } - else if (m_Type == 1) { - if (m_NewCoords) { - CUDA_CHECK(cudaYoloLayer_nc( - inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, batchSize, - inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, 
m_NetHeight, gridSizeX, gridSizeY, - m_NumClasses, numBBoxes, scaleXY, v_anchors, v_mask, stream)); - } - else { - CUDA_CHECK(cudaYoloLayer( - inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, batchSize, - inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, - m_NumClasses, numBBoxes, scaleXY, v_anchors, v_mask, stream)); - } - } - else { - void* softmax; - CUDA_CHECK(cudaMalloc(&softmax, sizeof(float) * inputSize * batchSize)); - CUDA_CHECK(cudaMemsetAsync((float*)softmax, 0, sizeof(float) * inputSize * batchSize, stream)); - - CUDA_CHECK(cudaRegionLayer( - inputs[i], softmax, num_detections, detection_boxes, detection_scores, detection_classes, batchSize, - inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, m_NumClasses, - numBBoxes, v_anchors, stream)); - - CUDA_CHECK(cudaFree(softmax)); - } - - if (anchors.size() > 0) { - CUDA_CHECK(cudaFree(v_anchors)); - } - if (mask.size() > 0) { - CUDA_CHECK(cudaFree(v_mask)); - } + if (m_Type == 2) { // YOLOR incorrect param: scale_x_y = 2.0 + CUDA_CHECK(cudaYoloLayer_r(inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, + batchSize, inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, + m_NumClasses, numBBoxes, 2.0, v_anchors, v_mask, stream)); + } + else if (m_Type == 1) { + if (m_NewCoords) { + CUDA_CHECK(cudaYoloLayer_nc( inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, + batchSize, inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, + m_NumClasses, numBBoxes, scaleXY, v_anchors, v_mask, stream)); } - } - - return 0; -} - -size_t YoloLayer::getSerializationSize() const noexcept -{ - size_t totalSize = 0; - - totalSize += sizeof(m_NetWidth); - totalSize += sizeof(m_NetHeight); - totalSize += sizeof(m_NumClasses); - totalSize += sizeof(m_NewCoords); - totalSize += sizeof(m_OutputSize); - totalSize += sizeof(m_Type); - totalSize += sizeof(m_ScoreThreshold); - - if (m_Type != 3) { - uint yoloTensorsSize = m_YoloTensors.size(); - totalSize += sizeof(yoloTensorsSize); - - for (uint i = 0; i < yoloTensorsSize; ++i) - { - const TensorInfo& curYoloTensor = m_YoloTensors.at(i); - totalSize += sizeof(curYoloTensor.gridSizeX); - totalSize += sizeof(curYoloTensor.gridSizeY); - totalSize += sizeof(curYoloTensor.numBBoxes); - totalSize += sizeof(curYoloTensor.scaleXY); - totalSize += sizeof(uint) + sizeof(curYoloTensor.anchors[0]) * curYoloTensor.anchors.size(); - totalSize += sizeof(uint) + sizeof(curYoloTensor.mask[0]) * curYoloTensor.mask.size(); + else { + CUDA_CHECK(cudaYoloLayer(inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, + batchSize, inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, + m_NumClasses, numBBoxes, scaleXY, v_anchors, v_mask, stream)); } - } + } + else { + void* softmax; + CUDA_CHECK(cudaMalloc(&softmax, sizeof(float) * inputSize * batchSize)); + CUDA_CHECK(cudaMemsetAsync((float*)softmax, 0, sizeof(float) * inputSize * batchSize, stream)); - return totalSize; + CUDA_CHECK(cudaRegionLayer(inputs[i], softmax, num_detections, detection_boxes, detection_scores, detection_classes, + batchSize, inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, + m_NumClasses, numBBoxes, v_anchors, stream)); + + CUDA_CHECK(cudaFree(softmax)); + } + + if (anchors.size() > 0) { + CUDA_CHECK(cudaFree(v_anchors)); 
+      }
+      if (mask.size() > 0) {
+        CUDA_CHECK(cudaFree(v_mask));
+      }
+    }
+  }
+
+  return 0;
 }
-void YoloLayer::serialize(void* buffer) const noexcept
+size_t
+YoloLayer::getSerializationSize() const noexcept
 {
-    char *d = static_cast<char*>(buffer);
+  size_t totalSize = 0;
-    write(d, m_NetWidth);
-    write(d, m_NetHeight);
-    write(d, m_NumClasses);
-    write(d, m_NewCoords);
-    write(d, m_OutputSize);
-    write(d, m_Type);
-    write(d, m_ScoreThreshold);
+  totalSize += sizeof(m_NetWidth);
+  totalSize += sizeof(m_NetHeight);
+  totalSize += sizeof(m_NumClasses);
+  totalSize += sizeof(m_NewCoords);
+  totalSize += sizeof(m_OutputSize);
+  totalSize += sizeof(m_Type);
+  totalSize += sizeof(m_ScoreThreshold);
-    if (m_Type != 3) {
-        uint yoloTensorsSize = m_YoloTensors.size();
-        write(d, yoloTensorsSize);
-        for (uint i = 0; i < yoloTensorsSize; ++i)
-        {
-            const TensorInfo& curYoloTensor = m_YoloTensors.at(i);
-            write(d, curYoloTensor.gridSizeX);
-            write(d, curYoloTensor.gridSizeY);
-            write(d, curYoloTensor.numBBoxes);
-            write(d, curYoloTensor.scaleXY);
+  if (m_Type != 3 && m_Type != 4) {
+    uint yoloTensorsSize = m_YoloTensors.size();
+    totalSize += sizeof(yoloTensorsSize);
-            uint anchorsSize = curYoloTensor.anchors.size();
-            write(d, anchorsSize);
-            for (uint j = 0; j < anchorsSize; ++j)
-            {
-                write(d, curYoloTensor.anchors[j]);
-            }
-
-            uint maskSize = curYoloTensor.mask.size();
-            write(d, maskSize);
-            for (uint j = 0; j < maskSize; ++j)
-            {
-                write(d, curYoloTensor.mask[j]);
-            }
-        }
+    for (uint i = 0; i < yoloTensorsSize; ++i) {
+      const TensorInfo& curYoloTensor = m_YoloTensors.at(i);
+      totalSize += sizeof(curYoloTensor.gridSizeX);
+      totalSize += sizeof(curYoloTensor.gridSizeY);
+      totalSize += sizeof(curYoloTensor.numBBoxes);
+      totalSize += sizeof(curYoloTensor.scaleXY);
+      totalSize += sizeof(uint) + sizeof(curYoloTensor.anchors[0]) * curYoloTensor.anchors.size();
+      totalSize += sizeof(uint) + sizeof(curYoloTensor.mask[0]) * curYoloTensor.mask.size();
     }
+  }
+
+  return totalSize;
 }
-nvinfer1::IPluginV2* YoloLayer::clone() const noexcept
+void
+YoloLayer::serialize(void* buffer) const noexcept
 {
-    return new YoloLayer (
-        m_NetWidth, m_NetHeight, m_NumClasses, m_NewCoords, m_YoloTensors, m_OutputSize, m_Type, m_ScoreThreshold);
+  char* d = static_cast<char*>(buffer);
+
+  write(d, m_NetWidth);
+  write(d, m_NetHeight);
+  write(d, m_NumClasses);
+  write(d, m_NewCoords);
+  write(d, m_OutputSize);
+  write(d, m_Type);
+  write(d, m_ScoreThreshold);
+
+  if (m_Type != 3 && m_Type != 4) {
+    uint yoloTensorsSize = m_YoloTensors.size();
+    write(d, yoloTensorsSize);
+    for (uint i = 0; i < yoloTensorsSize; ++i) {
+      const TensorInfo& curYoloTensor = m_YoloTensors.at(i);
+      write(d, curYoloTensor.gridSizeX);
+      write(d, curYoloTensor.gridSizeY);
+      write(d, curYoloTensor.numBBoxes);
+      write(d, curYoloTensor.scaleXY);
+
+      uint anchorsSize = curYoloTensor.anchors.size();
+      write(d, anchorsSize);
+      for (uint j = 0; j < anchorsSize; ++j)
+        write(d, curYoloTensor.anchors[j]);
+
+      uint maskSize = curYoloTensor.mask.size();
+      write(d, maskSize);
+      for (uint j = 0; j < maskSize; ++j)
+        write(d, curYoloTensor.mask[j]);
+    }
+  }
+}
+
+nvinfer1::IPluginV2*
+YoloLayer::clone() const noexcept
+{
+  return new YoloLayer(m_NetWidth, m_NetHeight, m_NumClasses, m_NewCoords, m_YoloTensors, m_OutputSize, m_Type,
+      m_ScoreThreshold);
 }
 REGISTER_TENSORRT_PLUGIN(YoloLayerPluginCreator);
diff --git a/nvdsinfer_custom_impl_Yolo/yoloPlugins.h b/nvdsinfer_custom_impl_Yolo/yoloPlugins.h
index 961b630..7d41791 100644
--- a/nvdsinfer_custom_impl_Yolo/yoloPlugins.h
+++ b/nvdsinfer_custom_impl_Yolo/yoloPlugins.h
@@ -26,88 +26,64 @@
 #ifndef __YOLO_PLUGINS__
 #define __YOLO_PLUGINS__
-#include
-#include
-#include
-#include
-#include
-
-#include
-
-#include "NvInferPlugin.h"
-
 #include "yolo.h"
-#define CUDA_CHECK(status) \
-    { \
-        if (status != 0) \
-        { \
-            std::cout << "CUDA failure: " << cudaGetErrorString(status) << " in file " << __FILE__ << " at line " \
-                << __LINE__ << std::endl; \
-            abort(); \
-        } \
-    }
+#define CUDA_CHECK(status) { \
+  if (status != 0) { \
+    std::cout << "CUDA failure: " << cudaGetErrorString(status) << " in file " << __FILE__ << " at line " << __LINE__ << \
+        std::endl; \
+    abort(); \
+  } \
+}
-namespace
-{
-const char* YOLOLAYER_PLUGIN_VERSION {"1"};
-const char* YOLOLAYER_PLUGIN_NAME {"YoloLayer_TRT"};
+namespace {
+  const char* YOLOLAYER_PLUGIN_VERSION {"1"};
+  const char* YOLOLAYER_PLUGIN_NAME {"YoloLayer_TRT"};
 } // namespace
-class YoloLayer : public nvinfer1::IPluginV2
-{
-public:
-    YoloLayer (const void* data, size_t length);
+class YoloLayer : public nvinfer1::IPluginV2 {
+  public:
+    YoloLayer(const void* data, size_t length);
-    YoloLayer (
-        const uint& netWidth, const uint& netHeight, const uint& numClasses, const uint& newCoords,
+    YoloLayer(const uint& netWidth, const uint& netHeight, const uint& numClasses, const uint& newCoords,
         const std::vector<TensorInfo>& yoloTensors, const uint64_t& outputSize, const uint& modelType,
         const float& scoreThreshold);
-    const char* getPluginType () const noexcept override { return YOLOLAYER_PLUGIN_NAME; }
+    const char* getPluginType() const noexcept override { return YOLOLAYER_PLUGIN_NAME; }
-    const char* getPluginVersion () const noexcept override { return YOLOLAYER_PLUGIN_VERSION; }
+    const char* getPluginVersion() const noexcept override { return YOLOLAYER_PLUGIN_VERSION; }
-    int getNbOutputs () const noexcept override { return 4; }
+    int getNbOutputs() const noexcept override { return 4; }
-    nvinfer1::Dims getOutputDimensions (
-        int index, const nvinfer1::Dims* inputs,
-        int nbInputDims) noexcept override;
+    nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) noexcept override;
-    bool supportsFormat (
-        nvinfer1::DataType type, nvinfer1::PluginFormat format) const noexcept override;
+    bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const noexcept override;
-    void configureWithFormat (
-        const nvinfer1::Dims* inputDims, int nbInputs, const nvinfer1::Dims* outputDims, int nbOutputs,
+    void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs, const nvinfer1::Dims* outputDims, int nbOutputs,
        nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) noexcept override;
-    int initialize () noexcept override { return 0; }
+    int initialize() noexcept override { return 0; }
-    void terminate () noexcept override {}
+    void terminate() noexcept override {}
-    size_t getWorkspaceSize (int maxBatchSize) const noexcept override { return 0; }
+    size_t getWorkspaceSize(int maxBatchSize) const noexcept override { return 0; }
-    int32_t enqueue (
-        int batchSize, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream)
+    int32_t enqueue(int batchSize, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream)
        noexcept override;
     size_t getSerializationSize() const noexcept override;
-    void serialize (void* buffer) const noexcept override;
+    void serialize(void* buffer) const noexcept override;
-    void destroy () noexcept override { delete this; }
+    void
destroy() noexcept override { delete this; } nvinfer1::IPluginV2* clone() const noexcept override; - void setPluginNamespace (const char* pluginNamespace) noexcept override { - m_Namespace = pluginNamespace; - } + void setPluginNamespace(const char* pluginNamespace) noexcept override { m_Namespace = pluginNamespace; } - virtual const char* getPluginNamespace () const noexcept override { - return m_Namespace.c_str(); - } + virtual const char* getPluginNamespace() const noexcept override { return m_Namespace.c_str(); } -private: + private: std::string m_Namespace {""}; uint m_NetWidth {0}; uint m_NetHeight {0}; @@ -119,47 +95,37 @@ private: float m_ScoreThreshold {0}; }; -class YoloLayerPluginCreator : public nvinfer1::IPluginCreator -{ -public: - YoloLayerPluginCreator () {} +class YoloLayerPluginCreator : public nvinfer1::IPluginCreator { + public: + YoloLayerPluginCreator() {} - ~YoloLayerPluginCreator () {} + ~YoloLayerPluginCreator() {} - const char* getPluginName () const noexcept override { return YOLOLAYER_PLUGIN_NAME; } + const char* getPluginName() const noexcept override { return YOLOLAYER_PLUGIN_NAME; } - const char* getPluginVersion () const noexcept override { return YOLOLAYER_PLUGIN_VERSION; } + const char* getPluginVersion() const noexcept override { return YOLOLAYER_PLUGIN_VERSION; } const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override { - std::cerr<< "YoloLayerPluginCreator::getFieldNames is not implemented" << std::endl; - return nullptr; + std::cerr<< "YoloLayerPluginCreator::getFieldNames is not implemented" << std::endl; + return nullptr; } - nvinfer1::IPluginV2* createPlugin ( - const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override - { - std::cerr<< "YoloLayerPluginCreator::getFieldNames is not implemented"; - return nullptr; + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override { + std::cerr<< "YoloLayerPluginCreator::getFieldNames is not implemented"; + return nullptr; } - nvinfer1::IPluginV2* deserializePlugin ( - const char* name, const void* serialData, size_t serialLength) noexcept override - { - std::cout << "Deserialize yoloLayer plugin: " << name << std::endl; - return new YoloLayer(serialData, serialLength); + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override { + std::cout << "Deserialize yoloLayer plugin: " << name << std::endl; + return new YoloLayer(serialData, serialLength); } - void setPluginNamespace(const char* libNamespace) noexcept override { - m_Namespace = libNamespace; - } - const char* getPluginNamespace() const noexcept override { - return m_Namespace.c_str(); - } + void setPluginNamespace(const char* libNamespace) noexcept override { m_Namespace = libNamespace; } -private: + const char* getPluginNamespace() const noexcept override { return m_Namespace.c_str(); } + + private: std::string m_Namespace {""}; }; -extern uint kNUM_CLASSES; - #endif // __YOLO_PLUGINS__ diff --git a/utils/gen_wts_ppyoloe.py b/utils/gen_wts_ppyoloe.py index 8c985b2..825e9f0 100644 --- a/utils/gen_wts_ppyoloe.py +++ b/utils/gen_wts_ppyoloe.py @@ -8,6 +8,7 @@ from ppdet.utils.cli import ArgsParser from ppdet.engine import Trainer from ppdet.slim import build_slim_model + class Layers(object): def __init__(self, size, fw, fc, letter_box): self.blocks = [0 for _ in range(300)] @@ -123,7 +124,7 @@ class Layers(object): def Shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None, 
output=''): self.current += 1 - r = 0 + r = None if route is not None: r = self.get_route(route) self.shuffle(reshape=reshape, transpose1=transpose1, transpose2=transpose2, route=r) @@ -156,7 +157,7 @@ class Layers(object): 'channels=3\n' + lb) - def convolutional(self, cv, act='linear', detect=False): + def convolutional(self, cv, act='linear'): self.blocks[self.current] += 1 self.get_state_dict(cv.state_dict()) @@ -178,9 +179,6 @@ class Layers(object): bias = cv.conv.bias bn = True if hasattr(cv, 'bn') else False - if detect: - act = 'logistic' - b = 'batch_normalize=1\n' if bn is True else '' g = 'groups=%d\n' % groups if groups > 1 else '' w = 'bias=0\n' if bias is None and bn is False else '' @@ -251,9 +249,9 @@ class Layers(object): def shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None): self.blocks[self.current] += 1 - r = 'reshape=%s\n' % str(reshape)[1:-1] if reshape is not None else '' - t1 = 'transpose1=%s\n' % str(transpose1)[1:-1] if transpose1 is not None else '' - t2 = 'transpose2=%s\n' % str(transpose2)[1:-1] if transpose2 is not None else '' + r = 'reshape=%s\n' % ', '.join(str(x) for x in reshape) if reshape is not None else '' + t1 = 'transpose1=%s\n' % ', '.join(str(x) for x in transpose1) if transpose1 is not None else '' + t2 = 'transpose2=%s\n' % ', '.join(str(x) for x in transpose2) if transpose2 is not None else '' f = 'from=%d\n' % route if route is not None else '' self.fc.write('\n[shuffle]\n' + @@ -419,13 +417,13 @@ with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc: layers.AvgPool2d() layers.ESEAttn(model.yolo_head.stem_cls[i]) layers.Conv2D(model.yolo_head.pred_cls[i], act='sigmoid') - layers.Shuffle(reshape=[model.yolo_head.num_classes, 0], route=feat, output='cls') + layers.Shuffle(reshape=[model.yolo_head.num_classes, 'hw'], route=feat, output='cls') layers.ESEAttn(model.yolo_head.stem_reg[i], route=-7) layers.Conv2D(model.yolo_head.pred_reg[i]) - layers.Shuffle(reshape=[4, model.yolo_head.reg_max + 1, 0], transpose2=[1, 0, 2], route=feat) + layers.Shuffle(reshape=[4, model.yolo_head.reg_max + 1, 'hw'], transpose2=[1, 0, 2], route=feat) layers.SoftMax(0) layers.Conv2D(model.yolo_head.proj_conv) - layers.Shuffle(reshape=[4, 0], route=feat, output='reg') + layers.Shuffle(reshape=[4, 'hw'], route=feat, output='reg') layers.Detect('cls') layers.Detect('reg') layers.get_anchors(model.yolo_head.anchor_points.reshape([-1]), model.yolo_head.stride_tensor) diff --git a/utils/gen_wts_yoloV8.py b/utils/gen_wts_yoloV8.py new file mode 100644 index 0000000..4be6b1f --- /dev/null +++ b/utils/gen_wts_yoloV8.py @@ -0,0 +1,315 @@ +import argparse +import os +import struct +import torch +from ultralytics.yolo.utils.torch_utils import select_device + + +class Layers(object): + def __init__(self, n, size, fw, fc): + self.blocks = [0 for _ in range(n)] + self.current = -1 + + self.width = size[0] if len(size) == 1 else size[1] + self.height = size[0] + + self.fw = fw + self.fc = fc + self.wc = 0 + + self.net() + + def Conv(self, child): + self.current = child.i + self.fc.write('\n# Conv\n') + + self.convolutional(child) + + def C2f(self, child): + self.current = child.i + self.fc.write('\n# C2f\n') + + self.convolutional(child.cv1) + self.c2f(child.m) + self.convolutional(child.cv2) + + def SPPF(self, child): + self.current = child.i + self.fc.write('\n# SPPF\n') + + self.convolutional(child.cv1) + self.maxpool(child.m) + self.maxpool(child.m) + self.maxpool(child.m) + self.route('-4, -3, -2, -1') + self.convolutional(child.cv2) + + def 
Upsample(self, child): + self.current = child.i + self.fc.write('\n# Upsample\n') + + self.upsample(child) + + def Concat(self, child): + self.current = child.i + self.fc.write('\n# Concat\n') + + r = [] + for i in range(1, len(child.f)): + r.append(self.get_route(child.f[i])) + self.route('-1, %s' % str(r)[1:-1]) + + def Detect(self, child): + self.current = child.i + self.fc.write('\n# Detect\n') + + output_idxs = [0 for _ in range(child.nl)] + for i in range(child.nl): + r = self.get_route(child.f[i]) + self.route('%d' % r) + for j in range(len(child.cv3[i])): + self.convolutional(child.cv3[i][j]) + self.route('%d' % (-1 - len(child.cv3[i]))) + for j in range(len(child.cv2[i])): + self.convolutional(child.cv2[i][j]) + self.route('-1, %d' % (-2 - len(child.cv2[i]))) + self.shuffle(reshape=[child.no, -1]) + output_idxs[i] = (-1 + i * (-4 - len(child.cv3[i]) - len(child.cv2[i]))) + self.route('%s' % str(output_idxs[::-1])[1:-1], axis=1) + self.yolo(child) + + def net(self): + self.fc.write('[net]\n' + + 'width=%d\n' % self.width + + 'height=%d\n' % self.height + + 'channels=3\n' + + 'letter_box=1\n') + + def convolutional(self, cv, act=None, detect=False): + self.blocks[self.current] += 1 + + self.get_state_dict(cv.state_dict()) + + if cv._get_name() == 'Conv2d': + filters = cv.out_channels + size = cv.kernel_size + stride = cv.stride + pad = cv.padding + groups = cv.groups + bias = cv.bias + bn = False + act = 'linear' if not detect else 'logistic' + else: + filters = cv.conv.out_channels + size = cv.conv.kernel_size + stride = cv.conv.stride + pad = cv.conv.padding + groups = cv.conv.groups + bias = cv.conv.bias + bn = True if hasattr(cv, 'bn') else False + if act is None: + act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear' + + b = 'batch_normalize=1\n' if bn is True else '' + g = 'groups=%d\n' % groups if groups > 1 else '' + w = 'bias=0\n' if bias is None and bn is False else '' + + self.fc.write('\n[convolutional]\n' + + b + + 'filters=%d\n' % filters + + 'size=%s\n' % self.get_value(size) + + 'stride=%s\n' % self.get_value(stride) + + 'pad=%s\n' % self.get_value(pad) + + g + + w + + 'activation=%s\n' % act) + + def c2f(self, m): + self.blocks[self.current] += 1 + + for x in m: + self.get_state_dict(x.state_dict()) + + n = len(m) + shortcut = 1 if m[0].add else 0 + filters = m[0].cv1.conv.out_channels + size = m[0].cv1.conv.kernel_size + stride = m[0].cv1.conv.stride + pad = m[0].cv1.conv.padding + groups = m[0].cv1.conv.groups + bias = m[0].cv1.conv.bias + bn = True if hasattr(m[0].cv1, 'bn') else False + act = 'linear' + if hasattr(m[0].cv1, 'act'): + act = self.get_activation(m[0].cv1.act._get_name()) + + b = 'batch_normalize=1\n' if bn is True else '' + g = 'groups=%d\n' % groups if groups > 1 else '' + w = 'bias=0\n' if bias is None and bn is False else '' + + self.fc.write('\n[c2f]\n' + + 'n=%d\n' % n + + 'shortcut=%d\n' % shortcut + + b + + 'filters=%d\n' % filters + + 'size=%s\n' % self.get_value(size) + + 'stride=%s\n' % self.get_value(stride) + + 'pad=%s\n' % self.get_value(pad) + + g + + w + + 'activation=%s\n' % act) + + def route(self, layers, axis=0): + self.blocks[self.current] += 1 + + a = 'axis=%d\n' % axis if axis != 0 else '' + + self.fc.write('\n[route]\n' + + 'layers=%s\n' % layers + + a) + + def shortcut(self, r, ew='add', act='linear'): + self.blocks[self.current] += 1 + + m = 'mode=mul\n' if ew == 'mul' else '' + + self.fc.write('\n[shortcut]\n' + + 'from=%d\n' % r + + m + + 'activation=%s\n' % act) + + def maxpool(self, m): + 
self.blocks[self.current] += 1 + + stride = m.stride + size = m.kernel_size + mode = m.ceil_mode + + m = 'maxpool_up' if mode else 'maxpool' + + self.fc.write('\n[%s]\n' % m + + 'stride=%d\n' % stride + + 'size=%d\n' % size) + + def upsample(self, child): + self.blocks[self.current] += 1 + + stride = child.scale_factor + + self.fc.write('\n[upsample]\n' + + 'stride=%d\n' % stride) + + def shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None): + self.blocks[self.current] += 1 + + r = 'reshape=%s\n' % ', '.join(str(x) for x in reshape) if reshape is not None else '' + t1 = 'transpose1=%s\n' % ', '.join(str(x) for x in transpose1) if transpose1 is not None else '' + t2 = 'transpose2=%s\n' % ', '.join(str(x) for x in transpose2) if transpose2 is not None else '' + f = 'from=%d\n' % route if route is not None else '' + + self.fc.write('\n[shuffle]\n' + + r + + t1 + + t2 + + f) + + def yolo(self, child): + self.blocks[self.current] += 1 + + self.fc.write('\n[detect_v8]\n' + + 'num=%d\n' % (child.reg_max * 4) + + 'classes=%d\n' % child.nc) + + def get_state_dict(self, state_dict): + for k, v in state_dict.items(): + if 'num_batches_tracked' not in k: + vr = v.reshape(-1).numpy() + self.fw.write('{} {} '.format(k, len(vr))) + for vv in vr: + self.fw.write(' ') + self.fw.write(struct.pack('>f', float(vv)).hex()) + self.fw.write('\n') + self.wc += 1 + + def get_anchors(self, anchor_points, stride_tensor): + vr = anchor_points.numpy() + self.fw.write('{} {} '.format('anchor_points', len(vr))) + for vv in vr: + self.fw.write(' ') + self.fw.write(struct.pack('>f', float(vv)).hex()) + self.fw.write('\n') + self.wc += 1 + vr = stride_tensor.numpy() + self.fw.write('{} {} '.format('stride_tensor', len(vr))) + for vv in vr: + self.fw.write(' ') + self.fw.write(struct.pack('>f', float(vv)).hex()) + self.fw.write('\n') + self.wc += 1 + + def get_value(self, key): + if type(key) == int: + return key + return key[0] if key[0] == key[1] else str(key)[1:-1] + + def get_route(self, n): + r = 0 + for i, b in enumerate(self.blocks): + if i <= n: + r += b + else: + break + return r - 1 + + def get_activation(self, act): + if act == 'Hardswish': + return 'hardswish' + elif act == 'LeakyReLU': + return 'leaky' + elif act == 'SiLU': + return 'silu' + return 'linear' + + +def parse_args(): + parser = argparse.ArgumentParser(description='PyTorch YOLOv8 conversion') + parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)') + parser.add_argument( + '-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])') + args = parser.parse_args() + if not os.path.isfile(args.weights): + raise SystemExit('Invalid weights file') + return args.weights, args.size + + +pt_file, inference_size = parse_args() + +model_name = os.path.basename(pt_file).split('.pt')[0] +wts_file = model_name + '.wts' if 'yolov8' in model_name else 'yolov8_' + model_name + '.wts' +cfg_file = model_name + '.cfg' if 'yolov8' in model_name else 'yolov8_' + model_name + '.cfg' + +device = select_device('cpu') +model = torch.load(pt_file, map_location=device)['model'].float() +model.to(device).eval() + +with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc: + layers = Layers(len(model.model), inference_size, fw, fc) + + for child in model.model.children(): + if child._get_name() == 'Conv': + layers.Conv(child) + elif child._get_name() == 'C2f': + layers.C2f(child) + elif child._get_name() == 'SPPF': + layers.SPPF(child) + elif child._get_name() == 'Upsample': 
+ layers.Upsample(child) + elif child._get_name() == 'Concat': + layers.Concat(child) + elif child._get_name() == 'Detect': + layers.Detect(child) + layers.get_anchors(child.anchors.reshape([-1]), child.strides.reshape([-1])) + else: + raise SystemExit('Model not supported') + +os.system('echo "%d" | cat - %s > temp && mv temp %s' % (layers.wc, wts_file, wts_file))
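
For reference, a minimal sketch (not part of the patch) of how the .wts layout written by gen_wts_yoloV8.py can be read back. The converter is invoked as, e.g., python3 gen_wts_yoloV8.py -w yolov8s.pt (the yolov8s.pt filename is only an example). The trailing os.system call prepends the number of weight blocks as the first line; every following line holds a tensor name, its element count, and the values encoded as big-endian float32 hex words. The read_wts helper name below is illustrative, not something the repository provides.

```
# Sketch: parse a .wts file produced by gen_wts_yoloV8.py.
# Assumed layout (derived from the script above): first line = block count,
# then one line per tensor: "<name> <count> <hex> <hex> ...", where each <hex>
# is a big-endian float32. Names may repeat across layers, so keep a list.
import struct
import sys


def read_wts(path):
    blocks = []
    with open(path) as f:
        declared = int(f.readline())  # count prepended by the os.system call
        for line in f:
            parts = line.split()
            if not parts:
                continue
            name, count = parts[0], int(parts[1])
            values = [struct.unpack('>f', bytes.fromhex(h))[0] for h in parts[2:2 + count]]
            blocks.append((name, values))
    assert declared == len(blocks), 'block count mismatch'
    return blocks


if __name__ == '__main__':
    for name, values in read_wts(sys.argv[1]):
        print(name, len(values))
```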