diff --git a/README.md b/README.md index 6af5558..5694949 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,8 @@ NVIDIA DeepStream SDK 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 configuration for YOLO mod ------------------------------------- ### **Big update on DeepStream-Yolo** ------------------------------------- +### Important: please generate the ONNX model and the TensorRT engine again with the updated files +------------------------------------- ### Future updates @@ -149,7 +151,7 @@ sample = 1920x1080 video - Eval ``` -nms-iou-threshold = 0.6 (Darknet) / 0.65 (YOLOv5, YOLOv6, YOLOv7, YOLOR and YOLOX) / 0.7 (Paddle, YOLO-NAS and YOLOv8) +nms-iou-threshold = 0.6 (Darknet) / 0.65 (YOLOv5, YOLOv6, YOLOv7, YOLOR and YOLOX) / 0.7 (Paddle, YOLO-NAS, YOLOv8 and YOLOv7-u6) pre-cluster-threshold = 0.001 topk = 300 ``` @@ -164,40 +166,49 @@ topk = 300 #### Results -**NOTE**: * = PyTorch +**NOTE**: * = PyTorch. -**NOTE**: ** = The YOLOv4 is trained with the trainvalno5k set, so the mAP is high on val2017 test +**NOTE**: ** = The YOLOv4 is trained with the trainvalno5k set, so the mAP is high on val2017 test. -**NOTE**: The p3.2xlarge instance (AWS) seems to max out at 625-635 FPS on DeepStream even using lighter models +**NOTE**: The p3.2xlarge instance (AWS) seems to max out at 625-635 FPS on DeepStream even using lighter models. -| DeepStream | Precision | Resolution | IoU=0.5:0.95 | IoU=0.5 | IoU=0.75 | FPS
(without display) | -|:----------------:|:---------:|:----------:|:------------:|:-------:|:--------:|:--------------------------:| -| YOLO-NAS L | FP16 | 640 | 0.484 | 0.658 | 0.532 | 235.27 | -| YOLO-NAS M | FP16 | 640 | 0.480 | 0.651 | 0.524 | 287.39 | -| YOLO-NAS S | FP16 | 640 | 0.442 | 0.614 | 0.485 | 478.52 | -| PP-YOLOE+_x | FP16 | 640 | 0. | 0. | 0. | | -| PP-YOLOE+_l | FP16 | 640 | 0. | 0. | 0. | | -| PP-YOLOE+_m | FP16 | 640 | 0. | 0. | 0. | | -| PP-YOLOE+_s | FP16 | 640 | 0.424 | 0.594 | 0.464 | 476.13 | -| PP-YOLOE-s (400) | FP16 | 640 | 0.423 | 0.589 | 0.463 | 461.23 | -| YOLOX-x | FP16 | 640 | 0.447 | 0.616 | 0.483 | 125.40 | -| YOLOX-l | FP16 | 640 | 0.430 | 0.598 | 0.466 | 193.10 | -| YOLOX-m | FP16 | 640 | 0.397 | 0.566 | 0.431 | 298.61 | -| YOLOX-s | FP16 | 640 | 0.335 | 0.502 | 0.365 | 522.05 | -| YOLOX-s legacy | FP16 | 640 | 0.375 | 0.569 | 0.407 | 518.52 | -| YOLOX-Darknet | FP16 | 640 | 0.414 | 0.595 | 0.453 | 212.88 | -| YOLOX-Tiny | FP16 | 640 | 0.274 | 0.427 | 0.292 | 633.95 | -| YOLOX-Nano | FP16 | 640 | 0.212 | 0.342 | 0.222 | 633.04 | -| YOLOv8x | FP16 | 640 | 0.499 | 0.669 | 0.545 | 130.49 | -| YOLOv8l | FP16 | 640 | 0.491 | 0.660 | 0.535 | 180.75 | -| YOLOv8m | FP16 | 640 | 0.468 | 0.637 | 0.510 | 278.08 | -| YOLOv8s | FP16 | 640 | 0.415 | 0.578 | 0.453 | 493.45 | -| YOLOv8n | FP16 | 640 | 0.343 | 0.492 | 0.373 | 627.43 | -| YOLOv7 | FP16 | 640 | 0. | 0. | 0. | | -| YOLOv6s 3.0 | FP16 | 640 | 0. | 0. | 0. | | -| YOLOv5s 7.0 | FP16 | 640 | 0. | 0. | 0. | | -| YOLOv4 | FP16 | 640 | 0. | 0. | 0. | | -| YOLOv3 | FP16 | 640 | 0. | 0. | 0. | | +| DeepStream | Precision | Resolution | IoU=0.5:0.95 | IoU=0.5 | IoU=0.75 | FPS
(without display) | +|:------------------:|:---------:|:----------:|:------------:|:-------:|:--------:|:--------------------------:| +| YOLO-NAS L | FP16 | 640 | 0.484 | 0.658 | 0.532 | 235.27 | +| YOLO-NAS M | FP16 | 640 | 0.480 | 0.651 | 0.524 | 287.39 | +| YOLO-NAS S | FP16 | 640 | 0.442 | 0.614 | 0.485 | 478.52 | +| PP-YOLOE+_x | FP16 | 640 | 0.528 | 0.705 | 0.579 | 121.17 | +| PP-YOLOE+_l | FP16 | 640 | 0.511 | 0.686 | 0.557 | 191.82 | +| PP-YOLOE+_m | FP16 | 640 | 0.483 | 0.658 | 0.528 | 264.39 | +| PP-YOLOE+_s | FP16 | 640 | 0.424 | 0.594 | 0.464 | 476.13 | +| PP-YOLOE-s (400) | FP16 | 640 | 0.423 | 0.589 | 0.463 | 461.23 | +| YOLOX-x | FP16 | 640 | 0.447 | 0.616 | 0.483 | 125.40 | +| YOLOX-l | FP16 | 640 | 0.430 | 0.598 | 0.466 | 193.10 | +| YOLOX-m | FP16 | 640 | 0.397 | 0.566 | 0.431 | 298.61 | +| YOLOX-s | FP16 | 640 | 0.335 | 0.502 | 0.365 | 522.05 | +| YOLOX-s legacy | FP16 | 640 | 0.375 | 0.569 | 0.407 | 518.52 | +| YOLOX-Darknet | FP16 | 640 | 0.414 | 0.595 | 0.453 | 212.88 | +| YOLOX-Tiny | FP16 | 640 | 0.274 | 0.427 | 0.292 | 633.95 | +| YOLOX-Nano | FP16 | 640 | 0.212 | 0.342 | 0.222 | 633.04 | +| YOLOv8x | FP16 | 640 | 0.499 | 0.669 | 0.545 | 130.49 | +| YOLOv8l | FP16 | 640 | 0.491 | 0.660 | 0.535 | 180.75 | +| YOLOv8m | FP16 | 640 | 0.468 | 0.637 | 0.510 | 278.08 | +| YOLOv8s | FP16 | 640 | 0.415 | 0.578 | 0.453 | 493.45 | +| YOLOv8n | FP16 | 640 | 0.343 | 0.492 | 0.373 | 627.43 | +| YOLOv7-u6 | FP16 | 640 | 0.484 | 0.652 | 0.530 | 193.54 | +| YOLOv7x* | FP16 | 640 | 0.496 | 0.679 | 0.536 | 155.07 | +| YOLOv7* | FP16 | 640 | 0.476 | 0.660 | 0.518 | 226.01 | +| YOLOv7-Tiny Leaky* | FP16 | 640 | 0.345 | 0.516 | 0.372 | 626.23 | +| YOLOv7-Tiny Leaky* | FP16 | 416 | 0.328 | 0.493 | 0.349 | 633.90 | +| YOLOv6-L 4.0 | FP16 | 640 | 0.490 | 0.671 | 0.535 | 178.41 | +| YOLOv6-M 4.0 | FP16 | 640 | 0.460 | 0.635 | 0.502 | 293.39 | +| YOLOv6-S 4.0 | FP16 | 640 | 0.416 | 0.585 | 0.453 | 513.90 | +| YOLOv6-N 4.0 | FP16 | 640 | 0.349 | 0.503 | 0.378 | 633.37 | +| YOLOv5x 7.0 | FP16 | 640 | 0.471 | 0.652 | 0.513 | 149.93 | +| YOLOv5l 7.0 | FP16 | 640 | 0.455 | 0.637 | 0.497 | 235.55 | +| YOLOv5m 7.0 | FP16 | 640 | 0.421 | 0.604 | 0.459 | 351.69 | +| YOLOv5s 7.0 | FP16 | 640 | 0.344 | 0.529 | 0.372 | 618.13 | +| YOLOv5n 7.0 | FP16 | 640 | 0.247 | 0.414 | 0.257 | 629.66 | ## diff --git a/docs/YOLOv6.md b/docs/YOLOv6.md index d5dfd69..bc5b517 100644 --- a/docs/YOLOv6.md +++ b/docs/YOLOv6.md @@ -1,5 +1,7 @@ # YOLOv6 usage +**NOTE**: You need to change the branch of the YOLOv6 repo according to the version of the model you want to convert. + **NOTE**: The yaml file is not required. * [Convert model](#convert-model) @@ -29,17 +31,17 @@ Copy the `export_yoloV6.py` file from `DeepStream-Yolo/utils` directory to the ` #### 3. Download the model -Download the `pt` file from [YOLOv6](https://github.com/meituan/YOLOv6/releases/) releases (example for YOLOv6-S 3.0) +Download the `pt` file from [YOLOv6](https://github.com/meituan/YOLOv6/releases/) releases (example for YOLOv6-S 4.0) ``` -wget https://github.com/meituan/YOLOv6/releases/download/0.3.0/yolov6s.pt +wget https://github.com/meituan/YOLOv6/releases/download/0.4.0/yolov6s.pt ``` **NOTE**: You can use your custom model. #### 4. Convert model -Generate the ONNX model file (example for YOLOv6-S 3.0) +Generate the ONNX model file (example for YOLOv6-S 4.0) ``` python3 export_yoloV6.py -w yolov6s.pt --simplify @@ -122,7 +124,7 @@ Open the `DeepStream-Yolo` folder and compile the lib ### Edit the config_infer_primary_yoloV6 file -Edit the `config_infer_primary_yoloV6.txt` file according to your model (example for YOLOv6-S 3.0 with 80 classes) +Edit the `config_infer_primary_yoloV6.txt` file according to your model (example for YOLOv6-S 4.0 with 80 classes) ``` [property] diff --git a/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp b/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp index dfa929f..624cbee 100644 --- a/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp +++ b/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp @@ -73,22 +73,22 @@ addBBoxProposal(const float bx1, const float by1, const float bx2, const float b } static std::vector -decodeTensorYolo(const float* detection, const uint& outputSize, const uint& count, const uint& netW, const uint& netH, +decodeTensorYolo(const float* detection, const uint& outputSize, const uint& netW, const uint& netH, const std::vector& preclusterThreshold) { std::vector binfo; for (uint b = 0; b < outputSize; ++b) { - float maxProb = count == 6 ? detection[b * count + 4] : detection[b * count + 4] * detection[b * count + 6]; - int maxIndex = (int) detection[b * count + 5]; + float maxProb = detection[b * 6 + 4]; + int maxIndex = (int) detection[b * 6 + 5]; if (maxProb < preclusterThreshold[maxIndex]) continue; - float bxc = detection[b * count + 0]; - float byc = detection[b * count + 1]; - float bw = detection[b * count + 2]; - float bh = detection[b * count + 3]; + float bxc = detection[b * 6 + 0]; + float byc = detection[b * 6 + 1]; + float bw = detection[b * 6 + 2]; + float bh = detection[b * 6 + 3]; float bx1 = bxc - bw / 2; float by1 = byc - bh / 2; @@ -102,22 +102,22 @@ decodeTensorYolo(const float* detection, const uint& outputSize, const uint& cou } static std::vector -decodeTensorYoloE(const float* detection, const uint& outputSize, const uint& count, const uint& netW, const uint& netH, +decodeTensorYoloE(const float* detection, const uint& outputSize, const uint& netW, const uint& netH, const std::vector& preclusterThreshold) { std::vector binfo; for (uint b = 0; b < outputSize; ++b) { - float maxProb = count == 6 ? detection[b * count + 4] : detection[b * count + 4] * detection[b * count + 6]; - int maxIndex = (int) detection[b * count + 5]; + float maxProb = detection[b * 6 + 4]; + int maxIndex = (int) detection[b * 6 + 5]; if (maxProb < preclusterThreshold[maxIndex]) continue; - float bx1 = detection[b * count + 0]; - float by1 = detection[b * count + 1]; - float bx2 = detection[b * count + 2]; - float by2 = detection[b * count + 3]; + float bx1 = detection[b * 6 + 0]; + float by1 = detection[b * 6 + 1]; + float bx2 = detection[b * 6 + 2]; + float by2 = detection[b * 6 + 3]; addBBoxProposal(bx1, by1, bx2, by2, netW, netH, maxIndex, maxProb, binfo); } @@ -139,9 +139,8 @@ NvDsInferParseCustomYolo(std::vector const& outputLayersInfo const NvDsInferLayerInfo& layer = outputLayersInfo[0]; const uint outputSize = layer.inferDims.d[0]; - const uint count = layer.inferDims.d[1]; - std::vector outObjs = decodeTensorYolo((const float*) (layer.buffer), outputSize, count, + std::vector outObjs = decodeTensorYolo((const float*) (layer.buffer), outputSize, networkInfo.width, networkInfo.height, detectionParams.perClassPreclusterThreshold); objects.insert(objects.end(), outObjs.begin(), outObjs.end()); @@ -165,9 +164,8 @@ NvDsInferParseCustomYoloE(std::vector const& outputLayersInf const NvDsInferLayerInfo& layer = outputLayersInfo[0]; const uint outputSize = layer.inferDims.d[0]; - const uint count = layer.inferDims.d[1]; - std::vector outObjs = decodeTensorYoloE((const float*) (layer.buffer), outputSize, count, + std::vector outObjs = decodeTensorYoloE((const float*) (layer.buffer), outputSize, networkInfo.width, networkInfo.height, detectionParams.perClassPreclusterThreshold); objects.insert(objects.end(), outObjs.begin(), outObjs.end()); diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward.cu b/nvdsinfer_custom_impl_Yolo/yoloForward.cu index e455425..98fa2ff 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloForward.cu +++ b/nvdsinfer_custom_impl_Yolo/yoloForward.cu @@ -50,13 +50,12 @@ __global__ void gpuYoloLayer(const float* input, float* output, int* count, cons int _count = (int)atomicAdd(count, 1); - output[_count * 7 + 0] = xc; - output[_count * 7 + 1] = yc; - output[_count * 7 + 2] = w; - output[_count * 7 + 3] = h; - output[_count * 7 + 4] = maxProb; - output[_count * 7 + 5] = maxIndex; - output[_count * 7 + 6] = objectness; + output[_count * 6 + 0] = xc; + output[_count * 6 + 1] = yc; + output[_count * 6 + 2] = w; + output[_count * 6 + 3] = h; + output[_count * 6 + 4] = maxProb * objectness; + output[_count * 6 + 5] = maxIndex; } cudaError_t cudaYoloLayer(const void* input, void* output, void* count, const uint& batchSize, uint64_t& inputSize, @@ -76,7 +75,7 @@ cudaError_t cudaYoloLayer(const void* input, void* output, void* count, const ui for (unsigned int batch = 0; batch < batchSize; ++batch) { gpuYoloLayer<<>>( reinterpret_cast (input) + (batch * inputSize), - reinterpret_cast (output) + (batch * 7 * outputSize), + reinterpret_cast (output) + (batch * 6 * outputSize), reinterpret_cast (count) + (batch), netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, scaleXY, reinterpret_cast (anchors), reinterpret_cast (mask)); diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu index 125cee3..e3cbc7f 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_nc.cu @@ -47,13 +47,12 @@ __global__ void gpuYoloLayer_nc(const float* input, float* output, int* count, c int _count = (int)atomicAdd(count, 1); - output[_count * 7 + 0] = xc; - output[_count * 7 + 1] = yc; - output[_count * 7 + 2] = w; - output[_count * 7 + 3] = h; - output[_count * 7 + 4] = maxProb; - output[_count * 7 + 5] = maxIndex; - output[_count * 7 + 6] = objectness; + output[_count * 6 + 0] = xc; + output[_count * 6 + 1] = yc; + output[_count * 6 + 2] = w; + output[_count * 6 + 3] = h; + output[_count * 6 + 4] = maxProb * objectness; + output[_count * 6 + 5] = maxIndex; } cudaError_t cudaYoloLayer_nc(const void* input, void* output, void* count, const uint& batchSize, uint64_t& inputSize, @@ -73,7 +72,7 @@ cudaError_t cudaYoloLayer_nc(const void* input, void* output, void* count, const for (unsigned int batch = 0; batch < batchSize; ++batch) { gpuYoloLayer_nc<<>>( reinterpret_cast (input) + (batch * inputSize), - reinterpret_cast (output) + (batch * 7 * outputSize), + reinterpret_cast (output) + (batch * 6 * outputSize), reinterpret_cast (count) + (batch), netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, scaleXY, reinterpret_cast (anchors), reinterpret_cast (mask)); diff --git a/nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu b/nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu index 5fb74de..c13a1f0 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu +++ b/nvdsinfer_custom_impl_Yolo/yoloForward_v2.cu @@ -68,13 +68,12 @@ __global__ void gpuRegionLayer(const float* input, float* softmax, float* output int _count = (int)atomicAdd(count, 1); - output[_count * 7 + 0] = xc; - output[_count * 7 + 1] = yc; - output[_count * 7 + 2] = w; - output[_count * 7 + 3] = h; - output[_count * 7 + 4] = maxProb; - output[_count * 7 + 5] = maxIndex; - output[_count * 7 + 6] = objectness; + output[_count * 6 + 0] = xc; + output[_count * 6 + 1] = yc; + output[_count * 6 + 2] = w; + output[_count * 6 + 3] = h; + output[_count * 6 + 4] = maxProb * objectness; + output[_count * 6 + 5] = maxIndex; } cudaError_t cudaRegionLayer(const void* input, void* softmax, void* output, void* count, const uint& batchSize, @@ -93,7 +92,7 @@ cudaError_t cudaRegionLayer(const void* input, void* softmax, void* output, void gpuRegionLayer<<>>( reinterpret_cast (input) + (batch * inputSize), reinterpret_cast (softmax) + (batch * inputSize), - reinterpret_cast (output) + (batch * 7 * outputSize), + reinterpret_cast (output) + (batch * 6 * outputSize), reinterpret_cast (count) + (batch), netWidth, netHeight, gridSizeX, gridSizeY, numOutputClasses, numBBoxes, reinterpret_cast (anchors)); diff --git a/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp b/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp index 6633d10..633d336 100644 --- a/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp +++ b/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp @@ -103,7 +103,7 @@ nvinfer1::Dims YoloLayer::getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) noexcept { assert(index == 0); - return nvinfer1::Dims{2, {static_cast(m_OutputSize), 7}}; + return nvinfer1::Dims{2, {static_cast(m_OutputSize), 6}}; } bool @@ -125,7 +125,7 @@ YoloLayer::enqueue(int batchSize, void const* const* inputs, void* const* output noexcept { void* output = outputs[0]; - CUDA_CHECK(cudaMemsetAsync((float*) output, 0, sizeof(float) * m_OutputSize * 7 * batchSize, stream)); + CUDA_CHECK(cudaMemsetAsync((float*) output, 0, sizeof(float) * m_OutputSize * 6 * batchSize, stream)); void* count = workspace; CUDA_CHECK(cudaMemsetAsync((int*) count, 0, sizeof(int) * batchSize, stream)); diff --git a/utils/export_yoloV5.py b/utils/export_yoloV5.py index fe403e5..4edbe78 100644 --- a/utils/export_yoloV5.py +++ b/utils/export_yoloV5.py @@ -19,7 +19,7 @@ class DeepStreamOutput(nn.Module): boxes = x[:, :, :4] objectness = x[:, :, 4:5] scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True) - return torch.cat((boxes, scores, classes, objectness), dim=2) + return torch.cat((boxes, scores * objectness, classes), dim=2) def suppress_warnings(): diff --git a/utils/export_yoloV6.py b/utils/export_yoloV6.py index b4c593b..dc51a23 100644 --- a/utils/export_yoloV6.py +++ b/utils/export_yoloV6.py @@ -6,20 +6,24 @@ import onnx import torch import torch.nn as nn from yolov6.utils.checkpoint import load_checkpoint -from yolov6.layers.common import RepVGGBlock, ConvModule, SiLU +from yolov6.layers.common import RepVGGBlock, SiLU from yolov6.models.effidehead import Detect +try: + from yolov6.layers.common import ConvModule +except ImportError: + from yolov6.layers.common import Conv as ConvModule + class DeepStreamOutput(nn.Module): def __init__(self): super().__init__() def forward(self, x): - print(x) boxes = x[:, :, :4] objectness = x[:, :, 4:5] scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True) - return torch.cat((boxes, scores, classes, objectness), dim=2) + return torch.cat((boxes, scores * objectness, classes), dim=2) def suppress_warnings(): diff --git a/utils/export_yoloV7.py b/utils/export_yoloV7.py index 73961ac..06e452f 100644 --- a/utils/export_yoloV7.py +++ b/utils/export_yoloV7.py @@ -19,7 +19,7 @@ class DeepStreamOutput(nn.Module): boxes = x[:, :, :4] objectness = x[:, :, 4:5] scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True) - return torch.cat((boxes, scores, classes, objectness), dim=2) + return torch.cat((boxes, scores * objectness, classes), dim=2) def suppress_warnings(): diff --git a/utils/export_yoloV7_u6.py b/utils/export_yoloV7_u6.py new file mode 100644 index 0000000..bd7f399 --- /dev/null +++ b/utils/export_yoloV7_u6.py @@ -0,0 +1,77 @@ +import os +import sys +import argparse +import warnings +import onnx +import torch +import torch.nn as nn +from models.experimental import attempt_load +from models.yolo import Detect, V6Detect, IV6Detect +from utils.torch_utils import select_device + + +class DeepStreamOutput(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x.transpose(1, 2) + boxes = x[:, :, :4] + scores, classes = torch.max(x[:, :, 4:], 2, keepdim=True) + return torch.cat((boxes, scores, classes), dim=2) + + +def suppress_warnings(): + warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) + warnings.filterwarnings('ignore', category=UserWarning) + warnings.filterwarnings('ignore', category=DeprecationWarning) + + +def yolov7_u6_export(weights, device): + model = attempt_load(weights, device=device, inplace=True, fuse=True) + model.eval() + for k, m in model.named_modules(): + if isinstance(m, (Detect, V6Detect, IV6Detect)): + m.inplace = False + m.dynamic = False + m.export = True + return model + + +def main(args): + suppress_warnings() + device = select_device('cpu') + model = yolov7_u6_export(args.weights, device) + + model = nn.Sequential(model, DeepStreamOutput()) + + img_size = args.size * 2 if len(args.size) == 1 else args.size + + onnx_input_im = torch.zeros(1, 3, *img_size).to(device) + onnx_output_file = os.path.basename(args.weights).split('.pt')[0] + '.onnx' + + torch.onnx.export(model, onnx_input_im, onnx_output_file, verbose=False, opset_version=args.opset, + do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes=None) + + if args.simplify: + import onnxsim + model_onnx = onnx.load(onnx_output_file) + model_onnx, _ = onnxsim.simplify(model_onnx) + onnx.save(model_onnx, onnx_output_file) + + +def parse_args(): + parser = argparse.ArgumentParser(description='DeepStream YOLOv7-u6 conversion') + parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)') + parser.add_argument('-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])') + parser.add_argument('--opset', type=int, default=12, help='ONNX opset version') + parser.add_argument('--simplify', action='store_true', help='ONNX simplify model') + args = parser.parse_args() + if not os.path.isfile(args.weights): + raise SystemExit('Invalid weights file') + return args + + +if __name__ == '__main__': + args = parse_args() + sys.exit(main(args)) diff --git a/utils/export_yolor.py b/utils/export_yolor.py index f1b3125..b20c1e2 100644 --- a/utils/export_yolor.py +++ b/utils/export_yolor.py @@ -16,7 +16,7 @@ class DeepStreamOutput(nn.Module): boxes = x[:, :, :4] objectness = x[:, :, 4:5] scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True) - return torch.cat((boxes, scores, classes, objectness), dim=2) + return torch.cat((boxes, scores * objectness, classes), dim=2) def suppress_warnings(): diff --git a/utils/export_yolox.py b/utils/export_yolox.py index 0c08e40..f51c61c 100644 --- a/utils/export_yolox.py +++ b/utils/export_yolox.py @@ -18,7 +18,7 @@ class DeepStreamOutput(nn.Module): boxes = x[:, :, :4] objectness = x[:, :, 4:5] scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True) - return torch.cat((boxes, scores, classes, objectness), dim=2) + return torch.cat((boxes, scores * objectness, classes), dim=2) def suppress_warnings():