Add YOLOX support

Marcos Luciano
2023-01-30 23:59:51 -03:00
parent f9c7a4dfca
commit 825d6bfda8
11 changed files with 746 additions and 24 deletions


@@ -5,7 +5,6 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
 ### Future updates
 * DeepStream tutorials
-* YOLOX support
 * YOLOv6 support
 * Dynamic batch-size
 * PP-YOLOE+ support
@@ -29,6 +28,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
 * Optimized NMS [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142)
 * Models benchmarks
 * **YOLOv8 support**
+* **YOLOX support**
 ##
@@ -47,6 +47,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
 * [PP-YOLOE usage](docs/PPYOLOE.md)
 * [YOLOv7 usage](docs/YOLOv7.md)
 * [YOLOv8 usage](docs/YOLOv8.md)
+* [YOLOX usage](docs/YOLOX.md)
 * [Using your custom model](docs/customModels.md)
 * [Multiple YOLO GIEs](docs/multipleGIEs.md)
@@ -112,6 +113,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
 * [PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
 * [YOLOv7](https://github.com/WongKinYiu/yolov7)
 * [YOLOv8](https://github.com/ultralytics/ultralytics)
+* [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)
 * [MobileNet-YOLO](https://github.com/dog-qiuqiu/MobileNet-Yolo)
 * [YOLO-Fastest](https://github.com/dog-qiuqiu/Yolo-Fastest)
@@ -135,7 +137,7 @@ sample = 1920x1080 video
 - Eval
 ```
-nms-iou-threshold = 0.6 (Darknet and YOLOv8) / 0.65 (YOLOR, YOLOv5 and YOLOv7) / 0.7 (Paddle)
+nms-iou-threshold = 0.6 (Darknet and YOLOv8) / 0.65 (YOLOR, YOLOv5, YOLOv7 and YOLOX) / 0.7 (Paddle)
 pre-cluster-threshold = 0.001
 topk = 300
 ```


config_infer_primary_yolox.txt Normal file
@@ -0,0 +1,27 @@
[property]
gpu-id=0
net-scale-factor=0
model-color-format=0
custom-network-config=yolox_s.cfg
model-file=yolox_s.wts
model-engine-file=model_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=0
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300


config_infer_primary_yolox_legacy.txt Normal file
@@ -0,0 +1,28 @@
[property]
gpu-id=0
net-scale-factor=0.0173520735727919486
offsets=123.675;116.28;103.53
model-color-format=0
custom-network-config=yolox_s.cfg
model-file=yolox_s.wts
model-engine-file=model_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=0
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet
[class-attrs-all]
nms-iou-threshold=0.45
pre-cluster-threshold=0.25
topk=300

docs/YOLOX.md Normal file

@@ -0,0 +1,137 @@
# YOLOX usage
**NOTE**: The yaml file is not required.
* [Convert model](#convert-model)
* [Compile the lib](#compile-the-lib)
* [Edit the config_infer_primary_yolox file](#edit-the-config_infer_primary_yolox-file)
* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file)
* [Testing the model](#testing-the-model)
##
### Convert model
#### 1. Download the YOLOX repo and install the requirements
```
git clone https://github.com/Megvii-BaseDetection/YOLOX
cd YOLOX
pip3 install -r requirements.txt
```
**NOTE**: It is recommended to use a Python virtualenv.
#### 2. Copy the converter
Copy the `gen_wts_yolox.py` file from the `DeepStream-Yolo/utils` directory to the `YOLOX` folder.
#### 3. Download the model
Download the `pth` file from [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX/releases) releases (example for YOLOX-s standard)
```
wget https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.pth
```
**NOTE**: You can use your custom model, but it is important to keep the YOLO model reference (`yolox_`) in your `cfg` and `weights`/`wts` filenames to generate the engine correctly.
#### 4. Convert model
Generate the `cfg` and `wts` files (example for YOLOX-s standard)
```
python3 gen_wts_yolox.py -w yolox_s.pth -e exps/default/yolox_s.py
```
#### 5. Copy generated files
Copy the generated `cfg` and `wts` files to the `DeepStream-Yolo` folder.
##
### Compile the lib
Open the `DeepStream-Yolo` folder and compile the lib
* DeepStream 6.1.1 on x86 platform
```
CUDA_VER=11.7 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.1 on x86 platform
```
CUDA_VER=11.6 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.0.1 / 6.0 on x86 platform
```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.1.1 / 6.1 on Jetson platform
```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.0.1 / 6.0 on Jetson platform
```
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
```
##
### Edit the config_infer_primary_yolox file
Edit the `config_infer_primary_yolox.txt` file according to your model (example for YOLOX-s standard)
```
[property]
...
custom-network-config=yolox_s.cfg
model-file=yolox_s.wts
...
```
**NOTE**: If you use the **legacy** model, you should edit the `config_infer_primary_yolox_legacy.txt` file.
**NOTE**: The **YOLOX standard** model does not apply normalization to the input image during preprocessing. It is important to set the `net-scale-factor` according to the values used in training.
```
net-scale-factor=0
```
**NOTE**: The **YOLOX legacy** model normalizes the input image during preprocessing. It is important to set the `net-scale-factor` and `offsets` according to the values used in training.
Defaults: `mean = 0.485, 0.456, 0.406` and `std = 0.229, 0.224, 0.225`
```
net-scale-factor=0.0173520735727919486
offsets=123.675;116.28;103.53
```
##
### Edit the deepstream_app_config file
```
...
[primary-gie]
...
config-file=config_infer_primary_yolox.txt
```
**NOTE**: If you use the **legacy** model, change it to `config_infer_primary_yolox_legacy.txt`.
##
### Testing the model
```
deepstream-app -c deepstream_app_config.txt
```


@@ -44,11 +44,11 @@ detectV8Layer(int layerIdx, std::map<std::string, std::string>& block, std::vect
 shuffle1Box->setName(shuffle1BoxLayerName.c_str());
 nvinfer1::Dims reshape1Dims = {3, {4, reg_max, inputDims.d[1]}};
 shuffle1Box->setReshapeDimensions(reshape1Dims);
-nvinfer1::Permutation permutation1;
-permutation1.order[0] = 1;
-permutation1.order[1] = 0;
-permutation1.order[2] = 2;
-shuffle1Box->setSecondTranspose(permutation1);
+nvinfer1::Permutation permutation1Box;
+permutation1Box.order[0] = 1;
+permutation1Box.order[1] = 0;
+permutation1Box.order[2] = 2;
+shuffle1Box->setSecondTranspose(permutation1Box);
 box = shuffle1Box->getOutput(0);
 nvinfer1::ISoftMaxLayer* softmax = network->addSoftMax(*box);
@@ -186,10 +186,10 @@ detectV8Layer(int layerIdx, std::map<std::string, std::string>& block, std::vect
 assert(shuffle != nullptr);
 std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx);
 shuffle->setName(shuffleLayerName.c_str());
-nvinfer1::Permutation permutation2;
-permutation2.order[0] = 1;
-permutation2.order[1] = 0;
-shuffle->setFirstTranspose(permutation2);
+nvinfer1::Permutation permutation;
+permutation.order[0] = 1;
+permutation.order[1] = 0;
+shuffle->setFirstTranspose(permutation);
 output = shuffle->getOutput(0);
 return output;


@@ -18,16 +18,15 @@ shuffleLayer(int layerIdx, std::string& layer, std::map<std::string, std::string
 std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx);
 shuffle->setName(shuffleLayerName.c_str());
+int from = -1;
+if (block.find("from") != block.end())
+from = std::stoi(block.at("from"));
+if (from < 0)
+from = tensorOutputs.size() + from;
+layer = std::to_string(from);
 if (block.find("reshape") != block.end()) {
-int from = -1;
-if (block.find("from") != block.end())
-from = std::stoi(block.at("from"));
-if (from < 0)
-from = tensorOutputs.size() + from;
-layer = std::to_string(from);
 nvinfer1::Dims inputTensorDims = tensorOutputs[from]->getDimensions();
 std::string strReshape = block.at("reshape");


@@ -136,7 +136,7 @@ Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition
 float eps = 1.0e-5;
 if (m_NetworkType.find("yolov5") != std::string::npos || m_NetworkType.find("yolov7") != std::string::npos ||
-m_NetworkType.find("yolov8") != std::string::npos)
+m_NetworkType.find("yolov8") != std::string::npos || m_NetworkType.find("yolox") != std::string::npos)
 eps = 1.0e-3;
 else if (m_NetworkType.find("yolor") != std::string::npos)
 eps = 1.0e-4;
@@ -398,6 +398,23 @@ Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition
 std::string layerName = "detect_v8";
 printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr));
 }
+else if (m_ConfigBlocks.at(i).at("type") == "detect_x") {
+modelType = 5;
+std::string blobName = "detect_x_" + std::to_string(i);
+nvinfer1::Dims prevTensorDims = previous->getDimensions();
+TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs);
+curYoloTensor.blobName = blobName;
+curYoloTensor.numBBoxes = prevTensorDims.d[0];
+m_NumClasses = prevTensorDims.d[1] - 5;
+std::string outputVol = dimsToString(previous->getDimensions());
+tensorOutputs.push_back(previous);
+yoloTensorInputs[yoloCountInputs] = previous;
+++yoloCountInputs;
+std::string layerName = "detect_x";
+printLayerInfo(layerIndex, layerName, "-", outputVol, std::to_string(weightPtr));
+}
 else {
 std::cerr << "\nUnsupported layer type --> \"" << m_ConfigBlocks.at(i).at("type") << "\"" << std::endl;
 assert(0);
@@ -415,7 +432,7 @@ Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition
 uint64_t outputSize = 0;
 for (uint j = 0; j < yoloCountInputs; ++j) {
 TensorInfo& curYoloTensor = m_YoloTensors.at(j);
-if (modelType == 3 || modelType == 4)
+if (modelType == 3 || modelType == 4 || modelType == 5)
 outputSize = curYoloTensor.numBBoxes;
 else
 outputSize += curYoloTensor.gridSizeX * curYoloTensor.gridSizeY * curYoloTensor.numBBoxes;
@@ -587,6 +604,41 @@ Yolo::parseConfigBlocks()
 TensorInfo outputTensor;
 m_YoloTensors.push_back(outputTensor);
 }
+else if (block.at("type") == "detect_x") {
+++m_YoloCount;
+TensorInfo outputTensor;
+std::vector<int> strides;
+std::string stridesString = block.at("strides");
+while (!stridesString.empty()) {
+int npos = stridesString.find_first_of(',');
+if (npos != -1) {
+int stride = std::stof(trim(stridesString.substr(0, npos)));
+strides.push_back(stride);
+stridesString.erase(0, npos + 1);
+}
+else {
+int stride = std::stof(trim(stridesString));
+strides.push_back(stride);
+break;
+}
+}
+for (uint i = 0; i < strides.size(); ++i) {
+int num_grid_y = m_InputH / strides[i];
+int num_grid_x = m_InputW / strides[i];
+for (int g1 = 0; g1 < num_grid_y; ++g1) {
+for (int g0 = 0; g0 < num_grid_x; ++g0) {
+outputTensor.anchors.push_back((float) g0);
+outputTensor.anchors.push_back((float) g1);
+outputTensor.mask.push_back(strides[i]);
+}
+}
+}
+m_YoloTensors.push_back(outputTensor);
+}
 }
 }
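The new `detect_x` block is anchor-free: for each entry in `strides`, every cell of the (m_InputW / stride) x (m_InputH / stride) grid contributes one prediction, with its grid coordinates stored in `anchors` and the stride itself in `mask`. A rough Python sketch of the same grid construction (the helper name is illustrative, not repo code):
```
def build_yolox_grids(input_w, input_h, strides=(8, 16, 32)):
    # flat per-prediction (grid_x, grid_y) pairs and strides,
    # mirroring the detect_x parsing above
    anchors, mask = [], []
    for s in strides:
        num_grid_y = input_h // s
        num_grid_x = input_w // s
        for g1 in range(num_grid_y):
            for g0 in range(num_grid_x):
                anchors += [float(g0), float(g1)]
                mask.append(s)
    return anchors, mask

anchors, mask = build_yolox_grids(640, 640)
print(len(mask))  # 8400 predictions for a 640x640 input (80*80 + 40*40 + 20*20)
```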


@@ -18,7 +18,7 @@ __global__ void gpuYoloLayer_v8(const float* input, int* num_detections, float*
 int maxIndex = -1;
 for (uint i = 0; i < numOutputClasses; ++i) {
-float prob = input[x_id * (4 + numOutputClasses) + i + 4];
+float prob = input[x_id * (4 + numOutputClasses) + 4 + i];
 if (prob > maxProb) {
 maxProb = prob;
 maxIndex = i;


@@ -0,0 +1,73 @@
/*
* Created by Marcos Luciano
* https://www.github.com/marcoslucianops
*/
#include <stdint.h>
__global__ void gpuYoloLayer_x(const float* input, int* num_detections, float* detection_boxes, float* detection_scores,
int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight,
const uint numOutputClasses, const uint64_t outputSize, const float* anchors, const int* mask)
{
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
if (x_id >= outputSize)
return;
const float objectness = input[x_id * (5 + numOutputClasses) + 4];
if (objectness < scoreThreshold)
return;
int count = (int)atomicAdd(num_detections, 1);
float x = (input[x_id * (5 + numOutputClasses) + 0] + anchors[x_id * 2]) * mask[x_id];
float y = (input[x_id * (5 + numOutputClasses) + 1] + anchors[x_id * 2 + 1]) * mask[x_id];
float w = __expf(input[x_id * (5 + numOutputClasses) + 2]) * mask[x_id];
float h = __expf(input[x_id * (5 + numOutputClasses) + 3]) * mask[x_id];
float maxProb = 0.0f;
int maxIndex = -1;
for (uint i = 0; i < numOutputClasses; ++i) {
float prob = input[x_id * (5 + numOutputClasses) + 5 + i];
if (prob > maxProb) {
maxProb = prob;
maxIndex = i;
}
}
detection_boxes[count * 4 + 0] = x - 0.5 * w;
detection_boxes[count * 4 + 1] = y - 0.5 * h;
detection_boxes[count * 4 + 2] = x + 0.5 * w;
detection_boxes[count * 4 + 3] = y + 0.5 * h;
detection_scores[count] = objectness * maxProb;
detection_classes[count] = maxIndex;
}
cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream);
cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream)
{
int threads_per_block = 16;
int number_of_blocks = (outputSize / threads_per_block) + 1;
for (unsigned int batch = 0; batch < batchSize; ++batch) {
gpuYoloLayer_x<<<number_of_blocks, threads_per_block, 0, stream>>>(
reinterpret_cast<const float*>(input) + (batch * (5 + numOutputClasses) * outputSize),
reinterpret_cast<int*>(num_detections) + (batch),
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
scoreThreshold, netWidth, netHeight, numOutputClasses, outputSize, reinterpret_cast<const float*>(anchors),
reinterpret_cast<const int*>(mask));
}
return cudaGetLastError();
}
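The `gpuYoloLayer_x` kernel decodes each prediction with the grid offset and stride collected in `anchors`/`mask`: the box center is `(t + grid) * stride`, the size is `exp(t) * stride`, and the score is objectness times the best class probability. A plain-Python reference of the same decode, assuming the `[tx, ty, tw, th, objectness, class probs...]` layout used by the kernel (the function name is illustrative):
```
import math

def decode_yolox_pred(pred, grid_x, grid_y, stride, score_threshold=0.25):
    # pred = [tx, ty, tw, th, objectness, class prob 0, class prob 1, ...]
    objectness = pred[4]
    if objectness < score_threshold:
        return None
    cx = (pred[0] + grid_x) * stride
    cy = (pred[1] + grid_y) * stride
    w = math.exp(pred[2]) * stride
    h = math.exp(pred[3]) * stride
    class_probs = pred[5:]
    best_class = max(range(len(class_probs)), key=lambda c: class_probs[c])
    score = objectness * class_probs[best_class]
    # corner-format box, as written to detection_boxes
    box = (cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h)
    return box, score, best_class
```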


@@ -38,6 +38,10 @@ namespace {
 }
 }
+cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
+void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
+const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream);
 cudaError_t cudaYoloLayer_v8(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
 void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
 const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream);
@@ -158,7 +162,35 @@ YoloLayer::enqueue(int batchSize, void const* const* inputs, void* const* output
 CUDA_CHECK(cudaMemsetAsync((float*)detection_scores, 0, sizeof(float) * m_OutputSize * batchSize, stream));
 CUDA_CHECK(cudaMemsetAsync((int*)detection_classes, 0, sizeof(int) * m_OutputSize * batchSize, stream));
-if (m_Type == 4) {
+if (m_Type == 5) {
+TensorInfo& curYoloTensor = m_YoloTensors.at(0);
+std::vector<float> anchors = curYoloTensor.anchors;
+std::vector<int> mask = curYoloTensor.mask;
+void* v_anchors;
+void* v_mask;
+if (anchors.size() > 0) {
+float* f_anchors = anchors.data();
+CUDA_CHECK(cudaMalloc(&v_anchors, sizeof(float) * anchors.size()));
+CUDA_CHECK(cudaMemcpyAsync(v_anchors, f_anchors, sizeof(float) * anchors.size(), cudaMemcpyHostToDevice, stream));
+}
+if (mask.size() > 0) {
+int* f_mask = mask.data();
+CUDA_CHECK(cudaMalloc(&v_mask, sizeof(int) * mask.size()));
+CUDA_CHECK(cudaMemcpyAsync(v_mask, f_mask, sizeof(int) * mask.size(), cudaMemcpyHostToDevice, stream));
+}
+CUDA_CHECK(cudaYoloLayer_x(inputs[0], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
+m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, v_anchors, v_mask, stream));
+if (anchors.size() > 0) {
+CUDA_CHECK(cudaFree(v_anchors));
+}
+if (mask.size() > 0) {
+CUDA_CHECK(cudaFree(v_mask));
+}
+}
+else if (m_Type == 4) {
 CUDA_CHECK(cudaYoloLayer_v8(inputs[0], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
 m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, stream));
 }

utils/gen_wts_yolox.py Normal file

@@ -0,0 +1,372 @@
import argparse
import os
import struct
import torch
from yolox.exp import get_exp
class Layers(object):
def __init__(self, size, fw, fc):
self.blocks = [0 for _ in range(300)]
self.current = -1
self.width = size[0] if len(size) == 1 else size[1]
self.height = size[0]
self.backbone_outs = []
self.fpn_feats = []
self.pan_feats = []
self.yolo_head = []
self.fw = fw
self.fc = fc
self.wc = 0
self.net()
def Conv(self, child):
self.current += 1
if child._get_name() == 'DWConv':
self.convolutional(child.dconv)
self.convolutional(child.pconv)
else:
self.convolutional(child)
def Focus(self, child):
self.current += 1
self.reorg()
self.convolutional(child.conv)
def BaseConv(self, child, stage='', act=None):
self.current += 1
self.convolutional(child, act=act)
if stage == 'fpn':
self.fpn_feats.append(self.current)
def CSPLayer(self, child, stage=''):
self.current += 1
self.convolutional(child.conv2)
self.route('-2')
self.convolutional(child.conv1)
idx = -3
for m in child.m:
if m.use_add:
self.convolutional(m.conv1)
if m.conv2._get_name() == 'DWConv':
self.convolutional(m.conv2.dconv)
self.convolutional(m.conv2.pconv)
self.shortcut(-4)
idx -= 4
else:
self.convolutional(m.conv2)
self.shortcut(-3)
idx -= 3
else:
self.convolutional(m.conv1)
if m.conv2._get_name() == 'DWConv':
self.convolutional(m.conv2.dconv)
self.convolutional(m.conv2.pconv)
idx -= 3
else:
self.convolutional(m.conv2)
idx -= 2
self.route('-1, %d' % idx)
self.convolutional(child.conv3)
if stage == 'backbone':
self.backbone_outs.append(self.current)
elif stage == 'pan':
self.pan_feats.append(self.current)
def SPPBottleneck(self, child):
self.current += 1
self.convolutional(child.conv1)
self.maxpool(child.m[0])
self.route('-2')
self.maxpool(child.m[1])
self.route('-4')
self.maxpool(child.m[2])
self.route('-6, -5, -3, -1')
self.convolutional(child.conv2)
def Upsample(self, child):
self.current += 1
self.upsample(child)
def Concat(self, route):
self.current += 1
r = self.get_route(route)
self.route('-1, %d' % r)
def Route(self, route):
self.current += 1
if route > 0:
r = self.get_route(route)
self.route('%d' % r)
else:
self.route('%d' % route)
def RouteShuffleOut(self, route):
self.current += 1
self.route(route)
self.shuffle(reshape=['c', 'hw'])
self.yolo_head.append(self.current)
def Detect(self, strides):
self.current += 1
routes = self.yolo_head[::-1]
for i, route in enumerate(routes):
routes[i] = self.get_route(route)
self.route(str(routes)[1:-1], axis=1)
self.shuffle(transpose1=[1, 0])
self.yolo(strides)
def net(self):
self.fc.write('[net]\n' +
'width=%d\n' % self.width +
'height=%d\n' % self.height +
'channels=3\n' +
'letter_box=1\n')
def reorg(self):
self.blocks[self.current] += 1
self.fc.write('\n[reorg]\n')
def convolutional(self, cv, act=None, detect=False):
self.blocks[self.current] += 1
self.get_state_dict(cv.state_dict())
if cv._get_name() == 'Conv2d':
filters = cv.out_channels
size = cv.kernel_size
stride = cv.stride
pad = cv.padding
groups = cv.groups
bias = cv.bias
bn = False
act = act if act is not None else 'linear'
else:
filters = cv.conv.out_channels
size = cv.conv.kernel_size
stride = cv.conv.stride
pad = cv.conv.padding
groups = cv.conv.groups
bias = cv.conv.bias
bn = True if hasattr(cv, 'bn') else False
if act is None:
act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear'
b = 'batch_normalize=1\n' if bn is True else ''
g = 'groups=%d\n' % groups if groups > 1 else ''
w = 'bias=0\n' if bias is None and bn is False else ''
self.fc.write('\n[convolutional]\n' +
b +
'filters=%d\n' % filters +
'size=%s\n' % self.get_value(size) +
'stride=%s\n' % self.get_value(stride) +
'pad=%s\n' % self.get_value(pad) +
g +
w +
'activation=%s\n' % act)
def route(self, layers, axis=0):
self.blocks[self.current] += 1
a = 'axis=%d\n' % axis if axis != 0 else ''
self.fc.write('\n[route]\n' +
'layers=%s\n' % layers +
a)
def shortcut(self, r, ew='add', act='linear'):
self.blocks[self.current] += 1
m = 'mode=mul\n' if ew == 'mul' else ''
self.fc.write('\n[shortcut]\n' +
'from=%d\n' % r +
m +
'activation=%s\n' % act)
def maxpool(self, m):
self.blocks[self.current] += 1
stride = m.stride
size = m.kernel_size
mode = m.ceil_mode
m = 'maxpool_up' if mode else 'maxpool'
self.fc.write('\n[%s]\n' % m +
'stride=%d\n' % stride +
'size=%d\n' % size)
def upsample(self, child):
self.blocks[self.current] += 1
stride = child.scale_factor
self.fc.write('\n[upsample]\n' +
'stride=%d\n' % stride)
def shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None):
self.blocks[self.current] += 1
r = 'reshape=%s\n' % ', '.join(str(x) for x in reshape) if reshape is not None else ''
t1 = 'transpose1=%s\n' % ', '.join(str(x) for x in transpose1) if transpose1 is not None else ''
t2 = 'transpose2=%s\n' % ', '.join(str(x) for x in transpose2) if transpose2 is not None else ''
f = 'from=%d\n' % route if route is not None else ''
self.fc.write('\n[shuffle]\n' +
r +
t1 +
t2 +
f)
def yolo(self, strides):
self.blocks[self.current] += 1
self.fc.write('\n[detect_x]\n' +
'strides=%s\n' % str(strides)[1:-1])
def get_state_dict(self, state_dict):
for k, v in state_dict.items():
if 'num_batches_tracked' not in k:
vr = v.reshape(-1).numpy()
self.fw.write('{} {} '.format(k, len(vr)))
for vv in vr:
self.fw.write(' ')
self.fw.write(struct.pack('>f', float(vv)).hex())
self.fw.write('\n')
self.wc += 1
def get_value(self, key):
if type(key) == int:
return key
return key[0] if key[0] == key[1] else str(key)[1:-1]
def get_route(self, n):
r = 0
for i, b in enumerate(self.blocks):
if i <= n:
r += b
else:
break
return r - 1
def get_activation(self, act):
if act == 'Hardswish':
return 'hardswish'
elif act == 'LeakyReLU':
return 'leaky'
elif act == 'SiLU':
return 'silu'
return 'linear'
def parse_args():
parser = argparse.ArgumentParser(description='PyTorch YOLOX conversion')
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pth) file path (required)')
parser.add_argument('-e', '--exp', required=True, help='Input exp (.py) file path (required)')
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit('Invalid weights file')
if not os.path.isfile(args.exp):
raise SystemExit('Invalid exp file')
return args.weights, args.exp
pth_file, exp_file = parse_args()
exp = get_exp(exp_file)
model = exp.get_model()
model.load_state_dict(torch.load(pth_file, map_location='cpu')['model'])
model.to('cpu').eval()
model_name = exp.exp_name
inference_size = (exp.input_size[1], exp.input_size[0])
backbone = model.backbone._get_name()
head = model.head._get_name()
wts_file = model_name + '.wts' if 'yolox' in model_name else 'yolox_' + model_name + '.wts'
cfg_file = model_name + '.cfg' if 'yolox' in model_name else 'yolox_' + model_name + '.cfg'
with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc:
layers = Layers(inference_size, fw, fc)
if backbone == 'YOLOPAFPN':
layers.fc.write('\n# YOLOPAFPN\n')
layers.Focus(model.backbone.backbone.stem)
layers.Conv(model.backbone.backbone.dark2[0])
layers.CSPLayer(model.backbone.backbone.dark2[1])
layers.Conv(model.backbone.backbone.dark3[0])
layers.CSPLayer(model.backbone.backbone.dark3[1], 'backbone')
layers.Conv(model.backbone.backbone.dark4[0])
layers.CSPLayer(model.backbone.backbone.dark4[1], 'backbone')
layers.Conv(model.backbone.backbone.dark5[0])
layers.SPPBottleneck(model.backbone.backbone.dark5[1])
layers.CSPLayer(model.backbone.backbone.dark5[2], 'backbone')
layers.BaseConv(model.backbone.lateral_conv0, 'fpn')
layers.Upsample(model.backbone.upsample)
layers.Concat(layers.backbone_outs[1])
layers.CSPLayer(model.backbone.C3_p4)
layers.BaseConv(model.backbone.reduce_conv1, 'fpn')
layers.Upsample(model.backbone.upsample)
layers.Concat(layers.backbone_outs[0])
layers.CSPLayer(model.backbone.C3_p3, 'pan')
layers.Conv(model.backbone.bu_conv2)
layers.Concat(layers.fpn_feats[1])
layers.CSPLayer(model.backbone.C3_n3, 'pan')
layers.Conv(model.backbone.bu_conv1)
layers.Concat(layers.fpn_feats[0])
layers.CSPLayer(model.backbone.C3_n4, 'pan')
layers.pan_feats = layers.pan_feats[::-1]
else:
raise SystemExit('Model not supported')
if head == 'YOLOXHead':
layers.fc.write('\n# YOLOXHead\n')
for i, feat in enumerate(layers.pan_feats):
idx = len(layers.pan_feats) - i - 1
dw = True if model.head.cls_convs[idx][0]._get_name() == 'DWConv' else False
if i > 0:
layers.Route(feat)
layers.BaseConv(model.head.stems[idx])
layers.Conv(model.head.cls_convs[idx][0])
layers.Conv(model.head.cls_convs[idx][1])
layers.BaseConv(model.head.cls_preds[idx], act='logistic')
if dw:
layers.Route(-6)
else:
layers.Route(-4)
layers.Conv(model.head.reg_convs[idx][0])
layers.Conv(model.head.reg_convs[idx][1])
layers.BaseConv(model.head.obj_preds[idx], act='logistic')
layers.Route(-2)
layers.BaseConv(model.head.reg_preds[idx])
if dw:
layers.RouteShuffleOut('-1, -3, -9')
else:
layers.RouteShuffleOut('-1, -3, -7')
layers.Detect(model.head.strides)
else:
raise SystemExit('Model not supported')
os.system('echo "%d" | cat - %s > temp && mv temp %s' % (layers.wc, wts_file, wts_file))