Add YOLOv8 support

This commit is contained in:
Marcos Luciano
2023-01-27 15:56:00 -03:00
parent f1cd701247
commit f9c7a4dfca
59 changed files with 3260 additions and 2763 deletions

View File

@@ -8,6 +8,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
* YOLOX support
* YOLOv6 support
* Dynamic batch-size
* PP-YOLOE+ support
### Improvements on this repository
@@ -22,11 +23,12 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
* New documentation for multiple models
* YOLOv5 support
* YOLOR support
* **GPU YOLO Decoder** [#138](https://github.com/marcoslucianops/DeepStream-Yolo/issues/138)
* **PP-YOLOE support**
* **YOLOv7 support**
* **Optimized NMS** [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142)
* **Models benchmarks**
* GPU YOLO Decoder [#138](https://github.com/marcoslucianops/DeepStream-Yolo/issues/138)
* PP-YOLOE support
* YOLOv7 support
* Optimized NMS [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142)
* Models benchmarks
* **YOLOv8 support**
##
@@ -44,6 +46,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
* [YOLOR usage](docs/YOLOR.md)
* [PP-YOLOE usage](docs/PPYOLOE.md)
* [YOLOv7 usage](docs/YOLOv7.md)
* [YOLOv8 usage](docs/YOLOv8.md)
* [Using your custom model](docs/customModels.md)
* [Multiple YOLO GIEs](docs/multipleGIEs.md)
@@ -108,6 +111,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
* [YOLOR](https://github.com/WongKinYiu/yolor)
* [PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
* [YOLOv7](https://github.com/WongKinYiu/yolov7)
* [YOLOv8](https://github.com/ultralytics/ultralytics)
* [MobileNet-YOLO](https://github.com/dog-qiuqiu/MobileNet-Yolo)
* [YOLO-Fastest](https://github.com/dog-qiuqiu/Yolo-Fastest)
@@ -131,7 +135,7 @@ sample = 1920x1080 video
- Eval
```
nms-iou-threshold = 0.6 (Darknet) / 0.65 (PyTorch) / 0.7 (Paddle)
nms-iou-threshold = 0.6 (Darknet and YOLOv8) / 0.65 (YOLOR, YOLOv5 and YOLOv7) / 0.7 (Paddle)
pre-cluster-threshold = 0.001
topk = 300
```

View File

@@ -16,6 +16,7 @@ process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=0
symmetric-padding=1
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet

View File

@@ -16,10 +16,10 @@ process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=1
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet
symmetric-padding=1
[class-attrs-all]
nms-iou-threshold=0.45

View File

@@ -16,6 +16,7 @@ process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=1
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet

View File

@@ -0,0 +1,27 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
custom-network-config=yolov8s.cfg
model-file=yolov8s.wts
model-engine-file=model_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=1
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300

View File

@@ -16,6 +16,7 @@ process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=1
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet

139
docs/YOLOv8.md Normal file
View File

@@ -0,0 +1,139 @@
# YOLOv8 usage
**NOTE**: The yaml file is not required.
* [Convert model](#convert-model)
* [Compile the lib](#compile-the-lib)
* [Edit the config_infer_primary_yoloV8 file](#edit-the-config_infer_primary_yolov8-file)
* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file)
* [Testing the model](#testing-the-model)
##
### Convert model
#### 1. Download the YOLOv8 repo and install the requirements
```
git clone https://github.com/ultralytics/ultralytics.git
cd ultralytics
pip3 install -r requirements.txt
```
**NOTE**: It is recommended to use Python virtualenv.
#### 2. Copy conversor
Copy the `gen_wts_yoloV8.py` file from `DeepStream-Yolo/utils` directory to the `ultralytics` folder.
#### 3. Download the model
Download the `pt` file from [YOLOv8](https://github.com/ultralytics/assets/releases/) releases (example for YOLOv8s)
```
wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s.pt
```
**NOTE**: You can use your custom model, but it is important to keep the YOLO model reference (`yolov8_`) in you `cfg` and `weights`/`wts` filenames to generate the engine correctly.
#### 4. Convert model
Generate the `cfg` and `wts` files (example for YOLOv8s)
```
python3 gen_wts_yoloV8.py -w yolov8s.pt
```
**NOTE**: To change the inference size (defaut: 640)
```
-s SIZE
--size SIZE
-s HEIGHT WIDTH
--size HEIGHT WIDTH
```
Example for 1280
```
-s 1280
```
or
```
-s 1280 1280
```
#### 5. Copy generated files
Copy the generated `cfg` and `wts` files to the `DeepStream-Yolo` folder.
##
### Compile the lib
Open the `DeepStream-Yolo` folder and compile the lib
* DeepStream 6.1.1 on x86 platform
```
CUDA_VER=11.7 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.1 on x86 platform
```
CUDA_VER=11.6 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.0.1 / 6.0 on x86 platform
```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.1.1 / 6.1 on Jetson platform
```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.0.1 / 6.0 on Jetson platform
```
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
```
##
### Edit the config_infer_primary_yoloV8 file
Edit the `config_infer_primary_yoloV8.txt` file according to your model (example for YOLOv8s)
```
[property]
...
custom-network-config=yolov8s.cfg
model-file=yolov8s.wts
...
```
##
### Edit the deepstream_app_config file
```
...
[primary-gie]
...
config-file=config_infer_primary_yoloV8.txt
```
##
### Testing the model
```
deepstream-app -c deepstream_app_config.txt
```

View File

@@ -4,20 +4,21 @@
*/
#include "calibrator.h"
#include <fstream>
#include <iterator>
namespace nvinfer1
{
Int8EntropyCalibrator2::Int8EntropyCalibrator2(const int &batchsize, const int &channels, const int &height, const int &width, const int &letterbox, const std::string &imgPath,
const std::string &calibTablePath):batchSize(batchsize), inputC(channels), inputH(height), inputW(width), letterBox(letterbox), calibTablePath(calibTablePath), imageIndex(0)
Int8EntropyCalibrator2::Int8EntropyCalibrator2(const int& batchsize, const int& channels, const int& height,
const int& width, const int& letterbox, const std::string& imgPath,
const std::string& calibTablePath) : batchSize(batchsize), inputC(channels), inputH(height), inputW(width),
letterBox(letterbox), calibTablePath(calibTablePath), imageIndex(0)
{
inputCount = batchsize * channels * height * width;
std::fstream f(imgPath);
if (f.is_open())
{
if (f.is_open()) {
std::string temp;
while (std::getline(f, temp)) imgPaths.push_back(temp);
while (std::getline(f, temp))
imgPaths.push_back(temp);
}
batchData = new float[inputCount];
CUDA_CHECK(cudaMalloc(&deviceInput, inputCount * sizeof(float)));
@@ -30,28 +31,29 @@ namespace nvinfer1
delete[] batchData;
}
int Int8EntropyCalibrator2::getBatchSize() const noexcept
int
Int8EntropyCalibrator2::getBatchSize() const noexcept
{
return batchSize;
}
bool Int8EntropyCalibrator2::getBatch(void **bindings, const char **names, int nbBindings) noexcept
bool
Int8EntropyCalibrator2::getBatch(void** bindings, const char** names, int nbBindings) noexcept
{
if (imageIndex + batchSize > uint(imgPaths.size()))
return false;
float* ptr = batchData;
for (size_t j = imageIndex; j < imageIndex + batchSize; ++j)
{
cv::Mat img = cv::imread(imgPaths[j], cv::IMREAD_COLOR);
for (size_t i = imageIndex; i < imageIndex + batchSize; ++i) {
cv::Mat img = cv::imread(imgPaths[i], cv::IMREAD_COLOR);
std::vector<float> inputData = prepareImage(img, inputC, inputH, inputW, letterBox);
int len = (int) (inputData.size());
memcpy(ptr, inputData.data(), len * sizeof(float));
ptr += inputData.size();
std::cout << "Load image: " << imgPaths[j] << std::endl;
std::cout << "Progress: " << (j + 1)*100. / imgPaths.size() << "%" << std::endl;
std::cout << "Load image: " << imgPaths[i] << std::endl;
std::cout << "Progress: " << (i + 1)*100. / imgPaths.size() << "%" << std::endl;
}
imageIndex += batchSize;
CUDA_CHECK(cudaMemcpy(deviceInput, batchData, inputCount * sizeof(float), cudaMemcpyHostToDevice));
@@ -59,84 +61,73 @@ namespace nvinfer1
return true;
}
const void* Int8EntropyCalibrator2::readCalibrationCache(std::size_t &length) noexcept
const void*
Int8EntropyCalibrator2::readCalibrationCache(std::size_t &length) noexcept
{
calibrationCache.clear();
std::ifstream input(calibTablePath, std::ios::binary);
input >> std::noskipws;
if (readCache && input.good())
{
std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
std::back_inserter(calibrationCache));
}
std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(calibrationCache));
length = calibrationCache.size();
return length ? calibrationCache.data() : nullptr;
}
void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, std::size_t length) noexcept
void
Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, std::size_t length) noexcept
{
std::ofstream output(calibTablePath, std::ios::binary);
output.write(reinterpret_cast<const char*>(cache), length);
}
}
std::vector<float> prepareImage(cv::Mat& img, int input_c, int input_h, int input_w, int letter_box)
std::vector<float>
prepareImage(cv::Mat& img, int input_c, int input_h, int input_w, int letter_box)
{
cv::Mat out;
int image_w = img.cols;
int image_h = img.rows;
if (image_w != input_w || image_h != input_h)
{
if (letter_box == 1)
{
if (image_w != input_w || image_h != input_h) {
if (letter_box == 1) {
float ratio_w = (float) image_w / (float) input_w;
float ratio_h = (float) image_h / (float) input_h;
if (ratio_w > ratio_h)
{
if (ratio_w > ratio_h) {
int new_width = input_w * ratio_h;
int x = (image_w - new_width) / 2;
cv::Rect roi(abs(x), 0, new_width, image_h);
out = img(roi);
}
else if (ratio_w < ratio_h)
{
else if (ratio_w < ratio_h) {
int new_height = input_h * ratio_w;
int y = (image_h - new_height) / 2;
cv::Rect roi(0, abs(y), image_w, new_height);
out = img(roi);
}
else {
else
out = img;
}
cv::resize(out, out, cv::Size(input_w, input_h), 0, 0, cv::INTER_CUBIC);
}
else
{
else {
cv::resize(img, out, cv::Size(input_w, input_h), 0, 0, cv::INTER_CUBIC);
}
cv::cvtColor(out, out, cv::COLOR_BGR2RGB);
}
else
{
cv::cvtColor(img, out, cv::COLOR_BGR2RGB);
}
if (input_c == 3)
{
out.convertTo(out, CV_32FC3, 1.0 / 255.0);
}
else
{
out.convertTo(out, CV_32FC1, 1.0 / 255.0);
}
std::vector<cv::Mat> input_channels(input_c);
cv::split(out, input_channels);
std::vector<float> result(input_h * input_w * input_c);
auto data = result.data();
int channelLength = input_h * input_w;
for (int i = 0; i < input_c; ++i)
{
for (int i = 0; i < input_c; ++i) {
memcpy(data, input_channels[i].data, channelLength * sizeof(float));
data += channelLength;
}
return result;
}

View File

@@ -6,38 +6,32 @@
#ifndef CALIBRATOR_H
#define CALIBRATOR_H
#include "opencv2/opencv.hpp"
#include "cuda_runtime.h"
#include "NvInfer.h"
#include <vector>
#include <string>
#ifndef CUDA_CHECK
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
#include "NvInfer.h"
#include "opencv2/opencv.hpp"
#define CUDA_CHECK(status) { \
if (status != 0) { \
std::cout << "CUDA failure: " << cudaGetErrorString(status) << " in file " << __FILE__ << " at line " << __LINE__ << \
std::endl; \
abort(); \
} \
}
#endif
namespace nvinfer1 {
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
public:
Int8EntropyCalibrator2(const int &batchsize,
const int &channels,
const int &height,
const int &width,
const int &letterbox,
const std::string &imgPath,
const std::string &calibTablePath);
Int8EntropyCalibrator2(const int& batchsize, const int& channels, const int& height, const int& width,
const int& letterbox, const std::string& imgPath, const std::string& calibTablePath);
virtual ~Int8EntropyCalibrator2();
int getBatchSize() const noexcept override;
bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override;
const void* readCalibrationCache(std::size_t& length) noexcept override;
void writeCalibrationCache(const void* cache, size_t length) noexcept override;
private:
@@ -55,7 +49,6 @@ namespace nvinfer1 {
bool readCache;
std::vector<char> calibrationCache;
};
}
std::vector<float> prepareImage(cv::Mat& img, int input_c, int input_h, int input_w, int letter_box);

View File

@@ -5,118 +5,107 @@
#include "activation_layer.h"
nvinfer1::ITensor* activationLayer(
int layerIdx,
std::string activation,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
#include <cassert>
#include <iostream>
nvinfer1::ITensor*
activationLayer(int layerIdx, std::string activation, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network,
std::string layerName)
{
nvinfer1::ITensor* output;
if (activation == "linear")
{
output = input;
}
else if (activation == "relu")
{
else if (activation == "relu") {
nvinfer1::IActivationLayer* relu = network->addActivation(*input, nvinfer1::ActivationType::kRELU);
assert(relu != nullptr);
std::string reluLayerName = "relu_" + std::to_string(layerIdx);
std::string reluLayerName = "relu_" + layerName + std::to_string(layerIdx);
relu->setName(reluLayerName.c_str());
output = relu->getOutput(0);
}
else if (activation == "sigmoid" || activation == "logistic")
{
else if (activation == "sigmoid" || activation == "logistic") {
nvinfer1::IActivationLayer* sigmoid = network->addActivation(*input, nvinfer1::ActivationType::kSIGMOID);
assert(sigmoid != nullptr);
std::string sigmoidLayerName = "sigmoid_" + std::to_string(layerIdx);
std::string sigmoidLayerName = "sigmoid_" + layerName + std::to_string(layerIdx);
sigmoid->setName(sigmoidLayerName.c_str());
output = sigmoid->getOutput(0);
}
else if (activation == "tanh")
{
else if (activation == "tanh") {
nvinfer1::IActivationLayer* tanh = network->addActivation(*input, nvinfer1::ActivationType::kTANH);
assert(tanh != nullptr);
std::string tanhLayerName = "tanh_" + std::to_string(layerIdx);
std::string tanhLayerName = "tanh_" + layerName + std::to_string(layerIdx);
tanh->setName(tanhLayerName.c_str());
output = tanh->getOutput(0);
}
else if (activation == "leaky")
{
else if (activation == "leaky") {
nvinfer1::IActivationLayer* leaky = network->addActivation(*input, nvinfer1::ActivationType::kLEAKY_RELU);
assert(leaky != nullptr);
std::string leakyLayerName = "leaky_" + std::to_string(layerIdx);
std::string leakyLayerName = "leaky_" + layerName + std::to_string(layerIdx);
leaky->setName(leakyLayerName.c_str());
leaky->setAlpha(0.1);
output = leaky->getOutput(0);
}
else if (activation == "softplus")
{
else if (activation == "softplus") {
nvinfer1::IActivationLayer* softplus = network->addActivation(*input, nvinfer1::ActivationType::kSOFTPLUS);
assert(softplus != nullptr);
std::string softplusLayerName = "softplus_" + std::to_string(layerIdx);
std::string softplusLayerName = "softplus_" + layerName + std::to_string(layerIdx);
softplus->setName(softplusLayerName.c_str());
output = softplus->getOutput(0);
}
else if (activation == "mish")
{
else if (activation == "mish") {
nvinfer1::IActivationLayer* softplus = network->addActivation(*input, nvinfer1::ActivationType::kSOFTPLUS);
assert(softplus != nullptr);
std::string softplusLayerName = "softplus_" + std::to_string(layerIdx);
std::string softplusLayerName = "softplus_" + layerName + std::to_string(layerIdx);
softplus->setName(softplusLayerName.c_str());
nvinfer1::IActivationLayer* tanh = network->addActivation(*softplus->getOutput(0), nvinfer1::ActivationType::kTANH);
assert(tanh != nullptr);
std::string tanhLayerName = "tanh_" + std::to_string(layerIdx);
std::string tanhLayerName = "tanh_" + layerName + std::to_string(layerIdx);
tanh->setName(tanhLayerName.c_str());
nvinfer1::IElementWiseLayer* mish
= network->addElementWise(*input, *tanh->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
nvinfer1::IElementWiseLayer* mish = network->addElementWise(*input, *tanh->getOutput(0),
nvinfer1::ElementWiseOperation::kPROD);
assert(mish != nullptr);
std::string mishLayerName = "mish_" + std::to_string(layerIdx);
std::string mishLayerName = "mish_" + layerName + std::to_string(layerIdx);
mish->setName(mishLayerName.c_str());
output = mish->getOutput(0);
}
else if (activation == "silu" || activation == "swish")
{
else if (activation == "silu" || activation == "swish") {
nvinfer1::IActivationLayer* sigmoid = network->addActivation(*input, nvinfer1::ActivationType::kSIGMOID);
assert(sigmoid != nullptr);
std::string sigmoidLayerName = "sigmoid_" + std::to_string(layerIdx);
std::string sigmoidLayerName = "sigmoid_" + layerName + std::to_string(layerIdx);
sigmoid->setName(sigmoidLayerName.c_str());
nvinfer1::IElementWiseLayer* silu
= network->addElementWise(*input, *sigmoid->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
nvinfer1::IElementWiseLayer* silu = network->addElementWise(*input, *sigmoid->getOutput(0),
nvinfer1::ElementWiseOperation::kPROD);
assert(silu != nullptr);
std::string siluLayerName = "silu_" + std::to_string(layerIdx);
std::string siluLayerName = "silu_" + layerName + std::to_string(layerIdx);
silu->setName(siluLayerName.c_str());
output = silu->getOutput(0);
}
else if (activation == "hardsigmoid")
{
else if (activation == "hardsigmoid") {
nvinfer1::IActivationLayer* hardsigmoid = network->addActivation(*input, nvinfer1::ActivationType::kHARD_SIGMOID);
assert(hardsigmoid != nullptr);
std::string hardsigmoidLayerName = "hardsigmoid_" + std::to_string(layerIdx);
std::string hardsigmoidLayerName = "hardsigmoid_" + layerName + std::to_string(layerIdx);
hardsigmoid->setName(hardsigmoidLayerName.c_str());
hardsigmoid->setAlpha(1.0 / 6.0);
hardsigmoid->setBeta(0.5);
output = hardsigmoid->getOutput(0);
}
else if (activation == "hardswish")
{
else if (activation == "hardswish") {
nvinfer1::IActivationLayer* hardsigmoid = network->addActivation(*input, nvinfer1::ActivationType::kHARD_SIGMOID);
assert(hardsigmoid != nullptr);
std::string hardsigmoidLayerName = "hardsigmoid_" + std::to_string(layerIdx);
std::string hardsigmoidLayerName = "hardsigmoid_" + layerName + std::to_string(layerIdx);
hardsigmoid->setName(hardsigmoidLayerName.c_str());
hardsigmoid->setAlpha(1.0 / 6.0);
hardsigmoid->setBeta(0.5);
nvinfer1::IElementWiseLayer* hardswish
= network->addElementWise(*input, *hardsigmoid->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
nvinfer1::IElementWiseLayer* hardswish = network->addElementWise(*input, *hardsigmoid->getOutput(0),
nvinfer1::ElementWiseOperation::kPROD);
assert(hardswish != nullptr);
std::string hardswishLayerName = "hardswish_" + std::to_string(layerIdx);
std::string hardswishLayerName = "hardswish_" + layerName + std::to_string(layerIdx);
hardswish->setName(hardswishLayerName.c_str());
output = hardswish->getOutput(0);
}
else
{
else {
std::cerr << "Activation not supported: " << activation << std::endl;
std::abort();
assert(0);
}
return output;
}

View File

@@ -6,15 +6,11 @@
#ifndef __ACTIVATION_LAYER_H__
#define __ACTIVATION_LAYER_H__
#include <cassert>
#include <iostream>
#include <string>
#include "NvInfer.h"
nvinfer1::ITensor* activationLayer(
int layerIdx,
std::string activation,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
nvinfer1::ITensor* activationLayer(int layerIdx, std::string activation, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network, std::string layerName = "");
#endif

View File

@@ -3,18 +3,14 @@
* https://www.github.com/marcoslucianops
*/
#include <math.h>
#include "batchnorm_layer.h"
nvinfer1::ITensor* batchnormLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr,
std::string weightsType,
float eps,
nvinfer1::ITensor* input,
#include <cassert>
#include <math.h>
nvinfer1::ITensor*
batchnormLayer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, std::string weightsType, float eps, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
@@ -30,50 +26,40 @@ nvinfer1::ITensor* batchnormLayer(
std::vector<float> bnRunningMean;
std::vector<float> bnRunningVar;
if (weightsType == "weights")
{
for (int i = 0; i < filters; ++i)
{
if (weightsType == "weights") {
for (int i = 0; i < filters; ++i) {
bnBiases.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnWeights.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnRunningMean.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5));
weightPtr++;
++weightPtr;
}
}
else
{
for (int i = 0; i < filters; ++i)
{
else {
for (int i = 0; i < filters; ++i) {
bnWeights.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnBiases.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnRunningMean.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnRunningVar.push_back(sqrt(weights[weightPtr] + eps));
weightPtr++;
++weightPtr;
}
}

View File

@@ -13,15 +13,8 @@
#include "activation_layer.h"
nvinfer1::ITensor* batchnormLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr,
std::string weightsType,
float eps,
nvinfer1::ITensor* input,
nvinfer1::ITensor* batchnormLayer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, std::string weightsType, float eps, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -0,0 +1,82 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "c2f_layer.h"
#include <cassert>
#include "convolutional_layer.h"
nvinfer1::ITensor*
c2fLayer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, std::string weightsType, float eps, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
assert(block.at("type") == "c2f");
assert(block.find("n") != block.end());
assert(block.find("shortcut") != block.end());
assert(block.find("filters") != block.end());
int n = std::stoi(block.at("n"));
bool shortcut = (block.at("shortcut") == "1");
int filters = std::stoi(block.at("filters"));
nvinfer1::Dims inputDims = input->getDimensions();
nvinfer1::ISliceLayer* sliceLt = network->addSlice(*input,nvinfer1::Dims{3, {0, 0, 0}},
nvinfer1::Dims{3, {inputDims.d[0] / 2, inputDims.d[1], inputDims.d[2]}}, nvinfer1::Dims{3, {1, 1, 1}});
assert(sliceLt != nullptr);
std::string sliceLtLayerName = "slice_lt_" + std::to_string(layerIdx);
sliceLt->setName(sliceLtLayerName.c_str());
nvinfer1::ITensor* lt = sliceLt->getOutput(0);
nvinfer1::ISliceLayer* sliceRb = network->addSlice(*input,nvinfer1::Dims{3, {inputDims.d[0] / 2, 0, 0}},
nvinfer1::Dims{3, {inputDims.d[0] / 2, inputDims.d[1], inputDims.d[2]}}, nvinfer1::Dims{3, {1, 1, 1}});
assert(sliceRb != nullptr);
std::string sliceRbLayerName = "slice_rb_" + std::to_string(layerIdx);
sliceRb->setName(sliceRbLayerName.c_str());
nvinfer1::ITensor* rb = sliceRb->getOutput(0);
std::vector<nvinfer1::ITensor*> concatInputs;
concatInputs.push_back(lt);
concatInputs.push_back(rb);
output = rb;
for (int i = 0; i < n; ++i) {
std::string cv1MlayerName = "c2f_1_" + std::to_string(i + 1) + "_";
nvinfer1::ITensor* cv1M = convolutionalLayer(layerIdx, block, weights, trtWeights, weightPtr, weightsType, filters, eps,
output, network, cv1MlayerName);
assert(cv1M != nullptr);
std::string cv2MlayerName = "c2f_2_" + std::to_string(i + 1) + "_";
nvinfer1::ITensor* cv2M = convolutionalLayer(layerIdx, block, weights, trtWeights, weightPtr, weightsType, filters, eps,
cv1M, network, cv2MlayerName);
assert(cv2M != nullptr);
if (shortcut) {
nvinfer1::IElementWiseLayer* ew = network->addElementWise(*rb, *cv2M, nvinfer1::ElementWiseOperation::kSUM);
assert(ew != nullptr);
std::string ewLayerName = "shortcut_c2f_" + std::to_string(i + 1) + "_" + std::to_string(layerIdx);
ew->setName(ewLayerName.c_str());
output = ew->getOutput(0);
concatInputs.push_back(output);
}
else {
output = cv2M;
concatInputs.push_back(output);
}
}
nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size());
assert(concat != nullptr);
std::string concatLayerName = "route_" + std::to_string(layerIdx);
concat->setName(concatLayerName.c_str());
concat->setAxis(0);
output = concat->getOutput(0);
return output;
}

View File

@@ -0,0 +1,18 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef __C2F_LAYER_H__
#define __C2F_LAYER_H__
#include <map>
#include <vector>
#include "NvInfer.h"
nvinfer1::ITensor* c2fLayer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, std::string weightsType, float eps, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,28 +5,27 @@
#include "channels_layer.h"
nvinfer1::ITensor* channelsLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::ITensor* implicitTensor,
nvinfer1::INetworkDefinition* network)
#include <cassert>
nvinfer1::ITensor*
channelsLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::ITensor* implicitTensor, nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
assert(block.at("type") == "shift_channels" || block.at("type") == "control_channels");
if (block.at("type") == "shift_channels") {
nvinfer1::IElementWiseLayer* shift
= network->addElementWise(*input, *implicitTensor, nvinfer1::ElementWiseOperation::kSUM);
nvinfer1::IElementWiseLayer* shift = network->addElementWise(*input, *implicitTensor,
nvinfer1::ElementWiseOperation::kSUM);
assert(shift != nullptr);
std::string shiftLayerName = "shift_channels_" + std::to_string(layerIdx);
shift->setName(shiftLayerName.c_str());
output = shift->getOutput(0);
}
else if (block.at("type") == "control_channels") {
nvinfer1::IElementWiseLayer* control
= network->addElementWise(*input, *implicitTensor, nvinfer1::ElementWiseOperation::kPROD);
nvinfer1::IElementWiseLayer* control = network->addElementWise(*input, *implicitTensor,
nvinfer1::ElementWiseOperation::kPROD);
assert(control != nullptr);
std::string controlLayerName = "control_channels_" + std::to_string(layerIdx);
control->setName(controlLayerName.c_str());

View File

@@ -7,15 +7,10 @@
#define __CHANNELS_LAYER_H__
#include <map>
#include <cassert>
#include "NvInfer.h"
nvinfer1::ITensor* channelsLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::ITensor* implicitTensor,
nvinfer1::INetworkDefinition* network);
nvinfer1::ITensor* channelsLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::ITensor* implicitTensor, nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,10 +5,10 @@
#include "cls_layer.h"
nvinfer1::ITensor* clsLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
#include <cassert>
nvinfer1::ITensor*
clsLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;

View File

@@ -7,14 +7,10 @@
#define __CLS_LAYER_H__
#include <map>
#include <cassert>
#include "NvInfer.h"
nvinfer1::ITensor* clsLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::ITensor* clsLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -3,24 +3,19 @@
* https://www.github.com/marcoslucianops
*/
#include <math.h>
#include "convolutional_layer.h"
nvinfer1::ITensor* convolutionalLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr,
std::string weightsType,
int& inputChannels,
float eps,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
#include <cassert>
#include <math.h>
nvinfer1::ITensor*
convolutionalLayer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, std::string weightsType, int& inputChannels, float eps,
nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network, std::string layerName)
{
nvinfer1::ITensor* output;
assert(block.at("type") == "convolutional");
assert(block.at("type") == "convolutional" || block.at("type") == "c2f");
assert(block.find("filters") != block.end());
assert(block.find("pad") != block.end());
assert(block.find("size") != block.end());
@@ -34,8 +29,7 @@ nvinfer1::ITensor* convolutionalLayer(
int bias = filters;
bool batchNormalize = false;
if (block.find("batch_normalize") != block.end())
{
if (block.find("batch_normalize") != block.end()) {
bias = 0;
batchNormalize = (block.at("batch_normalize") == "1");
}
@@ -61,57 +55,47 @@ nvinfer1::ITensor* convolutionalLayer(
nvinfer1::Weights convWt {nvinfer1::DataType::kFLOAT, nullptr, size};
nvinfer1::Weights convBias {nvinfer1::DataType::kFLOAT, nullptr, bias};
if (weightsType == "weights")
{
if (batchNormalize == false)
{
if (weightsType == "weights") {
if (batchNormalize == false) {
float* val;
if (bias != 0) {
val = new float[filters];
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
val[i] = weights[weightPtr];
weightPtr++;
++weightPtr;
}
convBias.values = val;
trtWeights.push_back(convBias);
}
val = new float[size];
for (int i = 0; i < size; ++i)
{
for (int i = 0; i < size; ++i) {
val[i] = weights[weightPtr];
weightPtr++;
++weightPtr;
}
convWt.values = val;
trtWeights.push_back(convWt);
}
else
{
for (int i = 0; i < filters; ++i)
{
else {
for (int i = 0; i < filters; ++i) {
bnBiases.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnWeights.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnRunningMean.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5));
weightPtr++;
++weightPtr;
}
float* val = new float[size];
for (int i = 0; i < size; ++i)
{
for (int i = 0; i < size; ++i) {
val[i] = weights[weightPtr];
weightPtr++;
++weightPtr;
}
convWt.values = val;
trtWeights.push_back(convWt);
@@ -119,57 +103,47 @@ nvinfer1::ITensor* convolutionalLayer(
trtWeights.push_back(convBias);
}
}
else
{
if (batchNormalize == false)
{
else {
if (batchNormalize == false) {
float* val = new float[size];
for (int i = 0; i < size; ++i)
{
for (int i = 0; i < size; ++i) {
val[i] = weights[weightPtr];
weightPtr++;
++weightPtr;
}
convWt.values = val;
trtWeights.push_back(convWt);
if (bias != 0) {
val = new float[filters];
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
val[i] = weights[weightPtr];
weightPtr++;
++weightPtr;
}
convBias.values = val;
trtWeights.push_back(convBias);
}
}
else
{
else {
float* val = new float[size];
for (int i = 0; i < size; ++i)
{
for (int i = 0; i < size; ++i) {
val[i] = weights[weightPtr];
weightPtr++;
++weightPtr;
}
convWt.values = val;
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnWeights.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnBiases.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnRunningMean.push_back(weights[weightPtr]);
weightPtr++;
++weightPtr;
}
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
bnRunningVar.push_back(sqrt(weights[weightPtr] + eps));
weightPtr++;
++weightPtr;
}
trtWeights.push_back(convWt);
if (bias != 0)
@@ -177,10 +151,10 @@ nvinfer1::ITensor* convolutionalLayer(
}
}
nvinfer1::IConvolutionLayer* conv
= network->addConvolutionNd(*input, filters, nvinfer1::Dims{2, {kernelSize, kernelSize}}, convWt, convBias);
nvinfer1::IConvolutionLayer* conv = network->addConvolutionNd(*input, filters, nvinfer1::Dims{2, {kernelSize, kernelSize}},
convWt, convBias);
assert(conv != nullptr);
std::string convLayerName = "conv_" + std::to_string(layerIdx);
std::string convLayerName = "conv_" + layerName + std::to_string(layerIdx);
conv->setName(convLayerName.c_str());
conv->setStrideNd(nvinfer1::Dims{2, {stride, stride}});
conv->setPaddingNd(nvinfer1::Dims{2, {pad, pad}});
@@ -190,8 +164,7 @@ nvinfer1::ITensor* convolutionalLayer(
output = conv->getOutput(0);
if (batchNormalize == true)
{
if (batchNormalize == true) {
size = filters;
nvinfer1::Weights shift {nvinfer1::DataType::kFLOAT, nullptr, size};
nvinfer1::Weights scale {nvinfer1::DataType::kFLOAT, nullptr, size};
@@ -214,12 +187,12 @@ nvinfer1::ITensor* convolutionalLayer(
nvinfer1::IScaleLayer* batchnorm = network->addScale(*output, nvinfer1::ScaleMode::kCHANNEL, shift, scale, power);
assert(batchnorm != nullptr);
std::string batchnormLayerName = "batchnorm_" + std::to_string(layerIdx);
std::string batchnormLayerName = "batchnorm_" + layerName + std::to_string(layerIdx);
batchnorm->setName(batchnormLayerName.c_str());
output = batchnorm->getOutput(0);
}
output = activationLayer(layerIdx, activation, output, network);
output = activationLayer(layerIdx, activation, output, network, layerName);
assert(output != nullptr);
return output;

View File

@@ -13,16 +13,8 @@
#include "activation_layer.h"
nvinfer1::ITensor* convolutionalLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr,
std::string weightsType,
int& inputChannels,
float eps,
nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
nvinfer1::ITensor* convolutionalLayer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, std::string weightsType, int& inputChannels, float eps,
nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network, std::string layerName = "");
#endif

View File

@@ -0,0 +1,196 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include "detect_v8_layer.h"
#include <cassert>
nvinfer1::ITensor*
detectV8Layer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
assert(block.at("type") == "detect_v8");
assert(block.find("num") != block.end());
assert(block.find("classes") != block.end());
int num = std::stoi(block.at("num"));
int classes = std::stoi(block.at("classes"));
int reg_max = num / 4;
nvinfer1::Dims inputDims = input->getDimensions();
nvinfer1::ISliceLayer* sliceBox = network->addSlice(*input, nvinfer1::Dims{2, {0, 0}},
nvinfer1::Dims{2, {num, inputDims.d[1]}}, nvinfer1::Dims{2, {1, 1}});
assert(sliceBox != nullptr);
std::string sliceBoxLayerName = "slice_box_" + std::to_string(layerIdx);
sliceBox->setName(sliceBoxLayerName.c_str());
nvinfer1::ITensor* box = sliceBox->getOutput(0);
nvinfer1::ISliceLayer* sliceCls = network->addSlice(*input, nvinfer1::Dims{2, {num, 0}},
nvinfer1::Dims{2, {classes, inputDims.d[1]}}, nvinfer1::Dims{2, {1, 1}});
assert(sliceCls != nullptr);
std::string sliceClsLayerName = "slice_cls_" + std::to_string(layerIdx);
sliceCls->setName(sliceClsLayerName.c_str());
nvinfer1::ITensor* cls = sliceCls->getOutput(0);
nvinfer1::IShuffleLayer* shuffle1Box = network->addShuffle(*box);
assert(shuffle1Box != nullptr);
std::string shuffle1BoxLayerName = "shuffle1_box_" + std::to_string(layerIdx);
shuffle1Box->setName(shuffle1BoxLayerName.c_str());
nvinfer1::Dims reshape1Dims = {3, {4, reg_max, inputDims.d[1]}};
shuffle1Box->setReshapeDimensions(reshape1Dims);
nvinfer1::Permutation permutation1;
permutation1.order[0] = 1;
permutation1.order[1] = 0;
permutation1.order[2] = 2;
shuffle1Box->setSecondTranspose(permutation1);
box = shuffle1Box->getOutput(0);
nvinfer1::ISoftMaxLayer* softmax = network->addSoftMax(*box);
assert(softmax != nullptr);
std::string softmaxLayerName = "softmax_box_" + std::to_string(layerIdx);
softmax->setName(softmaxLayerName.c_str());
softmax->setAxes(1 << 0);
box = softmax->getOutput(0);
nvinfer1::Weights dflWt {nvinfer1::DataType::kFLOAT, nullptr, reg_max};
float* val = new float[reg_max];
for (int i = 0; i < reg_max; ++i) {
val[i] = i;
}
dflWt.values = val;
nvinfer1::IConvolutionLayer* conv = network->addConvolutionNd(*box, 1, nvinfer1::Dims{2, {1, 1}}, dflWt,
nvinfer1::Weights{});
assert(conv != nullptr);
std::string convLayerName = "conv_box_" + std::to_string(layerIdx);
conv->setName(convLayerName.c_str());
conv->setStrideNd(nvinfer1::Dims{2, {1, 1}});
conv->setPaddingNd(nvinfer1::Dims{2, {0, 0}});
box = conv->getOutput(0);
nvinfer1::IShuffleLayer* shuffle2Box = network->addShuffle(*box);
assert(shuffle2Box != nullptr);
std::string shuffle2BoxLayerName = "shuffle2_box_" + std::to_string(layerIdx);
shuffle2Box->setName(shuffle2BoxLayerName.c_str());
nvinfer1::Dims reshape2Dims = {2, {4, inputDims.d[1]}};
shuffle2Box->setReshapeDimensions(reshape2Dims);
box = shuffle2Box->getOutput(0);
nvinfer1::Dims shuffle2BoxDims = box->getDimensions();
nvinfer1::ISliceLayer* sliceLtBox = network->addSlice(*box, nvinfer1::Dims{2, {0, 0}},
nvinfer1::Dims{2, {2, shuffle2BoxDims.d[1]}}, nvinfer1::Dims{2, {1, 1}});
assert(sliceLtBox != nullptr);
std::string sliceLtBoxLayerName = "slice_lt_box_" + std::to_string(layerIdx);
sliceLtBox->setName(sliceLtBoxLayerName.c_str());
nvinfer1::ITensor* lt = sliceLtBox->getOutput(0);
nvinfer1::ISliceLayer* sliceRbBox = network->addSlice(*box, nvinfer1::Dims{2, {2, 0}},
nvinfer1::Dims{2, {2, shuffle2BoxDims.d[1]}}, nvinfer1::Dims{2, {1, 1}});
assert(sliceRbBox != nullptr);
std::string sliceRbBoxLayerName = "slice_rb_box_" + std::to_string(layerIdx);
sliceRbBox->setName(sliceRbBoxLayerName.c_str());
nvinfer1::ITensor* rb = sliceRbBox->getOutput(0);
int channels = 2 * shuffle2BoxDims.d[1];
nvinfer1::Weights anchorPointsWt {nvinfer1::DataType::kFLOAT, nullptr, channels};
val = new float[channels];
for (int i = 0; i < channels; ++i) {
val[i] = weights[weightPtr];
++weightPtr;
}
anchorPointsWt.values = val;
trtWeights.push_back(anchorPointsWt);
nvinfer1::IConstantLayer* anchorPoints = network->addConstant(nvinfer1::Dims{2, {2, shuffle2BoxDims.d[1]}},
anchorPointsWt);
assert(anchorPoints != nullptr);
std::string anchorPointsLayerName = "anchor_points_" + std::to_string(layerIdx);
anchorPoints->setName(anchorPointsLayerName.c_str());
nvinfer1::ITensor* anchorPointsTensor = anchorPoints->getOutput(0);
nvinfer1::IElementWiseLayer* x1y1 = network->addElementWise(*anchorPointsTensor, *lt,
nvinfer1::ElementWiseOperation::kSUB);
assert(x1y1 != nullptr);
std::string x1y1LayerName = "x1y1_" + std::to_string(layerIdx);
x1y1->setName(x1y1LayerName.c_str());
nvinfer1::ITensor* x1y1Tensor = x1y1->getOutput(0);
nvinfer1::IElementWiseLayer* x2y2 = network->addElementWise(*rb, *anchorPointsTensor,
nvinfer1::ElementWiseOperation::kSUM);
assert(x2y2 != nullptr);
std::string x2y2LayerName = "x2y2_" + std::to_string(layerIdx);
x2y2->setName(x2y2LayerName.c_str());
nvinfer1::ITensor* x2y2Tensor = x2y2->getOutput(0);
std::vector<nvinfer1::ITensor*> concatBoxInputs;
concatBoxInputs.push_back(x1y1Tensor);
concatBoxInputs.push_back(x2y2Tensor);
nvinfer1::IConcatenationLayer* concatBox = network->addConcatenation(concatBoxInputs.data(), concatBoxInputs.size());
assert(concatBox != nullptr);
std::string concatBoxLayerName = "concat_box_" + std::to_string(layerIdx);
concatBox->setName(concatBoxLayerName.c_str());
concatBox->setAxis(0);
box = concatBox->getOutput(0);
channels = shuffle2BoxDims.d[1];
nvinfer1::Weights stridePointsWt {nvinfer1::DataType::kFLOAT, nullptr, channels};
val = new float[channels];
for (int i = 0; i < channels; ++i) {
val[i] = weights[weightPtr];
++weightPtr;
}
stridePointsWt.values = val;
trtWeights.push_back(stridePointsWt);
nvinfer1::IConstantLayer* stridePoints = network->addConstant(nvinfer1::Dims{2, {1, shuffle2BoxDims.d[1]}},
stridePointsWt);
assert(stridePoints != nullptr);
std::string stridePointsLayerName = "stride_points_" + std::to_string(layerIdx);
stridePoints->setName(stridePointsLayerName.c_str());
nvinfer1::ITensor* stridePointsTensor = stridePoints->getOutput(0);
nvinfer1::IElementWiseLayer* pred = network->addElementWise(*box, *stridePointsTensor,
nvinfer1::ElementWiseOperation::kPROD);
assert(pred != nullptr);
std::string predLayerName = "pred_" + std::to_string(layerIdx);
pred->setName(predLayerName.c_str());
box = pred->getOutput(0);
nvinfer1::IActivationLayer* sigmoid = network->addActivation(*cls, nvinfer1::ActivationType::kSIGMOID);
assert(sigmoid != nullptr);
std::string sigmoidLayerName = "sigmoid_cls_" + std::to_string(layerIdx);
sigmoid->setName(sigmoidLayerName.c_str());
cls = sigmoid->getOutput(0);
std::vector<nvinfer1::ITensor*> concatInputs;
concatInputs.push_back(box);
concatInputs.push_back(cls);
nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size());
assert(concat != nullptr);
std::string concatLayerName = "concat_" + std::to_string(layerIdx);
concat->setName(concatLayerName.c_str());
concat->setAxis(0);
output = concat->getOutput(0);
nvinfer1::IShuffleLayer* shuffle = network->addShuffle(*output);
assert(shuffle != nullptr);
std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx);
shuffle->setName(shuffleLayerName.c_str());
nvinfer1::Permutation permutation2;
permutation2.order[0] = 1;
permutation2.order[1] = 0;
shuffle->setFirstTranspose(permutation2);
output = shuffle->getOutput(0);
return output;
}

View File

@@ -0,0 +1,18 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#ifndef __DETECT_V8_LAYER_H__
#define __DETECT_V8_LAYER_H__
#include <map>
#include <vector>
#include "NvInfer.h"
nvinfer1::ITensor* detectV8Layer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,13 +5,11 @@
#include "implicit_layer.h"
nvinfer1::ITensor* implicitLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr,
nvinfer1::INetworkDefinition* network)
#include <cassert>
nvinfer1::ITensor*
implicitLayer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
@@ -23,10 +21,9 @@ nvinfer1::ITensor* implicitLayer(
nvinfer1::Weights convWt {nvinfer1::DataType::kFLOAT, nullptr, filters};
float* val = new float[filters];
for (int i = 0; i < filters; ++i)
{
for (int i = 0; i < filters; ++i) {
val[i] = weights[weightPtr];
weightPtr++;
++weightPtr;
}
convWt.values = val;
trtWeights.push_back(convWt);

View File

@@ -8,16 +8,10 @@
#include <map>
#include <vector>
#include <cassert>
#include "NvInfer.h"
nvinfer1::ITensor* implicitLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr,
nvinfer1::INetworkDefinition* network);
nvinfer1::ITensor* implicitLayer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,52 +5,49 @@
#include "pooling_layer.h"
nvinfer1::ITensor* poolingLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
#include <cassert>
#include <iostream>
nvinfer1::ITensor*
poolingLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
assert(block.at("type") == "maxpool" || block.at("type") == "avgpool");
if (block.at("type") == "maxpool")
{
if (block.at("type") == "maxpool") {
assert(block.find("size") != block.end());
assert(block.find("stride") != block.end());
int size = std::stoi(block.at("size"));
int stride = std::stoi(block.at("stride"));
nvinfer1::IPoolingLayer* maxpool
= network->addPoolingNd(*input, nvinfer1::PoolingType::kMAX, nvinfer1::Dims{2, {size, size}});
nvinfer1::IPoolingLayer* maxpool = network->addPoolingNd(*input, nvinfer1::PoolingType::kMAX,
nvinfer1::Dims{2, {size, size}});
assert(maxpool != nullptr);
std::string maxpoolLayerName = "maxpool_" + std::to_string(layerIdx);
maxpool->setName(maxpoolLayerName.c_str());
maxpool->setStrideNd(nvinfer1::Dims{2, {stride, stride}});
maxpool->setPaddingNd(nvinfer1::Dims{2, {(size - 1) / 2, (size - 1) / 2}});
if (size == 2 && stride == 1)
{
if (size == 2 && stride == 1) {
maxpool->setPrePadding(nvinfer1::Dims{2, {0, 0}});
maxpool->setPostPadding(nvinfer1::Dims{2, {1, 1}});
}
output = maxpool->getOutput(0);
}
else if (block.at("type") == "avgpool")
{
else if (block.at("type") == "avgpool") {
nvinfer1::Dims inputDims = input->getDimensions();
nvinfer1::IPoolingLayer* avgpool = network->addPoolingNd(
*input, nvinfer1::PoolingType::kAVERAGE, nvinfer1::Dims{2, {inputDims.d[1], inputDims.d[2]}});
nvinfer1::IPoolingLayer* avgpool = network->addPoolingNd(*input, nvinfer1::PoolingType::kAVERAGE,
nvinfer1::Dims{2, {inputDims.d[1], inputDims.d[2]}});
assert(avgpool != nullptr);
std::string avgpoolLayerName = "avgpool_" + std::to_string(layerIdx);
avgpool->setName(avgpoolLayerName.c_str());
output = avgpool->getOutput(0);
}
else
{
else {
std::cerr << "Pooling not supported: " << block.at("type") << std::endl;
std::abort();
assert(0);
}
return output;

View File

@@ -7,15 +7,10 @@
#define __POOLING_LAYER_H__
#include <map>
#include <cassert>
#include <iostream>
#include "NvInfer.h"
nvinfer1::ITensor* poolingLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::ITensor* poolingLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,10 +5,8 @@
#include "reduce_layer.h"
nvinfer1::ITensor* reduceLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::ITensor*
reduceLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
@@ -26,14 +24,12 @@ nvinfer1::ITensor* reduceLayer(
std::string strAxes = block.at("axes");
std::vector<int32_t> axes;
size_t lastPos = 0, pos = 0;
while ((pos = strAxes.find(',', lastPos)) != std::string::npos)
{
while ((pos = strAxes.find(',', lastPos)) != std::string::npos) {
int vL = std::stoi(trim(strAxes.substr(lastPos, pos - lastPos)));
axes.push_back(vL);
lastPos = pos + 1;
}
if (lastPos < strAxes.length())
{
if (lastPos < strAxes.length()) {
std::string lastV = trim(strAxes.substr(lastPos));
if (!lastV.empty())
axes.push_back(std::stoi(lastV));

View File

@@ -6,13 +6,9 @@
#ifndef __REDUCE_LAYER_H__
#define __REDUCE_LAYER_H__
#include "NvInfer.h"
#include "../utils.h"
nvinfer1::ITensor* reduceLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::ITensor* reduceLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,13 +5,11 @@
#include "reg_layer.h"
nvinfer1::ITensor* regLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr,
nvinfer1::ITensor* input,
#include <cassert>
nvinfer1::ITensor*
regLayer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
@@ -29,15 +27,15 @@ nvinfer1::ITensor* regLayer(
output = shuffle->getOutput(0);
nvinfer1::Dims shuffleDims = output->getDimensions();
nvinfer1::ISliceLayer* sliceLt = network->addSlice(
*output, nvinfer1::Dims{2, {0, 0}}, nvinfer1::Dims{2, {shuffleDims.d[0], 2}}, nvinfer1::Dims{2, {1, 1}});
nvinfer1::ISliceLayer* sliceLt = network->addSlice(*output, nvinfer1::Dims{2, {0, 0}},
nvinfer1::Dims{2, {shuffleDims.d[0], 2}}, nvinfer1::Dims{2, {1, 1}});
assert(sliceLt != nullptr);
std::string sliceLtLayerName = "slice_lt_" + std::to_string(layerIdx);
sliceLt->setName(sliceLtLayerName.c_str());
nvinfer1::ITensor* lt = sliceLt->getOutput(0);
nvinfer1::ISliceLayer* sliceRb = network->addSlice(
*output, nvinfer1::Dims{2, {0, 2}}, nvinfer1::Dims{2, {shuffleDims.d[0], 2}}, nvinfer1::Dims{2, {1, 1}});
nvinfer1::ISliceLayer* sliceRb = network->addSlice(*output, nvinfer1::Dims{2, {0, 2}},
nvinfer1::Dims{2, {shuffleDims.d[0], 2}}, nvinfer1::Dims{2, {1, 1}});
assert(sliceRb != nullptr);
std::string sliceRbLayerName = "slice_rb_" + std::to_string(layerIdx);
sliceRb->setName(sliceRbLayerName.c_str());
@@ -46,10 +44,9 @@ nvinfer1::ITensor* regLayer(
int channels = shuffleDims.d[0] * 2;
nvinfer1::Weights anchorPointsWt {nvinfer1::DataType::kFLOAT, nullptr, channels};
float* val = new float[channels];
for (int i = 0; i < channels; ++i)
{
for (int i = 0; i < channels; ++i) {
val[i] = weights[weightPtr];
weightPtr++;
++weightPtr;
}
anchorPointsWt.values = val;
trtWeights.push_back(anchorPointsWt);
@@ -60,15 +57,15 @@ nvinfer1::ITensor* regLayer(
anchorPoints->setName(anchorPointsLayerName.c_str());
nvinfer1::ITensor* anchorPointsTensor = anchorPoints->getOutput(0);
nvinfer1::IElementWiseLayer* x1y1
= network->addElementWise(*anchorPointsTensor, *lt, nvinfer1::ElementWiseOperation::kSUB);
nvinfer1::IElementWiseLayer* x1y1 = network->addElementWise(*anchorPointsTensor, *lt,
nvinfer1::ElementWiseOperation::kSUB);
assert(x1y1 != nullptr);
std::string x1y1LayerName = "x1y1_" + std::to_string(layerIdx);
x1y1->setName(x1y1LayerName.c_str());
nvinfer1::ITensor* x1y1Tensor = x1y1->getOutput(0);
nvinfer1::IElementWiseLayer* x2y2
= network->addElementWise(*rb, *anchorPointsTensor, nvinfer1::ElementWiseOperation::kSUM);
nvinfer1::IElementWiseLayer* x2y2 = network->addElementWise(*rb, *anchorPointsTensor,
nvinfer1::ElementWiseOperation::kSUM);
assert(x2y2 != nullptr);
std::string x2y2LayerName = "x2y2_" + std::to_string(layerIdx);
x2y2->setName(x2y2LayerName.c_str());
@@ -88,10 +85,9 @@ nvinfer1::ITensor* regLayer(
channels = shuffleDims.d[0];
nvinfer1::Weights stridePointsWt {nvinfer1::DataType::kFLOAT, nullptr, channels};
val = new float[channels];
for (int i = 0; i < channels; ++i)
{
for (int i = 0; i < channels; ++i) {
val[i] = weights[weightPtr];
weightPtr++;
++weightPtr;
}
stridePointsWt.values = val;
trtWeights.push_back(stridePointsWt);
@@ -102,8 +98,8 @@ nvinfer1::ITensor* regLayer(
stridePoints->setName(stridePointsLayerName.c_str());
nvinfer1::ITensor* stridePointsTensor = stridePoints->getOutput(0);
nvinfer1::IElementWiseLayer* pred
= network->addElementWise(*output, *stridePointsTensor, nvinfer1::ElementWiseOperation::kPROD);
nvinfer1::IElementWiseLayer* pred = network->addElementWise(*output, *stridePointsTensor,
nvinfer1::ElementWiseOperation::kPROD);
assert(pred != nullptr);
std::string predLayerName = "pred_" + std::to_string(layerIdx);
pred->setName(predLayerName.c_str());

View File

@@ -8,17 +8,11 @@
#include <map>
#include <vector>
#include <cassert>
#include "NvInfer.h"
nvinfer1::ITensor* regLayer(
int layerIdx,
std::map<std::string, std::string>& block,
std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr,
nvinfer1::ITensor* input,
nvinfer1::ITensor* regLayer(int layerIdx, std::map<std::string, std::string>& block, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, int& weightPtr, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,10 +5,11 @@
#include "reorg_layer.h"
nvinfer1::ITensor* reorgLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
#include <vector>
#include <cassert>
nvinfer1::ITensor*
reorgLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
@@ -17,30 +18,26 @@ nvinfer1::ITensor* reorgLayer(
nvinfer1::Dims inputDims = input->getDimensions();
nvinfer1::ISliceLayer *slice1 = network->addSlice(
*input, nvinfer1::Dims{3, {0, 0, 0}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}},
nvinfer1::Dims{3, {1, 2, 2}});
nvinfer1::ISliceLayer *slice1 = network->addSlice(*input, nvinfer1::Dims{3, {0, 0, 0}},
nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, nvinfer1::Dims{3, {1, 2, 2}});
assert(slice1 != nullptr);
std::string slice1LayerName = "slice1_" + std::to_string(layerIdx);
slice1->setName(slice1LayerName.c_str());
nvinfer1::ISliceLayer *slice2 = network->addSlice(
*input, nvinfer1::Dims{3, {0, 1, 0}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}},
nvinfer1::Dims{3, {1, 2, 2}});
nvinfer1::ISliceLayer *slice2 = network->addSlice(*input, nvinfer1::Dims{3, {0, 1, 0}},
nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, nvinfer1::Dims{3, {1, 2, 2}});
assert(slice2 != nullptr);
std::string slice2LayerName = "slice2_" + std::to_string(layerIdx);
slice2->setName(slice2LayerName.c_str());
nvinfer1::ISliceLayer *slice3 = network->addSlice(
*input, nvinfer1::Dims{3, {0, 0, 1}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}},
nvinfer1::Dims{3, {1, 2, 2}});
nvinfer1::ISliceLayer *slice3 = network->addSlice(*input, nvinfer1::Dims{3, {0, 0, 1}},
nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, nvinfer1::Dims{3, {1, 2, 2}});
assert(slice3 != nullptr);
std::string slice3LayerName = "slice3_" + std::to_string(layerIdx);
slice3->setName(slice3LayerName.c_str());
nvinfer1::ISliceLayer *slice4 = network->addSlice(
*input, nvinfer1::Dims{3, {0, 1, 1}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}},
nvinfer1::Dims{3, {1, 2, 2}});
nvinfer1::ISliceLayer *slice4 = network->addSlice(*input, nvinfer1::Dims{3, {0, 1, 1}},
nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}}, nvinfer1::Dims{3, {1, 2, 2}});
assert(slice4 != nullptr);
std::string slice4LayerName = "slice4_" + std::to_string(layerIdx);
slice4->setName(slice4LayerName.c_str());

View File

@@ -3,19 +3,14 @@
* https://www.github.com/marcoslucianops
*/
#ifndef __REORGV5_LAYER_H__
#define __REORGV5_LAYER_H__
#ifndef __REORG_LAYER_H__
#define __REORG_LAYER_H__
#include <map>
#include <vector>
#include <cassert>
#include "NvInfer.h"
nvinfer1::ITensor* reorgLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::ITensor* reorgLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,12 +5,9 @@
#include "route_layer.h"
nvinfer1::ITensor* routeLayer(
int layerIdx,
std::string& layers,
std::map<std::string, std::string>& block,
std::vector<nvinfer1::ITensor*> tensorOutputs,
nvinfer1::INetworkDefinition* network)
nvinfer1::ITensor*
routeLayer(int layerIdx, std::string& layers, std::map<std::string, std::string>& block,
std::vector<nvinfer1::ITensor*> tensorOutputs, nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
@@ -20,22 +17,19 @@ nvinfer1::ITensor* routeLayer(
std::string strLayers = block.at("layers");
std::vector<int> idxLayers;
size_t lastPos = 0, pos = 0;
while ((pos = strLayers.find(',', lastPos)) != std::string::npos)
{
while ((pos = strLayers.find(',', lastPos)) != std::string::npos) {
int vL = std::stoi(trim(strLayers.substr(lastPos, pos - lastPos)));
idxLayers.push_back(vL);
lastPos = pos + 1;
}
if (lastPos < strLayers.length())
{
if (lastPos < strLayers.length()) {
std::string lastV = trim(strLayers.substr(lastPos));
if (!lastV.empty())
idxLayers.push_back(std::stoi(lastV));
}
assert (!idxLayers.empty());
std::vector<nvinfer1::ITensor*> concatInputs;
for (uint i = 0; i < idxLayers.size(); ++i)
{
for (uint i = 0; i < idxLayers.size(); ++i) {
if (idxLayers[i] < 0)
idxLayers[i] = tensorOutputs.size() + idxLayers[i];
assert (idxLayers[i] >= 0 && idxLayers[i] < (int)tensorOutputs.size());
@@ -62,15 +56,13 @@ nvinfer1::ITensor* routeLayer(
output = concat->getOutput(0);
}
if (block.find("groups") != block.end())
{
if (block.find("groups") != block.end()) {
nvinfer1::Dims prevTensorDims = output->getDimensions();
int groups = stoi(block.at("groups"));
int group_id = stoi(block.at("group_id"));
int startSlice = (prevTensorDims.d[0] / groups) * group_id;
int channelSlice = (prevTensorDims.d[0] / groups);
nvinfer1::ISliceLayer* slice = network->addSlice(
*output, nvinfer1::Dims{3, {startSlice, 0, 0}},
nvinfer1::ISliceLayer* slice = network->addSlice(*output, nvinfer1::Dims{3, {startSlice, 0, 0}},
nvinfer1::Dims{3, {channelSlice, prevTensorDims.d[1], prevTensorDims.d[2]}}, nvinfer1::Dims{3, {1, 1, 1}});
assert(slice != nullptr);
std::string sliceLayerName = "slice_" + std::to_string(layerIdx);

View File

@@ -6,14 +6,9 @@
#ifndef __ROUTE_LAYER_H__
#define __ROUTE_LAYER_H__
#include "NvInfer.h"
#include "../utils.h"
nvinfer1::ITensor* routeLayer(
int layerIdx,
std::string& layers,
std::map<std::string, std::string>& block,
std::vector<nvinfer1::ITensor*> tensorOutputs,
nvinfer1::INetworkDefinition* network);
nvinfer1::ITensor* routeLayer(int layerIdx, std::string& layers, std::map<std::string, std::string>& block,
std::vector<nvinfer1::ITensor*> tensorOutputs, nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,15 +5,11 @@
#include "shortcut_layer.h"
nvinfer1::ITensor* shortcutLayer(
int layerIdx,
std::string mode,
std::string activation,
std::string inputVol,
std::string shortcutVol,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::ITensor* shortcutInput,
#include <cassert>
nvinfer1::ITensor*
shortcutLayer(int layerIdx, std::string mode, std::string activation, std::string inputVol, std::string shortcutVol,
std::map<std::string, std::string>& block, nvinfer1::ITensor* input, nvinfer1::ITensor* shortcutInput,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
@@ -25,19 +21,16 @@ nvinfer1::ITensor* shortcutLayer(
if (mode == "mul")
operation = nvinfer1::ElementWiseOperation::kPROD;
if (mode == "add" && inputVol != shortcutVol)
{
nvinfer1::ISliceLayer* slice = network->addSlice(
*shortcutInput, nvinfer1::Dims{3, {0, 0, 0}}, input->getDimensions(), nvinfer1::Dims{3, {1, 1, 1}});
if (mode == "add" && inputVol != shortcutVol) {
nvinfer1::ISliceLayer* slice = network->addSlice(*shortcutInput, nvinfer1::Dims{3, {0, 0, 0}}, input->getDimensions(),
nvinfer1::Dims{3, {1, 1, 1}});
assert(slice != nullptr);
std::string sliceLayerName = "slice_" + std::to_string(layerIdx);
slice->setName(sliceLayerName.c_str());
output = slice->getOutput(0);
}
else
{
output = shortcutInput;
}
nvinfer1::IElementWiseLayer* shortcut = network->addElementWise(*input, *output, operation);
assert(shortcut != nullptr);

View File

@@ -12,15 +12,8 @@
#include "activation_layer.h"
nvinfer1::ITensor* shortcutLayer(
int layerIdx,
std::string mode,
std::string activation,
std::string inputVol,
std::string shortcutVol,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::ITensor* shortcut,
nvinfer1::INetworkDefinition* network);
nvinfer1::ITensor* shortcutLayer(int layerIdx, std::string mode, std::string activation, std::string inputVol,
std::string shortcutVol, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::ITensor* shortcut, nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,13 +5,9 @@
#include "shuffle_layer.h"
nvinfer1::ITensor* shuffleLayer(
int layerIdx,
std::string& layer,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
std::vector<nvinfer1::ITensor*> tensorOutputs,
nvinfer1::INetworkDefinition* network)
nvinfer1::ITensor*
shuffleLayer(int layerIdx, std::string& layer, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
std::vector<nvinfer1::ITensor*> tensorOutputs, nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;
@@ -22,25 +18,7 @@ nvinfer1::ITensor* shuffleLayer(
std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx);
shuffle->setName(shuffleLayerName.c_str());
if (block.find("reshape") != block.end())
{
std::string strReshape = block.at("reshape");
std::vector<int32_t> reshape;
size_t lastPos = 0, pos = 0;
while ((pos = strReshape.find(',', lastPos)) != std::string::npos)
{
int vL = std::stoi(trim(strReshape.substr(lastPos, pos - lastPos)));
reshape.push_back(vL);
lastPos = pos + 1;
}
if (lastPos < strReshape.length())
{
std::string lastV = trim(strReshape.substr(lastPos));
if (!lastV.empty())
reshape.push_back(std::stoi(lastV));
}
assert(!reshape.empty());
if (block.find("reshape") != block.end()) {
int from = -1;
if (block.find("from") != block.end())
from = std::stoi(block.at("from"));
@@ -51,33 +29,72 @@ nvinfer1::ITensor* shuffleLayer(
layer = std::to_string(from);
nvinfer1::Dims inputTensorDims = tensorOutputs[from]->getDimensions();
int32_t l = inputTensorDims.d[1] * inputTensorDims.d[2];
std::string strReshape = block.at("reshape");
std::vector<int32_t> reshape;
size_t lastPos = 0, pos = 0;
while ((pos = strReshape.find(',', lastPos)) != std::string::npos) {
std::string V = trim(strReshape.substr(lastPos, pos - lastPos));
if (V == "c")
reshape.push_back(inputTensorDims.d[0]);
else if (V == "ch")
reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[1]);
else if (V == "cw")
reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[2]);
else if (V == "h")
reshape.push_back(inputTensorDims.d[1]);
else if (V == "hw")
reshape.push_back(inputTensorDims.d[1] * inputTensorDims.d[2]);
else if (V == "w")
reshape.push_back(inputTensorDims.d[2]);
else if (V == "chw")
reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[1] * inputTensorDims.d[2]);
else
reshape.push_back(std::stoi(V));
lastPos = pos + 1;
}
if (lastPos < strReshape.length()) {
std::string lastV = trim(strReshape.substr(lastPos));
if (!lastV.empty()) {
if (lastV == "c")
reshape.push_back(inputTensorDims.d[0]);
else if (lastV == "ch")
reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[1]);
else if (lastV == "cw")
reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[2]);
else if (lastV == "h")
reshape.push_back(inputTensorDims.d[1]);
else if (lastV == "hw")
reshape.push_back(inputTensorDims.d[1] * inputTensorDims.d[2]);
else if (lastV == "w")
reshape.push_back(inputTensorDims.d[2]);
else if (lastV == "chw")
reshape.push_back(inputTensorDims.d[0] * inputTensorDims.d[1] * inputTensorDims.d[2]);
else
reshape.push_back(std::stoi(lastV));
}
}
assert(!reshape.empty());
nvinfer1::Dims reshapeDims;
reshapeDims.nbDims = reshape.size();
for (uint i = 0; i < reshape.size(); ++i)
if (reshape[i] == 0)
reshapeDims.d[i] = l;
else
reshapeDims.d[i] = reshape[i];
shuffle->setReshapeDimensions(reshapeDims);
}
if (block.find("transpose1") != block.end())
{
if (block.find("transpose1") != block.end()) {
std::string strTranspose1 = block.at("transpose1");
std::vector<int32_t> transpose1;
size_t lastPos = 0, pos = 0;
while ((pos = strTranspose1.find(',', lastPos)) != std::string::npos)
{
while ((pos = strTranspose1.find(',', lastPos)) != std::string::npos) {
int vL = std::stoi(trim(strTranspose1.substr(lastPos, pos - lastPos)));
transpose1.push_back(vL);
lastPos = pos + 1;
}
if (lastPos < strTranspose1.length())
{
if (lastPos < strTranspose1.length()) {
std::string lastV = trim(strTranspose1.substr(lastPos));
if (!lastV.empty())
transpose1.push_back(std::stoi(lastV));
@@ -91,19 +108,16 @@ nvinfer1::ITensor* shuffleLayer(
shuffle->setFirstTranspose(permutation1);
}
if (block.find("transpose2") != block.end())
{
if (block.find("transpose2") != block.end()) {
std::string strTranspose2 = block.at("transpose2");
std::vector<int32_t> transpose2;
size_t lastPos = 0, pos = 0;
while ((pos = strTranspose2.find(',', lastPos)) != std::string::npos)
{
while ((pos = strTranspose2.find(',', lastPos)) != std::string::npos) {
int vL = std::stoi(trim(strTranspose2.substr(lastPos, pos - lastPos)));
transpose2.push_back(vL);
lastPos = pos + 1;
}
if (lastPos < strTranspose2.length())
{
if (lastPos < strTranspose2.length()) {
std::string lastV = trim(strTranspose2.substr(lastPos));
if (!lastV.empty())
transpose2.push_back(std::stoi(lastV));

View File

@@ -6,15 +6,9 @@
#ifndef __SHUFFLE_LAYER_H__
#define __SHUFFLE_LAYER_H__
#include "NvInfer.h"
#include "../utils.h"
nvinfer1::ITensor* shuffleLayer(
int layerIdx,
std::string& layer,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
std::vector<nvinfer1::ITensor*> tensorOutputs,
nvinfer1::INetworkDefinition* network);
nvinfer1::ITensor* shuffleLayer(int layerIdx, std::string& layer, std::map<std::string, std::string>& block,
nvinfer1::ITensor* input, std::vector<nvinfer1::ITensor*> tensorOutputs, nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,10 +5,10 @@
#include "softmax_layer.h"
nvinfer1::ITensor* softmaxLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
#include <cassert>
nvinfer1::ITensor*
softmaxLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;

View File

@@ -7,14 +7,10 @@
#define __SOFTMAX_LAYER_H__
#include <map>
#include <cassert>
#include "NvInfer.h"
nvinfer1::ITensor* softmaxLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::ITensor* softmaxLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -5,10 +5,10 @@
#include "upsample_layer.h"
nvinfer1::ITensor* upsampleLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
#include <cassert>
nvinfer1::ITensor*
upsampleLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network)
{
nvinfer1::ITensor* output;

View File

@@ -7,14 +7,10 @@
#define __UPSAMPLE_LAYER_H__
#include <map>
#include <cassert>
#include "NvInfer.h"
nvinfer1::ITensor* upsampleLayer(
int layerIdx,
std::map<std::string, std::string>& block,
nvinfer1::ITensor* input,
nvinfer1::ITensor* upsampleLayer(int layerIdx, std::map<std::string, std::string>& block, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network);
#endif

View File

@@ -23,16 +23,17 @@
* https://www.github.com/marcoslucianops
*/
#include <algorithm>
#include "nvdsinfer_custom_impl.h"
#include "nvdsinfer_context.h"
#include "yoloPlugins.h"
#include "yolo.h"
#include <algorithm>
#include "yolo.h"
#define USE_CUDA_ENGINE_GET_API 1
static bool getYoloNetworkInfo(NetworkInfo &networkInfo, const NvDsInferContextInitParams* initParams)
static bool
getYoloNetworkInfo(NetworkInfo& networkInfo, const NvDsInferContextInitParams* initParams)
{
std::string yoloCfg = initParams->customNetworkConfigFilePath;
std::string yoloType;
@@ -60,14 +61,12 @@ static bool getYoloNetworkInfo(NetworkInfo &networkInfo, const NvDsInferContextI
else if (initParams->networkMode == 2)
networkInfo.networkMode = "FP16";
if (networkInfo.configFilePath.empty() || networkInfo.wtsFilePath.empty())
{
if (networkInfo.configFilePath.empty() || networkInfo.wtsFilePath.empty()) {
std::cerr << "YOLO config file or weights file is not specified\n" << std::endl;
return false;
}
if (!fileExists(networkInfo.configFilePath) || !fileExists(networkInfo.wtsFilePath))
{
if (!fileExists(networkInfo.configFilePath) || !fileExists(networkInfo.wtsFilePath)) {
std::cerr << "YOLO config file or weights file is not exist\n" << std::endl;
return false;
}
@@ -76,8 +75,9 @@ static bool getYoloNetworkInfo(NetworkInfo &networkInfo, const NvDsInferContextI
}
#if !USE_CUDA_ENGINE_GET_API
IModelParser* NvDsInferCreateModelParser(
const NvDsInferContextInitParams* initParams) {
IModelParser*
NvDsInferCreateModelParser(const NvDsInferContextInitParams* initParams)
{
NetworkInfo networkInfo;
if (!getYoloNetworkInfo(networkInfo, initParams))
return nullptr;
@@ -85,19 +85,13 @@ IModelParser* NvDsInferCreateModelParser(
return new Yolo(networkInfo);
}
#else
extern "C"
bool NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder * const builder,
nvinfer1::IBuilderConfig * const builderConfig,
const NvDsInferContextInitParams * const initParams,
nvinfer1::DataType dataType,
nvinfer1::ICudaEngine *& cudaEngine);
extern "C" bool
NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder* const builder, nvinfer1::IBuilderConfig* const builderConfig,
const NvDsInferContextInitParams* const initParams, nvinfer1::DataType dataType, nvinfer1::ICudaEngine*& cudaEngine);
extern "C"
bool NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder * const builder,
nvinfer1::IBuilderConfig * const builderConfig,
const NvDsInferContextInitParams * const initParams,
nvinfer1::DataType dataType,
nvinfer1::ICudaEngine *& cudaEngine)
extern "C" bool
NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder* const builder, nvinfer1::IBuilderConfig* const builderConfig,
const NvDsInferContextInitParams* const initParams, nvinfer1::DataType dataType, nvinfer1::ICudaEngine*& cudaEngine)
{
NetworkInfo networkInfo;
if (!getYoloNetworkInfo(networkInfo, initParams))
@@ -105,8 +99,7 @@ bool NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder * const builder,
Yolo yolo(networkInfo);
cudaEngine = yolo.createEngine(builder, builderConfig);
if (cudaEngine == nullptr)
{
if (cudaEngine == nullptr) {
std::cerr << "Failed to build CUDA engine on " << networkInfo.configFilePath << std::endl;
return false;
}

View File

@@ -23,20 +23,17 @@
* https://www.github.com/marcoslucianops
*/
#include <algorithm>
#include <cmath>
#include <sstream>
#include "nvdsinfer_custom_impl.h"
#include "utils.h"
#include "utils.h"
#include "yoloPlugins.h"
extern "C" bool NvDsInferParseYolo(
std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
extern "C" bool
NvDsInferParseYolo(std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams, std::vector<NvDsInferParseObjectInfo>& objectList);
static NvDsInferParseObjectInfo convertBBox(
const float& bx1, const float& by1, const float& bx2, const float& by2, const uint& netW, const uint& netH)
static NvDsInferParseObjectInfo
convertBBox(const float& bx1, const float& by1, const float& bx2, const float& by2, const uint& netW, const uint& netH)
{
NvDsInferParseObjectInfo b;
@@ -58,8 +55,8 @@ static NvDsInferParseObjectInfo convertBBox(
return b;
}
static void addBBoxProposal(
const float bx1, const float by1, const float bx2, const float by2, const uint& netW, const uint& netH,
static void
addBBoxProposal(const float bx1, const float by1, const float bx2, const float by2, const uint& netW, const uint& netH,
const int maxIndex, const float maxProb, std::vector<NvDsInferParseObjectInfo>& binfo)
{
NvDsInferParseObjectInfo bbi = convertBBox(bx1, by1, bx2, by2, netW, netH);
@@ -70,14 +67,14 @@ static void addBBoxProposal(
binfo.push_back(bbi);
}
static std::vector<NvDsInferParseObjectInfo> decodeYoloTensor(
const int* counts, const float* boxes, const float* scores, const int* classes, const uint& netW, const uint& netH)
static std::vector<NvDsInferParseObjectInfo>
decodeYoloTensor(const int* counts, const float* boxes, const float* scores, const int* classes, const uint& netW,
const uint& netH)
{
std::vector<NvDsInferParseObjectInfo> binfo;
uint numBoxes = counts[0];
for (uint b = 0; b < numBoxes; ++b)
{
for (uint b = 0; b < numBoxes; ++b) {
float bx1 = boxes[b * 4 + 0];
float by1 = boxes[b * 4 + 1];
float bx2 = boxes[b * 4 + 2];
@@ -91,23 +88,15 @@ static std::vector<NvDsInferParseObjectInfo> decodeYoloTensor(
return binfo;
}
static bool NvDsInferParseCustomYolo(
std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams, std::vector<NvDsInferParseObjectInfo>& objectList,
const uint &numClasses)
{
if (outputLayersInfo.empty())
static bool
NvDsInferParseCustomYolo(std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams, std::vector<NvDsInferParseObjectInfo>& objectList)
{
if (outputLayersInfo.empty()) {
std::cerr << "ERROR: Could not find output layer in bbox parsing" << std::endl;
return false;
}
if (numClasses != detectionParams.numClassesConfigured)
{
std::cerr << "WARNING: Num classes mismatch. Configured: " << detectionParams.numClassesConfigured
<< ", detected by network: " << numClasses << std::endl;
}
std::vector<NvDsInferParseObjectInfo> objects;
const NvDsInferLayerInfo& counts = outputLayersInfo[0];
@@ -115,10 +104,9 @@ static bool NvDsInferParseCustomYolo(
const NvDsInferLayerInfo& scores = outputLayersInfo[2];
const NvDsInferLayerInfo& classes = outputLayersInfo[3];
std::vector<NvDsInferParseObjectInfo> outObjs =
decodeYoloTensor(
(const int*)(counts.buffer), (const float*)(boxes.buffer), (const float*)(scores.buffer),
(const int*)(classes.buffer), networkInfo.width, networkInfo.height);
std::vector<NvDsInferParseObjectInfo> outObjs = decodeYoloTensor((const int*) (counts.buffer),
(const float*) (boxes.buffer), (const float*) (scores.buffer), (const int*) (classes.buffer), networkInfo.width,
networkInfo.height);
objects.insert(objects.end(), outObjs.begin(), outObjs.end());
@@ -127,14 +115,11 @@ static bool NvDsInferParseCustomYolo(
return true;
}
extern "C" bool NvDsInferParseYolo(
std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
extern "C" bool
NvDsInferParseYolo(std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams, std::vector<NvDsInferParseObjectInfo>& objectList)
{
int num_classes = kNUM_CLASSES;
return NvDsInferParseCustomYolo (
outputLayersInfo, networkInfo, detectionParams, objectList, num_classes);
return NvDsInferParseCustomYolo(outputLayersInfo, networkInfo, detectionParams, objectList);
}
CHECK_CUSTOM_PARSE_FUNC_PROTOTYPE(NvDsInferParseYolo);

View File

@@ -25,45 +25,50 @@
#include "utils.h"
#include <experimental/filesystem>
#include <iomanip>
#include <algorithm>
#include <math.h>
#include <experimental/filesystem>
static void leftTrim(std::string& s)
static void
leftTrim(std::string& s)
{
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !isspace(ch); }));
}
static void rightTrim(std::string& s)
static void
rightTrim(std::string& s)
{
s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !isspace(ch); }).base(), s.end());
}
std::string trim(std::string s)
std::string
trim(std::string s)
{
leftTrim(s);
rightTrim(s);
return s;
}
float clamp(const float val, const float minVal, const float maxVal)
float
clamp(const float val, const float minVal, const float maxVal)
{
assert(minVal <= maxVal);
return std::min(maxVal, std::max(minVal, val));
}
bool fileExists(const std::string fileName, bool verbose)
bool
fileExists(const std::string fileName, bool verbose)
{
if (!std::experimental::filesystem::exists(std::experimental::filesystem::path(fileName)))
{
if (verbose) std::cout << "\nFile does not exist: " << fileName << std::endl;
if (!std::experimental::filesystem::exists(std::experimental::filesystem::path(fileName))) {
if (verbose)
std::cout << "\nFile does not exist: " << fileName << std::endl;
return false;
}
return true;
}
std::vector<float> loadWeights(const std::string weightsFilePath, const std::string& networkType)
std::vector<float>
loadWeights(const std::string weightsFilePath, const std::string& networkType)
{
assert(fileExists(weightsFilePath));
std::cout << "\nLoading pre-trained weights" << std::endl;
@@ -75,27 +80,24 @@ std::vector<float> loadWeights(const std::string weightsFilePath, const std::str
assert(file.good());
std::string line;
if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos)
{
if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos) {
// Remove 4 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 4);
}
else
{
else {
// Remove 5 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 5);
}
char floatWeight[4];
while (!file.eof())
{
while (!file.eof()) {
file.read(floatWeight, 4);
assert(file.gcount() == 4);
weights.push_back(*reinterpret_cast<float*>(floatWeight));
if (file.peek() == std::istream::traits_type::eof()) break;
if (file.peek() == std::istream::traits_type::eof())
break;
}
}
else if (weightsFilePath.find(".wts") != std::string::npos) {
std::ifstream file(weightsFilePath);
assert(file.good());
@@ -109,29 +111,29 @@ std::vector<float> loadWeights(const std::string weightsFilePath, const std::str
while (count--) {
file >> name >> std::dec >> size;
for (uint32_t x = 0, y = size; x < y; ++x)
{
for (uint32_t x = 0, y = size; x < y; ++x) {
file >> std::hex >> floatWeight;
weights.push_back(*reinterpret_cast<float*>(&floatWeight));
};
}
}
else {
std::cerr << "\nFile " << weightsFilePath << " is not supported" << std::endl;
std::abort();
assert(0);
}
std::cout << "Loading weights of " << networkType << " complete"
<< std::endl;
std::cout << "Loading weights of " << networkType << " complete" << std::endl;
std::cout << "Total weights read: " << weights.size() << std::endl;
return weights;
}
std::string dimsToString(const nvinfer1::Dims d)
std::string
dimsToString(const nvinfer1::Dims d)
{
std::stringstream s;
assert(d.nbDims >= 1);
std::stringstream s;
s << "[";
for (int i = 0; i < d.nbDims - 1; ++i)
s << d.d[i] << ", ";
@@ -140,7 +142,8 @@ std::string dimsToString(const nvinfer1::Dims d)
return s.str();
}
int getNumChannels(nvinfer1::ITensor* t)
int
getNumChannels(nvinfer1::ITensor* t)
{
nvinfer1::Dims d = t->getDimensions();
assert(d.nbDims == 3);
@@ -148,8 +151,9 @@ int getNumChannels(nvinfer1::ITensor* t)
return d.d[0];
}
void printLayerInfo(
std::string layerIndex, std::string layerName, std::string layerInput, std::string layerOutput, std::string weightPtr)
void
printLayerInfo(std::string layerIndex, std::string layerName, std::string layerInput, std::string layerOutput,
std::string weightPtr)
{
std::cout << std::setw(8) << std::left << layerIndex << std::setw(30) << std::left << layerName;
std::cout << std::setw(20) << std::left << layerInput << std::setw(20) << std::left << layerOutput;

View File

@@ -23,7 +23,6 @@
* https://www.github.com/marcoslucianops
*/
#ifndef __UTILS_H__
#define __UTILS_H__
@@ -36,11 +35,17 @@
#include "NvInfer.h"
std::string trim(std::string s);
float clamp(const float val, const float minVal, const float maxVal);
bool fileExists(const std::string fileName, bool verbose = true);
std::vector<float> loadWeights(const std::string weightsFilePath, const std::string& networkType);
std::string dimsToString(const nvinfer1::Dims d);
int getNumChannels(nvinfer1::ITensor* t);
void printLayerInfo(
std::string layerIndex, std::string layerName, std::string layerInput, std::string layerOutput, std::string weightPtr);

View File

@@ -25,39 +25,27 @@
#include "yolo.h"
#include "yoloPlugins.h"
#include <stdlib.h>
#ifdef OPENCV
#include "calibrator.h"
#endif
Yolo::Yolo(const NetworkInfo& networkInfo)
: m_InputBlobName(networkInfo.inputBlobName),
m_NetworkType(networkInfo.networkType),
m_ConfigFilePath(networkInfo.configFilePath),
m_WtsFilePath(networkInfo.wtsFilePath),
m_Int8CalibPath(networkInfo.int8CalibPath),
m_DeviceType(networkInfo.deviceType),
m_NumDetectedClasses(networkInfo.numDetectedClasses),
m_ClusterMode(networkInfo.clusterMode),
m_NetworkMode(networkInfo.networkMode),
m_ScoreThreshold(networkInfo.scoreThreshold),
m_InputH(0),
m_InputW(0),
m_InputC(0),
m_InputSize(0),
m_NumClasses(0),
m_LetterBox(0),
m_NewCoords(0),
m_YoloCount(0)
{}
Yolo::Yolo(const NetworkInfo& networkInfo) : m_InputBlobName(networkInfo.inputBlobName),
m_NetworkType(networkInfo.networkType), m_ConfigFilePath(networkInfo.configFilePath),
m_WtsFilePath(networkInfo.wtsFilePath), m_Int8CalibPath(networkInfo.int8CalibPath), m_DeviceType(networkInfo.deviceType),
m_NumDetectedClasses(networkInfo.numDetectedClasses), m_ClusterMode(networkInfo.clusterMode),
m_NetworkMode(networkInfo.networkMode), m_ScoreThreshold(networkInfo.scoreThreshold), m_InputH(0), m_InputW(0),
m_InputC(0), m_InputSize(0), m_NumClasses(0), m_LetterBox(0), m_NewCoords(0), m_YoloCount(0)
{
}
Yolo::~Yolo()
{
destroyNetworkUtils();
}
nvinfer1::ICudaEngine *Yolo::createEngine(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config)
nvinfer1::ICudaEngine*
Yolo::createEngine(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config)
{
assert (builder);
@@ -65,52 +53,44 @@ nvinfer1::ICudaEngine *Yolo::createEngine(nvinfer1::IBuilder* builder, nvinfer1:
parseConfigBlocks();
nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0);
if (parseModel(*network) != NVDSINFER_SUCCESS)
{
if (parseModel(*network) != NVDSINFER_SUCCESS) {
delete network;
return nullptr;
}
std::cout << "Building the TensorRT Engine\n" << std::endl;
if (m_NumClasses != m_NumDetectedClasses)
{
if (m_NumClasses != m_NumDetectedClasses) {
std::cout << "NOTE: Number of classes mismatch, make sure to set num-detected-classes=" << m_NumClasses
<< " in config_infer file\n" << std::endl;
}
if (m_LetterBox == 1)
{
if (m_LetterBox == 1) {
std::cout << "NOTE: letter_box is set in cfg file, make sure to set maintain-aspect-ratio=1 in config_infer file"
<< " to get better accuracy\n" << std::endl;
}
if (m_ClusterMode != 2)
{
std::cout << "NOTE: Wrong cluster-mode is set, make sure to set cluster-mode=2 in config_infer file\n"
<< std::endl;
if (m_ClusterMode != 2) {
std::cout << "NOTE: Wrong cluster-mode is set, make sure to set cluster-mode=2 in config_infer file\n" << std::endl;
}
if (m_NetworkMode == "INT8" && !fileExists(m_Int8CalibPath))
{
if (m_NetworkMode == "INT8" && !fileExists(m_Int8CalibPath)) {
assert(builder->platformHasFastInt8());
#ifdef OPENCV
std::string calib_image_list;
int calib_batch_size;
if (getenv("INT8_CALIB_IMG_PATH"))
calib_image_list = getenv("INT8_CALIB_IMG_PATH");
else
{
else {
std::cerr << "INT8_CALIB_IMG_PATH not set" << std::endl;
std::abort();
assert(0);
}
if (getenv("INT8_CALIB_BATCH_SIZE"))
calib_batch_size = std::stoi(getenv("INT8_CALIB_BATCH_SIZE"));
else
{
else {
std::cerr << "INT8_CALIB_BATCH_SIZE not set" << std::endl;
std::abort();
assert(0);
}
nvinfer1::Int8EntropyCalibrator2 *calibrator = new nvinfer1::Int8EntropyCalibrator2(
calib_batch_size, m_InputC, m_InputH, m_InputW, m_LetterBox, calib_image_list, m_Int8CalibPath);
nvinfer1::IInt8EntropyCalibrator2 *calibrator = new Int8EntropyCalibrator2(calib_batch_size, m_InputC, m_InputH,
m_InputW, m_LetterBox, calib_image_list, m_Int8CalibPath);
config->setFlag(nvinfer1::BuilderFlag::kINT8);
config->setInt8Calibrator(calibrator);
#else
@@ -129,7 +109,8 @@ nvinfer1::ICudaEngine *Yolo::createEngine(nvinfer1::IBuilder* builder, nvinfer1:
return engine;
}
NvDsInferStatus Yolo::parseModel(nvinfer1::INetworkDefinition& network) {
NvDsInferStatus
Yolo::parseModel(nvinfer1::INetworkDefinition& network) {
destroyNetworkUtils();
std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
@@ -144,24 +125,23 @@ NvDsInferStatus Yolo::parseModel(nvinfer1::INetworkDefinition& network) {
return status;
}
NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition& network)
NvDsInferStatus
Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition& network)
{
int weightPtr = 0;
std::string weightsType;
std::string weightsType = "wts";
if (m_WtsFilePath.find(".weights") != std::string::npos)
weightsType = "weights";
else
weightsType = "wts";
float eps = 1.0e-5;
if (m_NetworkType.find("yolov5") != std::string::npos || m_NetworkType.find("yolov7") != std::string::npos)
if (m_NetworkType.find("yolov5") != std::string::npos || m_NetworkType.find("yolov7") != std::string::npos ||
m_NetworkType.find("yolov8") != std::string::npos)
eps = 1.0e-3;
else if (m_NetworkType.find("yolor") != std::string::npos)
eps = 1.0e-4;
nvinfer1::ITensor* data = network.addInput(
m_InputBlobName.c_str(), nvinfer1::DataType::kFLOAT,
nvinfer1::ITensor* data = network.addInput(m_InputBlobName.c_str(), nvinfer1::DataType::kFLOAT,
nvinfer1::Dims{3, {static_cast<int>(m_InputC), static_cast<int>(m_InputH), static_cast<int>(m_InputW)}});
assert(data != nullptr && data->getDimensions().nbDims > 0);
@@ -173,40 +153,42 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
int modelType = -1;
for (uint i = 0; i < m_ConfigBlocks.size(); ++i)
{
for (uint i = 0; i < m_ConfigBlocks.size(); ++i) {
std::string layerIndex = "(" + std::to_string(tensorOutputs.size()) + ")";
if (m_ConfigBlocks.at(i).at("type") == "net")
printLayerInfo("", "Layer", "Input Shape", "Output Shape", "WeightPtr");
else if (m_ConfigBlocks.at(i).at("type") == "convolutional")
{
else if (m_ConfigBlocks.at(i).at("type") == "convolutional") {
int channels = getNumChannels(previous);
std::string inputVol = dimsToString(previous->getDimensions());
previous = convolutionalLayer(
i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, channels, eps, previous, &network);
previous = convolutionalLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, channels, eps,
previous, &network);
assert(previous != nullptr);
std::string outputVol = dimsToString(previous->getDimensions());
tensorOutputs.push_back(previous);
std::string layerName = "conv_" + m_ConfigBlocks.at(i).at("activation");
printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr));
}
else if (m_ConfigBlocks.at(i).at("type") == "batchnorm")
{
else if (m_ConfigBlocks.at(i).at("type") == "c2f") {
std::string inputVol = dimsToString(previous->getDimensions());
previous = batchnormLayer(
i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, eps, previous, &network);
previous = c2fLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, eps, previous, &network);
assert(previous != nullptr);
std::string outputVol = dimsToString(previous->getDimensions());
tensorOutputs.push_back(previous);
std::string layerName = "c2f_" + m_ConfigBlocks.at(i).at("activation");
printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr));
}
else if (m_ConfigBlocks.at(i).at("type") == "batchnorm") {
std::string inputVol = dimsToString(previous->getDimensions());
previous = batchnormLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, eps, previous,
&network);
assert(previous != nullptr);
std::string outputVol = dimsToString(previous->getDimensions());
tensorOutputs.push_back(previous);
std::string layerName = "batchnorm_" + m_ConfigBlocks.at(i).at("activation");
printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr));
}
else if (m_ConfigBlocks.at(i).at("type") == "implicit_add" || m_ConfigBlocks.at(i).at("type") == "implicit_mul")
{
else if (m_ConfigBlocks.at(i).at("type") == "implicit_add" || m_ConfigBlocks.at(i).at("type") == "implicit_mul") {
previous = implicitLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, &network);
assert(previous != nullptr);
std::string outputVol = dimsToString(previous->getDimensions());
@@ -214,10 +196,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = m_ConfigBlocks.at(i).at("type");
printLayerInfo(layerIndex, layerName, "-", outputVol, std::to_string(weightPtr));
}
else if (m_ConfigBlocks.at(i).at("type") == "shift_channels" ||
m_ConfigBlocks.at(i).at("type") == "control_channels")
{
else if (m_ConfigBlocks.at(i).at("type") == "shift_channels" || m_ConfigBlocks.at(i).at("type") == "control_channels") {
assert(m_ConfigBlocks.at(i).find("from") != m_ConfigBlocks.at(i).end());
int from = stoi(m_ConfigBlocks.at(i).at("from"));
if (from > 0)
@@ -234,9 +213,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = m_ConfigBlocks.at(i).at("type") + ": " + std::to_string(i + from - 1);
printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-");
}
else if (m_ConfigBlocks.at(i).at("type") == "shortcut")
{
else if (m_ConfigBlocks.at(i).at("type") == "shortcut") {
assert(m_ConfigBlocks.at(i).find("from") != m_ConfigBlocks.at(i).end());
int from = stoi(m_ConfigBlocks.at(i).at("from"));
if (from > 0)
@@ -255,9 +232,8 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string inputVol = dimsToString(previous->getDimensions());
std::string shortcutVol = dimsToString(tensorOutputs[i + from - 1]->getDimensions());
previous = shortcutLayer(
i, mode, activation, inputVol, shortcutVol, m_ConfigBlocks.at(i), previous, tensorOutputs[i + from - 1],
&network);
previous = shortcutLayer(i, mode, activation, inputVol, shortcutVol, m_ConfigBlocks.at(i), previous,
tensorOutputs[i + from - 1], &network);
assert(previous != nullptr);
std::string outputVol = dimsToString(previous->getDimensions());
tensorOutputs.push_back(previous);
@@ -267,9 +243,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
if (mode == "add" && inputVol != shortcutVol)
std::cout << inputVol << " +" << shortcutVol << std::endl;
}
else if (m_ConfigBlocks.at(i).at("type") == "route")
{
else if (m_ConfigBlocks.at(i).at("type") == "route") {
std::string layers;
previous = routeLayer(i, layers, m_ConfigBlocks.at(i), tensorOutputs, &network);
assert(previous != nullptr);
@@ -278,9 +252,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = "route: " + layers;
printLayerInfo(layerIndex, layerName, "-", outputVol, "-");
}
else if (m_ConfigBlocks.at(i).at("type") == "upsample")
{
else if (m_ConfigBlocks.at(i).at("type") == "upsample") {
std::string inputVol = dimsToString(previous->getDimensions());
previous = upsampleLayer(i, m_ConfigBlocks[i], previous, &network);
assert(previous != nullptr);
@@ -289,9 +261,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = "upsample";
printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-");
}
else if (m_ConfigBlocks.at(i).at("type") == "maxpool" || m_ConfigBlocks.at(i).at("type") == "avgpool")
{
else if (m_ConfigBlocks.at(i).at("type") == "maxpool" || m_ConfigBlocks.at(i).at("type") == "avgpool") {
std::string inputVol = dimsToString(previous->getDimensions());
previous = poolingLayer(i, m_ConfigBlocks.at(i), previous, &network);
assert(previous != nullptr);
@@ -300,9 +270,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = m_ConfigBlocks.at(i).at("type");
printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-");
}
else if (m_ConfigBlocks.at(i).at("type") == "reorg")
{
else if (m_ConfigBlocks.at(i).at("type") == "reorg") {
std::string inputVol = dimsToString(previous->getDimensions());
if (m_NetworkType.find("yolov2") != std::string::npos) {
nvinfer1::IPluginV2* reorgPlugin = createReorgPlugin(2);
@@ -321,9 +289,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = "reorg";
printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-");
}
else if (m_ConfigBlocks.at(i).at("type") == "reduce")
{
else if (m_ConfigBlocks.at(i).at("type") == "reduce") {
std::string inputVol = dimsToString(previous->getDimensions());
previous = reduceLayer(i, m_ConfigBlocks.at(i), previous, &network);
assert(previous != nullptr);
@@ -332,9 +298,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = "reduce";
printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-");
}
else if (m_ConfigBlocks.at(i).at("type") == "shuffle")
{
else if (m_ConfigBlocks.at(i).at("type") == "shuffle") {
std::string layer;
std::string inputVol = dimsToString(previous->getDimensions());
previous = shuffleLayer(i, layer, m_ConfigBlocks.at(i), previous, tensorOutputs, &network);
@@ -344,9 +308,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = "shuffle: " + layer;
printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-");
}
else if (m_ConfigBlocks.at(i).at("type") == "softmax")
{
else if (m_ConfigBlocks.at(i).at("type") == "softmax") {
std::string inputVol = dimsToString(previous->getDimensions());
previous = softmaxLayer(i, m_ConfigBlocks.at(i), previous, &network);
assert(previous != nullptr);
@@ -355,9 +317,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = "softmax";
printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-");
}
else if (m_ConfigBlocks.at(i).at("type") == "yolo" || m_ConfigBlocks.at(i).at("type") == "region")
{
else if (m_ConfigBlocks.at(i).at("type") == "yolo" || m_ConfigBlocks.at(i).at("type") == "region") {
if (m_ConfigBlocks.at(i).at("type") == "yolo")
if (m_NetworkType.find("yolor") != std::string::npos)
modelType = 2;
@@ -380,9 +340,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = modelType != 0 ? "yolo" : "region";
printLayerInfo(layerIndex, layerName, inputVol, "-", "-");
}
else if (m_ConfigBlocks.at(i).at("type") == "cls")
{
else if (m_ConfigBlocks.at(i).at("type") == "cls") {
modelType = 3;
std::string blobName = "cls_" + std::to_string(i);
@@ -402,9 +360,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = "cls";
printLayerInfo(layerIndex, layerName, inputVol, outputVol, "-");
}
else if (m_ConfigBlocks.at(i).at("type") == "reg")
{
else if (m_ConfigBlocks.at(i).at("type") == "reg") {
modelType = 3;
std::string blobName = "reg_" + std::to_string(i);
@@ -423,36 +379,50 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
std::string layerName = "reg";
printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr));
}
else if (m_ConfigBlocks.at(i).at("type") == "detect_v8") {
modelType = 4;
else
{
std::cout << "\nUnsupported layer type --> \"" << m_ConfigBlocks.at(i).at("type") << "\"" << std::endl;
std::string blobName = "detect_v8_" + std::to_string(i);
nvinfer1::Dims prevTensorDims = previous->getDimensions();
TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs);
curYoloTensor.blobName = blobName;
curYoloTensor.numBBoxes = prevTensorDims.d[1];
std::string inputVol = dimsToString(previous->getDimensions());
previous = detectV8Layer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, previous, &network);
assert(previous != nullptr);
std::string outputVol = dimsToString(previous->getDimensions());
tensorOutputs.push_back(previous);
yoloTensorInputs[yoloCountInputs] = previous;
++yoloCountInputs;
std::string layerName = "detect_v8";
printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr));
}
else {
std::cerr << "\nUnsupported layer type --> \"" << m_ConfigBlocks.at(i).at("type") << "\"" << std::endl;
assert(0);
}
}
if ((int)weights.size() != weightPtr)
{
std::cout << "\nNumber of unused weights left: " << weights.size() - weightPtr << std::endl;
if ((int) weights.size() != weightPtr) {
std::cerr << "\nNumber of unused weights left: " << weights.size() - weightPtr << std::endl;
assert(0);
}
if (m_YoloCount == yoloCountInputs)
{
if (m_YoloCount == yoloCountInputs) {
assert((modelType != -1) && "\nCould not determine model type");
uint64_t outputSize = 0;
for (uint j = 0; j < yoloCountInputs; ++j)
{
for (uint j = 0; j < yoloCountInputs; ++j) {
TensorInfo& curYoloTensor = m_YoloTensors.at(j);
if (modelType == 3)
if (modelType == 3 || modelType == 4)
outputSize = curYoloTensor.numBBoxes;
else
outputSize += curYoloTensor.gridSizeX * curYoloTensor.gridSizeY * curYoloTensor.numBBoxes;
}
nvinfer1::IPluginV2* yoloPlugin = new YoloLayer(
m_InputW, m_InputH, m_NumClasses, m_NewCoords, m_YoloTensors, outputSize, modelType, m_ScoreThreshold);
nvinfer1::IPluginV2* yoloPlugin = new YoloLayer(m_InputW, m_InputH, m_NumClasses, m_NewCoords, m_YoloTensors, outputSize,
modelType, m_ScoreThreshold);
assert(yoloPlugin != nullptr);
nvinfer1::IPluginV2Layer* yolo = network.addPluginV2(yoloTensorInputs, m_YoloCount, *yoloPlugin);
assert(yolo != nullptr);
@@ -478,15 +448,13 @@ NvDsInferStatus Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::IN
network.markOutput(*detection_classes);
}
else {
std::cout << "\nError in yolo cfg file" << std::endl;
std::cerr << "\nError in yolo cfg file" << std::endl;
assert(0);
}
std::cout << "\nOutput YOLO blob names: " << std::endl;
for (auto& tensor : m_YoloTensors)
{
std::cout << tensor.blobName << std::endl;
}
int nbLayers = network.getNbLayers();
std::cout << "\nTotal number of YOLO layers: " << nbLayers << "\n" << std::endl;
@@ -504,16 +472,13 @@ Yolo::parseConfigFile (const std::string cfgFilePath)
std::vector<std::map<std::string, std::string>> blocks;
std::map<std::string, std::string> block;
while (getline(file, line))
{
if (line.size() == 0) continue;
if (line.front() == ' ') continue;
if (line.front() == '#') continue;
while (getline(file, line)) {
if (line.size() == 0 || line.front() == ' ' || line.front() == '#')
continue;
line = trim(line);
if (line.front() == '[')
{
if (block.size() > 0)
{
if (line.front() == '[') {
if (block.size() > 0) {
blocks.push_back(block);
block.clear();
}
@@ -521,24 +486,23 @@ Yolo::parseConfigFile (const std::string cfgFilePath)
std::string value = trim(line.substr(1, line.size() - 2));
block.insert(std::pair<std::string, std::string>(key, value));
}
else
{
else {
int cpos = line.find('=');
std::string key = trim(line.substr(0, cpos));
std::string value = trim(line.substr(cpos + 1));
block.insert(std::pair<std::string, std::string>(key, value));
}
}
blocks.push_back(block);
return blocks;
}
void Yolo::parseConfigBlocks()
{
for (auto block : m_ConfigBlocks)
{
if (block.at("type") == "net")
void
Yolo::parseConfigBlocks()
{
for (auto block : m_ConfigBlocks) {
if (block.at("type") == "net") {
assert((block.find("height") != block.end()) && "Missing 'height' param in network cfg");
assert((block.find("width") != block.end()) && "Missing 'width' param in network cfg");
assert((block.find("channels") != block.end()) && "Missing 'channels' param in network cfg");
@@ -549,62 +513,51 @@ void Yolo::parseConfigBlocks()
m_InputSize = m_InputC * m_InputH * m_InputW;
if (block.find("letter_box") != block.end())
{
m_LetterBox = std::stoul(block.at("letter_box"));
}
}
else if ((block.at("type") == "region") || (block.at("type") == "yolo"))
{
assert((block.find("num") != block.end())
&& std::string("Missing 'num' param in " + block.at("type") + " layer").c_str());
assert((block.find("classes") != block.end())
&& std::string("Missing 'classes' param in " + block.at("type") + " layer").c_str());
assert((block.find("anchors") != block.end())
&& std::string("Missing 'anchors' param in " + block.at("type") + " layer").c_str());
assert((block.find("num") != block.end()) &&
std::string("Missing 'num' param in " + block.at("type") + " layer").c_str());
assert((block.find("classes") != block.end()) &&
std::string("Missing 'classes' param in " + block.at("type") + " layer").c_str());
assert((block.find("anchors") != block.end()) &&
std::string("Missing 'anchors' param in " + block.at("type") + " layer").c_str());
++m_YoloCount;
m_NumClasses = std::stoul(block.at("classes"));
if (block.find("new_coords") != block.end())
{
m_NewCoords = std::stoul(block.at("new_coords"));
}
TensorInfo outputTensor;
std::string anchorString = block.at("anchors");
while (!anchorString.empty())
{
while (!anchorString.empty()) {
int npos = anchorString.find_first_of(',');
if (npos != -1)
{
if (npos != -1) {
float anchor = std::stof(trim(anchorString.substr(0, npos)));
outputTensor.anchors.push_back(anchor);
anchorString.erase(0, npos + 1);
}
else
{
else {
float anchor = std::stof(trim(anchorString));
outputTensor.anchors.push_back(anchor);
break;
}
}
if (block.find("mask") != block.end())
{
if (block.find("mask") != block.end()) {
std::string maskString = block.at("mask");
while (!maskString.empty())
{
while (!maskString.empty()) {
int npos = maskString.find_first_of(',');
if (npos != -1)
{
if (npos != -1) {
int mask = std::stoul(trim(maskString.substr(0, npos)));
outputTensor.mask.push_back(mask);
maskString.erase(0, npos + 1);
}
else
{
else {
int mask = std::stoul(trim(maskString));
outputTensor.mask.push_back(mask);
break;
@@ -613,29 +566,32 @@ void Yolo::parseConfigBlocks()
}
if (block.find("scale_x_y") != block.end())
{
outputTensor.scaleXY = std::stof(block.at("scale_x_y"));
}
else
{
outputTensor.scaleXY = 1.0;
}
outputTensor.numBBoxes
= outputTensor.mask.size() > 0 ? outputTensor.mask.size() : std::stoul(trim(block.at("num")));
outputTensor.numBBoxes = outputTensor.mask.size() > 0 ? outputTensor.mask.size() : std::stoul(trim(block.at("num")));
m_YoloTensors.push_back(outputTensor);
}
else if ((block.at("type") == "cls") || (block.at("type") == "reg"))
{
else if ((block.at("type") == "cls") || (block.at("type") == "reg")) {
++m_YoloCount;
TensorInfo outputTensor;
m_YoloTensors.push_back(outputTensor);
}
else if (block.at("type") == "detect_v8") {
++m_YoloCount;
m_NumClasses = std::stoul(block.at("classes"));
TensorInfo outputTensor;
m_YoloTensors.push_back(outputTensor);
}
}
}
void Yolo::destroyNetworkUtils()
void
Yolo::destroyNetworkUtils()
{
for (uint i = 0; i < m_TrtWeights.size(); ++i)
if (m_TrtWeights[i].count > 0)

View File

@@ -26,7 +26,11 @@
#ifndef _YOLO_H_
#define _YOLO_H_
#include "NvInferPlugin.h"
#include "nvdsinfer_custom_impl.h"
#include "layers/convolutional_layer.h"
#include "layers/c2f_layer.h"
#include "layers/batchnorm_layer.h"
#include "layers/implicit_layer.h"
#include "layers/channels_layer.h"
@@ -40,8 +44,7 @@
#include "layers/softmax_layer.h"
#include "layers/cls_layer.h"
#include "layers/reg_layer.h"
#include "nvdsinfer_custom_impl.h"
#include "layers/detect_v8_layer.h"
struct NetworkInfo
{

View File

@@ -7,10 +7,10 @@
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
__global__ void gpuYoloLayer(
const float* input, int* num_detections, float* detection_boxes, float* detection_scores, int* detection_classes,
const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, const uint gridSizeY,
const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask)
__global__ void gpuYoloLayer(const float* input, int* num_detections, float* detection_boxes, float* detection_scores,
int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX,
const uint gridSizeY, const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors,
const int* mask)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
@@ -22,8 +22,7 @@ __global__ void gpuYoloLayer(
const int numGridCells = gridSizeX * gridSizeY;
const int bbindex = y_id * gridSizeX + x_id;
const float objectness
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
const float objectness = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
if (objectness < scoreThreshold)
return;
@@ -33,32 +32,22 @@ __global__ void gpuYoloLayer(
const float alpha = scaleXY;
const float beta = -0.5 * (scaleXY - 1);
float x
= (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)])
* alpha + beta + x_id) * netWidth / gridSizeX;
float x = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta + x_id)
* netWidth / gridSizeX;
float y
= (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)])
* alpha + beta + y_id) * netHeight / gridSizeY;
float y = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta + y_id)
* netHeight / gridSizeY;
float w
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)])
* anchors[mask[z_id] * 2];
float w = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * anchors[mask[z_id] * 2];
float h
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)])
* anchors[mask[z_id] * 2 + 1];
float h = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * anchors[mask[z_id] * 2 + 1];
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numOutputClasses; ++i)
{
float prob
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
if (prob > maxProb)
{
for (uint i = 0; i < numOutputClasses; ++i) {
float prob = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
@@ -72,33 +61,28 @@ __global__ void gpuYoloLayer(
detection_classes[count] = maxIndex;
}
cudaError_t cudaYoloLayer(
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaYoloLayer(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaYoloLayer(
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream)
cudaError_t cudaYoloLayer(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream)
{
dim3 threads_per_block(16, 16, 4);
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
(gridSizeY / threads_per_block.y) + 1,
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, (gridSizeY / threads_per_block.y) + 1,
(numBBoxes / threads_per_block.z) + 1);
for (unsigned int batch = 0; batch < batchSize; ++batch)
{
for (unsigned int batch = 0; batch < batchSize; ++batch) {
gpuYoloLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
reinterpret_cast<const float*>(input) + (batch * inputSize),
reinterpret_cast<int*>(num_detections) + (batch),
reinterpret_cast<const float*>(input) + (batch * inputSize), reinterpret_cast<int*>(num_detections) + (batch),
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
scoreThreshold, netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, scaleXY,
reinterpret_cast<const float*>(anchors), reinterpret_cast<const int*>(mask));
reinterpret_cast<int*>(detection_classes) + (batch * outputSize), scoreThreshold, netWidth, netHeight, gridSizeX,
gridSizeY, numOutputClasses, numBBoxes, scaleXY, reinterpret_cast<const float*>(anchors),
reinterpret_cast<const int*>(mask));
}
return cudaGetLastError();
}

View File

@@ -4,11 +4,9 @@
*/
#include <stdint.h>
#include <stdio.h>
__global__ void gpuYoloLayer_e(
const float* cls, const float* reg, int* num_detections, float* detection_boxes, float* detection_scores,
int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight,
__global__ void gpuYoloLayer_e(const float* cls, const float* reg, int* num_detections, float* detection_boxes,
float* detection_scores, int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight,
const uint numOutputClasses, const uint64_t outputSize)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
@@ -19,13 +17,9 @@ __global__ void gpuYoloLayer_e(
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numOutputClasses; ++i)
{
float prob
= cls[x_id * numOutputClasses + i];
if (prob > maxProb)
{
for (uint i = 0; i < numOutputClasses; ++i) {
float prob = cls[x_id * numOutputClasses + i];
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
@@ -44,29 +38,27 @@ __global__ void gpuYoloLayer_e(
detection_classes[count] = maxIndex;
}
cudaError_t cudaYoloLayer_e(
const void* cls, const void* reg, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream);
cudaError_t cudaYoloLayer_e(const void* cls, const void* reg, void* num_detections, void* detection_boxes,
void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& outputSize,
const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& numOutputClasses,
cudaStream_t stream);
cudaError_t cudaYoloLayer_e(
const void* cls, const void* reg, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream)
cudaError_t cudaYoloLayer_e(const void* cls, const void* reg, void* num_detections, void* detection_boxes,
void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& outputSize,
const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& numOutputClasses,
cudaStream_t stream)
{
int threads_per_block = 16;
int number_of_blocks = (outputSize / threads_per_block) + 1;
for (unsigned int batch = 0; batch < batchSize; ++batch)
{
for (unsigned int batch = 0; batch < batchSize; ++batch) {
gpuYoloLayer_e<<<number_of_blocks, threads_per_block, 0, stream>>>(
reinterpret_cast<const float*>(cls) + (batch * numOutputClasses * outputSize),
reinterpret_cast<const float*>(reg) + (batch * 4 * outputSize),
reinterpret_cast<int*>(num_detections) + (batch),
reinterpret_cast<const float*>(reg) + (batch * 4 * outputSize), reinterpret_cast<int*>(num_detections) + (batch),
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
scoreThreshold, netWidth, netHeight, numOutputClasses, outputSize);
reinterpret_cast<int*>(detection_classes) + (batch * outputSize), scoreThreshold, netWidth, netHeight,
numOutputClasses, outputSize);
}
return cudaGetLastError();
}

View File

@@ -5,10 +5,10 @@
#include <stdint.h>
__global__ void gpuYoloLayer_nc(
const float* input, int* num_detections, float* detection_boxes, float* detection_scores, int* detection_classes,
const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, const uint gridSizeY,
const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask)
__global__ void gpuYoloLayer_nc(const float* input, int* num_detections, float* detection_boxes, float* detection_scores,
int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX,
const uint gridSizeY, const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors,
const int* mask)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
@@ -20,8 +20,7 @@ __global__ void gpuYoloLayer_nc(
const int numGridCells = gridSizeX * gridSizeY;
const int bbindex = y_id * gridSizeX + x_id;
const float objectness
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)];
const float objectness = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)];
if (objectness < scoreThreshold)
return;
@@ -31,32 +30,22 @@ __global__ void gpuYoloLayer_nc(
const float alpha = scaleXY;
const float beta = -0.5 * (scaleXY - 1);
float x
= (input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
* alpha + beta + x_id) * netWidth / gridSizeX;
float x = (input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)] * alpha + beta + x_id) * netWidth /
gridSizeX;
float y
= (input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]
* alpha + beta + y_id) * netHeight / gridSizeY;
float y = (input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)] * alpha + beta + y_id) * netHeight /
gridSizeY;
float w
= __powf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] * 2, 2)
* anchors[mask[z_id] * 2];
float w = __powf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)] * 2, 2) * anchors[mask[z_id] * 2];
float h
= __powf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] * 2, 2)
* anchors[mask[z_id] * 2 + 1];
float h = __powf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)] * 2, 2) * anchors[mask[z_id] * 2 + 1];
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numOutputClasses; ++i)
{
float prob
= input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
if (prob > maxProb)
{
for (uint i = 0; i < numOutputClasses; ++i) {
float prob = input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
@@ -70,33 +59,28 @@ __global__ void gpuYoloLayer_nc(
detection_classes[count] = maxIndex;
}
cudaError_t cudaYoloLayer_nc(
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaYoloLayer_nc(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaYoloLayer_nc(
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream)
cudaError_t cudaYoloLayer_nc(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream)
{
dim3 threads_per_block(16, 16, 4);
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
(gridSizeY / threads_per_block.y) + 1,
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, (gridSizeY / threads_per_block.y) + 1,
(numBBoxes / threads_per_block.z) + 1);
for (unsigned int batch = 0; batch < batchSize; ++batch)
{
for (unsigned int batch = 0; batch < batchSize; ++batch) {
gpuYoloLayer_nc<<<number_of_blocks, threads_per_block, 0, stream>>>(
reinterpret_cast<const float*>(input) + (batch * inputSize),
reinterpret_cast<int*>(num_detections) + (batch),
reinterpret_cast<const float*>(input) + (batch * inputSize), reinterpret_cast<int*>(num_detections) + (batch),
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
scoreThreshold, netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, scaleXY,
reinterpret_cast<const float*>(anchors), reinterpret_cast<const int*>(mask));
reinterpret_cast<int*>(detection_classes) + (batch * outputSize), scoreThreshold, netWidth, netHeight, gridSizeX,
gridSizeY, numOutputClasses, numBBoxes, scaleXY, reinterpret_cast<const float*>(anchors),
reinterpret_cast<const int*>(mask));
}
return cudaGetLastError();
}

View File

@@ -7,10 +7,10 @@
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
__global__ void gpuYoloLayer_r(
const float* input, int* num_detections, float* detection_boxes, float* detection_scores, int* detection_classes,
const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX, const uint gridSizeY,
const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors, const int* mask)
__global__ void gpuYoloLayer_r(const float* input, int* num_detections, float* detection_boxes, float* detection_scores,
int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX,
const uint gridSizeY, const uint numOutputClasses, const uint numBBoxes, const float scaleXY, const float* anchors,
const int* mask)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
@@ -22,8 +22,7 @@ __global__ void gpuYoloLayer_r(
const int numGridCells = gridSizeX * gridSizeY;
const int bbindex = y_id * gridSizeX + x_id;
const float objectness
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
const float objectness = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
if (objectness < scoreThreshold)
return;
@@ -33,32 +32,24 @@ __global__ void gpuYoloLayer_r(
const float alpha = scaleXY;
const float beta = -0.5 * (scaleXY - 1);
float x
= (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)])
* alpha + beta + x_id) * netWidth / gridSizeX;
float x = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) * alpha + beta + x_id)
* netWidth / gridSizeX;
float y
= (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)])
* alpha + beta + y_id) * netHeight / gridSizeY;
float y = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) * alpha + beta + y_id)
* netHeight / gridSizeY;
float w
= __powf(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2)
float w = __powf(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * 2, 2)
* anchors[mask[z_id] * 2];
float h
= __powf(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2)
float h = __powf(sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * 2, 2)
* anchors[mask[z_id] * 2 + 1];
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numOutputClasses; ++i)
{
float prob
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
if (prob > maxProb)
{
for (uint i = 0; i < numOutputClasses; ++i) {
float prob = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))]);
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
@@ -72,33 +63,28 @@ __global__ void gpuYoloLayer_r(
detection_classes[count] = maxIndex;
}
cudaError_t cudaYoloLayer_r(
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaYoloLayer_r(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaYoloLayer_r(
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream)
cudaError_t cudaYoloLayer_r(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream)
{
dim3 threads_per_block(16, 16, 4);
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
(gridSizeY / threads_per_block.y) + 1,
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, (gridSizeY / threads_per_block.y) + 1,
(numBBoxes / threads_per_block.z) + 1);
for (unsigned int batch = 0; batch < batchSize; ++batch)
{
for (unsigned int batch = 0; batch < batchSize; ++batch) {
gpuYoloLayer_r<<<number_of_blocks, threads_per_block, 0, stream>>>(
reinterpret_cast<const float*>(input) + (batch * inputSize),
reinterpret_cast<int*>(num_detections) + (batch),
reinterpret_cast<const float*>(input) + (batch * inputSize), reinterpret_cast<int*>(num_detections) + (batch),
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
scoreThreshold, netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, scaleXY,
reinterpret_cast<const float*>(anchors), reinterpret_cast<const int*>(mask));
reinterpret_cast<int*>(detection_classes) + (batch * outputSize), scoreThreshold, netWidth, netHeight, gridSizeX,
gridSizeY, numOutputClasses, numBBoxes, scaleXY, reinterpret_cast<const float*>(anchors),
reinterpret_cast<const int*>(mask));
}
return cudaGetLastError();
}

View File

@@ -7,9 +7,8 @@
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
__device__ void softmaxGPU(
const float* input, const int bbindex, const int numGridCells, uint z_id, const uint numOutputClasses, float temp,
float* output)
__device__ void softmaxGPU(const float* input, const int bbindex, const int numGridCells, uint z_id,
const uint numOutputClasses, float temp, float* output)
{
int i;
float sum = 0;
@@ -28,10 +27,9 @@ __device__ void softmaxGPU(
}
}
__global__ void gpuRegionLayer(
const float* input, float* softmax, int* num_detections, float* detection_boxes, float* detection_scores,
int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, const uint gridSizeX,
const uint gridSizeY, const uint numOutputClasses, const uint numBBoxes, const float* anchors)
__global__ void gpuRegionLayer(const float* input, float* softmax, int* num_detections, float* detection_boxes,
float* detection_scores, int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight,
const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses, const uint numBBoxes, const float* anchors)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
@@ -43,42 +41,31 @@ __global__ void gpuRegionLayer(
const int numGridCells = gridSizeX * gridSizeY;
const int bbindex = y_id * gridSizeX + x_id;
const float objectness
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
const float objectness = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 4)]);
if (objectness < scoreThreshold)
return;
int count = (int)atomicAdd(num_detections, 1);
float x
= (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)])
+ x_id) * netWidth / gridSizeX;
float x = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]) + x_id) * netWidth / gridSizeX;
float y
= (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)])
+ y_id) * netHeight / gridSizeY;
float y = (sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 1)]) + y_id) * netHeight / gridSizeY;
float w
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)])
* anchors[z_id * 2] * netWidth / gridSizeX;
float w = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 2)]) * anchors[z_id * 2] * netWidth /
gridSizeX;
float h
= __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)])
* anchors[z_id * 2 + 1] * netHeight / gridSizeY;
float h = __expf(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 3)]) * anchors[z_id * 2 + 1] * netHeight /
gridSizeY;
softmaxGPU(input, bbindex, numGridCells, z_id, numOutputClasses, 1.0, softmax);
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numOutputClasses; ++i)
{
float prob
= softmax[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
if (prob > maxProb)
{
for (uint i = 0; i < numOutputClasses; ++i) {
float prob = softmax[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + (5 + i))];
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
@@ -92,34 +79,28 @@ __global__ void gpuRegionLayer(
detection_classes[count] = maxIndex;
}
cudaError_t cudaRegionLayer(
const void* input, void* softmax, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, const void* anchors, cudaStream_t stream);
cudaError_t cudaRegionLayer(const void* input, void* softmax, void* num_detections, void* detection_boxes,
void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize,
const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY,
const uint& numOutputClasses, const uint& numBBoxes, const void* anchors, cudaStream_t stream);
cudaError_t cudaRegionLayer(
const void* input, void* softmax, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, const void* anchors, cudaStream_t stream)
cudaError_t cudaRegionLayer(const void* input, void* softmax, void* num_detections, void* detection_boxes,
void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize,
const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY,
const uint& numOutputClasses, const uint& numBBoxes, const void* anchors, cudaStream_t stream)
{
dim3 threads_per_block(16, 16, 4);
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
(gridSizeY / threads_per_block.y) + 1,
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1, (gridSizeY / threads_per_block.y) + 1,
(numBBoxes / threads_per_block.z) + 1);
for (unsigned int batch = 0; batch < batchSize; ++batch)
{
for (unsigned int batch = 0; batch < batchSize; ++batch) {
gpuRegionLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
reinterpret_cast<const float*>(input) + (batch * inputSize),
reinterpret_cast<float*>(softmax) + (batch * inputSize),
reinterpret_cast<const float*>(input) + (batch * inputSize), reinterpret_cast<float*>(softmax) + (batch * inputSize),
reinterpret_cast<int*>(num_detections) + (batch),
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
scoreThreshold, netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes,
reinterpret_cast<const float*>(anchors));
reinterpret_cast<int*>(detection_classes) + (batch * outputSize), scoreThreshold, netWidth, netHeight, gridSizeX,
gridSizeY, numOutputClasses, numBBoxes, reinterpret_cast<const float*>(anchors));
}
return cudaGetLastError();
}

View File

@@ -0,0 +1,62 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include <stdint.h>
__global__ void gpuYoloLayer_v8(const float* input, int* num_detections, float* detection_boxes, float* detection_scores,
int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight,
const uint numOutputClasses, const uint64_t outputSize)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
if (x_id >= outputSize)
return;
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numOutputClasses; ++i) {
float prob = input[x_id * (4 + numOutputClasses) + i + 4];
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
}
if (maxProb < scoreThreshold)
return;
int count = (int)atomicAdd(num_detections, 1);
detection_boxes[count * 4 + 0] = input[x_id * (4 + numOutputClasses) + 0];
detection_boxes[count * 4 + 1] = input[x_id * (4 + numOutputClasses) + 1];
detection_boxes[count * 4 + 2] = input[x_id * (4 + numOutputClasses) + 2];
detection_boxes[count * 4 + 3] = input[x_id * (4 + numOutputClasses) + 3];
detection_scores[count] = maxProb;
detection_classes[count] = maxIndex;
}
cudaError_t cudaYoloLayer_v8(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream);
cudaError_t cudaYoloLayer_v8(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream)
{
int threads_per_block = 16;
int number_of_blocks = (outputSize / threads_per_block) + 1;
for (unsigned int batch = 0; batch < batchSize; ++batch) {
gpuYoloLayer_v8<<<number_of_blocks, threads_per_block, 0, stream>>>(
reinterpret_cast<const float*>(input) + (batch * (4 + numOutputClasses) * outputSize),
reinterpret_cast<int*>(num_detections) + (batch),
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
scoreThreshold, netWidth, netHeight, numOutputClasses, outputSize);
}
return cudaGetLastError();
}

View File

@@ -24,60 +24,50 @@
*/
#include "yoloPlugins.h"
#include "NvInferPlugin.h"
#include <cassert>
#include <iostream>
#include <memory>
uint kNUM_CLASSES;
namespace {
template <typename T>
void write(char*& buffer, const T& val)
{
void write(char*& buffer, const T& val) {
*reinterpret_cast<T*>(buffer) = val;
buffer += sizeof(T);
}
template <typename T>
void read(const char*& buffer, T& val)
{
void read(const char*& buffer, T& val) {
val = *reinterpret_cast<const T*>(buffer);
buffer += sizeof(T);
}
}
cudaError_t cudaYoloLayer_e(
const void* cls, const void* reg, void* num_detections, void* detection_boxes, void* detection_scores,
cudaError_t cudaYoloLayer_v8(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream);
cudaError_t cudaYoloLayer_r(
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaYoloLayer_e(const void* cls, const void* reg, void* num_detections, void* detection_boxes,
void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& outputSize,
const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& numOutputClasses,
cudaStream_t stream);
cudaError_t cudaYoloLayer_nc(
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaYoloLayer(
const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes,
const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses, const uint& numBBoxes,
const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaRegionLayer(
const void* input, void* softmax, void* num_detections, void* detection_boxes, void* detection_scores,
cudaError_t cudaYoloLayer_r(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, const void* anchors, cudaStream_t stream);
const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
YoloLayer::YoloLayer (const void* data, size_t length)
{
cudaError_t cudaYoloLayer_nc(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaYoloLayer(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize, const float& scoreThreshold,
const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY, const uint& numOutputClasses,
const uint& numBBoxes, const float& scaleXY, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaRegionLayer(const void* input, void* softmax, void* num_detections, void* detection_boxes,
void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& inputSize, uint64_t& outputSize,
const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& gridSizeX, const uint& gridSizeY,
const uint& numOutputClasses, const uint& numBBoxes, const void* anchors, cudaStream_t stream);
YoloLayer::YoloLayer(const void* data, size_t length) {
const char* d = static_cast<const char*>(data);
read(d, m_NetWidth);
@@ -88,11 +78,10 @@ YoloLayer::YoloLayer (const void* data, size_t length)
read(d, m_Type);
read(d, m_ScoreThreshold);
if (m_Type != 3) {
if (m_Type != 3 && m_Type != 4) {
uint yoloTensorsSize;
read(d, yoloTensorsSize);
for (uint i = 0; i < yoloTensorsSize; ++i)
{
for (uint i = 0; i < yoloTensorsSize; ++i) {
TensorInfo curYoloTensor;
read(d, curYoloTensor.gridSizeX);
read(d, curYoloTensor.gridSizeY);
@@ -101,8 +90,7 @@ YoloLayer::YoloLayer (const void* data, size_t length)
uint anchorsSize;
read(d, anchorsSize);
for (uint j = 0; j < anchorsSize; j++)
{
for (uint j = 0; j < anchorsSize; ++j) {
float result;
read(d, result);
curYoloTensor.anchors.push_back(result);
@@ -110,72 +98,55 @@ YoloLayer::YoloLayer (const void* data, size_t length)
uint maskSize;
read(d, maskSize);
for (uint j = 0; j < maskSize; j++)
{
for (uint j = 0; j < maskSize; ++j) {
int result;
read(d, result);
curYoloTensor.mask.push_back(result);
}
m_YoloTensors.push_back(curYoloTensor);
}
}
kNUM_CLASSES = m_NumClasses;
};
YoloLayer::YoloLayer(
const uint& netWidth, const uint& netHeight, const uint& numClasses, const uint& newCoords,
YoloLayer::YoloLayer(const uint& netWidth, const uint& netHeight, const uint& numClasses, const uint& newCoords,
const std::vector<TensorInfo>& yoloTensors, const uint64_t& outputSize, const uint& modelType,
const float& scoreThreshold) :
m_NetWidth(netWidth),
m_NetHeight(netHeight),
m_NumClasses(numClasses),
m_NewCoords(newCoords),
m_YoloTensors(yoloTensors),
m_OutputSize(outputSize),
m_Type(modelType),
const float& scoreThreshold) : m_NetWidth(netWidth), m_NetHeight(netHeight), m_NumClasses(numClasses),
m_NewCoords(newCoords), m_YoloTensors(yoloTensors), m_OutputSize(outputSize), m_Type(modelType),
m_ScoreThreshold(scoreThreshold)
{
assert(m_NetWidth > 0);
assert(m_NetHeight > 0);
kNUM_CLASSES = m_NumClasses;
};
nvinfer1::Dims
YoloLayer::getOutputDimensions(
int index, const nvinfer1::Dims* inputs, int nbInputDims) noexcept
YoloLayer::getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) noexcept
{
assert(index <= 4);
if (index == 0) {
if (index == 0)
return nvinfer1::Dims{1, {1}};
}
else if (index == 1) {
else if (index == 1)
return nvinfer1::Dims{2, {static_cast<int>(m_OutputSize), 4}};
}
return nvinfer1::Dims{1, {static_cast<int>(m_OutputSize)}};
}
bool YoloLayer::supportsFormat (
nvinfer1::DataType type, nvinfer1::PluginFormat format) const noexcept {
return (type == nvinfer1::DataType::kFLOAT &&
format == nvinfer1::PluginFormat::kLINEAR);
bool
YoloLayer::supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const noexcept {
return (type == nvinfer1::DataType::kFLOAT && format == nvinfer1::PluginFormat::kLINEAR);
}
void
YoloLayer::configureWithFormat (
const nvinfer1::Dims* inputDims, int nbInputs,
const nvinfer1::Dims* outputDims, int nbOutputs,
nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) noexcept
YoloLayer::configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs, const nvinfer1::Dims* outputDims,
int nbOutputs, nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) noexcept
{
assert(nbInputs > 0);
assert(format == nvinfer1::PluginFormat::kLINEAR);
assert(inputDims != nullptr);
}
int32_t YoloLayer::enqueue (
int batchSize, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept
int32_t
YoloLayer::enqueue(int batchSize, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream)
noexcept
{
void* num_detections = outputs[0];
void* detection_boxes = outputs[1];
@@ -187,17 +158,17 @@ int32_t YoloLayer::enqueue (
CUDA_CHECK(cudaMemsetAsync((float*)detection_scores, 0, sizeof(float) * m_OutputSize * batchSize, stream));
CUDA_CHECK(cudaMemsetAsync((int*)detection_classes, 0, sizeof(int) * m_OutputSize * batchSize, stream));
if (m_Type == 3)
{
CUDA_CHECK(cudaYoloLayer_e(
inputs[0], inputs[1], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
if (m_Type == 4) {
CUDA_CHECK(cudaYoloLayer_v8(inputs[0], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, stream));
}
else
{
else if (m_Type == 3) {
CUDA_CHECK(cudaYoloLayer_e(inputs[0], inputs[1], num_detections, detection_boxes, detection_scores, detection_classes,
batchSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, stream));
}
else {
uint yoloTensorsSize = m_YoloTensors.size();
for (uint i = 0; i < yoloTensorsSize; ++i)
{
for (uint i = 0; i < yoloTensorsSize; ++i) {
TensorInfo& curYoloTensor = m_YoloTensors.at(i);
uint numBBoxes = curYoloTensor.numBBoxes;
@@ -212,8 +183,7 @@ int32_t YoloLayer::enqueue (
if (anchors.size() > 0) {
float* f_anchors = anchors.data();
CUDA_CHECK(cudaMalloc(&v_anchors, sizeof(float) * anchors.size()));
CUDA_CHECK(cudaMemcpyAsync(v_anchors, f_anchors, sizeof(float) * anchors.size(), cudaMemcpyHostToDevice,
stream));
CUDA_CHECK(cudaMemcpyAsync(v_anchors, f_anchors, sizeof(float) * anchors.size(), cudaMemcpyHostToDevice, stream));
}
if (mask.size() > 0) {
int* f_mask = mask.data();
@@ -224,22 +194,19 @@ int32_t YoloLayer::enqueue (
uint64_t inputSize = gridSizeX * gridSizeY * (numBBoxes * (4 + 1 + m_NumClasses));
if (m_Type == 2) { // YOLOR incorrect param: scale_x_y = 2.0
CUDA_CHECK(cudaYoloLayer_r(
inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, batchSize, inputSize,
m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, m_NumClasses, numBBoxes,
2.0, v_anchors, v_mask, stream));
CUDA_CHECK(cudaYoloLayer_r(inputs[i], num_detections, detection_boxes, detection_scores, detection_classes,
batchSize, inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY,
m_NumClasses, numBBoxes, 2.0, v_anchors, v_mask, stream));
}
else if (m_Type == 1) {
if (m_NewCoords) {
CUDA_CHECK(cudaYoloLayer_nc(
inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY,
CUDA_CHECK(cudaYoloLayer_nc( inputs[i], num_detections, detection_boxes, detection_scores, detection_classes,
batchSize, inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY,
m_NumClasses, numBBoxes, scaleXY, v_anchors, v_mask, stream));
}
else {
CUDA_CHECK(cudaYoloLayer(
inputs[i], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY,
CUDA_CHECK(cudaYoloLayer(inputs[i], num_detections, detection_boxes, detection_scores, detection_classes,
batchSize, inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY,
m_NumClasses, numBBoxes, scaleXY, v_anchors, v_mask, stream));
}
}
@@ -248,10 +215,9 @@ int32_t YoloLayer::enqueue (
CUDA_CHECK(cudaMalloc(&softmax, sizeof(float) * inputSize * batchSize));
CUDA_CHECK(cudaMemsetAsync((float*)softmax, 0, sizeof(float) * inputSize * batchSize, stream));
CUDA_CHECK(cudaRegionLayer(
inputs[i], softmax, num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY, m_NumClasses,
numBBoxes, v_anchors, stream));
CUDA_CHECK(cudaRegionLayer(inputs[i], softmax, num_detections, detection_boxes, detection_scores, detection_classes,
batchSize, inputSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, gridSizeX, gridSizeY,
m_NumClasses, numBBoxes, v_anchors, stream));
CUDA_CHECK(cudaFree(softmax));
}
@@ -268,7 +234,8 @@ int32_t YoloLayer::enqueue (
return 0;
}
size_t YoloLayer::getSerializationSize() const noexcept
size_t
YoloLayer::getSerializationSize() const noexcept
{
size_t totalSize = 0;
@@ -280,12 +247,11 @@ size_t YoloLayer::getSerializationSize() const noexcept
totalSize += sizeof(m_Type);
totalSize += sizeof(m_ScoreThreshold);
if (m_Type != 3) {
if (m_Type != 3 && m_Type != 4) {
uint yoloTensorsSize = m_YoloTensors.size();
totalSize += sizeof(yoloTensorsSize);
for (uint i = 0; i < yoloTensorsSize; ++i)
{
for (uint i = 0; i < yoloTensorsSize; ++i) {
const TensorInfo& curYoloTensor = m_YoloTensors.at(i);
totalSize += sizeof(curYoloTensor.gridSizeX);
totalSize += sizeof(curYoloTensor.gridSizeY);
@@ -299,7 +265,8 @@ size_t YoloLayer::getSerializationSize() const noexcept
return totalSize;
}
void YoloLayer::serialize(void* buffer) const noexcept
void
YoloLayer::serialize(void* buffer) const noexcept
{
char* d = static_cast<char*>(buffer);
@@ -311,11 +278,10 @@ void YoloLayer::serialize(void* buffer) const noexcept
write(d, m_Type);
write(d, m_ScoreThreshold);
if (m_Type != 3) {
if (m_Type != 3 && m_Type != 4) {
uint yoloTensorsSize = m_YoloTensors.size();
write(d, yoloTensorsSize);
for (uint i = 0; i < yoloTensorsSize; ++i)
{
for (uint i = 0; i < yoloTensorsSize; ++i) {
const TensorInfo& curYoloTensor = m_YoloTensors.at(i);
write(d, curYoloTensor.gridSizeX);
write(d, curYoloTensor.gridSizeY);
@@ -325,24 +291,21 @@ void YoloLayer::serialize(void* buffer) const noexcept
uint anchorsSize = curYoloTensor.anchors.size();
write(d, anchorsSize);
for (uint j = 0; j < anchorsSize; ++j)
{
write(d, curYoloTensor.anchors[j]);
}
uint maskSize = curYoloTensor.mask.size();
write(d, maskSize);
for (uint j = 0; j < maskSize; ++j)
{
write(d, curYoloTensor.mask[j]);
}
}
}
}
nvinfer1::IPluginV2* YoloLayer::clone() const noexcept
nvinfer1::IPluginV2*
YoloLayer::clone() const noexcept
{
return new YoloLayer (
m_NetWidth, m_NetHeight, m_NumClasses, m_NewCoords, m_YoloTensors, m_OutputSize, m_Type, m_ScoreThreshold);
return new YoloLayer(m_NetWidth, m_NetHeight, m_NumClasses, m_NewCoords, m_YoloTensors, m_OutputSize, m_Type,
m_ScoreThreshold);
}
REGISTER_TENSORRT_PLUGIN(YoloLayerPluginCreator);

View File

@@ -26,41 +26,26 @@
#ifndef __YOLO_PLUGINS__
#define __YOLO_PLUGINS__
#include <cassert>
#include <cstring>
#include <cuda_runtime_api.h>
#include <iostream>
#include <memory>
#include <vector>
#include "NvInferPlugin.h"
#include "yolo.h"
#define CUDA_CHECK(status) \
{ \
if (status != 0) \
{ \
std::cout << "CUDA failure: " << cudaGetErrorString(status) << " in file " << __FILE__ << " at line " \
<< __LINE__ << std::endl; \
#define CUDA_CHECK(status) { \
if (status != 0) { \
std::cout << "CUDA failure: " << cudaGetErrorString(status) << " in file " << __FILE__ << " at line " << __LINE__ << \
std::endl; \
abort(); \
} \
}
namespace
{
namespace {
const char* YOLOLAYER_PLUGIN_VERSION {"1"};
const char* YOLOLAYER_PLUGIN_NAME {"YoloLayer_TRT"};
} // namespace
class YoloLayer : public nvinfer1::IPluginV2
{
class YoloLayer : public nvinfer1::IPluginV2 {
public:
YoloLayer(const void* data, size_t length);
YoloLayer (
const uint& netWidth, const uint& netHeight, const uint& numClasses, const uint& newCoords,
YoloLayer(const uint& netWidth, const uint& netHeight, const uint& numClasses, const uint& newCoords,
const std::vector<TensorInfo>& yoloTensors, const uint64_t& outputSize, const uint& modelType,
const float& scoreThreshold);
@@ -70,15 +55,11 @@ public:
int getNbOutputs() const noexcept override { return 4; }
nvinfer1::Dims getOutputDimensions (
int index, const nvinfer1::Dims* inputs,
int nbInputDims) noexcept override;
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) noexcept override;
bool supportsFormat (
nvinfer1::DataType type, nvinfer1::PluginFormat format) const noexcept override;
bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const noexcept override;
void configureWithFormat (
const nvinfer1::Dims* inputDims, int nbInputs, const nvinfer1::Dims* outputDims, int nbOutputs,
void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs, const nvinfer1::Dims* outputDims, int nbOutputs,
nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) noexcept override;
int initialize() noexcept override { return 0; }
@@ -87,8 +68,7 @@ public:
size_t getWorkspaceSize(int maxBatchSize) const noexcept override { return 0; }
int32_t enqueue (
int batchSize, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream)
int32_t enqueue(int batchSize, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream)
noexcept override;
size_t getSerializationSize() const noexcept override;
@@ -99,13 +79,9 @@ public:
nvinfer1::IPluginV2* clone() const noexcept override;
void setPluginNamespace (const char* pluginNamespace) noexcept override {
m_Namespace = pluginNamespace;
}
void setPluginNamespace(const char* pluginNamespace) noexcept override { m_Namespace = pluginNamespace; }
virtual const char* getPluginNamespace () const noexcept override {
return m_Namespace.c_str();
}
virtual const char* getPluginNamespace() const noexcept override { return m_Namespace.c_str(); }
private:
std::string m_Namespace {""};
@@ -119,8 +95,7 @@ private:
float m_ScoreThreshold {0};
};
class YoloLayerPluginCreator : public nvinfer1::IPluginCreator
{
class YoloLayerPluginCreator : public nvinfer1::IPluginCreator {
public:
YoloLayerPluginCreator() {}
@@ -135,31 +110,22 @@ public:
return nullptr;
}
nvinfer1::IPluginV2* createPlugin (
const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override
{
nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override {
std::cerr<< "YoloLayerPluginCreator::getFieldNames is not implemented";
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin (
const char* name, const void* serialData, size_t serialLength) noexcept override
{
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override {
std::cout << "Deserialize yoloLayer plugin: " << name << std::endl;
return new YoloLayer(serialData, serialLength);
}
void setPluginNamespace(const char* libNamespace) noexcept override {
m_Namespace = libNamespace;
}
const char* getPluginNamespace() const noexcept override {
return m_Namespace.c_str();
}
void setPluginNamespace(const char* libNamespace) noexcept override { m_Namespace = libNamespace; }
const char* getPluginNamespace() const noexcept override { return m_Namespace.c_str(); }
private:
std::string m_Namespace {""};
};
extern uint kNUM_CLASSES;
#endif // __YOLO_PLUGINS__

View File

@@ -8,6 +8,7 @@ from ppdet.utils.cli import ArgsParser
from ppdet.engine import Trainer
from ppdet.slim import build_slim_model
class Layers(object):
def __init__(self, size, fw, fc, letter_box):
self.blocks = [0 for _ in range(300)]
@@ -123,7 +124,7 @@ class Layers(object):
def Shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None, output=''):
self.current += 1
r = 0
r = None
if route is not None:
r = self.get_route(route)
self.shuffle(reshape=reshape, transpose1=transpose1, transpose2=transpose2, route=r)
@@ -156,7 +157,7 @@ class Layers(object):
'channels=3\n' +
lb)
def convolutional(self, cv, act='linear', detect=False):
def convolutional(self, cv, act='linear'):
self.blocks[self.current] += 1
self.get_state_dict(cv.state_dict())
@@ -178,9 +179,6 @@ class Layers(object):
bias = cv.conv.bias
bn = True if hasattr(cv, 'bn') else False
if detect:
act = 'logistic'
b = 'batch_normalize=1\n' if bn is True else ''
g = 'groups=%d\n' % groups if groups > 1 else ''
w = 'bias=0\n' if bias is None and bn is False else ''
@@ -251,9 +249,9 @@ class Layers(object):
def shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None):
self.blocks[self.current] += 1
r = 'reshape=%s\n' % str(reshape)[1:-1] if reshape is not None else ''
t1 = 'transpose1=%s\n' % str(transpose1)[1:-1] if transpose1 is not None else ''
t2 = 'transpose2=%s\n' % str(transpose2)[1:-1] if transpose2 is not None else ''
r = 'reshape=%s\n' % ', '.join(str(x) for x in reshape) if reshape is not None else ''
t1 = 'transpose1=%s\n' % ', '.join(str(x) for x in transpose1) if transpose1 is not None else ''
t2 = 'transpose2=%s\n' % ', '.join(str(x) for x in transpose2) if transpose2 is not None else ''
f = 'from=%d\n' % route if route is not None else ''
self.fc.write('\n[shuffle]\n' +
@@ -419,13 +417,13 @@ with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc:
layers.AvgPool2d()
layers.ESEAttn(model.yolo_head.stem_cls[i])
layers.Conv2D(model.yolo_head.pred_cls[i], act='sigmoid')
layers.Shuffle(reshape=[model.yolo_head.num_classes, 0], route=feat, output='cls')
layers.Shuffle(reshape=[model.yolo_head.num_classes, 'hw'], route=feat, output='cls')
layers.ESEAttn(model.yolo_head.stem_reg[i], route=-7)
layers.Conv2D(model.yolo_head.pred_reg[i])
layers.Shuffle(reshape=[4, model.yolo_head.reg_max + 1, 0], transpose2=[1, 0, 2], route=feat)
layers.Shuffle(reshape=[4, model.yolo_head.reg_max + 1, 'hw'], transpose2=[1, 0, 2], route=feat)
layers.SoftMax(0)
layers.Conv2D(model.yolo_head.proj_conv)
layers.Shuffle(reshape=[4, 0], route=feat, output='reg')
layers.Shuffle(reshape=[4, 'hw'], route=feat, output='reg')
layers.Detect('cls')
layers.Detect('reg')
layers.get_anchors(model.yolo_head.anchor_points.reshape([-1]), model.yolo_head.stride_tensor)

315
utils/gen_wts_yoloV8.py Normal file
View File

@@ -0,0 +1,315 @@
import argparse
import os
import struct
import torch
from ultralytics.yolo.utils.torch_utils import select_device
class Layers(object):
def __init__(self, n, size, fw, fc):
self.blocks = [0 for _ in range(n)]
self.current = -1
self.width = size[0] if len(size) == 1 else size[1]
self.height = size[0]
self.fw = fw
self.fc = fc
self.wc = 0
self.net()
def Conv(self, child):
self.current = child.i
self.fc.write('\n# Conv\n')
self.convolutional(child)
def C2f(self, child):
self.current = child.i
self.fc.write('\n# C2f\n')
self.convolutional(child.cv1)
self.c2f(child.m)
self.convolutional(child.cv2)
def SPPF(self, child):
self.current = child.i
self.fc.write('\n# SPPF\n')
self.convolutional(child.cv1)
self.maxpool(child.m)
self.maxpool(child.m)
self.maxpool(child.m)
self.route('-4, -3, -2, -1')
self.convolutional(child.cv2)
def Upsample(self, child):
self.current = child.i
self.fc.write('\n# Upsample\n')
self.upsample(child)
def Concat(self, child):
self.current = child.i
self.fc.write('\n# Concat\n')
r = []
for i in range(1, len(child.f)):
r.append(self.get_route(child.f[i]))
self.route('-1, %s' % str(r)[1:-1])
def Detect(self, child):
self.current = child.i
self.fc.write('\n# Detect\n')
output_idxs = [0 for _ in range(child.nl)]
for i in range(child.nl):
r = self.get_route(child.f[i])
self.route('%d' % r)
for j in range(len(child.cv3[i])):
self.convolutional(child.cv3[i][j])
self.route('%d' % (-1 - len(child.cv3[i])))
for j in range(len(child.cv2[i])):
self.convolutional(child.cv2[i][j])
self.route('-1, %d' % (-2 - len(child.cv2[i])))
self.shuffle(reshape=[child.no, -1])
output_idxs[i] = (-1 + i * (-4 - len(child.cv3[i]) - len(child.cv2[i])))
self.route('%s' % str(output_idxs[::-1])[1:-1], axis=1)
self.yolo(child)
def net(self):
self.fc.write('[net]\n' +
'width=%d\n' % self.width +
'height=%d\n' % self.height +
'channels=3\n' +
'letter_box=1\n')
def convolutional(self, cv, act=None, detect=False):
self.blocks[self.current] += 1
self.get_state_dict(cv.state_dict())
if cv._get_name() == 'Conv2d':
filters = cv.out_channels
size = cv.kernel_size
stride = cv.stride
pad = cv.padding
groups = cv.groups
bias = cv.bias
bn = False
act = 'linear' if not detect else 'logistic'
else:
filters = cv.conv.out_channels
size = cv.conv.kernel_size
stride = cv.conv.stride
pad = cv.conv.padding
groups = cv.conv.groups
bias = cv.conv.bias
bn = True if hasattr(cv, 'bn') else False
if act is None:
act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear'
b = 'batch_normalize=1\n' if bn is True else ''
g = 'groups=%d\n' % groups if groups > 1 else ''
w = 'bias=0\n' if bias is None and bn is False else ''
self.fc.write('\n[convolutional]\n' +
b +
'filters=%d\n' % filters +
'size=%s\n' % self.get_value(size) +
'stride=%s\n' % self.get_value(stride) +
'pad=%s\n' % self.get_value(pad) +
g +
w +
'activation=%s\n' % act)
def c2f(self, m):
self.blocks[self.current] += 1
for x in m:
self.get_state_dict(x.state_dict())
n = len(m)
shortcut = 1 if m[0].add else 0
filters = m[0].cv1.conv.out_channels
size = m[0].cv1.conv.kernel_size
stride = m[0].cv1.conv.stride
pad = m[0].cv1.conv.padding
groups = m[0].cv1.conv.groups
bias = m[0].cv1.conv.bias
bn = True if hasattr(m[0].cv1, 'bn') else False
act = 'linear'
if hasattr(m[0].cv1, 'act'):
act = self.get_activation(m[0].cv1.act._get_name())
b = 'batch_normalize=1\n' if bn is True else ''
g = 'groups=%d\n' % groups if groups > 1 else ''
w = 'bias=0\n' if bias is None and bn is False else ''
self.fc.write('\n[c2f]\n' +
'n=%d\n' % n +
'shortcut=%d\n' % shortcut +
b +
'filters=%d\n' % filters +
'size=%s\n' % self.get_value(size) +
'stride=%s\n' % self.get_value(stride) +
'pad=%s\n' % self.get_value(pad) +
g +
w +
'activation=%s\n' % act)
def route(self, layers, axis=0):
self.blocks[self.current] += 1
a = 'axis=%d\n' % axis if axis != 0 else ''
self.fc.write('\n[route]\n' +
'layers=%s\n' % layers +
a)
def shortcut(self, r, ew='add', act='linear'):
self.blocks[self.current] += 1
m = 'mode=mul\n' if ew == 'mul' else ''
self.fc.write('\n[shortcut]\n' +
'from=%d\n' % r +
m +
'activation=%s\n' % act)
def maxpool(self, m):
self.blocks[self.current] += 1
stride = m.stride
size = m.kernel_size
mode = m.ceil_mode
m = 'maxpool_up' if mode else 'maxpool'
self.fc.write('\n[%s]\n' % m +
'stride=%d\n' % stride +
'size=%d\n' % size)
def upsample(self, child):
self.blocks[self.current] += 1
stride = child.scale_factor
self.fc.write('\n[upsample]\n' +
'stride=%d\n' % stride)
def shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None):
self.blocks[self.current] += 1
r = 'reshape=%s\n' % ', '.join(str(x) for x in reshape) if reshape is not None else ''
t1 = 'transpose1=%s\n' % ', '.join(str(x) for x in transpose1) if transpose1 is not None else ''
t2 = 'transpose2=%s\n' % ', '.join(str(x) for x in transpose2) if transpose2 is not None else ''
f = 'from=%d\n' % route if route is not None else ''
self.fc.write('\n[shuffle]\n' +
r +
t1 +
t2 +
f)
def yolo(self, child):
self.blocks[self.current] += 1
self.fc.write('\n[detect_v8]\n' +
'num=%d\n' % (child.reg_max * 4) +
'classes=%d\n' % child.nc)
def get_state_dict(self, state_dict):
for k, v in state_dict.items():
if 'num_batches_tracked' not in k:
vr = v.reshape(-1).numpy()
self.fw.write('{} {} '.format(k, len(vr)))
for vv in vr:
self.fw.write(' ')
self.fw.write(struct.pack('>f', float(vv)).hex())
self.fw.write('\n')
self.wc += 1
def get_anchors(self, anchor_points, stride_tensor):
vr = anchor_points.numpy()
self.fw.write('{} {} '.format('anchor_points', len(vr)))
for vv in vr:
self.fw.write(' ')
self.fw.write(struct.pack('>f', float(vv)).hex())
self.fw.write('\n')
self.wc += 1
vr = stride_tensor.numpy()
self.fw.write('{} {} '.format('stride_tensor', len(vr)))
for vv in vr:
self.fw.write(' ')
self.fw.write(struct.pack('>f', float(vv)).hex())
self.fw.write('\n')
self.wc += 1
def get_value(self, key):
if type(key) == int:
return key
return key[0] if key[0] == key[1] else str(key)[1:-1]
def get_route(self, n):
r = 0
for i, b in enumerate(self.blocks):
if i <= n:
r += b
else:
break
return r - 1
def get_activation(self, act):
if act == 'Hardswish':
return 'hardswish'
elif act == 'LeakyReLU':
return 'leaky'
elif act == 'SiLU':
return 'silu'
return 'linear'
def parse_args():
parser = argparse.ArgumentParser(description='PyTorch YOLOv8 conversion')
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
parser.add_argument(
'-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])')
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit('Invalid weights file')
return args.weights, args.size
pt_file, inference_size = parse_args()
model_name = os.path.basename(pt_file).split('.pt')[0]
wts_file = model_name + '.wts' if 'yolov8' in model_name else 'yolov8_' + model_name + '.wts'
cfg_file = model_name + '.cfg' if 'yolov8' in model_name else 'yolov8_' + model_name + '.cfg'
device = select_device('cpu')
model = torch.load(pt_file, map_location=device)['model'].float()
model.to(device).eval()
with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc:
layers = Layers(len(model.model), inference_size, fw, fc)
for child in model.model.children():
if child._get_name() == 'Conv':
layers.Conv(child)
elif child._get_name() == 'C2f':
layers.C2f(child)
elif child._get_name() == 'SPPF':
layers.SPPF(child)
elif child._get_name() == 'Upsample':
layers.Upsample(child)
elif child._get_name() == 'Concat':
layers.Concat(child)
elif child._get_name() == 'Detect':
layers.Detect(child)
layers.get_anchors(child.anchors.reshape([-1]), child.strides.reshape([-1]))
else:
raise SystemExit('Model not supported')
os.system('echo "%d" | cat - %s > temp && mv temp %s' % (layers.wc, wts_file, wts_file))