diff --git a/README.md b/README.md index 96a700c..c5bb9b4 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,6 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models ### Future updates * DeepStream tutorials -* YOLOX support * YOLOv6 support * Dynamic batch-size * PP-YOLOE+ support @@ -29,6 +28,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models * Optimized NMS [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142) * Models benchmarks * **YOLOv8 support** +* **YOLOX support** ## @@ -47,6 +47,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models * [PP-YOLOE usage](docs/PPYOLOE.md) * [YOLOv7 usage](docs/YOLOv7.md) * [YOLOv8 usage](docs/YOLOv8.md) +* [YOLOX usage](docs/YOLOX.md) * [Using your custom model](docs/customModels.md) * [Multiple YOLO GIEs](docs/multipleGIEs.md) @@ -112,6 +113,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models * [PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe) * [YOLOv7](https://github.com/WongKinYiu/yolov7) * [YOLOv8](https://github.com/ultralytics/ultralytics) +* [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX) * [MobileNet-YOLO](https://github.com/dog-qiuqiu/MobileNet-Yolo) * [YOLO-Fastest](https://github.com/dog-qiuqiu/Yolo-Fastest) @@ -135,7 +137,7 @@ sample = 1920x1080 video - Eval ``` -nms-iou-threshold = 0.6 (Darknet and YOLOv8) / 0.65 (YOLOR, YOLOv5 and YOLOv7) / 0.7 (Paddle) +nms-iou-threshold = 0.6 (Darknet and YOLOv8) / 0.65 (YOLOR, YOLOv5, YOLOv7 and YOLOX) / 0.7 (Paddle) pre-cluster-threshold = 0.001 topk = 300 ``` diff --git a/config_infer_primary_yolox.txt b/config_infer_primary_yolox.txt new file mode 100644 index 0000000..e006344 --- /dev/null +++ b/config_infer_primary_yolox.txt @@ -0,0 +1,27 @@ +[property] +gpu-id=0 +net-scale-factor=0 +model-color-format=0 +custom-network-config=yolox_s.cfg +model-file=yolox_s.wts +model-engine-file=model_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=2 +maintain-aspect-ratio=1 +symmetric-padding=0 +parse-bbox-func-name=NvDsInferParseYolo +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so +engine-create-func-name=NvDsInferYoloCudaEngineGet + +[class-attrs-all] +nms-iou-threshold=0.45 +pre-cluster-threshold=0.25 +topk=300 diff --git a/config_infer_primary_yolox_legacy.txt b/config_infer_primary_yolox_legacy.txt new file mode 100644 index 0000000..5c078ce --- /dev/null +++ b/config_infer_primary_yolox_legacy.txt @@ -0,0 +1,28 @@ +[property] +gpu-id=0 +net-scale-factor=0.0173520735727919486 +offsets=123.675;116.28;103.53 +model-color-format=0 +custom-network-config=yolox_s.cfg +model-file=yolox_s.wts +model-engine-file=model_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=2 +maintain-aspect-ratio=1 +symmetric-padding=0 +parse-bbox-func-name=NvDsInferParseYolo +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so +engine-create-func-name=NvDsInferYoloCudaEngineGet + +[class-attrs-all] +nms-iou-threshold=0.45 +pre-cluster-threshold=0.25 +topk=300 diff --git a/docs/YOLOX.md b/docs/YOLOX.md new file mode 100644 index 0000000..055a1eb --- /dev/null +++ b/docs/YOLOX.md @@ -0,0 +1,137 @@ +# YOLOX usage + +**NOTE**: The yaml file is not required. + +* [Convert model](#convert-model) +* [Compile the lib](#compile-the-lib) +* [Edit the config_infer_primary_yolox file](#edit-the-config_infer_primary_yolox-file) +* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file) +* [Testing the model](#testing-the-model) + +## + +### Convert model + +#### 1. Download the YOLOX repo and install the requirements + +``` +git clone https://github.com/Megvii-BaseDetection/YOLOX +cd YOLOX +pip3 install -r requirements.txt +``` + +**NOTE**: It is recommended to use Python virtualenv. + +#### 2. Copy conversor + +Copy the `gen_wts_yolox.py` file from `DeepStream-Yolo/utils` directory to the `YOLOX` folder. + +#### 3. Download the model + +Download the `pth` file from [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX/releases) releases (example for YOLOX-s standard) + +``` +wget https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.pth +``` + +**NOTE**: You can use your custom model, but it is important to keep the YOLO model reference (`yolox_`) in you `cfg` and `weights`/`wts` filenames to generate the engine correctly. + +#### 4. Convert model + +Generate the `cfg` and `wts` files (example for YOLOX-s standard) + +``` +python3 gen_wts_yolox.py -w yolox_s.pth -e exps/default/yolox_s.py +``` + +#### 5. Copy generated files + +Copy the generated `cfg` and `wts` files to the `DeepStream-Yolo` folder. + +## + +### Compile the lib + +Open the `DeepStream-Yolo` folder and compile the lib + +* DeepStream 6.1.1 on x86 platform + + ``` + CUDA_VER=11.7 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.1 on x86 platform + + ``` + CUDA_VER=11.6 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.0.1 / 6.0 on x86 platform + + ``` + CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.1.1 / 6.1 on Jetson platform + + ``` + CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.0.1 / 6.0 on Jetson platform + + ``` + CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo + ``` + +## + +### Edit the config_infer_primary_yolox file + +Edit the `config_infer_primary_yolox.txt` file according to your model (example for YOLOX-s standard) + +``` +[property] +... +custom-network-config=yolox_s.cfg +model-file=yolox_s.wts +... +``` + +**NOTE**: If you use the **legacy** model, you should edit the `config_infer_primary_yolox_legacy.txt` file. + +**NOTE**: The **YOLOX standard** uses no normalization on the image preprocess. It is important to change the `net-scale-factor` according to the trained values. + +``` +net-scale-factor=0 +``` + +**NOTE**: The **YOLOX legacy** uses normalization on the image preprocess. It is important to change the `net-scale-factor` and `offsets` according to the trained values. + +Default: `mean = 0.485, 0.456, 0.406` and `std = 0.229, 0.224, 0.225` + +``` +net-scale-factor=0.0173520735727919486 +offsets=123.675;116.28;103.53 +``` + +## + +### Edit the deepstream_app_config file + +``` +... +[primary-gie] +... +config-file=config_infer_primary_yolox.txt +``` + +**NOTE**: If you use the **legacy** model, you should edit it to `config_infer_primary_yolox_legacy.txt`. + +## + +### Testing the model + +``` +deepstream-app -c deepstream_app_config.txt +``` diff --git a/nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.cpp index b8ad7e4..ed3e7ad 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/detect_v8_layer.cpp @@ -44,11 +44,11 @@ detectV8Layer(int layerIdx, std::map& block, std::vect shuffle1Box->setName(shuffle1BoxLayerName.c_str()); nvinfer1::Dims reshape1Dims = {3, {4, reg_max, inputDims.d[1]}}; shuffle1Box->setReshapeDimensions(reshape1Dims); - nvinfer1::Permutation permutation1; - permutation1.order[0] = 1; - permutation1.order[1] = 0; - permutation1.order[2] = 2; - shuffle1Box->setSecondTranspose(permutation1); + nvinfer1::Permutation permutation1Box; + permutation1Box.order[0] = 1; + permutation1Box.order[1] = 0; + permutation1Box.order[2] = 2; + shuffle1Box->setSecondTranspose(permutation1Box); box = shuffle1Box->getOutput(0); nvinfer1::ISoftMaxLayer* softmax = network->addSoftMax(*box); @@ -186,10 +186,10 @@ detectV8Layer(int layerIdx, std::map& block, std::vect assert(shuffle != nullptr); std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx); shuffle->setName(shuffleLayerName.c_str()); - nvinfer1::Permutation permutation2; - permutation2.order[0] = 1; - permutation2.order[1] = 0; - shuffle->setFirstTranspose(permutation2); + nvinfer1::Permutation permutation; + permutation.order[0] = 1; + permutation.order[1] = 0; + shuffle->setFirstTranspose(permutation); output = shuffle->getOutput(0); return output; diff --git a/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.cpp index b844b50..fd8ce4c 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/shuffle_layer.cpp @@ -18,16 +18,15 @@ shuffleLayer(int layerIdx, std::string& layer, std::mapsetName(shuffleLayerName.c_str()); + int from = -1; + if (block.find("from") != block.end()) + from = std::stoi(block.at("from")); + if (from < 0) + from = tensorOutputs.size() + from; + + layer = std::to_string(from); + if (block.find("reshape") != block.end()) { - int from = -1; - if (block.find("from") != block.end()) - from = std::stoi(block.at("from")); - - if (from < 0) - from = tensorOutputs.size() + from; - - layer = std::to_string(from); - nvinfer1::Dims inputTensorDims = tensorOutputs[from]->getDimensions(); std::string strReshape = block.at("reshape"); diff --git a/nvdsinfer_custom_impl_Yolo/yolo.cpp b/nvdsinfer_custom_impl_Yolo/yolo.cpp index f412df6..a7fa21c 100644 --- a/nvdsinfer_custom_impl_Yolo/yolo.cpp +++ b/nvdsinfer_custom_impl_Yolo/yolo.cpp @@ -136,7 +136,7 @@ Yolo::buildYoloNetwork(std::vector& weights, nvinfer1::INetworkDefinition float eps = 1.0e-5; if (m_NetworkType.find("yolov5") != std::string::npos || m_NetworkType.find("yolov7") != std::string::npos || - m_NetworkType.find("yolov8") != std::string::npos) + m_NetworkType.find("yolov8") != std::string::npos || m_NetworkType.find("yolox") != std::string::npos) eps = 1.0e-3; else if (m_NetworkType.find("yolor") != std::string::npos) eps = 1.0e-4; @@ -398,6 +398,23 @@ Yolo::buildYoloNetwork(std::vector& weights, nvinfer1::INetworkDefinition std::string layerName = "detect_v8"; printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr)); } + else if (m_ConfigBlocks.at(i).at("type") == "detect_x") { + modelType = 5; + + std::string blobName = "detect_x_" + std::to_string(i); + nvinfer1::Dims prevTensorDims = previous->getDimensions(); + TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs); + curYoloTensor.blobName = blobName; + curYoloTensor.numBBoxes = prevTensorDims.d[0]; + m_NumClasses = prevTensorDims.d[1] - 5; + + std::string outputVol = dimsToString(previous->getDimensions()); + tensorOutputs.push_back(previous); + yoloTensorInputs[yoloCountInputs] = previous; + ++yoloCountInputs; + std::string layerName = "detect_x"; + printLayerInfo(layerIndex, layerName, "-", outputVol, std::to_string(weightPtr)); + } else { std::cerr << "\nUnsupported layer type --> \"" << m_ConfigBlocks.at(i).at("type") << "\"" << std::endl; assert(0); @@ -415,7 +432,7 @@ Yolo::buildYoloNetwork(std::vector& weights, nvinfer1::INetworkDefinition uint64_t outputSize = 0; for (uint j = 0; j < yoloCountInputs; ++j) { TensorInfo& curYoloTensor = m_YoloTensors.at(j); - if (modelType == 3 || modelType == 4) + if (modelType == 3 || modelType == 4 || modelType == 5) outputSize = curYoloTensor.numBBoxes; else outputSize += curYoloTensor.gridSizeX * curYoloTensor.gridSizeY * curYoloTensor.numBBoxes; @@ -587,6 +604,41 @@ Yolo::parseConfigBlocks() TensorInfo outputTensor; m_YoloTensors.push_back(outputTensor); } + else if (block.at("type") == "detect_x") { + ++m_YoloCount; + TensorInfo outputTensor; + + std::vector strides; + + std::string stridesString = block.at("strides"); + while (!stridesString.empty()) { + int npos = stridesString.find_first_of(','); + if (npos != -1) { + int stride = std::stof(trim(stridesString.substr(0, npos))); + strides.push_back(stride); + stridesString.erase(0, npos + 1); + } + else { + int stride = std::stof(trim(stridesString)); + strides.push_back(stride); + break; + } + } + + for (uint i = 0; i < strides.size(); ++i) { + int num_grid_y = m_InputH / strides[i]; + int num_grid_x = m_InputW / strides[i]; + for (int g1 = 0; g1 < num_grid_y; ++g1) { + for (int g0 = 0; g0 < num_grid_x; ++g0) { + outputTensor.anchors.push_back((float) g0); + outputTensor.anchors.push_back((float) g1); + outputTensor.mask.push_back(strides[i]); + } + } + } + + m_YoloTensors.push_back(outputTensor); + } } } diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_v8.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_v8.cu index 8bc5413..696cf14 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloForward_v8.cu +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_v8.cu @@ -18,7 +18,7 @@ __global__ void gpuYoloLayer_v8(const float* input, int* num_detections, float* int maxIndex = -1; for (uint i = 0; i < numOutputClasses; ++i) { - float prob = input[x_id * (4 + numOutputClasses) + i + 4]; + float prob = input[x_id * (4 + numOutputClasses) + 4 + i]; if (prob > maxProb) { maxProb = prob; maxIndex = i; diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_x.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_x.cu new file mode 100644 index 0000000..966a669 --- /dev/null +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_x.cu @@ -0,0 +1,73 @@ +/* + * Created by Marcos Luciano + * https://www.github.com/marcoslucianops + */ + +#include + +__global__ void gpuYoloLayer_x(const float* input, int* num_detections, float* detection_boxes, float* detection_scores, + int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight, + const uint numOutputClasses, const uint64_t outputSize, const float* anchors, const int* mask) +{ + uint x_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (x_id >= outputSize) + return; + + const float objectness = input[x_id * (5 + numOutputClasses) + 4]; + + if (objectness < scoreThreshold) + return; + + int count = (int)atomicAdd(num_detections, 1); + + float x = (input[x_id * (5 + numOutputClasses) + 0] + anchors[x_id * 2]) * mask[x_id]; + + float y = (input[x_id * (5 + numOutputClasses) + 1] + anchors[x_id * 2 + 1]) * mask[x_id]; + + float w = __expf(input[x_id * (5 + numOutputClasses) + 2]) * mask[x_id]; + + float h = __expf(input[x_id * (5 + numOutputClasses) + 3]) * mask[x_id]; + + float maxProb = 0.0f; + int maxIndex = -1; + + for (uint i = 0; i < numOutputClasses; ++i) { + float prob = input[x_id * (5 + numOutputClasses) + 5 + i]; + if (prob > maxProb) { + maxProb = prob; + maxIndex = i; + } + } + + detection_boxes[count * 4 + 0] = x - 0.5 * w; + detection_boxes[count * 4 + 1] = y - 0.5 * h; + detection_boxes[count * 4 + 2] = x + 0.5 * w; + detection_boxes[count * 4 + 3] = y + 0.5 * h; + detection_scores[count] = objectness * maxProb; + detection_classes[count] = maxIndex; +} + +cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, + const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream); + +cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, + const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream) +{ + int threads_per_block = 16; + int number_of_blocks = (outputSize / threads_per_block) + 1; + + for (unsigned int batch = 0; batch < batchSize; ++batch) { + gpuYoloLayer_x<<>>( + reinterpret_cast(input) + (batch * (5 + numOutputClasses) * outputSize), + reinterpret_cast(num_detections) + (batch), + reinterpret_cast(detection_boxes) + (batch * 4 * outputSize), + reinterpret_cast(detection_scores) + (batch * outputSize), + reinterpret_cast(detection_classes) + (batch * outputSize), + scoreThreshold, netWidth, netHeight, numOutputClasses, outputSize, reinterpret_cast(anchors), + reinterpret_cast(mask)); + } + return cudaGetLastError(); +} diff --git a/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp b/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp index ebb24b5..88c6dbd 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp +++ b/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp @@ -38,6 +38,10 @@ namespace { } } +cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, + void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, + const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream); + cudaError_t cudaYoloLayer_v8(const void* input, void* num_detections, void* detection_boxes, void* detection_scores, void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth, const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream); @@ -158,7 +162,35 @@ YoloLayer::enqueue(int batchSize, void const* const* inputs, void* const* output CUDA_CHECK(cudaMemsetAsync((float*)detection_scores, 0, sizeof(float) * m_OutputSize * batchSize, stream)); CUDA_CHECK(cudaMemsetAsync((int*)detection_classes, 0, sizeof(int) * m_OutputSize * batchSize, stream)); - if (m_Type == 4) { + if (m_Type == 5) { + TensorInfo& curYoloTensor = m_YoloTensors.at(0); + std::vector anchors = curYoloTensor.anchors; + std::vector mask = curYoloTensor.mask; + + void* v_anchors; + void* v_mask; + if (anchors.size() > 0) { + float* f_anchors = anchors.data(); + CUDA_CHECK(cudaMalloc(&v_anchors, sizeof(float) * anchors.size())); + CUDA_CHECK(cudaMemcpyAsync(v_anchors, f_anchors, sizeof(float) * anchors.size(), cudaMemcpyHostToDevice, stream)); + } + if (mask.size() > 0) { + int* f_mask = mask.data(); + CUDA_CHECK(cudaMalloc(&v_mask, sizeof(int) * mask.size())); + CUDA_CHECK(cudaMemcpyAsync(v_mask, f_mask, sizeof(int) * mask.size(), cudaMemcpyHostToDevice, stream)); + } + + CUDA_CHECK(cudaYoloLayer_x(inputs[0], num_detections, detection_boxes, detection_scores, detection_classes, batchSize, + m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, v_anchors, v_mask, stream)); + + if (anchors.size() > 0) { + CUDA_CHECK(cudaFree(v_anchors)); + } + if (mask.size() > 0) { + CUDA_CHECK(cudaFree(v_mask)); + } + } + else if (m_Type == 4) { CUDA_CHECK(cudaYoloLayer_v8(inputs[0], num_detections, detection_boxes, detection_scores, detection_classes, batchSize, m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, stream)); } diff --git a/utils/gen_wts_yolox.py b/utils/gen_wts_yolox.py new file mode 100644 index 0000000..c47ff92 --- /dev/null +++ b/utils/gen_wts_yolox.py @@ -0,0 +1,372 @@ +import argparse +import os +import struct +import torch +from yolox.exp import get_exp + + +class Layers(object): + def __init__(self, size, fw, fc): + self.blocks = [0 for _ in range(300)] + self.current = -1 + + self.width = size[0] if len(size) == 1 else size[1] + self.height = size[0] + + self.backbone_outs = [] + self.fpn_feats = [] + self.pan_feats = [] + self.yolo_head = [] + + self.fw = fw + self.fc = fc + self.wc = 0 + + self.net() + + def Conv(self, child): + self.current += 1 + + if child._get_name() == 'DWConv': + self.convolutional(child.dconv) + self.convolutional(child.pconv) + else: + self.convolutional(child) + + def Focus(self, child): + self.current += 1 + + self.reorg() + self.convolutional(child.conv) + + def BaseConv(self, child, stage='', act=None): + self.current += 1 + + self.convolutional(child, act=act) + if stage == 'fpn': + self.fpn_feats.append(self.current) + + def CSPLayer(self, child, stage=''): + self.current += 1 + + self.convolutional(child.conv2) + self.route('-2') + self.convolutional(child.conv1) + idx = -3 + for m in child.m: + if m.use_add: + self.convolutional(m.conv1) + if m.conv2._get_name() == 'DWConv': + self.convolutional(m.conv2.dconv) + self.convolutional(m.conv2.pconv) + self.shortcut(-4) + idx -= 4 + else: + self.convolutional(m.conv2) + self.shortcut(-3) + idx -= 3 + else: + self.convolutional(m.conv1) + if m.conv2._get_name() == 'DWConv': + self.convolutional(m.conv2.dconv) + self.convolutional(m.conv2.pconv) + idx -= 3 + else: + self.convolutional(m.conv2) + idx -= 2 + self.route('-1, %d' % idx) + self.convolutional(child.conv3) + if stage == 'backbone': + self.backbone_outs.append(self.current) + elif stage == 'pan': + self.pan_feats.append(self.current) + + def SPPBottleneck(self, child): + self.current += 1 + + self.convolutional(child.conv1) + self.maxpool(child.m[0]) + self.route('-2') + self.maxpool(child.m[1]) + self.route('-4') + self.maxpool(child.m[2]) + self.route('-6, -5, -3, -1') + self.convolutional(child.conv2) + + def Upsample(self, child): + self.current += 1 + + self.upsample(child) + + def Concat(self, route): + self.current += 1 + + r = self.get_route(route) + self.route('-1, %d' % r) + + def Route(self, route): + self.current += 1 + + if route > 0: + r = self.get_route(route) + self.route('%d' % r) + else: + self.route('%d' % route) + + def RouteShuffleOut(self, route): + self.current += 1 + + self.route(route) + self.shuffle(reshape=['c', 'hw']) + self.yolo_head.append(self.current) + + def Detect(self, strides): + self.current += 1 + + routes = self.yolo_head[::-1] + + for i, route in enumerate(routes): + routes[i] = self.get_route(route) + self.route(str(routes)[1:-1], axis=1) + self.shuffle(transpose1=[1, 0]) + self.yolo(strides) + + def net(self): + self.fc.write('[net]\n' + + 'width=%d\n' % self.width + + 'height=%d\n' % self.height + + 'channels=3\n' + + 'letter_box=1\n') + + def reorg(self): + self.blocks[self.current] += 1 + + self.fc.write('\n[reorg]\n') + + def convolutional(self, cv, act=None, detect=False): + self.blocks[self.current] += 1 + + self.get_state_dict(cv.state_dict()) + + if cv._get_name() == 'Conv2d': + filters = cv.out_channels + size = cv.kernel_size + stride = cv.stride + pad = cv.padding + groups = cv.groups + bias = cv.bias + bn = False + act = act if act is not None else 'linear' + else: + filters = cv.conv.out_channels + size = cv.conv.kernel_size + stride = cv.conv.stride + pad = cv.conv.padding + groups = cv.conv.groups + bias = cv.conv.bias + bn = True if hasattr(cv, 'bn') else False + if act is None: + act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear' + + b = 'batch_normalize=1\n' if bn is True else '' + g = 'groups=%d\n' % groups if groups > 1 else '' + w = 'bias=0\n' if bias is None and bn is False else '' + + self.fc.write('\n[convolutional]\n' + + b + + 'filters=%d\n' % filters + + 'size=%s\n' % self.get_value(size) + + 'stride=%s\n' % self.get_value(stride) + + 'pad=%s\n' % self.get_value(pad) + + g + + w + + 'activation=%s\n' % act) + + def route(self, layers, axis=0): + self.blocks[self.current] += 1 + + a = 'axis=%d\n' % axis if axis != 0 else '' + + self.fc.write('\n[route]\n' + + 'layers=%s\n' % layers + + a) + + def shortcut(self, r, ew='add', act='linear'): + self.blocks[self.current] += 1 + + m = 'mode=mul\n' if ew == 'mul' else '' + + self.fc.write('\n[shortcut]\n' + + 'from=%d\n' % r + + m + + 'activation=%s\n' % act) + + def maxpool(self, m): + self.blocks[self.current] += 1 + + stride = m.stride + size = m.kernel_size + mode = m.ceil_mode + + m = 'maxpool_up' if mode else 'maxpool' + + self.fc.write('\n[%s]\n' % m + + 'stride=%d\n' % stride + + 'size=%d\n' % size) + + def upsample(self, child): + self.blocks[self.current] += 1 + + stride = child.scale_factor + + self.fc.write('\n[upsample]\n' + + 'stride=%d\n' % stride) + + def shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None): + self.blocks[self.current] += 1 + + r = 'reshape=%s\n' % ', '.join(str(x) for x in reshape) if reshape is not None else '' + t1 = 'transpose1=%s\n' % ', '.join(str(x) for x in transpose1) if transpose1 is not None else '' + t2 = 'transpose2=%s\n' % ', '.join(str(x) for x in transpose2) if transpose2 is not None else '' + f = 'from=%d\n' % route if route is not None else '' + + self.fc.write('\n[shuffle]\n' + + r + + t1 + + t2 + + f) + + def yolo(self, strides): + self.blocks[self.current] += 1 + + self.fc.write('\n[detect_x]\n' + + 'strides=%s\n' % str(strides)[1:-1]) + + def get_state_dict(self, state_dict): + for k, v in state_dict.items(): + if 'num_batches_tracked' not in k: + vr = v.reshape(-1).numpy() + self.fw.write('{} {} '.format(k, len(vr))) + for vv in vr: + self.fw.write(' ') + self.fw.write(struct.pack('>f', float(vv)).hex()) + self.fw.write('\n') + self.wc += 1 + + def get_value(self, key): + if type(key) == int: + return key + return key[0] if key[0] == key[1] else str(key)[1:-1] + + def get_route(self, n): + r = 0 + for i, b in enumerate(self.blocks): + if i <= n: + r += b + else: + break + return r - 1 + + def get_activation(self, act): + if act == 'Hardswish': + return 'hardswish' + elif act == 'LeakyReLU': + return 'leaky' + elif act == 'SiLU': + return 'silu' + return 'linear' + + +def parse_args(): + parser = argparse.ArgumentParser(description='PyTorch YOLOX conversion') + parser.add_argument('-w', '--weights', required=True, help='Input weights (.pth) file path (required)') + parser.add_argument('-e', '--exp', required=True, help='Input exp (.py) file path (required)') + args = parser.parse_args() + if not os.path.isfile(args.weights): + raise SystemExit('Invalid weights file') + if not os.path.isfile(args.exp): + raise SystemExit('Invalid exp file') + return args.weights, args.exp + + +pth_file, exp_file = parse_args() + +exp = get_exp(exp_file) +model = exp.get_model() +model.load_state_dict(torch.load(pth_file, map_location='cpu')['model']) +model.to('cpu').eval() + +model_name = exp.exp_name +inference_size = (exp.input_size[1], exp.input_size[0]) + +backbone = model.backbone._get_name() +head = model.head._get_name() + +wts_file = model_name + '.wts' if 'yolox' in model_name else 'yolox_' + model_name + '.wts' +cfg_file = model_name + '.cfg' if 'yolox' in model_name else 'yolox_' + model_name + '.cfg' + +with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc: + layers = Layers(inference_size, fw, fc) + + if backbone == 'YOLOPAFPN': + layers.fc.write('\n# YOLOPAFPN\n') + + layers.Focus(model.backbone.backbone.stem) + layers.Conv(model.backbone.backbone.dark2[0]) + layers.CSPLayer(model.backbone.backbone.dark2[1]) + layers.Conv(model.backbone.backbone.dark3[0]) + layers.CSPLayer(model.backbone.backbone.dark3[1], 'backbone') + layers.Conv(model.backbone.backbone.dark4[0]) + layers.CSPLayer(model.backbone.backbone.dark4[1], 'backbone') + layers.Conv(model.backbone.backbone.dark5[0]) + layers.SPPBottleneck(model.backbone.backbone.dark5[1]) + layers.CSPLayer(model.backbone.backbone.dark5[2], 'backbone') + layers.BaseConv(model.backbone.lateral_conv0, 'fpn') + layers.Upsample(model.backbone.upsample) + layers.Concat(layers.backbone_outs[1]) + layers.CSPLayer(model.backbone.C3_p4) + layers.BaseConv(model.backbone.reduce_conv1, 'fpn') + layers.Upsample(model.backbone.upsample) + layers.Concat(layers.backbone_outs[0]) + layers.CSPLayer(model.backbone.C3_p3, 'pan') + layers.Conv(model.backbone.bu_conv2) + layers.Concat(layers.fpn_feats[1]) + layers.CSPLayer(model.backbone.C3_n3, 'pan') + layers.Conv(model.backbone.bu_conv1) + layers.Concat(layers.fpn_feats[0]) + layers.CSPLayer(model.backbone.C3_n4, 'pan') + layers.pan_feats = layers.pan_feats[::-1] + else: + raise SystemExit('Model not supported') + + if head == 'YOLOXHead': + layers.fc.write('\n# YOLOXHead\n') + + for i, feat in enumerate(layers.pan_feats): + idx = len(layers.pan_feats) - i - 1 + dw = True if model.head.cls_convs[idx][0]._get_name() == 'DWConv' else False + if i > 0: + layers.Route(feat) + layers.BaseConv(model.head.stems[idx]) + layers.Conv(model.head.cls_convs[idx][0]) + layers.Conv(model.head.cls_convs[idx][1]) + layers.BaseConv(model.head.cls_preds[idx], act='logistic') + if dw: + layers.Route(-6) + else: + layers.Route(-4) + layers.Conv(model.head.reg_convs[idx][0]) + layers.Conv(model.head.reg_convs[idx][1]) + layers.BaseConv(model.head.obj_preds[idx], act='logistic') + layers.Route(-2) + layers.BaseConv(model.head.reg_preds[idx]) + if dw: + layers.RouteShuffleOut('-1, -3, -9') + else: + layers.RouteShuffleOut('-1, -3, -7') + layers.Detect(model.head.strides) + + else: + raise SystemExit('Model not supported') + +os.system('echo "%d" | cat - %s > temp && mv temp %s' % (layers.wc, wts_file, wts_file))