Added YOLOv5 6.0 native support
This commit is contained in:
24
config_infer_primary_yoloV5.txt
Normal file
24
config_infer_primary_yoloV5.txt
Normal file
@@ -0,0 +1,24 @@
|
||||
[property]
|
||||
gpu-id=0
|
||||
net-scale-factor=0.0039215697906911373
|
||||
model-color-format=0
|
||||
custom-network-config=yolov5n.cfg
|
||||
model-file=yolov5n.wts
|
||||
model-engine-file=model_b1_gpu0_fp32.engine
|
||||
#int8-calib-file=calib.table
|
||||
labelfile-path=labels.txt
|
||||
batch-size=1
|
||||
network-mode=0
|
||||
num-detected-classes=80
|
||||
interval=0
|
||||
gie-unique-id=1
|
||||
process-mode=1
|
||||
network-type=0
|
||||
cluster-mode=4
|
||||
maintain-aspect-ratio=0
|
||||
parse-bbox-func-name=NvDsInferParseYolo
|
||||
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
|
||||
engine-create-func-name=NvDsInferYoloCudaEngineGet
|
||||
|
||||
[class-attrs-all]
|
||||
pre-cluster-threshold=0.25
|
||||
@@ -12,7 +12,10 @@ nvinfer1::ILayer* activationLayer(
|
||||
nvinfer1::ITensor* input,
|
||||
nvinfer1::INetworkDefinition* network)
|
||||
{
|
||||
if (activation == "relu")
|
||||
if (activation == "linear") {
|
||||
// Pass
|
||||
}
|
||||
else if (activation == "relu")
|
||||
{
|
||||
nvinfer1::IActivationLayer* relu = network->addActivation(
|
||||
*input, nvinfer1::ActivationType::kRELU);
|
||||
@@ -78,5 +81,24 @@ nvinfer1::ILayer* activationLayer(
|
||||
mish->setName(mishLayerName.c_str());
|
||||
output = mish;
|
||||
}
|
||||
else if (activation == "silu")
|
||||
{
|
||||
nvinfer1::IActivationLayer* sigmoid = network->addActivation(
|
||||
*input, nvinfer1::ActivationType::kSIGMOID);
|
||||
assert(sigmoid != nullptr);
|
||||
std::string sigmoidLayerName = "sigmoid_" + std::to_string(layerIdx);
|
||||
sigmoid->setName(sigmoidLayerName.c_str());
|
||||
nvinfer1::IElementWiseLayer* silu = network->addElementWise(
|
||||
*sigmoid->getOutput(0), *input,
|
||||
nvinfer1::ElementWiseOperation::kPROD);
|
||||
assert(silu != nullptr);
|
||||
std::string siluLayerName = "silu_" + std::to_string(layerIdx);
|
||||
silu->setName(siluLayerName.c_str());
|
||||
output = silu;
|
||||
}
|
||||
else {
|
||||
std::cerr << "Activation not supported: " << activation << std::endl;
|
||||
std::abort();
|
||||
}
|
||||
return output;
|
||||
}
|
||||
@@ -12,6 +12,7 @@ nvinfer1::ILayer* convolutionalLayer(
|
||||
std::vector<float>& weights,
|
||||
std::vector<nvinfer1::Weights>& trtWeights,
|
||||
int& weightPtr,
|
||||
std::string weightsType,
|
||||
int& inputChannels,
|
||||
nvinfer1::ITensor* input,
|
||||
nvinfer1::INetworkDefinition* network)
|
||||
@@ -56,57 +57,111 @@ nvinfer1::ILayer* convolutionalLayer(
|
||||
nvinfer1::Weights convWt{nvinfer1::DataType::kFLOAT, nullptr, size};
|
||||
nvinfer1::Weights convBias{nvinfer1::DataType::kFLOAT, nullptr, bias};
|
||||
|
||||
if (batchNormalize == false)
|
||||
{
|
||||
float* val = new float[filters];
|
||||
for (int i = 0; i < filters; ++i)
|
||||
if (weightsType == "weights") {
|
||||
if (batchNormalize == false)
|
||||
{
|
||||
val[i] = weights[weightPtr];
|
||||
weightPtr++;
|
||||
float* val = new float[filters];
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
val[i] = weights[weightPtr];
|
||||
weightPtr++;
|
||||
}
|
||||
convBias.values = val;
|
||||
trtWeights.push_back(convBias);
|
||||
val = new float[size];
|
||||
for (int i = 0; i < size; ++i)
|
||||
{
|
||||
val[i] = weights[weightPtr];
|
||||
weightPtr++;
|
||||
}
|
||||
convWt.values = val;
|
||||
trtWeights.push_back(convWt);
|
||||
}
|
||||
convBias.values = val;
|
||||
trtWeights.push_back(convBias);
|
||||
val = new float[size];
|
||||
for (int i = 0; i < size; ++i)
|
||||
else
|
||||
{
|
||||
val[i] = weights[weightPtr];
|
||||
weightPtr++;
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
bnBiases.push_back(weights[weightPtr]);
|
||||
weightPtr++;
|
||||
}
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
bnWeights.push_back(weights[weightPtr]);
|
||||
weightPtr++;
|
||||
}
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
bnRunningMean.push_back(weights[weightPtr]);
|
||||
weightPtr++;
|
||||
}
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5));
|
||||
weightPtr++;
|
||||
}
|
||||
float* val = new float[size];
|
||||
for (int i = 0; i < size; ++i)
|
||||
{
|
||||
val[i] = weights[weightPtr];
|
||||
weightPtr++;
|
||||
}
|
||||
convWt.values = val;
|
||||
trtWeights.push_back(convWt);
|
||||
trtWeights.push_back(convBias);
|
||||
}
|
||||
convWt.values = val;
|
||||
trtWeights.push_back(convWt);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < filters; ++i)
|
||||
else {
|
||||
if (batchNormalize == false)
|
||||
{
|
||||
bnBiases.push_back(weights[weightPtr]);
|
||||
weightPtr++;
|
||||
float* val = new float[size];
|
||||
for (int i = 0; i < size; ++i)
|
||||
{
|
||||
val[i] = weights[weightPtr];
|
||||
weightPtr++;
|
||||
}
|
||||
convWt.values = val;
|
||||
trtWeights.push_back(convWt);
|
||||
val = new float[filters];
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
val[i] = weights[weightPtr];
|
||||
weightPtr++;
|
||||
}
|
||||
convBias.values = val;
|
||||
trtWeights.push_back(convBias);
|
||||
}
|
||||
|
||||
for (int i = 0; i < filters; ++i)
|
||||
else
|
||||
{
|
||||
bnWeights.push_back(weights[weightPtr]);
|
||||
weightPtr++;
|
||||
float* val = new float[size];
|
||||
for (int i = 0; i < size; ++i)
|
||||
{
|
||||
val[i] = weights[weightPtr];
|
||||
weightPtr++;
|
||||
}
|
||||
convWt.values = val;
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
bnWeights.push_back(weights[weightPtr]);
|
||||
weightPtr++;
|
||||
}
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
bnBiases.push_back(weights[weightPtr]);
|
||||
weightPtr++;
|
||||
}
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
bnRunningMean.push_back(weights[weightPtr]);
|
||||
weightPtr++;
|
||||
}
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5));
|
||||
weightPtr++;
|
||||
}
|
||||
trtWeights.push_back(convWt);
|
||||
trtWeights.push_back(convBias);
|
||||
}
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
bnRunningMean.push_back(weights[weightPtr]);
|
||||
weightPtr++;
|
||||
}
|
||||
for (int i = 0; i < filters; ++i)
|
||||
{
|
||||
bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5));
|
||||
weightPtr++;
|
||||
}
|
||||
float* val = new float[size];
|
||||
for (int i = 0; i < size; ++i)
|
||||
{
|
||||
val[i] = weights[weightPtr];
|
||||
weightPtr++;
|
||||
}
|
||||
convWt.values = val;
|
||||
trtWeights.push_back(convWt);
|
||||
trtWeights.push_back(convBias);
|
||||
}
|
||||
|
||||
nvinfer1::IConvolutionLayer* conv = network->addConvolution(
|
||||
|
||||
@@ -19,6 +19,7 @@ nvinfer1::ILayer* convolutionalLayer(
|
||||
std::vector<float>& weights,
|
||||
std::vector<nvinfer1::Weights>& trtWeights,
|
||||
int& weightPtr,
|
||||
std::string weightsType,
|
||||
int& inputChannels,
|
||||
nvinfer1::ITensor* input,
|
||||
nvinfer1::INetworkDefinition* network);
|
||||
|
||||
@@ -67,32 +67,63 @@ std::vector<float> loadWeights(const std::string weightsFilePath, const std::str
|
||||
{
|
||||
assert(fileExists(weightsFilePath));
|
||||
std::cout << "\nLoading pre-trained weights" << std::endl;
|
||||
std::ifstream file(weightsFilePath, std::ios_base::binary);
|
||||
assert(file.good());
|
||||
std::string line;
|
||||
|
||||
if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos)
|
||||
{
|
||||
// Remove 4 int32 bytes of data from the stream belonging to the header
|
||||
file.ignore(4 * 4);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Remove 5 int32 bytes of data from the stream belonging to the header
|
||||
file.ignore(4 * 5);
|
||||
}
|
||||
|
||||
std::vector<float> weights;
|
||||
char floatWeight[4];
|
||||
while (!file.eof())
|
||||
{
|
||||
file.read(floatWeight, 4);
|
||||
assert(file.gcount() == 4);
|
||||
weights.push_back(*reinterpret_cast<float*>(floatWeight));
|
||||
if (file.peek() == std::istream::traits_type::eof()) break;
|
||||
|
||||
if (weightsFilePath.find(".weights") != std::string::npos) {
|
||||
std::ifstream file(weightsFilePath, std::ios_base::binary);
|
||||
assert(file.good());
|
||||
std::string line;
|
||||
|
||||
if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos)
|
||||
{
|
||||
// Remove 4 int32 bytes of data from the stream belonging to the header
|
||||
file.ignore(4 * 4);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Remove 5 int32 bytes of data from the stream belonging to the header
|
||||
file.ignore(4 * 5);
|
||||
}
|
||||
|
||||
char floatWeight[4];
|
||||
while (!file.eof())
|
||||
{
|
||||
file.read(floatWeight, 4);
|
||||
assert(file.gcount() == 4);
|
||||
weights.push_back(*reinterpret_cast<float*>(floatWeight));
|
||||
if (file.peek() == std::istream::traits_type::eof()) break;
|
||||
}
|
||||
}
|
||||
|
||||
else if (weightsFilePath.find(".wts") != std::string::npos) {
|
||||
std::ifstream file(weightsFilePath);
|
||||
assert(file.good());
|
||||
int32_t count;
|
||||
file >> count;
|
||||
assert(count > 0 && "Invalid .wts file.");
|
||||
|
||||
uint32_t floatWeight;
|
||||
std::string name;
|
||||
uint32_t size;
|
||||
|
||||
while (count--) {
|
||||
file >> name >> std::dec >> size;
|
||||
for (uint32_t x = 0, y = size; x < y; ++x)
|
||||
{
|
||||
file >> std::hex >> floatWeight;
|
||||
weights.push_back(*reinterpret_cast<float *>(&floatWeight));
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
else {
|
||||
std::cerr << "File " << weightsFilePath << " is not supported" << std::endl;
|
||||
std::abort();
|
||||
}
|
||||
|
||||
std::cout << "Loading weights of " << networkType << " complete"
|
||||
<< std::endl;
|
||||
<< std::endl;
|
||||
std::cout << "Total weights read: " << weights.size() << std::endl;
|
||||
return weights;
|
||||
}
|
||||
|
||||
@@ -73,9 +73,6 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
|
||||
parseConfigBlocks();
|
||||
orderParams(&m_OutputMasks);
|
||||
|
||||
std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
|
||||
std::vector<nvinfer1::Weights> trtWeights;
|
||||
|
||||
nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0);
|
||||
if (parseModel(*network) != NVDSINFER_SUCCESS) {
|
||||
network->destroy();
|
||||
@@ -134,7 +131,7 @@ NvDsInferStatus Yolo::parseModel(nvinfer1::INetworkDefinition& network) {
|
||||
destroyNetworkUtils();
|
||||
|
||||
std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
|
||||
std::cout << "Building YOLO network" << std::endl;
|
||||
std::cout << "Building YOLO network\n" << std::endl;
|
||||
NvDsInferStatus status = buildYoloNetwork(weights, network);
|
||||
|
||||
if (status == NVDSINFER_SUCCESS) {
|
||||
@@ -151,6 +148,15 @@ NvDsInferStatus Yolo::buildYoloNetwork(
|
||||
int weightPtr = 0;
|
||||
int channels = m_InputC;
|
||||
|
||||
std::string weightsType;
|
||||
|
||||
if (m_WtsFilePath.find(".weights") != std::string::npos) {
|
||||
weightsType = "weights";
|
||||
}
|
||||
else {
|
||||
weightsType = "wts";
|
||||
}
|
||||
|
||||
nvinfer1::ITensor* data =
|
||||
network.addInput(m_InputBlobName.c_str(), nvinfer1::DataType::kFLOAT,
|
||||
nvinfer1::Dims3{static_cast<int>(m_InputC),
|
||||
@@ -171,7 +177,7 @@ NvDsInferStatus Yolo::buildYoloNetwork(
|
||||
|
||||
else if (m_ConfigBlocks.at(i).at("type") == "convolutional") {
|
||||
std::string inputVol = dimsToString(previous->getDimensions());
|
||||
nvinfer1::ILayer* out = convolutionalLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, channels, previous, &network);
|
||||
nvinfer1::ILayer* out = convolutionalLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, channels, previous, &network);
|
||||
previous = out->getOutput(0);
|
||||
assert(previous != nullptr);
|
||||
channels = getNumChannels(previous);
|
||||
@@ -272,10 +278,10 @@ NvDsInferStatus Yolo::buildYoloNetwork(
|
||||
beta_nms = std::stof(m_ConfigBlocks.at(i).at("beta_nms"));
|
||||
}
|
||||
nvinfer1::IPluginV2* yoloPlugin
|
||||
= new YoloLayer(m_OutputTensors.at(outputTensorCount).numBBoxes,
|
||||
m_OutputTensors.at(outputTensorCount).numClasses,
|
||||
m_OutputTensors.at(outputTensorCount).gridSizeX,
|
||||
m_OutputTensors.at(outputTensorCount).gridSizeY,
|
||||
= new YoloLayer(curYoloTensor.numBBoxes,
|
||||
curYoloTensor.numClasses,
|
||||
curYoloTensor.gridSizeX,
|
||||
curYoloTensor.gridSizeY,
|
||||
1, new_coords, scale_x_y, beta_nms,
|
||||
curYoloTensor.anchors,
|
||||
m_OutputMasks);
|
||||
@@ -436,7 +442,7 @@ void Yolo::parseConfigBlocks()
|
||||
m_LetterBox = 0;
|
||||
}
|
||||
}
|
||||
else if ((block.at("type") == "region") || (block.at("type") == "yolo"))
|
||||
else if ((block.at("type") == "region") || (block.at("type") == "yolo") || (block.at("type") == "detect"))
|
||||
{
|
||||
assert((block.find("num") != block.end())
|
||||
&& std::string("Missing 'num' param in " + block.at("type") + " layer").c_str());
|
||||
@@ -466,9 +472,7 @@ void Yolo::parseConfigBlocks()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (block.find("mask") != block.end()) {
|
||||
|
||||
std::string maskString = block.at("mask");
|
||||
std::vector<int> pMASKS;
|
||||
while (!maskString.empty())
|
||||
|
||||
133
readme.md
133
readme.md
@@ -6,14 +6,12 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
|
||||
|
||||
* New documentation for multiple models
|
||||
* DeepStream tutorials
|
||||
* Native PyTorch support (YOLOv5 and YOLOR)
|
||||
* Native YOLOR support
|
||||
* Native PP-YOLO support
|
||||
* Models benchmark
|
||||
* GPU NMS
|
||||
* Dynamic batch-size
|
||||
|
||||
**NOTE**: The support for YOLOv5 was removed in this current update. If you want the old repo version, please use the commit 297e0e9 and DeepStream 5.1 requirements.
|
||||
|
||||
### Improvements on this repository
|
||||
|
||||
* Darknet CFG params parser (it doesn't need to edit nvdsparsebbox_Yolo.cpp or another file for native models)
|
||||
@@ -24,6 +22,7 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
|
||||
* Support for convolutional groups
|
||||
* Support for INT8 calibration
|
||||
* Support for non square models
|
||||
* **YOLOv5 6.0 native support**
|
||||
|
||||
##
|
||||
|
||||
@@ -33,6 +32,7 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
|
||||
* [Tested models](#tested-models)
|
||||
* [dGPU installation](#dgpu-installation)
|
||||
* [Basic usage](#basic-usage)
|
||||
* [YOLOv5 usage](#yolov5-usage)
|
||||
* [INT8 calibration](#int8-calibration)
|
||||
* [Using your custom model](docs/customModels.md)
|
||||
|
||||
@@ -48,9 +48,14 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models
|
||||
* [NVIDIA DeepStream SDK 6.0](https://developer.nvidia.com/deepstream-sdk)
|
||||
* [DeepStream-Yolo](https://github.com/marcoslucianops/DeepStream-Yolo)
|
||||
|
||||
**For YOLOv5**:
|
||||
|
||||
* [PyTorch >= 1.7.0](https://pytorch.org/get-started/locally/)
|
||||
|
||||
##
|
||||
|
||||
### Tested models
|
||||
* [YOLOv5 6.0](https://github.com/ultralytics/yolov5) [[pt]](https://github.com/ultralytics/yolov5/releases/tag/v6.0)
|
||||
* [YOLOv4x-Mish](https://github.com/AlexeyAB/darknet) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4x-mish.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4x-mish.weights)]
|
||||
* [YOLOv4-CSP](https://github.com/WongKinYiu/ScaledYOLOv4/tree/yolov4-csp) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-csp.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-csp.weights)]
|
||||
* [YOLOv4](https://github.com/AlexeyAB/darknet) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v3_optimal/yolov4.weights)]
|
||||
@@ -265,6 +270,7 @@ deepstream-app -c deepstream_app_config.txt
|
||||
**NOTE**: If you want to use YOLOv2 or YOLOv2-Tiny models, change the deepstream_app_config.txt file before run it
|
||||
|
||||
```
|
||||
...
|
||||
[primary-gie]
|
||||
enable=1
|
||||
gpu-id=0
|
||||
@@ -277,6 +283,127 @@ config-file=config_infer_primary_yoloV2.txt
|
||||
|
||||
##
|
||||
|
||||
### YOLOv5 usage
|
||||
|
||||
#### 1. Copy gen_wts_yoloV5.py from DeepStream-Yolo/utils to ultralytics/yolov5 folder
|
||||
|
||||
#### 2. Open the ultralytics/yolov5 folder
|
||||
|
||||
#### 3. Download pt file from [ultralytics/yolov5](https://github.com/ultralytics/yolov5/releases/tag/v6.0) website (example for YOLOv5n)
|
||||
|
||||
```
|
||||
wget https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5n.pt
|
||||
```
|
||||
|
||||
#### 4. Generate cfg and wts files (example for YOLOv5n)
|
||||
|
||||
```
|
||||
python3 gen_wts_yoloV5.py -w yolov5n.pt
|
||||
```
|
||||
|
||||
#### 5. Copy generated cfg and wts files to DeepStream-Yolo folder
|
||||
|
||||
#### 6. Open DeepStream-Yolo folder
|
||||
|
||||
#### 7. Compile lib
|
||||
|
||||
* x86 platform
|
||||
|
||||
```
|
||||
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
|
||||
```
|
||||
|
||||
* Jetson platform
|
||||
|
||||
```
|
||||
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
|
||||
```
|
||||
|
||||
#### 8. Edit config_infer_primary_yoloV5.txt for your model (example for YOLOv5n)
|
||||
|
||||
```
|
||||
[property]
|
||||
...
|
||||
# 0=RGB, 1=BGR, 2=GRAYSCALE
|
||||
model-color-format=0
|
||||
# CFG
|
||||
custom-network-config=yolov5n.cfg
|
||||
# WTS
|
||||
model-file=yolov5n.wts
|
||||
# Generated TensorRT model (will be created if it doesn't exist)
|
||||
model-engine-file=model_b1_gpu0_fp32.engine
|
||||
# Model labels file
|
||||
labelfile-path=labels.txt
|
||||
# Batch size
|
||||
batch-size=1
|
||||
# 0=FP32, 1=INT8, 2=FP16 mode
|
||||
network-mode=0
|
||||
# Number of classes in label file
|
||||
num-detected-classes=80
|
||||
...
|
||||
[class-attrs-all]
|
||||
# CONF_THRESH
|
||||
pre-cluster-threshold=0.25
|
||||
```
|
||||
|
||||
#### 8. Change the deepstream_app_config.txt file
|
||||
|
||||
```
|
||||
...
|
||||
[primary-gie]
|
||||
enable=1
|
||||
gpu-id=0
|
||||
gie-unique-id=1
|
||||
nvbuf-memory-type=0
|
||||
config-file=config_infer_primary_yoloV5.txt
|
||||
```
|
||||
|
||||
#### 9. Run
|
||||
|
||||
```
|
||||
deepstream-app -c deepstream_app_config.txt
|
||||
```
|
||||
|
||||
**NOTE**: For YOLOv5 P6 or custom models, check the gen_wts_yoloV5.py args and use them according to your model
|
||||
|
||||
* Input weights (.pt) file path **(required)**
|
||||
|
||||
```
|
||||
-w or --weights
|
||||
```
|
||||
|
||||
* Input cfg (.yaml) file path
|
||||
|
||||
```
|
||||
-c or --yaml
|
||||
```
|
||||
|
||||
* Model width **(default = 640 / 1280 [P6])**
|
||||
|
||||
```
|
||||
-mw or --width
|
||||
```
|
||||
|
||||
* Model height **(default = 640 / 1280 [P6])**
|
||||
|
||||
```
|
||||
-mh or --height
|
||||
```
|
||||
|
||||
* Model channels **(default = 3)**
|
||||
|
||||
```
|
||||
-mc or --channels
|
||||
```
|
||||
|
||||
* P6 model
|
||||
|
||||
```
|
||||
--p6
|
||||
```
|
||||
|
||||
##
|
||||
|
||||
### INT8 calibration
|
||||
|
||||
#### 1. Install OpenCV
|
||||
|
||||
344
utils/gen_wts_yoloV5.py
Normal file
344
utils/gen_wts_yoloV5.py
Normal file
@@ -0,0 +1,344 @@
|
||||
import argparse
|
||||
import yaml
|
||||
import math
|
||||
import os
|
||||
import struct
|
||||
import torch
|
||||
from utils.torch_utils import select_device
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="PyTorch conversion")
|
||||
parser.add_argument("-w", "--weights", required=True, help="Input weights (.pt) file path (required)")
|
||||
parser.add_argument("-c", "--yaml", help="Input cfg (.yaml) file path")
|
||||
parser.add_argument("-mw", "--width", help="Model width (default = 640 / 1280 [P6])")
|
||||
parser.add_argument("-mh", "--height", help="Model height (default = 640 / 1280 [P6])")
|
||||
parser.add_argument("-mc", "--channels", help="Model channels (default = 3)")
|
||||
parser.add_argument("--p6", action="store_true", help="P6 model")
|
||||
args = parser.parse_args()
|
||||
if not os.path.isfile(args.weights):
|
||||
raise SystemExit("Invalid weights file")
|
||||
if not args.yaml:
|
||||
args.yaml = ""
|
||||
if not args.width:
|
||||
args.width = 1280 if args.p6 else 640
|
||||
if not args.height:
|
||||
args.height = 1280 if args.p6 else 640
|
||||
if not args.channels:
|
||||
args.channels = 3
|
||||
return args.weights, args.yaml, args.width, args.height, args.channels, args.p6
|
||||
|
||||
|
||||
def get_width(x, gw, divisor=8):
|
||||
return int(math.ceil((x * gw) / divisor)) * divisor
|
||||
|
||||
|
||||
def get_depth(x, gd):
|
||||
if x == 1:
|
||||
return 1
|
||||
r = int(round(x * gd))
|
||||
if x * gd - int(x * gd) == 0.5 and int(x * gd) % 2 == 0:
|
||||
r -= 1
|
||||
return max(r, 1)
|
||||
|
||||
|
||||
pt_file, yaml_file, model_width, model_height, model_channels, p6 = parse_args()
|
||||
|
||||
model_name = pt_file.split(".pt")[0]
|
||||
wts_file = model_name + ".wts"
|
||||
cfg_file = model_name + ".cfg"
|
||||
|
||||
if yaml_file == "":
|
||||
yaml_file = "models/" + model_name + ".yaml"
|
||||
if not os.path.isfile(yaml_file):
|
||||
yaml_file = "models/hub/" + model_name + ".yaml"
|
||||
if not os.path.isfile(yaml_file):
|
||||
raise SystemExit("YAML file not found")
|
||||
elif not os.path.isfile(yaml_file):
|
||||
raise SystemExit("Invalid YAML file")
|
||||
|
||||
device = select_device("cpu")
|
||||
model = torch.load(pt_file, map_location=device)["model"].float()
|
||||
model.to(device).eval()
|
||||
|
||||
with open(wts_file, "w") as f:
|
||||
wts_write = ""
|
||||
conv_count = 0
|
||||
cv1 = ""
|
||||
cv3 = ""
|
||||
cv3_idx = 0
|
||||
sppf_idx = 11 if p6 else 9
|
||||
for k, v in model.state_dict().items():
|
||||
if not "num_batches_tracked" in k and not "anchors" in k and not "anchor_grid" in k:
|
||||
vr = v.reshape(-1).cpu().numpy()
|
||||
idx = int(k.split(".")[1])
|
||||
if ".cv1." in k and not ".m." in k and idx != sppf_idx:
|
||||
cv1 += "{} {} ".format(k, len(vr))
|
||||
for vv in vr:
|
||||
cv1 += " "
|
||||
cv1 += struct.pack(">f" ,float(vv)).hex()
|
||||
cv1 += "\n"
|
||||
conv_count += 1
|
||||
elif cv1 != "" and ".m." in k:
|
||||
wts_write += cv1
|
||||
cv1 = ""
|
||||
if ".cv3." in k:
|
||||
cv3 += "{} {} ".format(k, len(vr))
|
||||
for vv in vr:
|
||||
cv3 += " "
|
||||
cv3 += struct.pack(">f" ,float(vv)).hex()
|
||||
cv3 += "\n"
|
||||
cv3_idx = idx
|
||||
conv_count += 1
|
||||
elif cv3 != "" and cv3_idx != idx:
|
||||
wts_write += cv3
|
||||
cv3 = ""
|
||||
cv3_idx = 0
|
||||
if not ".cv3." in k and not (".cv1." in k and not ".m." in k and idx != sppf_idx):
|
||||
wts_write += "{} {} ".format(k, len(vr))
|
||||
for vv in vr:
|
||||
wts_write += " "
|
||||
wts_write += struct.pack(">f" ,float(vv)).hex()
|
||||
wts_write += "\n"
|
||||
conv_count += 1
|
||||
f.write("{}\n".format(conv_count))
|
||||
f.write(wts_write)
|
||||
|
||||
with open(cfg_file, "w") as c:
|
||||
with open(yaml_file, "r") as f:
|
||||
nc = 0
|
||||
depth_multiple = 0
|
||||
width_multiple = 0
|
||||
anchors = ""
|
||||
masks = []
|
||||
num = 0
|
||||
detections = []
|
||||
layers = []
|
||||
f = yaml.load(f,Loader=yaml.FullLoader)
|
||||
c.write("[net]\n")
|
||||
c.write("width=%d\n" % model_width)
|
||||
c.write("height=%d\n" % model_height)
|
||||
c.write("channels=%d\n" % model_channels)
|
||||
for l in f:
|
||||
if l == "nc":
|
||||
nc = f[l]
|
||||
elif l == "depth_multiple":
|
||||
depth_multiple = f[l]
|
||||
elif l == "width_multiple":
|
||||
width_multiple = f[l]
|
||||
elif l == "anchors":
|
||||
a = []
|
||||
for v in f[l]:
|
||||
a.extend(v)
|
||||
mask = []
|
||||
for _ in range(int(len(v) / 2)):
|
||||
mask.append(num)
|
||||
num += 1
|
||||
masks.append(mask)
|
||||
anchors = str(a)[1:-1]
|
||||
elif l == "backbone" or l == "head":
|
||||
for v in f[l]:
|
||||
if v[2] == "Conv":
|
||||
layer = ""
|
||||
blocks = 0
|
||||
layer += "\n[convolutional]\n"
|
||||
layer += "batch_normalize=1\n"
|
||||
layer += "filters=%d\n" % get_width(v[3][0], width_multiple)
|
||||
layer += "size=%d\n" % v[3][1]
|
||||
layer += "stride=%d\n" % v[3][2]
|
||||
layer += "pad=1\n"
|
||||
layer += "activation=silu\n"
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
elif v[2] == "C3":
|
||||
layer = ""
|
||||
blocks = 0
|
||||
layer += "\n# C3\n"
|
||||
# SPLIT
|
||||
layer += "\n[convolutional]\n"
|
||||
layer += "batch_normalize=1\n"
|
||||
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
|
||||
layer += "size=1\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "pad=1\n"
|
||||
layer += "activation=silu\n"
|
||||
blocks += 1
|
||||
layer += "\n[route]\n"
|
||||
layer += "layers=-2\n"
|
||||
blocks += 1
|
||||
layer += "\n[convolutional]\n"
|
||||
layer += "batch_normalize=1\n"
|
||||
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
|
||||
layer += "size=1\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "pad=1\n"
|
||||
layer += "activation=silu\n"
|
||||
blocks += 1
|
||||
# Residual Block
|
||||
if len(v[3]) == 1 or v[3][1] == True:
|
||||
for _ in range(get_depth(v[1], depth_multiple)):
|
||||
layer += "\n[convolutional]\n"
|
||||
layer += "batch_normalize=1\n"
|
||||
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
|
||||
layer += "size=1\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "pad=1\n"
|
||||
layer += "activation=silu\n"
|
||||
blocks += 1
|
||||
layer += "\n[convolutional]\n"
|
||||
layer += "batch_normalize=1\n"
|
||||
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
|
||||
layer += "size=3\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "pad=1\n"
|
||||
layer += "activation=silu\n"
|
||||
blocks += 1
|
||||
layer += "\n[shortcut]\n"
|
||||
layer += "from=-3\n"
|
||||
layer += "activation=linear\n"
|
||||
blocks += 1
|
||||
# Merge
|
||||
layer += "\n[route]\n"
|
||||
layer += "layers=-1, -%d\n" % (3 * get_depth(v[1], depth_multiple) + 3)
|
||||
blocks += 1
|
||||
else:
|
||||
for _ in range(get_depth(v[1], depth_multiple)):
|
||||
layer += "\n[convolutional]\n"
|
||||
layer += "batch_normalize=1\n"
|
||||
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
|
||||
layer += "size=1\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "pad=1\n"
|
||||
layer += "activation=silu\n"
|
||||
blocks += 1
|
||||
layer += "\n[convolutional]\n"
|
||||
layer += "batch_normalize=1\n"
|
||||
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
|
||||
layer += "size=3\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "pad=1\n"
|
||||
layer += "activation=silu\n"
|
||||
blocks += 1
|
||||
# Merge
|
||||
layer += "\n[route]\n"
|
||||
layer += "layers=-1, -%d\n" % (2 * get_depth(v[1], depth_multiple) + 3)
|
||||
blocks += 1
|
||||
# Transition
|
||||
layer += "\n[convolutional]\n"
|
||||
layer += "batch_normalize=1\n"
|
||||
layer += "filters=%d\n" % get_width(v[3][0], width_multiple)
|
||||
layer += "size=1\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "pad=1\n"
|
||||
layer += "activation=silu\n"
|
||||
layer += "\n##########\n"
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
elif v[2] == "SPPF":
|
||||
layer = ""
|
||||
blocks = 0
|
||||
layer += "\n# SPPF\n"
|
||||
layer += "\n[convolutional]\n"
|
||||
layer += "batch_normalize=1\n"
|
||||
layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple)
|
||||
layer += "size=1\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "pad=1\n"
|
||||
layer += "activation=silu\n"
|
||||
blocks += 1
|
||||
layer += "\n[maxpool]\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "size=%d\n" % v[3][1]
|
||||
blocks += 1
|
||||
layer += "\n[route]\n"
|
||||
layer += "layers=-2\n"
|
||||
blocks += 1
|
||||
layer += "\n[maxpool]\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "size=%d\n" % v[3][1]
|
||||
blocks += 1
|
||||
layer += "\n[route]\n"
|
||||
layer += "layers=-2\n"
|
||||
blocks += 1
|
||||
layer += "\n[maxpool]\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "size=%d\n" % v[3][1]
|
||||
blocks += 1
|
||||
layer += "\n[route]\n"
|
||||
layer += "layers=-1, -3, -5, -6\n"
|
||||
blocks += 1
|
||||
layer += "\n[convolutional]\n"
|
||||
layer += "batch_normalize=1\n"
|
||||
layer += "filters=%d\n" % get_width(v[3][0], width_multiple)
|
||||
layer += "size=1\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "pad=1\n"
|
||||
layer += "activation=silu\n"
|
||||
layer += "\n##########\n"
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
elif v[2] == "nn.Upsample":
|
||||
layer = ""
|
||||
blocks = 0
|
||||
layer += "\n[upsample]\n"
|
||||
layer += "stride=%d\n" % v[3][1]
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
elif v[2] == "Concat":
|
||||
layer = ""
|
||||
blocks = 0
|
||||
route = v[0][1]
|
||||
r = 0
|
||||
if route > 0:
|
||||
for i, item in enumerate(layers):
|
||||
if i <= route:
|
||||
r += item[1]
|
||||
else:
|
||||
break
|
||||
else:
|
||||
route = len(layers) + route
|
||||
for i, item in enumerate(layers):
|
||||
if i <= route:
|
||||
r += item[1]
|
||||
else:
|
||||
break
|
||||
layer += "\n# Concat\n"
|
||||
layer += "\n[route]\n"
|
||||
layer += "layers=-1, %d\n" % (r - 1)
|
||||
layer += "\n##########\n"
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
elif v[2] == "Detect":
|
||||
for i, n in enumerate(v[0]):
|
||||
layer = ""
|
||||
blocks = 0
|
||||
r = 0
|
||||
for j, item in enumerate(layers):
|
||||
if j <= n:
|
||||
r += item[1]
|
||||
else:
|
||||
break
|
||||
layer += "\n# Detect\n"
|
||||
layer += "\n[route]\n"
|
||||
layer += "layers=%d\n" % (r - 1)
|
||||
blocks += 1
|
||||
layer += "\n[convolutional]\n"
|
||||
layer += "size=1\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "pad=1\n"
|
||||
layer += "filters=%d\n" % ((nc + 5) * 3)
|
||||
layer += "activation=logistic\n"
|
||||
blocks += 1
|
||||
layer += "\n[yolo]\n"
|
||||
layer += "mask=%s\n" % str(masks[i])[1:-1]
|
||||
layer += "anchors=%s\n" % anchors
|
||||
layer += "classes=%d\n" % nc
|
||||
layer += "num=%d\n" % num
|
||||
layer += "scale_x_y=2.0\n"
|
||||
layer += "beta_nms=0.6\n"
|
||||
layer += "new_coords=1\n"
|
||||
layer += "\n##########\n"
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
for layer in layers:
|
||||
c.write(layer[0])
|
||||
Reference in New Issue
Block a user