Fix INT8 calibrator
This commit is contained in:
@@ -73,6 +73,23 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
|
||||
parseConfigBlocks();
|
||||
orderParams(&m_OutputMasks);
|
||||
|
||||
std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
|
||||
std::vector<nvinfer1::Weights> trtWeights;
|
||||
|
||||
nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0);
|
||||
if (parseModel(*network) != NVDSINFER_SUCCESS) {
|
||||
network->destroy();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::cout << "Building the TensorRT Engine" << std::endl;
|
||||
|
||||
if (m_LetterBox == 1) {
|
||||
std::cout << "\nNOTE: letter_box is set in cfg file, make sure to set maintain-aspect-ratio=1 in config_infer file to get better accuracy\n" << std::endl;
|
||||
}
|
||||
|
||||
nvinfer1::IBuilderConfig *config = builder->createBuilderConfig();
|
||||
|
||||
if (m_NetworkMode == "INT8" && !fileExists(m_Int8CalibPath)) {
|
||||
assert(builder->platformHasFastInt8());
|
||||
#ifdef OPENCV
|
||||
@@ -92,31 +109,15 @@ nvinfer1::ICudaEngine *Yolo::createEngine (nvinfer1::IBuilder* builder)
|
||||
std::cerr << "INT8_CALIB_BATCH_SIZE not set" << std::endl;
|
||||
std::abort();
|
||||
}
|
||||
nvinfer1::int8EntroyCalibrator *calibrator = new nvinfer1::int8EntroyCalibrator(calib_batch_size, m_InputC, m_InputH, m_InputW, m_LetterBox, calib_image_list, m_Int8CalibPath);
|
||||
builder->setInt8Mode(true);
|
||||
builder->setInt8Calibrator(calibrator);
|
||||
nvinfer1::Int8EntropyCalibrator2 *calibrator = new nvinfer1::Int8EntropyCalibrator2(calib_batch_size, m_InputC, m_InputH, m_InputW, m_LetterBox, calib_image_list, m_Int8CalibPath);
|
||||
config->setFlag(nvinfer1::BuilderFlag::kINT8);
|
||||
config->setInt8Calibrator(calibrator);
|
||||
#else
|
||||
std::cerr << "OpenCV is required to run INT8 calibrator" << std::endl;
|
||||
std::abort();
|
||||
#endif
|
||||
}
|
||||
|
||||
std::vector<float> weights = loadWeights(m_WtsFilePath, m_NetworkType);
|
||||
std::vector<nvinfer1::Weights> trtWeights;
|
||||
|
||||
nvinfer1::INetworkDefinition *network = builder->createNetworkV2(0);
|
||||
if (parseModel(*network) != NVDSINFER_SUCCESS) {
|
||||
network->destroy();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::cout << "Building the TensorRT Engine" << std::endl;
|
||||
|
||||
if (m_LetterBox == 1) {
|
||||
std::cout << "\nNOTE: letter_box is set in cfg file, make sure to set maintain-aspect-ratio=1 in config_infer file to get better accuracy\n" << std::endl;
|
||||
}
|
||||
|
||||
nvinfer1::IBuilderConfig *config = builder->createBuilderConfig();
|
||||
nvinfer1::ICudaEngine *engine = builder->buildEngineWithConfig(*network, *config);
|
||||
if (engine) {
|
||||
std::cout << "Building complete\n" << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user