Added YOLOv5 6.0 native support

This commit is contained in:
unknown
2021-12-09 15:44:17 -03:00
parent dcc44b730c
commit bfd9268a31
8 changed files with 688 additions and 80 deletions

View File

@@ -0,0 +1,24 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
custom-network-config=yolov5n.cfg
model-file=yolov5n.wts
model-engine-file=model_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=4
maintain-aspect-ratio=0
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet
[class-attrs-all]
pre-cluster-threshold=0.25

View File

@@ -12,7 +12,10 @@ nvinfer1::ILayer* activationLayer(
nvinfer1::ITensor* input, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network) nvinfer1::INetworkDefinition* network)
{ {
if (activation == "relu") if (activation == "linear") {
// Pass
}
else if (activation == "relu")
{ {
nvinfer1::IActivationLayer* relu = network->addActivation( nvinfer1::IActivationLayer* relu = network->addActivation(
*input, nvinfer1::ActivationType::kRELU); *input, nvinfer1::ActivationType::kRELU);
@@ -78,5 +81,24 @@ nvinfer1::ILayer* activationLayer(
mish->setName(mishLayerName.c_str()); mish->setName(mishLayerName.c_str());
output = mish; output = mish;
} }
else if (activation == "silu")
{
nvinfer1::IActivationLayer* sigmoid = network->addActivation(
*input, nvinfer1::ActivationType::kSIGMOID);
assert(sigmoid != nullptr);
std::string sigmoidLayerName = "sigmoid_" + std::to_string(layerIdx);
sigmoid->setName(sigmoidLayerName.c_str());
nvinfer1::IElementWiseLayer* silu = network->addElementWise(
*sigmoid->getOutput(0), *input,
nvinfer1::ElementWiseOperation::kPROD);
assert(silu != nullptr);
std::string siluLayerName = "silu_" + std::to_string(layerIdx);
silu->setName(siluLayerName.c_str());
output = silu;
}
else {
std::cerr << "Activation not supported: " << activation << std::endl;
std::abort();
}
return output; return output;
} }

View File

@@ -12,6 +12,7 @@ nvinfer1::ILayer* convolutionalLayer(
std::vector<float>& weights, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr, int& weightPtr,
std::string weightsType,
int& inputChannels, int& inputChannels,
nvinfer1::ITensor* input, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network) nvinfer1::INetworkDefinition* network)
@@ -56,57 +57,111 @@ nvinfer1::ILayer* convolutionalLayer(
nvinfer1::Weights convWt{nvinfer1::DataType::kFLOAT, nullptr, size}; nvinfer1::Weights convWt{nvinfer1::DataType::kFLOAT, nullptr, size};
nvinfer1::Weights convBias{nvinfer1::DataType::kFLOAT, nullptr, bias}; nvinfer1::Weights convBias{nvinfer1::DataType::kFLOAT, nullptr, bias};
if (batchNormalize == false) if (weightsType == "weights") {
{ if (batchNormalize == false)
float* val = new float[filters];
for (int i = 0; i < filters; ++i)
{ {
val[i] = weights[weightPtr]; float* val = new float[filters];
weightPtr++; for (int i = 0; i < filters; ++i)
{
val[i] = weights[weightPtr];
weightPtr++;
}
convBias.values = val;
trtWeights.push_back(convBias);
val = new float[size];
for (int i = 0; i < size; ++i)
{
val[i] = weights[weightPtr];
weightPtr++;
}
convWt.values = val;
trtWeights.push_back(convWt);
} }
convBias.values = val; else
trtWeights.push_back(convBias);
val = new float[size];
for (int i = 0; i < size; ++i)
{ {
val[i] = weights[weightPtr]; for (int i = 0; i < filters; ++i)
weightPtr++; {
bnBiases.push_back(weights[weightPtr]);
weightPtr++;
}
for (int i = 0; i < filters; ++i)
{
bnWeights.push_back(weights[weightPtr]);
weightPtr++;
}
for (int i = 0; i < filters; ++i)
{
bnRunningMean.push_back(weights[weightPtr]);
weightPtr++;
}
for (int i = 0; i < filters; ++i)
{
bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5));
weightPtr++;
}
float* val = new float[size];
for (int i = 0; i < size; ++i)
{
val[i] = weights[weightPtr];
weightPtr++;
}
convWt.values = val;
trtWeights.push_back(convWt);
trtWeights.push_back(convBias);
} }
convWt.values = val;
trtWeights.push_back(convWt);
} }
else else {
{ if (batchNormalize == false)
for (int i = 0; i < filters; ++i)
{ {
bnBiases.push_back(weights[weightPtr]); float* val = new float[size];
weightPtr++; for (int i = 0; i < size; ++i)
{
val[i] = weights[weightPtr];
weightPtr++;
}
convWt.values = val;
trtWeights.push_back(convWt);
val = new float[filters];
for (int i = 0; i < filters; ++i)
{
val[i] = weights[weightPtr];
weightPtr++;
}
convBias.values = val;
trtWeights.push_back(convBias);
} }
else
for (int i = 0; i < filters; ++i)
{ {
bnWeights.push_back(weights[weightPtr]); float* val = new float[size];
weightPtr++; for (int i = 0; i < size; ++i)
{
val[i] = weights[weightPtr];
weightPtr++;
}
convWt.values = val;
for (int i = 0; i < filters; ++i)
{
bnWeights.push_back(weights[weightPtr]);
weightPtr++;
}
for (int i = 0; i < filters; ++i)
{
bnBiases.push_back(weights[weightPtr]);
weightPtr++;
}
for (int i = 0; i < filters; ++i)
{
bnRunningMean.push_back(weights[weightPtr]);
weightPtr++;
}
for (int i = 0; i < filters; ++i)
{
bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5));
weightPtr++;
}
trtWeights.push_back(convWt);
trtWeights.push_back(convBias);
} }
for (int i = 0; i < filters; ++i)
{
bnRunningMean.push_back(weights[weightPtr]);
weightPtr++;
}
for (int i = 0; i < filters; ++i)
{
bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5));
weightPtr++;
}
float* val = new float[size];
for (int i = 0; i < size; ++i)
{
val[i] = weights[weightPtr];
weightPtr++;
}
convWt.values = val;
trtWeights.push_back(convWt);
trtWeights.push_back(convBias);
} }
nvinfer1::IConvolutionLayer* conv = network->addConvolution( nvinfer1::IConvolutionLayer* conv = network->addConvolution(

View File

@@ -19,6 +19,7 @@ nvinfer1::ILayer* convolutionalLayer(
std::vector<float>& weights, std::vector<float>& weights,
std::vector<nvinfer1::Weights>& trtWeights, std::vector<nvinfer1::Weights>& trtWeights,
int& weightPtr, int& weightPtr,
std::string weightsType,
int& inputChannels, int& inputChannels,
nvinfer1::ITensor* input, nvinfer1::ITensor* input,
nvinfer1::INetworkDefinition* network); nvinfer1::INetworkDefinition* network);

View File

@@ -67,32 +67,63 @@ std::vector<float> loadWeights(const std::string weightsFilePath, const std::str
{ {
assert(fileExists(weightsFilePath)); assert(fileExists(weightsFilePath));
std::cout << "\nLoading pre-trained weights" << std::endl; std::cout << "\nLoading pre-trained weights" << std::endl;
std::ifstream file(weightsFilePath, std::ios_base::binary);
assert(file.good());
std::string line;
if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos)
{
// Remove 4 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 4);
}
else
{
// Remove 5 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 5);
}
std::vector<float> weights; std::vector<float> weights;
char floatWeight[4];
while (!file.eof()) if (weightsFilePath.find(".weights") != std::string::npos) {
{ std::ifstream file(weightsFilePath, std::ios_base::binary);
file.read(floatWeight, 4); assert(file.good());
assert(file.gcount() == 4); std::string line;
weights.push_back(*reinterpret_cast<float*>(floatWeight));
if (file.peek() == std::istream::traits_type::eof()) break; if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos)
{
// Remove 4 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 4);
}
else
{
// Remove 5 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 5);
}
char floatWeight[4];
while (!file.eof())
{
file.read(floatWeight, 4);
assert(file.gcount() == 4);
weights.push_back(*reinterpret_cast<float*>(floatWeight));
if (file.peek() == std::istream::traits_type::eof()) break;
}
} }
else if (weightsFilePath.find(".wts") != std::string::npos) {
std::ifstream file(weightsFilePath);
assert(file.good());
int32_t count;
file >> count;
assert(count > 0 && "Invalid .wts file.");
uint32_t floatWeight;
std::string name;
uint32_t size;
while (count--) {
file >> name >> std::dec >> size;
for (uint32_t x = 0, y = size; x < y; ++x)
{
file >> std::hex >> floatWeight;
weights.push_back(*reinterpret_cast<float *>(&floatWeight));
};
}
}
else {
std::cerr << "File " << weightsFilePath << " is not supported" << std::endl;
std::abort();
}
std::cout << "Loading weights of " << networkType << " complete" std::cout << "Loading weights of " << networkType << " complete"
<< std::endl; << std::endl;
std::cout << "Total weights read: " << weights.size() << std::endl; std::cout << "Total weights read: " << weights.size() << std::endl;
return weights; return weights;
} }

View File

@@ -73,9 +73,6 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
parseConfigBlocks(); parseConfigBlocks();
orderParams(&m_OutputMasks); orderParams(&m_OutputMasks);
std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
std::vector<nvinfer1::Weights> trtWeights;
nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0); nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0);
if (parseModel(*network) != NVDSINFER_SUCCESS) { if (parseModel(*network) != NVDSINFER_SUCCESS) {
network->destroy(); network->destroy();
@@ -134,7 +131,7 @@ NvDsInferStatus Yolo::parseModel(nvinfer1::INetworkDefinition& network) {
destroyNetworkUtils(); destroyNetworkUtils();
std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType); std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
std::cout << "Building YOLO network" << std::endl; std::cout << "Building YOLO network\n" << std::endl;
NvDsInferStatus status = buildYoloNetwork(weights, network); NvDsInferStatus status = buildYoloNetwork(weights, network);
if (status == NVDSINFER_SUCCESS) { if (status == NVDSINFER_SUCCESS) {
@@ -151,6 +148,15 @@ NvDsInferStatus Yolo::buildYoloNetwork(
int weightPtr = 0; int weightPtr = 0;
int channels = m_InputC; int channels = m_InputC;
std::string weightsType;
if (m_WtsFilePath.find(".weights") != std::string::npos) {
weightsType = "weights";
}
else {
weightsType = "wts";
}
nvinfer1::ITensor* data = nvinfer1::ITensor* data =
network.addInput(m_InputBlobName.c_str(), nvinfer1::DataType::kFLOAT, network.addInput(m_InputBlobName.c_str(), nvinfer1::DataType::kFLOAT,
nvinfer1::Dims3{static_cast<int>(m_InputC), nvinfer1::Dims3{static_cast<int>(m_InputC),
@@ -171,7 +177,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(
else if (m_ConfigBlocks.at(i).at("type") == "convolutional") { else if (m_ConfigBlocks.at(i).at("type") == "convolutional") {
std::string inputVol = dimsToString(previous->getDimensions()); std::string inputVol = dimsToString(previous->getDimensions());
nvinfer1::ILayer* out = convolutionalLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, channels, previous, &network); nvinfer1::ILayer* out = convolutionalLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, channels, previous, &network);
previous = out->getOutput(0); previous = out->getOutput(0);
assert(previous != nullptr); assert(previous != nullptr);
channels = getNumChannels(previous); channels = getNumChannels(previous);
@@ -272,10 +278,10 @@ NvDsInferStatus Yolo::buildYoloNetwork(
beta_nms = std::stof(m_ConfigBlocks.at(i).at("beta_nms")); beta_nms = std::stof(m_ConfigBlocks.at(i).at("beta_nms"));
} }
nvinfer1::IPluginV2* yoloPlugin nvinfer1::IPluginV2* yoloPlugin
= new YoloLayer(m_OutputTensors.at(outputTensorCount).numBBoxes, = new YoloLayer(curYoloTensor.numBBoxes,
m_OutputTensors.at(outputTensorCount).numClasses, curYoloTensor.numClasses,
m_OutputTensors.at(outputTensorCount).gridSizeX, curYoloTensor.gridSizeX,
m_OutputTensors.at(outputTensorCount).gridSizeY, curYoloTensor.gridSizeY,
1, new_coords, scale_x_y, beta_nms, 1, new_coords, scale_x_y, beta_nms,
curYoloTensor.anchors, curYoloTensor.anchors,
m_OutputMasks); m_OutputMasks);
@@ -436,7 +442,7 @@ void Yolo::parseConfigBlocks()
m_LetterBox = 0; m_LetterBox = 0;
} }
} }
else if ((block.at("type") == "region") || (block.at("type") == "yolo")) else if ((block.at("type") == "region") || (block.at("type") == "yolo") || (block.at("type") == "detect"))
{ {
assert((block.find("num") != block.end()) assert((block.find("num") != block.end())
&& std::string("Missing 'num' param in " + block.at("type") + " layer").c_str()); && std::string("Missing 'num' param in " + block.at("type") + " layer").c_str());
@@ -466,9 +472,7 @@ void Yolo::parseConfigBlocks()
} }
} }
if (block.find("mask") != block.end()) { if (block.find("mask") != block.end()) {
std::string maskString = block.at("mask"); std::string maskString = block.at("mask");
std::vector<int> pMASKS; std::vector<int> pMASKS;
while (!maskString.empty()) while (!maskString.empty())

133
readme.md
View File

@@ -6,14 +6,12 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
* New documentation for multiple models * New documentation for multiple models
* DeepStream tutorials * DeepStream tutorials
* Native PyTorch support (YOLOv5 and YOLOR) * Native YOLOR support
* Native PP-YOLO support * Native PP-YOLO support
* Models benchmark * Models benchmark
* GPU NMS * GPU NMS
* Dynamic batch-size * Dynamic batch-size
**NOTE**: The support for YOLOv5 was removed in this current update. If you want the old repo version, please use the commit 297e0e9 and DeepStream 5.1 requirements.
### Improvements on this repository ### Improvements on this repository
* Darknet CFG params parser (it doesn't need to edit nvdsparsebbox_Yolo.cpp or another file for native models) * Darknet CFG params parser (it doesn't need to edit nvdsparsebbox_Yolo.cpp or another file for native models)
@@ -24,6 +22,7 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
* Support for convolutional groups * Support for convolutional groups
* Support for INT8 calibration * Support for INT8 calibration
* Support for non square models * Support for non square models
* **YOLOv5 6.0 native support**
## ##
@@ -33,6 +32,7 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
* [Tested models](#tested-models) * [Tested models](#tested-models)
* [dGPU installation](#dgpu-installation) * [dGPU installation](#dgpu-installation)
* [Basic usage](#basic-usage) * [Basic usage](#basic-usage)
* [YOLOv5 usage](#yolov5-usage)
* [INT8 calibration](#int8-calibration) * [INT8 calibration](#int8-calibration)
* [Using your custom model](docs/customModels.md) * [Using your custom model](docs/customModels.md)
@@ -48,9 +48,14 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
* [NVIDIA DeepStream SDK 6.0](https://developer.nvidia.com/deepstream-sdk) * [NVIDIA DeepStream SDK 6.0](https://developer.nvidia.com/deepstream-sdk)
* [DeepStream-Yolo](https://github.com/marcoslucianops/DeepStream-Yolo) * [DeepStream-Yolo](https://github.com/marcoslucianops/DeepStream-Yolo)
**For YOLOv5**:
* [PyTorch >= 1.7.0](https://pytorch.org/get-started/locally/)
## ##
### Tested models ### Tested models
* [YOLOv5 6.0](https://github.com/ultralytics/yolov5) [[pt]](https://github.com/ultralytics/yolov5/releases/tag/v6.0)
* [YOLOv4x-Mish](https://github.com/AlexeyAB/darknet) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4x-mish.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4x-mish.weights)] * [YOLOv4x-Mish](https://github.com/AlexeyAB/darknet) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4x-mish.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4x-mish.weights)]
* [YOLOv4-CSP](https://github.com/WongKinYiu/ScaledYOLOv4/tree/yolov4-csp) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-csp.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-csp.weights)] * [YOLOv4-CSP](https://github.com/WongKinYiu/ScaledYOLOv4/tree/yolov4-csp) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-csp.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-csp.weights)]
* [YOLOv4](https://github.com/AlexeyAB/darknet) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v3_optimal/yolov4.weights)] * [YOLOv4](https://github.com/AlexeyAB/darknet) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v3_optimal/yolov4.weights)]
@@ -265,6 +270,7 @@ deepstream-app -c deepstream_app_config.txt
**NOTE**: If you want to use YOLOv2 or YOLOv2-Tiny models, change the deepstream_app_config.txt file before run it **NOTE**: If you want to use YOLOv2 or YOLOv2-Tiny models, change the deepstream_app_config.txt file before run it
``` ```
...
[primary-gie] [primary-gie]
enable=1 enable=1
gpu-id=0 gpu-id=0
@@ -277,6 +283,127 @@ config-file=config_infer_primary_yoloV2.txt
## ##
### YOLOv5 usage
#### 1. Copy gen_wts_yoloV5.py from DeepStream-Yolo/utils to ultralytics/yolov5 folder
#### 2. Open the ultralytics/yolov5 folder
#### 3. Download pt file from [ultralytics/yolov5](https://github.com/ultralytics/yolov5/releases/tag/v6.0) website (example for YOLOv5n)
```
wget https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5n.pt
```
#### 4. Generate cfg and wts files (example for YOLOv5n)
```
python3 gen_wts_yoloV5.py -w yolov5n.pt
```
#### 5. Copy generated cfg and wts files to DeepStream-Yolo folder
#### 6. Open DeepStream-Yolo folder
#### 7. Compile lib
* x86 platform
```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```
* Jetson platform
```
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
```
#### 8. Edit config_infer_primary_yoloV5.txt for your model (example for YOLOv5n)
```
[property]
...
# 0=RGB, 1=BGR, 2=GRAYSCALE
model-color-format=0
# CFG
custom-network-config=yolov5n.cfg
# WTS
model-file=yolov5n.wts
# Generated TensorRT model (will be created if it doesn't exist)
model-engine-file=model_b1_gpu0_fp32.engine
# Model labels file
labelfile-path=labels.txt
# Batch size
batch-size=1
# 0=FP32, 1=INT8, 2=FP16 mode
network-mode=0
# Number of classes in label file
num-detected-classes=80
...
[class-attrs-all]
# CONF_THRESH
pre-cluster-threshold=0.25
```
#### 8. Change the deepstream_app_config.txt file
```
...
[primary-gie]
enable=1
gpu-id=0
gie-unique-id=1
nvbuf-memory-type=0
config-file=config_infer_primary_yoloV5.txt
```
#### 9. Run
```
deepstream-app -c deepstream_app_config.txt
```
**NOTE**: For YOLOv5 P6 or custom models, check the gen_wts_yoloV5.py args and use them according to your model
* Input weights (.pt) file path **(required)**
```
-w or --weights
```
* Input cfg (.yaml) file path
```
-c or --yaml
```
* Model width **(default = 640 / 1280 [P6])**
```
-mw or --width
```
* Model height **(default = 640 / 1280 [P6])**
```
-mh or --height
```
* Model channels **(default = 3)**
```
-mc or --channels
```
* P6 model
```
--p6
```
##
### INT8 calibration ### INT8 calibration
#### 1. Install OpenCV #### 1. Install OpenCV

344
utils/gen_wts_yoloV5.py Normal file
View File

@@ -0,0 +1,344 @@
import argparse
import yaml
import math
import os
import struct
import torch
from utils.torch_utils import select_device
def parse_args():
parser = argparse.ArgumentParser(description="PyTorch conversion")
parser.add_argument("-w", "--weights", required=True, help="Input weights (.pt) file path (required)")
parser.add_argument("-c", "--yaml", help="Input cfg (.yaml) file path")
parser.add_argument("-mw", "--width", help="Model width (default = 640 / 1280 [P6])")
parser.add_argument("-mh", "--height", help="Model height (default = 640 / 1280 [P6])")
parser.add_argument("-mc", "--channels", help="Model channels (default = 3)")
parser.add_argument("--p6", action="store_true", help="P6 model")
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit("Invalid weights file")
if not args.yaml:
args.yaml = ""
if not args.width:
args.width = 1280 if args.p6 else 640
if not args.height:
args.height = 1280 if args.p6 else 640
if not args.channels:
args.channels = 3
return args.weights, args.yaml, args.width, args.height, args.channels, args.p6
def get_width(x, gw, divisor=8):
return int(math.ceil((x * gw) / divisor)) * divisor
def get_depth(x, gd):
if x == 1:
return 1
r = int(round(x * gd))
if x * gd - int(x * gd) == 0.5 and int(x * gd) % 2 == 0:
r -= 1
return max(r, 1)
pt_file, yaml_file, model_width, model_height, model_channels, p6 = parse_args()
model_name = pt_file.split(".pt")[0]
wts_file = model_name + ".wts"
cfg_file = model_name + ".cfg"
if yaml_file == "":
yaml_file = "models/" + model_name + ".yaml"
if not os.path.isfile(yaml_file):
yaml_file = "models/hub/" + model_name + ".yaml"
if not os.path.isfile(yaml_file):
raise SystemExit("YAML file not found")
elif not os.path.isfile(yaml_file):
raise SystemExit("Invalid YAML file")
device = select_device("cpu")
model = torch.load(pt_file, map_location=device)["model"].float()
model.to(device).eval()
with open(wts_file, "w") as f:
wts_write = ""
conv_count = 0
cv1 = ""
cv3 = ""
cv3_idx = 0
sppf_idx = 11 if p6 else 9
for k, v in model.state_dict().items():
if not "num_batches_tracked" in k and not "anchors" in k and not "anchor_grid" in k:
vr = v.reshape(-1).cpu().numpy()
idx = int(k.split(".")[1])
if ".cv1." in k and not ".m." in k and idx != sppf_idx:
cv1 += "{} {} ".format(k, len(vr))
for vv in vr:
cv1 += " "
cv1 += struct.pack(">f" ,float(vv)).hex()
cv1 += "\n"
conv_count += 1
elif cv1 != "" and ".m." in k:
wts_write += cv1
cv1 = ""
if ".cv3." in k:
cv3 += "{} {} ".format(k, len(vr))
for vv in vr:
cv3 += " "
cv3 += struct.pack(">f" ,float(vv)).hex()
cv3 += "\n"
cv3_idx = idx
conv_count += 1
elif cv3 != "" and cv3_idx != idx:
wts_write += cv3
cv3 = ""
cv3_idx = 0
if not ".cv3." in k and not (".cv1." in k and not ".m." in k and idx != sppf_idx):
wts_write += "{} {} ".format(k, len(vr))
for vv in vr:
wts_write += " "
wts_write += struct.pack(">f" ,float(vv)).hex()
wts_write += "\n"
conv_count += 1
f.write("{}\n".format(conv_count))
f.write(wts_write)
with open(cfg_file, "w") as c:
with open(yaml_file, "r") as f:
nc = 0
depth_multiple = 0
width_multiple = 0
anchors = ""
masks = []
num = 0
detections = []
layers = []
f = yaml.load(f,Loader=yaml.FullLoader)
c.write("[net]\n")
c.write("width=%d\n" % model_width)
c.write("height=%d\n" % model_height)
c.write("channels=%d\n" % model_channels)
for l in f:
if l == "nc":
nc = f[l]
elif l == "depth_multiple":
depth_multiple = f[l]
elif l == "width_multiple":
width_multiple = f[l]
elif l == "anchors":
a = []
for v in f[l]:
a.extend(v)
mask = []
for _ in range(int(len(v) / 2)):
mask.append(num)
num += 1
masks.append(mask)
anchors = str(a)[1:-1]
elif l == "backbone" or l == "head":
for v in f[l]:
if v[2] == "Conv":
layer = ""
blocks = 0
layer += "\n[convolutional]\n"
layer += "batch_normalize=1\n"
layer += "filters=%d\n" % get_width(v[3][0], width_multiple)
layer += "size=%d\n" % v[3][1]
layer += "stride=%d\n" % v[3][2]
layer += "pad=1\n"
layer += "activation=silu\n"
blocks += 1
layers.append([layer, blocks])
elif v[2] == "C3":
layer = ""
blocks = 0
layer += "\n# C3\n"
# SPLIT
layer += "\n[convolutional]\n"
layer += "batch_normalize=1\n"
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
layer += "size=1\n"
layer += "stride=1\n"
layer += "pad=1\n"
layer += "activation=silu\n"
blocks += 1
layer += "\n[route]\n"
layer += "layers=-2\n"
blocks += 1
layer += "\n[convolutional]\n"
layer += "batch_normalize=1\n"
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
layer += "size=1\n"
layer += "stride=1\n"
layer += "pad=1\n"
layer += "activation=silu\n"
blocks += 1
# Residual Block
if len(v[3]) == 1 or v[3][1] == True:
for _ in range(get_depth(v[1], depth_multiple)):
layer += "\n[convolutional]\n"
layer += "batch_normalize=1\n"
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
layer += "size=1\n"
layer += "stride=1\n"
layer += "pad=1\n"
layer += "activation=silu\n"
blocks += 1
layer += "\n[convolutional]\n"
layer += "batch_normalize=1\n"
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
layer += "size=3\n"
layer += "stride=1\n"
layer += "pad=1\n"
layer += "activation=silu\n"
blocks += 1
layer += "\n[shortcut]\n"
layer += "from=-3\n"
layer += "activation=linear\n"
blocks += 1
# Merge
layer += "\n[route]\n"
layer += "layers=-1, -%d\n" % (3 * get_depth(v[1], depth_multiple) + 3)
blocks += 1
else:
for _ in range(get_depth(v[1], depth_multiple)):
layer += "\n[convolutional]\n"
layer += "batch_normalize=1\n"
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
layer += "size=1\n"
layer += "stride=1\n"
layer += "pad=1\n"
layer += "activation=silu\n"
blocks += 1
layer += "\n[convolutional]\n"
layer += "batch_normalize=1\n"
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
layer += "size=3\n"
layer += "stride=1\n"
layer += "pad=1\n"
layer += "activation=silu\n"
blocks += 1
# Merge
layer += "\n[route]\n"
layer += "layers=-1, -%d\n" % (2 * get_depth(v[1], depth_multiple) + 3)
blocks += 1
# Transition
layer += "\n[convolutional]\n"
layer += "batch_normalize=1\n"
layer += "filters=%d\n" % get_width(v[3][0], width_multiple)
layer += "size=1\n"
layer += "stride=1\n"
layer += "pad=1\n"
layer += "activation=silu\n"
layer += "\n##########\n"
blocks += 1
layers.append([layer, blocks])
elif v[2] == "SPPF":
layer = ""
blocks = 0
layer += "\n# SPPF\n"
layer += "\n[convolutional]\n"
layer += "batch_normalize=1\n"
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
layer += "size=1\n"
layer += "stride=1\n"
layer += "pad=1\n"
layer += "activation=silu\n"
blocks += 1
layer += "\n[maxpool]\n"
layer += "stride=1\n"
layer += "size=%d\n" % v[3][1]
blocks += 1
layer += "\n[route]\n"
layer += "layers=-2\n"
blocks += 1
layer += "\n[maxpool]\n"
layer += "stride=1\n"
layer += "size=%d\n" % v[3][1]
blocks += 1
layer += "\n[route]\n"
layer += "layers=-2\n"
blocks += 1
layer += "\n[maxpool]\n"
layer += "stride=1\n"
layer += "size=%d\n" % v[3][1]
blocks += 1
layer += "\n[route]\n"
layer += "layers=-1, -3, -5, -6\n"
blocks += 1
layer += "\n[convolutional]\n"
layer += "batch_normalize=1\n"
layer += "filters=%d\n" % get_width(v[3][0], width_multiple)
layer += "size=1\n"
layer += "stride=1\n"
layer += "pad=1\n"
layer += "activation=silu\n"
layer += "\n##########\n"
blocks += 1
layers.append([layer, blocks])
elif v[2] == "nn.Upsample":
layer = ""
blocks = 0
layer += "\n[upsample]\n"
layer += "stride=%d\n" % v[3][1]
blocks += 1
layers.append([layer, blocks])
elif v[2] == "Concat":
layer = ""
blocks = 0
route = v[0][1]
r = 0
if route > 0:
for i, item in enumerate(layers):
if i <= route:
r += item[1]
else:
break
else:
route = len(layers) + route
for i, item in enumerate(layers):
if i <= route:
r += item[1]
else:
break
layer += "\n# Concat\n"
layer += "\n[route]\n"
layer += "layers=-1, %d\n" % (r - 1)
layer += "\n##########\n"
blocks += 1
layers.append([layer, blocks])
elif v[2] == "Detect":
for i, n in enumerate(v[0]):
layer = ""
blocks = 0
r = 0
for j, item in enumerate(layers):
if j <= n:
r += item[1]
else:
break
layer += "\n# Detect\n"
layer += "\n[route]\n"
layer += "layers=%d\n" % (r - 1)
blocks += 1
layer += "\n[convolutional]\n"
layer += "size=1\n"
layer += "stride=1\n"
layer += "pad=1\n"
layer += "filters=%d\n" % ((nc + 5) * 3)
layer += "activation=logistic\n"
blocks += 1
layer += "\n[yolo]\n"
layer += "mask=%s\n" % str(masks[i])[1:-1]
layer += "anchors=%s\n" % anchors
layer += "classes=%d\n" % nc
layer += "num=%d\n" % num
layer += "scale_x_y=2.0\n"
layer += "beta_nms=0.6\n"
layer += "new_coords=1\n"
layer += "\n##########\n"
blocks += 1
layers.append([layer, blocks])
for layer in layers:
c.write(layer[0])