Add YOLOX support
This commit is contained in:
@@ -5,7 +5,6 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
|
||||
### Future updates
|
||||
|
||||
* DeepStream tutorials
|
||||
* YOLOX support
|
||||
* YOLOv6 support
|
||||
* Dynamic batch-size
|
||||
* PP-YOLOE+ support
|
||||
@@ -29,6 +28,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
|
||||
* Optimized NMS [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142)
|
||||
* Models benchmarks
|
||||
* **YOLOv8 support**
|
||||
* **YOLOX support**
|
||||
|
||||
##
|
||||
|
||||
@@ -47,6 +47,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
|
||||
* [PP-YOLOE usage](docs/PPYOLOE.md)
|
||||
* [YOLOv7 usage](docs/YOLOv7.md)
|
||||
* [YOLOv8 usage](docs/YOLOv8.md)
|
||||
* [YOLOX usage](docs/YOLOX.md)
|
||||
* [Using your custom model](docs/customModels.md)
|
||||
* [Multiple YOLO GIEs](docs/multipleGIEs.md)
|
||||
|
||||
@@ -112,6 +113,7 @@ NVIDIA DeepStream SDK 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO models
|
||||
* [PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
|
||||
* [YOLOv7](https://github.com/WongKinYiu/yolov7)
|
||||
* [YOLOv8](https://github.com/ultralytics/ultralytics)
|
||||
* [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)
|
||||
* [MobileNet-YOLO](https://github.com/dog-qiuqiu/MobileNet-Yolo)
|
||||
* [YOLO-Fastest](https://github.com/dog-qiuqiu/Yolo-Fastest)
|
||||
|
||||
@@ -135,7 +137,7 @@ sample = 1920x1080 video
|
||||
- Eval
|
||||
|
||||
```
|
||||
nms-iou-threshold = 0.6 (Darknet and YOLOv8) / 0.65 (YOLOR, YOLOv5 and YOLOv7) / 0.7 (Paddle)
|
||||
nms-iou-threshold = 0.6 (Darknet and YOLOv8) / 0.65 (YOLOR, YOLOv5, YOLOv7 and YOLOX) / 0.7 (Paddle)
|
||||
pre-cluster-threshold = 0.001
|
||||
topk = 300
|
||||
```
|
||||
|
||||
27
config_infer_primary_yolox.txt
Normal file
27
config_infer_primary_yolox.txt
Normal file
@@ -0,0 +1,27 @@
|
||||
[property]
|
||||
gpu-id=0
|
||||
net-scale-factor=0
|
||||
model-color-format=0
|
||||
custom-network-config=yolox_s.cfg
|
||||
model-file=yolox_s.wts
|
||||
model-engine-file=model_b1_gpu0_fp32.engine
|
||||
#int8-calib-file=calib.table
|
||||
labelfile-path=labels.txt
|
||||
batch-size=1
|
||||
network-mode=0
|
||||
num-detected-classes=80
|
||||
interval=0
|
||||
gie-unique-id=1
|
||||
process-mode=1
|
||||
network-type=0
|
||||
cluster-mode=2
|
||||
maintain-aspect-ratio=1
|
||||
symmetric-padding=0
|
||||
parse-bbox-func-name=NvDsInferParseYolo
|
||||
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
|
||||
engine-create-func-name=NvDsInferYoloCudaEngineGet
|
||||
|
||||
[class-attrs-all]
|
||||
nms-iou-threshold=0.45
|
||||
pre-cluster-threshold=0.25
|
||||
topk=300
|
||||
28
config_infer_primary_yolox_legacy.txt
Normal file
28
config_infer_primary_yolox_legacy.txt
Normal file
@@ -0,0 +1,28 @@
|
||||
[property]
|
||||
gpu-id=0
|
||||
net-scale-factor=0.0173520735727919486
|
||||
offsets=123.675;116.28;103.53
|
||||
model-color-format=0
|
||||
custom-network-config=yolox_s.cfg
|
||||
model-file=yolox_s.wts
|
||||
model-engine-file=model_b1_gpu0_fp32.engine
|
||||
#int8-calib-file=calib.table
|
||||
labelfile-path=labels.txt
|
||||
batch-size=1
|
||||
network-mode=0
|
||||
num-detected-classes=80
|
||||
interval=0
|
||||
gie-unique-id=1
|
||||
process-mode=1
|
||||
network-type=0
|
||||
cluster-mode=2
|
||||
maintain-aspect-ratio=1
|
||||
symmetric-padding=0
|
||||
parse-bbox-func-name=NvDsInferParseYolo
|
||||
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
|
||||
engine-create-func-name=NvDsInferYoloCudaEngineGet
|
||||
|
||||
[class-attrs-all]
|
||||
nms-iou-threshold=0.45
|
||||
pre-cluster-threshold=0.25
|
||||
topk=300
|
||||
137
docs/YOLOX.md
Normal file
137
docs/YOLOX.md
Normal file
@@ -0,0 +1,137 @@
|
||||
# YOLOX usage
|
||||
|
||||
**NOTE**: The yaml file is not required.
|
||||
|
||||
* [Convert model](#convert-model)
|
||||
* [Compile the lib](#compile-the-lib)
|
||||
* [Edit the config_infer_primary_yolox file](#edit-the-config_infer_primary_yolox-file)
|
||||
* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file)
|
||||
* [Testing the model](#testing-the-model)
|
||||
|
||||
##
|
||||
|
||||
### Convert model
|
||||
|
||||
#### 1. Download the YOLOX repo and install the requirements
|
||||
|
||||
```
|
||||
git clone https://github.com/Megvii-BaseDetection/YOLOX
|
||||
cd YOLOX
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
**NOTE**: It is recommended to use Python virtualenv.
|
||||
|
||||
#### 2. Copy conversor
|
||||
|
||||
Copy the `gen_wts_yolox.py` file from `DeepStream-Yolo/utils` directory to the `YOLOX` folder.
|
||||
|
||||
#### 3. Download the model
|
||||
|
||||
Download the `pth` file from [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX/releases) releases (example for YOLOX-s standard)
|
||||
|
||||
```
|
||||
wget https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.pth
|
||||
```
|
||||
|
||||
**NOTE**: You can use your custom model, but it is important to keep the YOLO model reference (`yolox_`) in you `cfg` and `weights`/`wts` filenames to generate the engine correctly.
|
||||
|
||||
#### 4. Convert model
|
||||
|
||||
Generate the `cfg` and `wts` files (example for YOLOX-s standard)
|
||||
|
||||
```
|
||||
python3 gen_wts_yolox.py -w yolox_s.pth -e exps/default/yolox_s.py
|
||||
```
|
||||
|
||||
#### 5. Copy generated files
|
||||
|
||||
Copy the generated `cfg` and `wts` files to the `DeepStream-Yolo` folder.
|
||||
|
||||
##
|
||||
|
||||
### Compile the lib
|
||||
|
||||
Open the `DeepStream-Yolo` folder and compile the lib
|
||||
|
||||
* DeepStream 6.1.1 on x86 platform
|
||||
|
||||
```
|
||||
CUDA_VER=11.7 make -C nvdsinfer_custom_impl_Yolo
|
||||
```
|
||||
|
||||
* DeepStream 6.1 on x86 platform
|
||||
|
||||
```
|
||||
CUDA_VER=11.6 make -C nvdsinfer_custom_impl_Yolo
|
||||
```
|
||||
|
||||
* DeepStream 6.0.1 / 6.0 on x86 platform
|
||||
|
||||
```
|
||||
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
|
||||
```
|
||||
|
||||
* DeepStream 6.1.1 / 6.1 on Jetson platform
|
||||
|
||||
```
|
||||
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
|
||||
```
|
||||
|
||||
* DeepStream 6.0.1 / 6.0 on Jetson platform
|
||||
|
||||
```
|
||||
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
|
||||
```
|
||||
|
||||
##
|
||||
|
||||
### Edit the config_infer_primary_yolox file
|
||||
|
||||
Edit the `config_infer_primary_yolox.txt` file according to your model (example for YOLOX-s standard)
|
||||
|
||||
```
|
||||
[property]
|
||||
...
|
||||
custom-network-config=yolox_s.cfg
|
||||
model-file=yolox_s.wts
|
||||
...
|
||||
```
|
||||
|
||||
**NOTE**: If you use the **legacy** model, you should edit the `config_infer_primary_yolox_legacy.txt` file.
|
||||
|
||||
**NOTE**: The **YOLOX standard** uses no normalization on the image preprocess. It is important to change the `net-scale-factor` according to the trained values.
|
||||
|
||||
```
|
||||
net-scale-factor=0
|
||||
```
|
||||
|
||||
**NOTE**: The **YOLOX legacy** uses normalization on the image preprocess. It is important to change the `net-scale-factor` and `offsets` according to the trained values.
|
||||
|
||||
Default: `mean = 0.485, 0.456, 0.406` and `std = 0.229, 0.224, 0.225`
|
||||
|
||||
```
|
||||
net-scale-factor=0.0173520735727919486
|
||||
offsets=123.675;116.28;103.53
|
||||
```
|
||||
|
||||
##
|
||||
|
||||
### Edit the deepstream_app_config file
|
||||
|
||||
```
|
||||
...
|
||||
[primary-gie]
|
||||
...
|
||||
config-file=config_infer_primary_yolox.txt
|
||||
```
|
||||
|
||||
**NOTE**: If you use the **legacy** model, you should edit it to `config_infer_primary_yolox_legacy.txt`.
|
||||
|
||||
##
|
||||
|
||||
### Testing the model
|
||||
|
||||
```
|
||||
deepstream-app -c deepstream_app_config.txt
|
||||
```
|
||||
@@ -44,11 +44,11 @@ detectV8Layer(int layerIdx, std::map<std::string, std::string>& block, std::vect
|
||||
shuffle1Box->setName(shuffle1BoxLayerName.c_str());
|
||||
nvinfer1::Dims reshape1Dims = {3, {4, reg_max, inputDims.d[1]}};
|
||||
shuffle1Box->setReshapeDimensions(reshape1Dims);
|
||||
nvinfer1::Permutation permutation1;
|
||||
permutation1.order[0] = 1;
|
||||
permutation1.order[1] = 0;
|
||||
permutation1.order[2] = 2;
|
||||
shuffle1Box->setSecondTranspose(permutation1);
|
||||
nvinfer1::Permutation permutation1Box;
|
||||
permutation1Box.order[0] = 1;
|
||||
permutation1Box.order[1] = 0;
|
||||
permutation1Box.order[2] = 2;
|
||||
shuffle1Box->setSecondTranspose(permutation1Box);
|
||||
box = shuffle1Box->getOutput(0);
|
||||
|
||||
nvinfer1::ISoftMaxLayer* softmax = network->addSoftMax(*box);
|
||||
@@ -186,10 +186,10 @@ detectV8Layer(int layerIdx, std::map<std::string, std::string>& block, std::vect
|
||||
assert(shuffle != nullptr);
|
||||
std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx);
|
||||
shuffle->setName(shuffleLayerName.c_str());
|
||||
nvinfer1::Permutation permutation2;
|
||||
permutation2.order[0] = 1;
|
||||
permutation2.order[1] = 0;
|
||||
shuffle->setFirstTranspose(permutation2);
|
||||
nvinfer1::Permutation permutation;
|
||||
permutation.order[0] = 1;
|
||||
permutation.order[1] = 0;
|
||||
shuffle->setFirstTranspose(permutation);
|
||||
output = shuffle->getOutput(0);
|
||||
|
||||
return output;
|
||||
|
||||
@@ -18,16 +18,15 @@ shuffleLayer(int layerIdx, std::string& layer, std::map<std::string, std::string
|
||||
std::string shuffleLayerName = "shuffle_" + std::to_string(layerIdx);
|
||||
shuffle->setName(shuffleLayerName.c_str());
|
||||
|
||||
int from = -1;
|
||||
if (block.find("from") != block.end())
|
||||
from = std::stoi(block.at("from"));
|
||||
if (from < 0)
|
||||
from = tensorOutputs.size() + from;
|
||||
|
||||
layer = std::to_string(from);
|
||||
|
||||
if (block.find("reshape") != block.end()) {
|
||||
int from = -1;
|
||||
if (block.find("from") != block.end())
|
||||
from = std::stoi(block.at("from"));
|
||||
|
||||
if (from < 0)
|
||||
from = tensorOutputs.size() + from;
|
||||
|
||||
layer = std::to_string(from);
|
||||
|
||||
nvinfer1::Dims inputTensorDims = tensorOutputs[from]->getDimensions();
|
||||
|
||||
std::string strReshape = block.at("reshape");
|
||||
|
||||
@@ -136,7 +136,7 @@ Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition
|
||||
|
||||
float eps = 1.0e-5;
|
||||
if (m_NetworkType.find("yolov5") != std::string::npos || m_NetworkType.find("yolov7") != std::string::npos ||
|
||||
m_NetworkType.find("yolov8") != std::string::npos)
|
||||
m_NetworkType.find("yolov8") != std::string::npos || m_NetworkType.find("yolox") != std::string::npos)
|
||||
eps = 1.0e-3;
|
||||
else if (m_NetworkType.find("yolor") != std::string::npos)
|
||||
eps = 1.0e-4;
|
||||
@@ -398,6 +398,23 @@ Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition
|
||||
std::string layerName = "detect_v8";
|
||||
printLayerInfo(layerIndex, layerName, inputVol, outputVol, std::to_string(weightPtr));
|
||||
}
|
||||
else if (m_ConfigBlocks.at(i).at("type") == "detect_x") {
|
||||
modelType = 5;
|
||||
|
||||
std::string blobName = "detect_x_" + std::to_string(i);
|
||||
nvinfer1::Dims prevTensorDims = previous->getDimensions();
|
||||
TensorInfo& curYoloTensor = m_YoloTensors.at(yoloCountInputs);
|
||||
curYoloTensor.blobName = blobName;
|
||||
curYoloTensor.numBBoxes = prevTensorDims.d[0];
|
||||
m_NumClasses = prevTensorDims.d[1] - 5;
|
||||
|
||||
std::string outputVol = dimsToString(previous->getDimensions());
|
||||
tensorOutputs.push_back(previous);
|
||||
yoloTensorInputs[yoloCountInputs] = previous;
|
||||
++yoloCountInputs;
|
||||
std::string layerName = "detect_x";
|
||||
printLayerInfo(layerIndex, layerName, "-", outputVol, std::to_string(weightPtr));
|
||||
}
|
||||
else {
|
||||
std::cerr << "\nUnsupported layer type --> \"" << m_ConfigBlocks.at(i).at("type") << "\"" << std::endl;
|
||||
assert(0);
|
||||
@@ -415,7 +432,7 @@ Yolo::buildYoloNetwork(std::vector<float>& weights, nvinfer1::INetworkDefinition
|
||||
uint64_t outputSize = 0;
|
||||
for (uint j = 0; j < yoloCountInputs; ++j) {
|
||||
TensorInfo& curYoloTensor = m_YoloTensors.at(j);
|
||||
if (modelType == 3 || modelType == 4)
|
||||
if (modelType == 3 || modelType == 4 || modelType == 5)
|
||||
outputSize = curYoloTensor.numBBoxes;
|
||||
else
|
||||
outputSize += curYoloTensor.gridSizeX * curYoloTensor.gridSizeY * curYoloTensor.numBBoxes;
|
||||
@@ -587,6 +604,41 @@ Yolo::parseConfigBlocks()
|
||||
TensorInfo outputTensor;
|
||||
m_YoloTensors.push_back(outputTensor);
|
||||
}
|
||||
else if (block.at("type") == "detect_x") {
|
||||
++m_YoloCount;
|
||||
TensorInfo outputTensor;
|
||||
|
||||
std::vector<int> strides;
|
||||
|
||||
std::string stridesString = block.at("strides");
|
||||
while (!stridesString.empty()) {
|
||||
int npos = stridesString.find_first_of(',');
|
||||
if (npos != -1) {
|
||||
int stride = std::stof(trim(stridesString.substr(0, npos)));
|
||||
strides.push_back(stride);
|
||||
stridesString.erase(0, npos + 1);
|
||||
}
|
||||
else {
|
||||
int stride = std::stof(trim(stridesString));
|
||||
strides.push_back(stride);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint i = 0; i < strides.size(); ++i) {
|
||||
int num_grid_y = m_InputH / strides[i];
|
||||
int num_grid_x = m_InputW / strides[i];
|
||||
for (int g1 = 0; g1 < num_grid_y; ++g1) {
|
||||
for (int g0 = 0; g0 < num_grid_x; ++g0) {
|
||||
outputTensor.anchors.push_back((float) g0);
|
||||
outputTensor.anchors.push_back((float) g1);
|
||||
outputTensor.mask.push_back(strides[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m_YoloTensors.push_back(outputTensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ __global__ void gpuYoloLayer_v8(const float* input, int* num_detections, float*
|
||||
int maxIndex = -1;
|
||||
|
||||
for (uint i = 0; i < numOutputClasses; ++i) {
|
||||
float prob = input[x_id * (4 + numOutputClasses) + i + 4];
|
||||
float prob = input[x_id * (4 + numOutputClasses) + 4 + i];
|
||||
if (prob > maxProb) {
|
||||
maxProb = prob;
|
||||
maxIndex = i;
|
||||
|
||||
73
nvdsinfer_custom_impl_Yolo/yoloForward_x.cu
Normal file
73
nvdsinfer_custom_impl_Yolo/yoloForward_x.cu
Normal file
@@ -0,0 +1,73 @@
|
||||
/*
|
||||
* Created by Marcos Luciano
|
||||
* https://www.github.com/marcoslucianops
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
__global__ void gpuYoloLayer_x(const float* input, int* num_detections, float* detection_boxes, float* detection_scores,
|
||||
int* detection_classes, const float scoreThreshold, const uint netWidth, const uint netHeight,
|
||||
const uint numOutputClasses, const uint64_t outputSize, const float* anchors, const int* mask)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (x_id >= outputSize)
|
||||
return;
|
||||
|
||||
const float objectness = input[x_id * (5 + numOutputClasses) + 4];
|
||||
|
||||
if (objectness < scoreThreshold)
|
||||
return;
|
||||
|
||||
int count = (int)atomicAdd(num_detections, 1);
|
||||
|
||||
float x = (input[x_id * (5 + numOutputClasses) + 0] + anchors[x_id * 2]) * mask[x_id];
|
||||
|
||||
float y = (input[x_id * (5 + numOutputClasses) + 1] + anchors[x_id * 2 + 1]) * mask[x_id];
|
||||
|
||||
float w = __expf(input[x_id * (5 + numOutputClasses) + 2]) * mask[x_id];
|
||||
|
||||
float h = __expf(input[x_id * (5 + numOutputClasses) + 3]) * mask[x_id];
|
||||
|
||||
float maxProb = 0.0f;
|
||||
int maxIndex = -1;
|
||||
|
||||
for (uint i = 0; i < numOutputClasses; ++i) {
|
||||
float prob = input[x_id * (5 + numOutputClasses) + 5 + i];
|
||||
if (prob > maxProb) {
|
||||
maxProb = prob;
|
||||
maxIndex = i;
|
||||
}
|
||||
}
|
||||
|
||||
detection_boxes[count * 4 + 0] = x - 0.5 * w;
|
||||
detection_boxes[count * 4 + 1] = y - 0.5 * h;
|
||||
detection_boxes[count * 4 + 2] = x + 0.5 * w;
|
||||
detection_boxes[count * 4 + 3] = y + 0.5 * h;
|
||||
detection_scores[count] = objectness * maxProb;
|
||||
detection_classes[count] = maxIndex;
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream)
|
||||
{
|
||||
int threads_per_block = 16;
|
||||
int number_of_blocks = (outputSize / threads_per_block) + 1;
|
||||
|
||||
for (unsigned int batch = 0; batch < batchSize; ++batch) {
|
||||
gpuYoloLayer_x<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * (5 + numOutputClasses) * outputSize),
|
||||
reinterpret_cast<int*>(num_detections) + (batch),
|
||||
reinterpret_cast<float*>(detection_boxes) + (batch * 4 * outputSize),
|
||||
reinterpret_cast<float*>(detection_scores) + (batch * outputSize),
|
||||
reinterpret_cast<int*>(detection_classes) + (batch * outputSize),
|
||||
scoreThreshold, netWidth, netHeight, numOutputClasses, outputSize, reinterpret_cast<const float*>(anchors),
|
||||
reinterpret_cast<const int*>(mask));
|
||||
}
|
||||
return cudaGetLastError();
|
||||
}
|
||||
@@ -38,6 +38,10 @@ namespace {
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer_x(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& numOutputClasses, const void* anchors, const void* mask, cudaStream_t stream);
|
||||
|
||||
cudaError_t cudaYoloLayer_v8(const void* input, void* num_detections, void* detection_boxes, void* detection_scores,
|
||||
void* detection_classes, const uint& batchSize, uint64_t& outputSize, const float& scoreThreshold, const uint& netWidth,
|
||||
const uint& netHeight, const uint& numOutputClasses, cudaStream_t stream);
|
||||
@@ -158,7 +162,35 @@ YoloLayer::enqueue(int batchSize, void const* const* inputs, void* const* output
|
||||
CUDA_CHECK(cudaMemsetAsync((float*)detection_scores, 0, sizeof(float) * m_OutputSize * batchSize, stream));
|
||||
CUDA_CHECK(cudaMemsetAsync((int*)detection_classes, 0, sizeof(int) * m_OutputSize * batchSize, stream));
|
||||
|
||||
if (m_Type == 4) {
|
||||
if (m_Type == 5) {
|
||||
TensorInfo& curYoloTensor = m_YoloTensors.at(0);
|
||||
std::vector<float> anchors = curYoloTensor.anchors;
|
||||
std::vector<int> mask = curYoloTensor.mask;
|
||||
|
||||
void* v_anchors;
|
||||
void* v_mask;
|
||||
if (anchors.size() > 0) {
|
||||
float* f_anchors = anchors.data();
|
||||
CUDA_CHECK(cudaMalloc(&v_anchors, sizeof(float) * anchors.size()));
|
||||
CUDA_CHECK(cudaMemcpyAsync(v_anchors, f_anchors, sizeof(float) * anchors.size(), cudaMemcpyHostToDevice, stream));
|
||||
}
|
||||
if (mask.size() > 0) {
|
||||
int* f_mask = mask.data();
|
||||
CUDA_CHECK(cudaMalloc(&v_mask, sizeof(int) * mask.size()));
|
||||
CUDA_CHECK(cudaMemcpyAsync(v_mask, f_mask, sizeof(int) * mask.size(), cudaMemcpyHostToDevice, stream));
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaYoloLayer_x(inputs[0], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
|
||||
m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, v_anchors, v_mask, stream));
|
||||
|
||||
if (anchors.size() > 0) {
|
||||
CUDA_CHECK(cudaFree(v_anchors));
|
||||
}
|
||||
if (mask.size() > 0) {
|
||||
CUDA_CHECK(cudaFree(v_mask));
|
||||
}
|
||||
}
|
||||
else if (m_Type == 4) {
|
||||
CUDA_CHECK(cudaYoloLayer_v8(inputs[0], num_detections, detection_boxes, detection_scores, detection_classes, batchSize,
|
||||
m_OutputSize, m_ScoreThreshold, m_NetWidth, m_NetHeight, m_NumClasses, stream));
|
||||
}
|
||||
|
||||
372
utils/gen_wts_yolox.py
Normal file
372
utils/gen_wts_yolox.py
Normal file
@@ -0,0 +1,372 @@
|
||||
import argparse
|
||||
import os
|
||||
import struct
|
||||
import torch
|
||||
from yolox.exp import get_exp
|
||||
|
||||
|
||||
class Layers(object):
|
||||
def __init__(self, size, fw, fc):
|
||||
self.blocks = [0 for _ in range(300)]
|
||||
self.current = -1
|
||||
|
||||
self.width = size[0] if len(size) == 1 else size[1]
|
||||
self.height = size[0]
|
||||
|
||||
self.backbone_outs = []
|
||||
self.fpn_feats = []
|
||||
self.pan_feats = []
|
||||
self.yolo_head = []
|
||||
|
||||
self.fw = fw
|
||||
self.fc = fc
|
||||
self.wc = 0
|
||||
|
||||
self.net()
|
||||
|
||||
def Conv(self, child):
|
||||
self.current += 1
|
||||
|
||||
if child._get_name() == 'DWConv':
|
||||
self.convolutional(child.dconv)
|
||||
self.convolutional(child.pconv)
|
||||
else:
|
||||
self.convolutional(child)
|
||||
|
||||
def Focus(self, child):
|
||||
self.current += 1
|
||||
|
||||
self.reorg()
|
||||
self.convolutional(child.conv)
|
||||
|
||||
def BaseConv(self, child, stage='', act=None):
|
||||
self.current += 1
|
||||
|
||||
self.convolutional(child, act=act)
|
||||
if stage == 'fpn':
|
||||
self.fpn_feats.append(self.current)
|
||||
|
||||
def CSPLayer(self, child, stage=''):
|
||||
self.current += 1
|
||||
|
||||
self.convolutional(child.conv2)
|
||||
self.route('-2')
|
||||
self.convolutional(child.conv1)
|
||||
idx = -3
|
||||
for m in child.m:
|
||||
if m.use_add:
|
||||
self.convolutional(m.conv1)
|
||||
if m.conv2._get_name() == 'DWConv':
|
||||
self.convolutional(m.conv2.dconv)
|
||||
self.convolutional(m.conv2.pconv)
|
||||
self.shortcut(-4)
|
||||
idx -= 4
|
||||
else:
|
||||
self.convolutional(m.conv2)
|
||||
self.shortcut(-3)
|
||||
idx -= 3
|
||||
else:
|
||||
self.convolutional(m.conv1)
|
||||
if m.conv2._get_name() == 'DWConv':
|
||||
self.convolutional(m.conv2.dconv)
|
||||
self.convolutional(m.conv2.pconv)
|
||||
idx -= 3
|
||||
else:
|
||||
self.convolutional(m.conv2)
|
||||
idx -= 2
|
||||
self.route('-1, %d' % idx)
|
||||
self.convolutional(child.conv3)
|
||||
if stage == 'backbone':
|
||||
self.backbone_outs.append(self.current)
|
||||
elif stage == 'pan':
|
||||
self.pan_feats.append(self.current)
|
||||
|
||||
def SPPBottleneck(self, child):
|
||||
self.current += 1
|
||||
|
||||
self.convolutional(child.conv1)
|
||||
self.maxpool(child.m[0])
|
||||
self.route('-2')
|
||||
self.maxpool(child.m[1])
|
||||
self.route('-4')
|
||||
self.maxpool(child.m[2])
|
||||
self.route('-6, -5, -3, -1')
|
||||
self.convolutional(child.conv2)
|
||||
|
||||
def Upsample(self, child):
|
||||
self.current += 1
|
||||
|
||||
self.upsample(child)
|
||||
|
||||
def Concat(self, route):
|
||||
self.current += 1
|
||||
|
||||
r = self.get_route(route)
|
||||
self.route('-1, %d' % r)
|
||||
|
||||
def Route(self, route):
|
||||
self.current += 1
|
||||
|
||||
if route > 0:
|
||||
r = self.get_route(route)
|
||||
self.route('%d' % r)
|
||||
else:
|
||||
self.route('%d' % route)
|
||||
|
||||
def RouteShuffleOut(self, route):
|
||||
self.current += 1
|
||||
|
||||
self.route(route)
|
||||
self.shuffle(reshape=['c', 'hw'])
|
||||
self.yolo_head.append(self.current)
|
||||
|
||||
def Detect(self, strides):
|
||||
self.current += 1
|
||||
|
||||
routes = self.yolo_head[::-1]
|
||||
|
||||
for i, route in enumerate(routes):
|
||||
routes[i] = self.get_route(route)
|
||||
self.route(str(routes)[1:-1], axis=1)
|
||||
self.shuffle(transpose1=[1, 0])
|
||||
self.yolo(strides)
|
||||
|
||||
def net(self):
|
||||
self.fc.write('[net]\n' +
|
||||
'width=%d\n' % self.width +
|
||||
'height=%d\n' % self.height +
|
||||
'channels=3\n' +
|
||||
'letter_box=1\n')
|
||||
|
||||
def reorg(self):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
self.fc.write('\n[reorg]\n')
|
||||
|
||||
def convolutional(self, cv, act=None, detect=False):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
self.get_state_dict(cv.state_dict())
|
||||
|
||||
if cv._get_name() == 'Conv2d':
|
||||
filters = cv.out_channels
|
||||
size = cv.kernel_size
|
||||
stride = cv.stride
|
||||
pad = cv.padding
|
||||
groups = cv.groups
|
||||
bias = cv.bias
|
||||
bn = False
|
||||
act = act if act is not None else 'linear'
|
||||
else:
|
||||
filters = cv.conv.out_channels
|
||||
size = cv.conv.kernel_size
|
||||
stride = cv.conv.stride
|
||||
pad = cv.conv.padding
|
||||
groups = cv.conv.groups
|
||||
bias = cv.conv.bias
|
||||
bn = True if hasattr(cv, 'bn') else False
|
||||
if act is None:
|
||||
act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear'
|
||||
|
||||
b = 'batch_normalize=1\n' if bn is True else ''
|
||||
g = 'groups=%d\n' % groups if groups > 1 else ''
|
||||
w = 'bias=0\n' if bias is None and bn is False else ''
|
||||
|
||||
self.fc.write('\n[convolutional]\n' +
|
||||
b +
|
||||
'filters=%d\n' % filters +
|
||||
'size=%s\n' % self.get_value(size) +
|
||||
'stride=%s\n' % self.get_value(stride) +
|
||||
'pad=%s\n' % self.get_value(pad) +
|
||||
g +
|
||||
w +
|
||||
'activation=%s\n' % act)
|
||||
|
||||
def route(self, layers, axis=0):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
a = 'axis=%d\n' % axis if axis != 0 else ''
|
||||
|
||||
self.fc.write('\n[route]\n' +
|
||||
'layers=%s\n' % layers +
|
||||
a)
|
||||
|
||||
def shortcut(self, r, ew='add', act='linear'):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
m = 'mode=mul\n' if ew == 'mul' else ''
|
||||
|
||||
self.fc.write('\n[shortcut]\n' +
|
||||
'from=%d\n' % r +
|
||||
m +
|
||||
'activation=%s\n' % act)
|
||||
|
||||
def maxpool(self, m):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
stride = m.stride
|
||||
size = m.kernel_size
|
||||
mode = m.ceil_mode
|
||||
|
||||
m = 'maxpool_up' if mode else 'maxpool'
|
||||
|
||||
self.fc.write('\n[%s]\n' % m +
|
||||
'stride=%d\n' % stride +
|
||||
'size=%d\n' % size)
|
||||
|
||||
def upsample(self, child):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
stride = child.scale_factor
|
||||
|
||||
self.fc.write('\n[upsample]\n' +
|
||||
'stride=%d\n' % stride)
|
||||
|
||||
def shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
r = 'reshape=%s\n' % ', '.join(str(x) for x in reshape) if reshape is not None else ''
|
||||
t1 = 'transpose1=%s\n' % ', '.join(str(x) for x in transpose1) if transpose1 is not None else ''
|
||||
t2 = 'transpose2=%s\n' % ', '.join(str(x) for x in transpose2) if transpose2 is not None else ''
|
||||
f = 'from=%d\n' % route if route is not None else ''
|
||||
|
||||
self.fc.write('\n[shuffle]\n' +
|
||||
r +
|
||||
t1 +
|
||||
t2 +
|
||||
f)
|
||||
|
||||
def yolo(self, strides):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
self.fc.write('\n[detect_x]\n' +
|
||||
'strides=%s\n' % str(strides)[1:-1])
|
||||
|
||||
def get_state_dict(self, state_dict):
|
||||
for k, v in state_dict.items():
|
||||
if 'num_batches_tracked' not in k:
|
||||
vr = v.reshape(-1).numpy()
|
||||
self.fw.write('{} {} '.format(k, len(vr)))
|
||||
for vv in vr:
|
||||
self.fw.write(' ')
|
||||
self.fw.write(struct.pack('>f', float(vv)).hex())
|
||||
self.fw.write('\n')
|
||||
self.wc += 1
|
||||
|
||||
def get_value(self, key):
|
||||
if type(key) == int:
|
||||
return key
|
||||
return key[0] if key[0] == key[1] else str(key)[1:-1]
|
||||
|
||||
def get_route(self, n):
|
||||
r = 0
|
||||
for i, b in enumerate(self.blocks):
|
||||
if i <= n:
|
||||
r += b
|
||||
else:
|
||||
break
|
||||
return r - 1
|
||||
|
||||
def get_activation(self, act):
|
||||
if act == 'Hardswish':
|
||||
return 'hardswish'
|
||||
elif act == 'LeakyReLU':
|
||||
return 'leaky'
|
||||
elif act == 'SiLU':
|
||||
return 'silu'
|
||||
return 'linear'
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='PyTorch YOLOX conversion')
|
||||
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pth) file path (required)')
|
||||
parser.add_argument('-e', '--exp', required=True, help='Input exp (.py) file path (required)')
|
||||
args = parser.parse_args()
|
||||
if not os.path.isfile(args.weights):
|
||||
raise SystemExit('Invalid weights file')
|
||||
if not os.path.isfile(args.exp):
|
||||
raise SystemExit('Invalid exp file')
|
||||
return args.weights, args.exp
|
||||
|
||||
|
||||
pth_file, exp_file = parse_args()
|
||||
|
||||
exp = get_exp(exp_file)
|
||||
model = exp.get_model()
|
||||
model.load_state_dict(torch.load(pth_file, map_location='cpu')['model'])
|
||||
model.to('cpu').eval()
|
||||
|
||||
model_name = exp.exp_name
|
||||
inference_size = (exp.input_size[1], exp.input_size[0])
|
||||
|
||||
backbone = model.backbone._get_name()
|
||||
head = model.head._get_name()
|
||||
|
||||
wts_file = model_name + '.wts' if 'yolox' in model_name else 'yolox_' + model_name + '.wts'
|
||||
cfg_file = model_name + '.cfg' if 'yolox' in model_name else 'yolox_' + model_name + '.cfg'
|
||||
|
||||
with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc:
|
||||
layers = Layers(inference_size, fw, fc)
|
||||
|
||||
if backbone == 'YOLOPAFPN':
|
||||
layers.fc.write('\n# YOLOPAFPN\n')
|
||||
|
||||
layers.Focus(model.backbone.backbone.stem)
|
||||
layers.Conv(model.backbone.backbone.dark2[0])
|
||||
layers.CSPLayer(model.backbone.backbone.dark2[1])
|
||||
layers.Conv(model.backbone.backbone.dark3[0])
|
||||
layers.CSPLayer(model.backbone.backbone.dark3[1], 'backbone')
|
||||
layers.Conv(model.backbone.backbone.dark4[0])
|
||||
layers.CSPLayer(model.backbone.backbone.dark4[1], 'backbone')
|
||||
layers.Conv(model.backbone.backbone.dark5[0])
|
||||
layers.SPPBottleneck(model.backbone.backbone.dark5[1])
|
||||
layers.CSPLayer(model.backbone.backbone.dark5[2], 'backbone')
|
||||
layers.BaseConv(model.backbone.lateral_conv0, 'fpn')
|
||||
layers.Upsample(model.backbone.upsample)
|
||||
layers.Concat(layers.backbone_outs[1])
|
||||
layers.CSPLayer(model.backbone.C3_p4)
|
||||
layers.BaseConv(model.backbone.reduce_conv1, 'fpn')
|
||||
layers.Upsample(model.backbone.upsample)
|
||||
layers.Concat(layers.backbone_outs[0])
|
||||
layers.CSPLayer(model.backbone.C3_p3, 'pan')
|
||||
layers.Conv(model.backbone.bu_conv2)
|
||||
layers.Concat(layers.fpn_feats[1])
|
||||
layers.CSPLayer(model.backbone.C3_n3, 'pan')
|
||||
layers.Conv(model.backbone.bu_conv1)
|
||||
layers.Concat(layers.fpn_feats[0])
|
||||
layers.CSPLayer(model.backbone.C3_n4, 'pan')
|
||||
layers.pan_feats = layers.pan_feats[::-1]
|
||||
else:
|
||||
raise SystemExit('Model not supported')
|
||||
|
||||
if head == 'YOLOXHead':
|
||||
layers.fc.write('\n# YOLOXHead\n')
|
||||
|
||||
for i, feat in enumerate(layers.pan_feats):
|
||||
idx = len(layers.pan_feats) - i - 1
|
||||
dw = True if model.head.cls_convs[idx][0]._get_name() == 'DWConv' else False
|
||||
if i > 0:
|
||||
layers.Route(feat)
|
||||
layers.BaseConv(model.head.stems[idx])
|
||||
layers.Conv(model.head.cls_convs[idx][0])
|
||||
layers.Conv(model.head.cls_convs[idx][1])
|
||||
layers.BaseConv(model.head.cls_preds[idx], act='logistic')
|
||||
if dw:
|
||||
layers.Route(-6)
|
||||
else:
|
||||
layers.Route(-4)
|
||||
layers.Conv(model.head.reg_convs[idx][0])
|
||||
layers.Conv(model.head.reg_convs[idx][1])
|
||||
layers.BaseConv(model.head.obj_preds[idx], act='logistic')
|
||||
layers.Route(-2)
|
||||
layers.BaseConv(model.head.reg_preds[idx])
|
||||
if dw:
|
||||
layers.RouteShuffleOut('-1, -3, -9')
|
||||
else:
|
||||
layers.RouteShuffleOut('-1, -3, -7')
|
||||
layers.Detect(model.head.strides)
|
||||
|
||||
else:
|
||||
raise SystemExit('Model not supported')
|
||||
|
||||
os.system('echo "%d" | cat - %s > temp && mv temp %s' % (layers.wc, wts_file, wts_file))
|
||||
Reference in New Issue
Block a user