This commit is contained in:
Marcos Luciano
2023-06-05 18:33:03 -03:00
parent 79d4a0a8cd
commit 9fd80c5248
25 changed files with 108 additions and 41 deletions

View File

@@ -62,6 +62,7 @@ getYoloNetworkInfo(NetworkInfo& networkInfo, const NvDsInferContextInitParams* i
networkInfo.clusterMode = initParams->clusterMode;
networkInfo.scaleFactor = initParams->networkScaleFactor;
networkInfo.offsets = initParams->offsets;
networkInfo.workspaceSize = initParams->workspaceSize;
if (initParams->networkMode == NvDsInferNetworkMode_FP32)
networkInfo.networkMode = "FP32";
@@ -101,6 +102,8 @@ NvDsInferCreateModelParser(const NvDsInferContextInitParams* initParams)
return new Yolo(networkInfo);
}
#else
#if NV_TENSORRT_MAJOR >= 8
extern "C" bool
NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder* const builder, nvinfer1::IBuilderConfig* const builderConfig,
const NvDsInferContextInitParams* const initParams, nvinfer1::DataType dataType, nvinfer1::ICudaEngine*& cudaEngine);
@@ -108,13 +111,29 @@ NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder* const builder, nvinfer1::IBuilder
extern "C" bool
NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder* const builder, nvinfer1::IBuilderConfig* const builderConfig,
const NvDsInferContextInitParams* const initParams, nvinfer1::DataType dataType, nvinfer1::ICudaEngine*& cudaEngine)
#else
extern "C" bool
NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder* const builder, const NvDsInferContextInitParams* const initParams,
nvinfer1::DataType dataType, nvinfer1::ICudaEngine*& cudaEngine);
extern "C" bool
NvDsInferYoloCudaEngineGet(nvinfer1::IBuilder* const builder, const NvDsInferContextInitParams* const initParams,
nvinfer1::DataType dataType, nvinfer1::ICudaEngine*& cudaEngine)
#endif
{
NetworkInfo networkInfo;
if (!getYoloNetworkInfo(networkInfo, initParams))
return false;
Yolo yolo(networkInfo);
#if NV_TENSORRT_MAJOR >= 8
cudaEngine = yolo.createEngine(builder, builderConfig);
#else
cudaEngine = yolo.createEngine(builder);
#endif
if (cudaEngine == nullptr) {
std::cerr << "Failed to build CUDA engine" << std::endl;
return false;