New features and fixes
This commit is contained in:
@@ -35,39 +35,56 @@
|
||||
static bool
|
||||
getYoloNetworkInfo(NetworkInfo& networkInfo, const NvDsInferContextInitParams* initParams)
|
||||
{
|
||||
std::string yoloCfg = initParams->customNetworkConfigFilePath;
|
||||
std::string yoloType;
|
||||
std::string onnxWtsFilePath = initParams->onnxFilePath;
|
||||
std::string darknetWtsFilePath = initParams->modelFilePath;
|
||||
std::string darknetCfgFilePath = initParams->customNetworkConfigFilePath;
|
||||
|
||||
std::transform(yoloCfg.begin(), yoloCfg.end(), yoloCfg.begin(), [] (uint8_t c) {
|
||||
std::string yoloType = onnxWtsFilePath != "" ? "onnx" : "darknet";
|
||||
std::string modelName = yoloType == "onnx" ?
|
||||
onnxWtsFilePath.substr(0, onnxWtsFilePath.find(".onnx")).substr(onnxWtsFilePath.rfind("/") + 1) :
|
||||
darknetWtsFilePath.substr(0, darknetWtsFilePath.find(".weights")).substr(darknetWtsFilePath.rfind("/") + 1);
|
||||
|
||||
std::transform(modelName.begin(), modelName.end(), modelName.begin(), [] (uint8_t c) {
|
||||
return std::tolower(c);
|
||||
});
|
||||
|
||||
yoloType = yoloCfg.substr(0, yoloCfg.find(".cfg"));
|
||||
|
||||
networkInfo.inputBlobName = "input";
|
||||
networkInfo.networkType = yoloType;
|
||||
networkInfo.configFilePath = initParams->customNetworkConfigFilePath;
|
||||
networkInfo.wtsFilePath = initParams->modelFilePath;
|
||||
networkInfo.modelName = modelName;
|
||||
networkInfo.onnxWtsFilePath = onnxWtsFilePath;
|
||||
networkInfo.darknetWtsFilePath = darknetWtsFilePath;
|
||||
networkInfo.darknetCfgFilePath = darknetCfgFilePath;
|
||||
networkInfo.batchSize = initParams->maxBatchSize;
|
||||
networkInfo.implicitBatch = initParams->forceImplicitBatchDimension;
|
||||
networkInfo.int8CalibPath = initParams->int8CalibrationFilePath;
|
||||
networkInfo.deviceType = (initParams->useDLA ? "kDLA" : "kGPU");
|
||||
networkInfo.deviceType = initParams->useDLA ? "kDLA" : "kGPU";
|
||||
networkInfo.numDetectedClasses = initParams->numDetectedClasses;
|
||||
networkInfo.clusterMode = initParams->clusterMode;
|
||||
networkInfo.scaleFactor = initParams->networkScaleFactor;
|
||||
networkInfo.offsets = initParams->offsets;
|
||||
|
||||
if (initParams->networkMode == 0)
|
||||
if (initParams->networkMode == NvDsInferNetworkMode_FP32)
|
||||
networkInfo.networkMode = "FP32";
|
||||
else if (initParams->networkMode == 1)
|
||||
else if (initParams->networkMode == NvDsInferNetworkMode_INT8)
|
||||
networkInfo.networkMode = "INT8";
|
||||
else if (initParams->networkMode == 2)
|
||||
else if (initParams->networkMode == NvDsInferNetworkMode_FP16)
|
||||
networkInfo.networkMode = "FP16";
|
||||
|
||||
if (networkInfo.configFilePath.empty() || networkInfo.wtsFilePath.empty()) {
|
||||
std::cerr << "YOLO config file or weights file is not specified\n" << std::endl;
|
||||
return false;
|
||||
if (yoloType == "onnx") {
|
||||
if (!fileExists(networkInfo.onnxWtsFilePath)) {
|
||||
std::cerr << "ONNX model file does not exist\n" << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!fileExists(networkInfo.configFilePath) || !fileExists(networkInfo.wtsFilePath)) {
|
||||
std::cerr << "YOLO config file or weights file is not exist\n" << std::endl;
|
||||
return false;
|
||||
else {
|
||||
if (!fileExists(networkInfo.darknetWtsFilePath)) {
|
||||
std::cerr << "Darknet weights file does not exist\n" << std::endl;
|
||||
return false;
|
||||
}
|
||||
else if (!fileExists(networkInfo.darknetCfgFilePath)) {
|
||||
std::cerr << "Darknet cfg file does not exist\n" << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -99,7 +116,7 @@ NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder* const builder, nvinfer1::IBuilder
|
||||
Yolo yolo(networkInfo);
|
||||
cudaEngine = yolo.createEngine(builder, builderConfig);
|
||||
if (cudaEngine == nullptr) {
|
||||
std::cerr << "Failed to build CUDA engine on " << networkInfo.configFilePath << std::endl;
|
||||
std::cerr << "Failed to build CUDA engine" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user