Added YOLOv5 6.0 native support

This commit is contained in:
unknown
2021-12-09 15:44:17 -03:00
parent dcc44b730c
commit bfd9268a31
8 changed files with 688 additions and 80 deletions

View File

@@ -67,32 +67,63 @@ std::vector<float> loadWeights(const std::string weightsFilePath, const std::str
{
assert(fileExists(weightsFilePath));
std::cout << "\nLoading pre-trained weights" << std::endl;
std::ifstream file(weightsFilePath, std::ios_base::binary);
assert(file.good());
std::string line;
if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos)
{
// Remove 4 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 4);
}
else
{
// Remove 5 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 5);
}
std::vector<float> weights;
char floatWeight[4];
while (!file.eof())
{
file.read(floatWeight, 4);
assert(file.gcount() == 4);
weights.push_back(*reinterpret_cast<float*>(floatWeight));
if (file.peek() == std::istream::traits_type::eof()) break;
if (weightsFilePath.find(".weights") != std::string::npos) {
std::ifstream file(weightsFilePath, std::ios_base::binary);
assert(file.good());
std::string line;
if (networkType.find("yolov2") != std::string::npos && networkType.find("yolov2-tiny") == std::string::npos)
{
// Remove 4 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 4);
}
else
{
// Remove 5 int32 bytes of data from the stream belonging to the header
file.ignore(4 * 5);
}
char floatWeight[4];
while (!file.eof())
{
file.read(floatWeight, 4);
assert(file.gcount() == 4);
weights.push_back(*reinterpret_cast<float*>(floatWeight));
if (file.peek() == std::istream::traits_type::eof()) break;
}
}
else if (weightsFilePath.find(".wts") != std::string::npos) {
std::ifstream file(weightsFilePath);
assert(file.good());
int32_t count;
file >> count;
assert(count > 0 && "Invalid .wts file.");
uint32_t floatWeight;
std::string name;
uint32_t size;
while (count--) {
file >> name >> std::dec >> size;
for (uint32_t x = 0, y = size; x < y; ++x)
{
file >> std::hex >> floatWeight;
weights.push_back(*reinterpret_cast<float *>(&floatWeight));
};
}
}
else {
std::cerr << "File " << weightsFilePath << " is not supported" << std::endl;
std::abort();
}
std::cout << "Loading weights of " << networkType << " complete"
<< std::endl;
<< std::endl;
std::cout << "Total weights read: " << weights.size() << std::endl;
return weights;
}