From bfd9268a310370792746e25fded95004e1803776 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 9 Dec 2021 15:44:17 -0300 Subject: [PATCH] Added YOLOv5 6.0 native support --- config_infer_primary_yoloV5.txt | 24 ++ .../layers/activation_layer.cpp | 24 +- .../layers/convolutional_layer.cpp | 139 ++++--- .../layers/convolutional_layer.h | 1 + nvdsinfer_custom_impl_Yolo/utils.cpp | 75 ++-- nvdsinfer_custom_impl_Yolo/yolo.cpp | 28 +- readme.md | 133 ++++++- utils/gen_wts_yoloV5.py | 344 ++++++++++++++++++ 8 files changed, 688 insertions(+), 80 deletions(-) create mode 100644 config_infer_primary_yoloV5.txt create mode 100644 utils/gen_wts_yoloV5.py diff --git a/config_infer_primary_yoloV5.txt b/config_infer_primary_yoloV5.txt new file mode 100644 index 0000000..28f2a35 --- /dev/null +++ b/config_infer_primary_yoloV5.txt @@ -0,0 +1,24 @@ +[property] +gpu-id=0 +net-scale-factor=0.0039215697906911373 +model-color-format=0 +custom-network-config=yolov5n.cfg +model-file=yolov5n.wts +model-engine-file=model_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=4 +maintain-aspect-ratio=0 +parse-bbox-func-name=NvDsInferParseYolo +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so +engine-create-func-name=NvDsInferYoloCudaEngineGet + +[class-attrs-all] +pre-cluster-threshold=0.25 diff --git a/nvdsinfer_custom_impl_Yolo/layers/activation_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/activation_layer.cpp index d730fd2..a1ae957 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/activation_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/activation_layer.cpp @@ -12,7 +12,10 @@ nvinfer1::ILayer* activationLayer( nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) { - if (activation == "relu") + if (activation == "linear") { + // Pass + } + else if (activation == "relu") { nvinfer1::IActivationLayer* relu = network->addActivation( *input, nvinfer1::ActivationType::kRELU); @@ -78,5 +81,24 @@ nvinfer1::ILayer* activationLayer( mish->setName(mishLayerName.c_str()); output = mish; } + else if (activation == "silu") + { + nvinfer1::IActivationLayer* sigmoid = network->addActivation( + *input, nvinfer1::ActivationType::kSIGMOID); + assert(sigmoid != nullptr); + std::string sigmoidLayerName = "sigmoid_" + std::to_string(layerIdx); + sigmoid->setName(sigmoidLayerName.c_str()); + nvinfer1::IElementWiseLayer* silu = network->addElementWise( + *sigmoid->getOutput(0), *input, + nvinfer1::ElementWiseOperation::kPROD); + assert(silu != nullptr); + std::string siluLayerName = "silu_" + std::to_string(layerIdx); + silu->setName(siluLayerName.c_str()); + output = silu; + } + else { + std::cerr << "Activation not supported: " << activation << std::endl; + std::abort(); + } return output; } \ No newline at end of file diff --git a/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp b/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp index abb0d32..1be7b3f 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp +++ b/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.cpp @@ -12,6 +12,7 @@ nvinfer1::ILayer* convolutionalLayer( std::vector& weights, std::vector& trtWeights, int& weightPtr, + std::string weightsType, int& inputChannels, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) @@ -56,57 +57,111 @@ nvinfer1::ILayer* convolutionalLayer( nvinfer1::Weights convWt{nvinfer1::DataType::kFLOAT, nullptr, size}; nvinfer1::Weights convBias{nvinfer1::DataType::kFLOAT, nullptr, bias}; - if (batchNormalize == false) - { - float* val = new float[filters]; - for (int i = 0; i < filters; ++i) + if (weightsType == "weights") { + if (batchNormalize == false) { - val[i] = weights[weightPtr]; - weightPtr++; + float* val = new float[filters]; + for (int i = 0; i < filters; ++i) + { + val[i] = weights[weightPtr]; + weightPtr++; + } + convBias.values = val; + trtWeights.push_back(convBias); + val = new float[size]; + for (int i = 0; i < size; ++i) + { + val[i] = weights[weightPtr]; + weightPtr++; + } + convWt.values = val; + trtWeights.push_back(convWt); } - convBias.values = val; - trtWeights.push_back(convBias); - val = new float[size]; - for (int i = 0; i < size; ++i) + else { - val[i] = weights[weightPtr]; - weightPtr++; + for (int i = 0; i < filters; ++i) + { + bnBiases.push_back(weights[weightPtr]); + weightPtr++; + } + for (int i = 0; i < filters; ++i) + { + bnWeights.push_back(weights[weightPtr]); + weightPtr++; + } + for (int i = 0; i < filters; ++i) + { + bnRunningMean.push_back(weights[weightPtr]); + weightPtr++; + } + for (int i = 0; i < filters; ++i) + { + bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5)); + weightPtr++; + } + float* val = new float[size]; + for (int i = 0; i < size; ++i) + { + val[i] = weights[weightPtr]; + weightPtr++; + } + convWt.values = val; + trtWeights.push_back(convWt); + trtWeights.push_back(convBias); } - convWt.values = val; - trtWeights.push_back(convWt); } - else - { - for (int i = 0; i < filters; ++i) + else { + if (batchNormalize == false) { - bnBiases.push_back(weights[weightPtr]); - weightPtr++; + float* val = new float[size]; + for (int i = 0; i < size; ++i) + { + val[i] = weights[weightPtr]; + weightPtr++; + } + convWt.values = val; + trtWeights.push_back(convWt); + val = new float[filters]; + for (int i = 0; i < filters; ++i) + { + val[i] = weights[weightPtr]; + weightPtr++; + } + convBias.values = val; + trtWeights.push_back(convBias); } - - for (int i = 0; i < filters; ++i) + else { - bnWeights.push_back(weights[weightPtr]); - weightPtr++; + float* val = new float[size]; + for (int i = 0; i < size; ++i) + { + val[i] = weights[weightPtr]; + weightPtr++; + } + convWt.values = val; + for (int i = 0; i < filters; ++i) + { + bnWeights.push_back(weights[weightPtr]); + weightPtr++; + } + for (int i = 0; i < filters; ++i) + { + bnBiases.push_back(weights[weightPtr]); + weightPtr++; + } + for (int i = 0; i < filters; ++i) + { + bnRunningMean.push_back(weights[weightPtr]); + weightPtr++; + } + for (int i = 0; i < filters; ++i) + { + bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5)); + weightPtr++; + } + trtWeights.push_back(convWt); + trtWeights.push_back(convBias); } - for (int i = 0; i < filters; ++i) - { - bnRunningMean.push_back(weights[weightPtr]); - weightPtr++; - } - for (int i = 0; i < filters; ++i) - { - bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5)); - weightPtr++; - } - float* val = new float[size]; - for (int i = 0; i < size; ++i) - { - val[i] = weights[weightPtr]; - weightPtr++; - } - convWt.values = val; - trtWeights.push_back(convWt); - trtWeights.push_back(convBias); } nvinfer1::IConvolutionLayer* conv = network->addConvolution( diff --git a/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.h b/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.h index b114493..a3e0bea 100644 --- a/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.h +++ b/nvdsinfer_custom_impl_Yolo/layers/convolutional_layer.h @@ -19,6 +19,7 @@ nvinfer1::ILayer* convolutionalLayer( std::vector& weights, std::vector& trtWeights, int& weightPtr, + std::string weightsType, int& inputChannels, nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network); diff --git a/nvdsinfer_custom_impl_Yolo/utils.cpp b/nvdsinfer_custom_impl_Yolo/utils.cpp index c89302b..2fd7911 100644 --- a/nvdsinfer_custom_impl_Yolo/utils.cpp +++ b/nvdsinfer_custom_impl_Yolo/utils.cpp @@ -67,32 +67,63 @@ std::vector loadWeights(const std::string weightsFilePath, const std::str { assert(fileExists(weightsFilePath)); std::cout << "\nLoading pre-trained weights" << std::endl; - std::ifstream file(weightsFilePath, std::ios_base::binary); - assert(file.good()); - std::string line; - - if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos) - { - // Remove 4 int32 bytes of data from the stream belonging to the header - file.ignore(4 * 4); - } - else - { - // Remove 5 int32 bytes of data from the stream belonging to the header - file.ignore(4 * 5); - } std::vector weights; - char floatWeight[4]; - while (!file.eof()) - { - file.read(floatWeight, 4); - assert(file.gcount() == 4); - weights.push_back(*reinterpret_cast(floatWeight)); - if (file.peek() == std::istream::traits_type::eof()) break; + + if (weightsFilePath.find(".weights") != std::string::npos) { + std::ifstream file(weightsFilePath, std::ios_base::binary); + assert(file.good()); + std::string line; + + if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos) + { + // Remove 4 int32 bytes of data from the stream belonging to the header + file.ignore(4 * 4); + } + else + { + // Remove 5 int32 bytes of data from the stream belonging to the header + file.ignore(4 * 5); + } + + char floatWeight[4]; + while (!file.eof()) + { + file.read(floatWeight, 4); + assert(file.gcount() == 4); + weights.push_back(*reinterpret_cast(floatWeight)); + if (file.peek() == std::istream::traits_type::eof()) break; + } } + + else if (weightsFilePath.find(".wts") != std::string::npos) { + std::ifstream file(weightsFilePath); + assert(file.good()); + int32_t count; + file >> count; + assert(count > 0 && "Invalid .wts file."); + + uint32_t floatWeight; + std::string name; + uint32_t size; + + while (count--) { + file >> name >> std::dec >> size; + for (uint32_t x = 0, y = size; x < y; ++x) + { + file >> std::hex >> floatWeight; + weights.push_back(*reinterpret_cast(&floatWeight)); + }; + } + } + + else { + std::cerr << "File " << weightsFilePath << " is not supported" << std::endl; + std::abort(); + } + std::cout << "Loading weights of " << networkType << " complete" - << std::endl; + << std::endl; std::cout << "Total weights read: " << weights.size() << std::endl; return weights; } diff --git a/nvdsinfer_custom_impl_Yolo/yolo.cpp b/nvdsinfer_custom_impl_Yolo/yolo.cpp index 8e7d9e0..a0fb9d1 100644 --- a/nvdsinfer_custom_impl_Yolo/yolo.cpp +++ b/nvdsinfer_custom_impl_Yolo/yolo.cpp @@ -73,9 +73,6 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder) parseConfigBlocks(); orderParams(&m_OutputMasks); - std::vector weights = loadWeights(m_WtsFilePath, m_NetworkType); - std::vector trtWeights; - nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0); if (parseModel(*network) != NVDSINFER_SUCCESS) { network->destroy(); @@ -134,7 +131,7 @@ NvDsInferStatus Yolo::parseModel(nvinfer1::INetworkDefinition& network) { destroyNetworkUtils(); std::vector weights = loadWeights(m_WtsFilePath, m_NetworkType); - std::cout << "Building YOLO network" << std::endl; + std::cout << "Building YOLO network\n" << std::endl; NvDsInferStatus status = buildYoloNetwork(weights, network); if (status == NVDSINFER_SUCCESS) { @@ -151,6 +148,15 @@ NvDsInferStatus Yolo::buildYoloNetwork( int weightPtr = 0; int channels = m_InputC; + std::string weightsType; + + if (m_WtsFilePath.find(".weights") != std::string::npos) { + weightsType = "weights"; + } + else { + weightsType = "wts"; + } + nvinfer1::ITensor* data = network.addInput(m_InputBlobName.c_str(), nvinfer1::DataType::kFLOAT, nvinfer1::Dims3{static_cast(m_InputC), @@ -171,7 +177,7 @@ NvDsInferStatus Yolo::buildYoloNetwork( else if (m_ConfigBlocks.at(i).at("type") == "convolutional") { std::string inputVol = dimsToString(previous->getDimensions()); - nvinfer1::ILayer* out = convolutionalLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, channels, previous, &network); + nvinfer1::ILayer* out = convolutionalLayer(i, m_ConfigBlocks.at(i), weights, m_TrtWeights, weightPtr, weightsType, channels, previous, &network); previous = out->getOutput(0); assert(previous != nullptr); channels = getNumChannels(previous); @@ -272,10 +278,10 @@ NvDsInferStatus Yolo::buildYoloNetwork( beta_nms = std::stof(m_ConfigBlocks.at(i).at("beta_nms")); } nvinfer1::IPluginV2* yoloPlugin - = new YoloLayer(m_OutputTensors.at(outputTensorCount).numBBoxes, - m_OutputTensors.at(outputTensorCount).numClasses, - m_OutputTensors.at(outputTensorCount).gridSizeX, - m_OutputTensors.at(outputTensorCount).gridSizeY, + = new YoloLayer(curYoloTensor.numBBoxes, + curYoloTensor.numClasses, + curYoloTensor.gridSizeX, + curYoloTensor.gridSizeY, 1, new_coords, scale_x_y, beta_nms, curYoloTensor.anchors, m_OutputMasks); @@ -436,7 +442,7 @@ void Yolo::parseConfigBlocks() m_LetterBox = 0; } } - else if ((block.at("type") == "region") || (block.at("type") == "yolo")) + else if ((block.at("type") == "region") || (block.at("type") == "yolo") || (block.at("type") == "detect")) { assert((block.find("num") != block.end()) && std::string("Missing 'num' param in " + block.at("type") + " layer").c_str()); @@ -466,9 +472,7 @@ void Yolo::parseConfigBlocks() } } - if (block.find("mask") != block.end()) { - std::string maskString = block.at("mask"); std::vector pMASKS; while (!maskString.empty()) diff --git a/readme.md b/readme.md index b02319e..1bd89f7 100644 --- a/readme.md +++ b/readme.md @@ -6,14 +6,12 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models * New documentation for multiple models * DeepStream tutorials -* Native PyTorch support (YOLOv5 and YOLOR) +* Native YOLOR support * Native PP-YOLO support * Models benchmark * GPU NMS * Dynamic batch-size -**NOTE**: The support for YOLOv5 was removed in this current update. If you want the old repo version, please use the commit 297e0e9 and DeepStream 5.1 requirements. - ### Improvements on this repository * Darknet CFG params parser (it doesn't need to edit nvdsparsebbox_Yolo.cpp or another file for native models) @@ -24,6 +22,7 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models * Support for convolutional groups * Support for INT8 calibration * Support for non square models +* **YOLOv5 6.0 native support** ## @@ -33,6 +32,7 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models * [Tested models](#tested-models) * [dGPU installation](#dgpu-installation) * [Basic usage](#basic-usage) +* [YOLOv5 usage](#yolov5-usage) * [INT8 calibration](#int8-calibration) * [Using your custom model](docs/customModels.md) @@ -48,9 +48,14 @@ NVIDIA DeepStream SDK 6.0 configuration for YOLO models * [NVIDIA DeepStream SDK 6.0](https://developer.nvidia.com/deepstream-sdk) * [DeepStream-Yolo](https://github.com/marcoslucianops/DeepStream-Yolo) +**For YOLOv5**: + +* [PyTorch >= 1.7.0](https://pytorch.org/get-started/locally/) + ## ### Tested models +* [YOLOv5 6.0](https://github.com/ultralytics/yolov5) [[pt]](https://github.com/ultralytics/yolov5/releases/tag/v6.0) * [YOLOv4x-Mish](https://github.com/AlexeyAB/darknet) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4x-mish.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4x-mish.weights)] * [YOLOv4-CSP](https://github.com/WongKinYiu/ScaledYOLOv4/tree/yolov4-csp) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-csp.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-csp.weights)] * [YOLOv4](https://github.com/AlexeyAB/darknet) [[cfg](https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4.cfg)] [[weights](https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v3_optimal/yolov4.weights)] @@ -265,6 +270,7 @@ deepstream-app -c deepstream_app_config.txt **NOTE**: If you want to use YOLOv2 or YOLOv2-Tiny models, change the deepstream_app_config.txt file before run it ``` +... [primary-gie] enable=1 gpu-id=0 @@ -277,6 +283,127 @@ config-file=config_infer_primary_yoloV2.txt ## +### YOLOv5 usage + +#### 1. Copy gen_wts_yoloV5.py from DeepStream-Yolo/utils to ultralytics/yolov5 folder + +#### 2. Open the ultralytics/yolov5 folder + +#### 3. Download pt file from [ultralytics/yolov5](https://github.com/ultralytics/yolov5/releases/tag/v6.0) website (example for YOLOv5n) + +``` +wget https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5n.pt +``` + +#### 4. Generate cfg and wts files (example for YOLOv5n) + +``` +python3 gen_wts_yoloV5.py -w yolov5n.pt +``` + +#### 5. Copy generated cfg and wts files to DeepStream-Yolo folder + +#### 6. Open DeepStream-Yolo folder + +#### 7. Compile lib + +* x86 platform + +``` +CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo +``` + +* Jetson platform + +``` +CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo +``` + +#### 8. Edit config_infer_primary_yoloV5.txt for your model (example for YOLOv5n) + +``` +[property] +... +# 0=RGB, 1=BGR, 2=GRAYSCALE +model-color-format=0 +# CFG +custom-network-config=yolov5n.cfg +# WTS +model-file=yolov5n.wts +# Generated TensorRT model (will be created if it doesn't exist) +model-engine-file=model_b1_gpu0_fp32.engine +# Model labels file +labelfile-path=labels.txt +# Batch size +batch-size=1 +# 0=FP32, 1=INT8, 2=FP16 mode +network-mode=0 +# Number of classes in label file +num-detected-classes=80 +... +[class-attrs-all] +# CONF_THRESH +pre-cluster-threshold=0.25 +``` + +#### 8. Change the deepstream_app_config.txt file + +``` +... +[primary-gie] +enable=1 +gpu-id=0 +gie-unique-id=1 +nvbuf-memory-type=0 +config-file=config_infer_primary_yoloV5.txt +``` + +#### 9. Run + +``` +deepstream-app -c deepstream_app_config.txt +``` + +**NOTE**: For YOLOv5 P6 or custom models, check the gen_wts_yoloV5.py args and use them according to your model + +* Input weights (.pt) file path **(required)** + +``` +-w or --weights +``` + +* Input cfg (.yaml) file path + +``` +-c or --yaml +``` + +* Model width **(default = 640 / 1280 [P6])** + +``` +-mw or --width +``` + +* Model height **(default = 640 / 1280 [P6])** + +``` +-mh or --height +``` + +* Model channels **(default = 3)** + +``` +-mc or --channels +``` + +* P6 model + +``` +--p6 +``` + +## + ### INT8 calibration #### 1. Install OpenCV diff --git a/utils/gen_wts_yoloV5.py b/utils/gen_wts_yoloV5.py new file mode 100644 index 0000000..97adc72 --- /dev/null +++ b/utils/gen_wts_yoloV5.py @@ -0,0 +1,344 @@ +import argparse +import yaml +import math +import os +import struct +import torch +from utils.torch_utils import select_device + + +def parse_args(): + parser = argparse.ArgumentParser(description="PyTorch conversion") + parser.add_argument("-w", "--weights", required=True, help="Input weights (.pt) file path (required)") + parser.add_argument("-c", "--yaml", help="Input cfg (.yaml) file path") + parser.add_argument("-mw", "--width", help="Model width (default = 640 / 1280 [P6])") + parser.add_argument("-mh", "--height", help="Model height (default = 640 / 1280 [P6])") + parser.add_argument("-mc", "--channels", help="Model channels (default = 3)") + parser.add_argument("--p6", action="store_true", help="P6 model") + args = parser.parse_args() + if not os.path.isfile(args.weights): + raise SystemExit("Invalid weights file") + if not args.yaml: + args.yaml = "" + if not args.width: + args.width = 1280 if args.p6 else 640 + if not args.height: + args.height = 1280 if args.p6 else 640 + if not args.channels: + args.channels = 3 + return args.weights, args.yaml, args.width, args.height, args.channels, args.p6 + + +def get_width(x, gw, divisor=8): + return int(math.ceil((x * gw) / divisor)) * divisor + + +def get_depth(x, gd): + if x == 1: + return 1 + r = int(round(x * gd)) + if x * gd - int(x * gd) == 0.5 and int(x * gd) % 2 == 0: + r -= 1 + return max(r, 1) + + +pt_file, yaml_file, model_width, model_height, model_channels, p6 = parse_args() + +model_name = pt_file.split(".pt")[0] +wts_file = model_name + ".wts" +cfg_file = model_name + ".cfg" + +if yaml_file == "": + yaml_file = "models/" + model_name + ".yaml" + if not os.path.isfile(yaml_file): + yaml_file = "models/hub/" + model_name + ".yaml" + if not os.path.isfile(yaml_file): + raise SystemExit("YAML file not found") +elif not os.path.isfile(yaml_file): + raise SystemExit("Invalid YAML file") + +device = select_device("cpu") +model = torch.load(pt_file, map_location=device)["model"].float() +model.to(device).eval() + +with open(wts_file, "w") as f: + wts_write = "" + conv_count = 0 + cv1 = "" + cv3 = "" + cv3_idx = 0 + sppf_idx = 11 if p6 else 9 + for k, v in model.state_dict().items(): + if not "num_batches_tracked" in k and not "anchors" in k and not "anchor_grid" in k: + vr = v.reshape(-1).cpu().numpy() + idx = int(k.split(".")[1]) + if ".cv1." in k and not ".m." in k and idx != sppf_idx: + cv1 += "{} {} ".format(k, len(vr)) + for vv in vr: + cv1 += " " + cv1 += struct.pack(">f" ,float(vv)).hex() + cv1 += "\n" + conv_count += 1 + elif cv1 != "" and ".m." in k: + wts_write += cv1 + cv1 = "" + if ".cv3." in k: + cv3 += "{} {} ".format(k, len(vr)) + for vv in vr: + cv3 += " " + cv3 += struct.pack(">f" ,float(vv)).hex() + cv3 += "\n" + cv3_idx = idx + conv_count += 1 + elif cv3 != "" and cv3_idx != idx: + wts_write += cv3 + cv3 = "" + cv3_idx = 0 + if not ".cv3." in k and not (".cv1." in k and not ".m." in k and idx != sppf_idx): + wts_write += "{} {} ".format(k, len(vr)) + for vv in vr: + wts_write += " " + wts_write += struct.pack(">f" ,float(vv)).hex() + wts_write += "\n" + conv_count += 1 + f.write("{}\n".format(conv_count)) + f.write(wts_write) + +with open(cfg_file, "w") as c: + with open(yaml_file, "r") as f: + nc = 0 + depth_multiple = 0 + width_multiple = 0 + anchors = "" + masks = [] + num = 0 + detections = [] + layers = [] + f = yaml.load(f,Loader=yaml.FullLoader) + c.write("[net]\n") + c.write("width=%d\n" % model_width) + c.write("height=%d\n" % model_height) + c.write("channels=%d\n" % model_channels) + for l in f: + if l == "nc": + nc = f[l] + elif l == "depth_multiple": + depth_multiple = f[l] + elif l == "width_multiple": + width_multiple = f[l] + elif l == "anchors": + a = [] + for v in f[l]: + a.extend(v) + mask = [] + for _ in range(int(len(v) / 2)): + mask.append(num) + num += 1 + masks.append(mask) + anchors = str(a)[1:-1] + elif l == "backbone" or l == "head": + for v in f[l]: + if v[2] == "Conv": + layer = "" + blocks = 0 + layer += "\n[convolutional]\n" + layer += "batch_normalize=1\n" + layer += "filters=%d\n" % get_width(v[3][0], width_multiple) + layer += "size=%d\n" % v[3][1] + layer += "stride=%d\n" % v[3][2] + layer += "pad=1\n" + layer += "activation=silu\n" + blocks += 1 + layers.append([layer, blocks]) + elif v[2] == "C3": + layer = "" + blocks = 0 + layer += "\n# C3\n" + # SPLIT + layer += "\n[convolutional]\n" + layer += "batch_normalize=1\n" + layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple) + layer += "size=1\n" + layer += "stride=1\n" + layer += "pad=1\n" + layer += "activation=silu\n" + blocks += 1 + layer += "\n[route]\n" + layer += "layers=-2\n" + blocks += 1 + layer += "\n[convolutional]\n" + layer += "batch_normalize=1\n" + layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple) + layer += "size=1\n" + layer += "stride=1\n" + layer += "pad=1\n" + layer += "activation=silu\n" + blocks += 1 + # Residual Block + if len(v[3]) == 1 or v[3][1] == True: + for _ in range(get_depth(v[1], depth_multiple)): + layer += "\n[convolutional]\n" + layer += "batch_normalize=1\n" + layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple) + layer += "size=1\n" + layer += "stride=1\n" + layer += "pad=1\n" + layer += "activation=silu\n" + blocks += 1 + layer += "\n[convolutional]\n" + layer += "batch_normalize=1\n" + layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple) + layer += "size=3\n" + layer += "stride=1\n" + layer += "pad=1\n" + layer += "activation=silu\n" + blocks += 1 + layer += "\n[shortcut]\n" + layer += "from=-3\n" + layer += "activation=linear\n" + blocks += 1 + # Merge + layer += "\n[route]\n" + layer += "layers=-1, -%d\n" % (3 * get_depth(v[1], depth_multiple) + 3) + blocks += 1 + else: + for _ in range(get_depth(v[1], depth_multiple)): + layer += "\n[convolutional]\n" + layer += "batch_normalize=1\n" + layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple) + layer += "size=1\n" + layer += "stride=1\n" + layer += "pad=1\n" + layer += "activation=silu\n" + blocks += 1 + layer += "\n[convolutional]\n" + layer += "batch_normalize=1\n" + layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple) + layer += "size=3\n" + layer += "stride=1\n" + layer += "pad=1\n" + layer += "activation=silu\n" + blocks += 1 + # Merge + layer += "\n[route]\n" + layer += "layers=-1, -%d\n" % (2 * get_depth(v[1], depth_multiple) + 3) + blocks += 1 + # Transition + layer += "\n[convolutional]\n" + layer += "batch_normalize=1\n" + layer += "filters=%d\n" % get_width(v[3][0], width_multiple) + layer += "size=1\n" + layer += "stride=1\n" + layer += "pad=1\n" + layer += "activation=silu\n" + layer += "\n##########\n" + blocks += 1 + layers.append([layer, blocks]) + elif v[2] == "SPPF": + layer = "" + blocks = 0 + layer += "\n# SPPF\n" + layer += "\n[convolutional]\n" + layer += "batch_normalize=1\n" + layer += "filters=%d\n" % get_width(v[3][0] / 2, width_multiple) + layer += "size=1\n" + layer += "stride=1\n" + layer += "pad=1\n" + layer += "activation=silu\n" + blocks += 1 + layer += "\n[maxpool]\n" + layer += "stride=1\n" + layer += "size=%d\n" % v[3][1] + blocks += 1 + layer += "\n[route]\n" + layer += "layers=-2\n" + blocks += 1 + layer += "\n[maxpool]\n" + layer += "stride=1\n" + layer += "size=%d\n" % v[3][1] + blocks += 1 + layer += "\n[route]\n" + layer += "layers=-2\n" + blocks += 1 + layer += "\n[maxpool]\n" + layer += "stride=1\n" + layer += "size=%d\n" % v[3][1] + blocks += 1 + layer += "\n[route]\n" + layer += "layers=-1, -3, -5, -6\n" + blocks += 1 + layer += "\n[convolutional]\n" + layer += "batch_normalize=1\n" + layer += "filters=%d\n" % get_width(v[3][0], width_multiple) + layer += "size=1\n" + layer += "stride=1\n" + layer += "pad=1\n" + layer += "activation=silu\n" + layer += "\n##########\n" + blocks += 1 + layers.append([layer, blocks]) + elif v[2] == "nn.Upsample": + layer = "" + blocks = 0 + layer += "\n[upsample]\n" + layer += "stride=%d\n" % v[3][1] + blocks += 1 + layers.append([layer, blocks]) + elif v[2] == "Concat": + layer = "" + blocks = 0 + route = v[0][1] + r = 0 + if route > 0: + for i, item in enumerate(layers): + if i <= route: + r += item[1] + else: + break + else: + route = len(layers) + route + for i, item in enumerate(layers): + if i <= route: + r += item[1] + else: + break + layer += "\n# Concat\n" + layer += "\n[route]\n" + layer += "layers=-1, %d\n" % (r - 1) + layer += "\n##########\n" + blocks += 1 + layers.append([layer, blocks]) + elif v[2] == "Detect": + for i, n in enumerate(v[0]): + layer = "" + blocks = 0 + r = 0 + for j, item in enumerate(layers): + if j <= n: + r += item[1] + else: + break + layer += "\n# Detect\n" + layer += "\n[route]\n" + layer += "layers=%d\n" % (r - 1) + blocks += 1 + layer += "\n[convolutional]\n" + layer += "size=1\n" + layer += "stride=1\n" + layer += "pad=1\n" + layer += "filters=%d\n" % ((nc + 5) * 3) + layer += "activation=logistic\n" + blocks += 1 + layer += "\n[yolo]\n" + layer += "mask=%s\n" % str(masks[i])[1:-1] + layer += "anchors=%s\n" % anchors + layer += "classes=%d\n" % nc + layer += "num=%d\n" % num + layer += "scale_x_y=2.0\n" + layer += "beta_nms=0.6\n" + layer += "new_coords=1\n" + layer += "\n##########\n" + blocks += 1 + layers.append([layer, blocks]) + for layer in layers: + c.write(layer[0])