Added YOLOv5 6.0 native support
This commit is contained in:
@@ -67,32 +67,63 @@ std::vector<float> loadWeights(const std::string weightsFilePath, const std::str
|
||||
{
|
||||
assert(fileExists(weightsFilePath));
|
||||
std::cout << "\nLoading pre-trained weights" << std::endl;
|
||||
std::ifstream file(weightsFilePath, std::ios_base::binary);
|
||||
assert(file.good());
|
||||
std::string line;
|
||||
|
||||
if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos)
|
||||
{
|
||||
// Remove 4 int32 bytes of data from the stream belonging to the header
|
||||
file.ignore(4 * 4);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Remove 5 int32 bytes of data from the stream belonging to the header
|
||||
file.ignore(4 * 5);
|
||||
}
|
||||
|
||||
std::vector<float> weights;
|
||||
char floatWeight[4];
|
||||
while (!file.eof())
|
||||
{
|
||||
file.read(floatWeight, 4);
|
||||
assert(file.gcount() == 4);
|
||||
weights.push_back(*reinterpret_cast<float*>(floatWeight));
|
||||
if (file.peek() == std::istream::traits_type::eof()) break;
|
||||
|
||||
if (weightsFilePath.find(".weights") != std::string::npos) {
|
||||
std::ifstream file(weightsFilePath, std::ios_base::binary);
|
||||
assert(file.good());
|
||||
std::string line;
|
||||
|
||||
if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos)
|
||||
{
|
||||
// Remove 4 int32 bytes of data from the stream belonging to the header
|
||||
file.ignore(4 * 4);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Remove 5 int32 bytes of data from the stream belonging to the header
|
||||
file.ignore(4 * 5);
|
||||
}
|
||||
|
||||
char floatWeight[4];
|
||||
while (!file.eof())
|
||||
{
|
||||
file.read(floatWeight, 4);
|
||||
assert(file.gcount() == 4);
|
||||
weights.push_back(*reinterpret_cast<float*>(floatWeight));
|
||||
if (file.peek() == std::istream::traits_type::eof()) break;
|
||||
}
|
||||
}
|
||||
|
||||
else if (weightsFilePath.find(".wts") != std::string::npos) {
|
||||
std::ifstream file(weightsFilePath);
|
||||
assert(file.good());
|
||||
int32_t count;
|
||||
file >> count;
|
||||
assert(count > 0 && "Invalid .wts file.");
|
||||
|
||||
uint32_t floatWeight;
|
||||
std::string name;
|
||||
uint32_t size;
|
||||
|
||||
while (count--) {
|
||||
file >> name >> std::dec >> size;
|
||||
for (uint32_t x = 0, y = size; x < y; ++x)
|
||||
{
|
||||
file >> std::hex >> floatWeight;
|
||||
weights.push_back(*reinterpret_cast<float *>(&floatWeight));
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
else {
|
||||
std::cerr << "File " << weightsFilePath << " is not supported" << std::endl;
|
||||
std::abort();
|
||||
}
|
||||
|
||||
std::cout << "Loading weights of " << networkType << " complete"
|
||||
<< std::endl;
|
||||
<< std::endl;
|
||||
std::cout << "Total weights read: " << weights.size() << std::endl;
|
||||
return weights;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user