Added YOLOR native support

YOLOR-CSP
YOLOR-CSP*
YOLOR-CSP-X
YOLOR-CSP-X*
This commit is contained in:
unknown
2021-12-12 00:47:32 -03:00
parent 7761ca7a6b
commit e2257a81c0
12 changed files with 336 additions and 6 deletions

View File

@@ -8,7 +8,7 @@ from utils.torch_utils import select_device
def parse_args():
parser = argparse.ArgumentParser(description="PyTorch conversion")
parser = argparse.ArgumentParser(description="PyTorch YOLOv5 conversion")
parser.add_argument("-w", "--weights", required=True, help="Input weights (.pt) file path (required)")
parser.add_argument("-c", "--yaml", help="Input cfg (.yaml) file path")
parser.add_argument("-mw", "--width", help="Model width (default = 640 / 1280 [P6])")
@@ -76,7 +76,7 @@ with open(wts_file, "w") as f:
cv1 += "{} {} ".format(k, len(vr))
for vv in vr:
cv1 += " "
cv1 += struct.pack(">f" ,float(vv)).hex()
cv1 += struct.pack(">f", float(vv)).hex()
cv1 += "\n"
conv_count += 1
elif cv1 != "" and ".m." in k:
@@ -86,7 +86,7 @@ with open(wts_file, "w") as f:
cv3 += "{} {} ".format(k, len(vr))
for vv in vr:
cv3 += " "
cv3 += struct.pack(">f" ,float(vv)).hex()
cv3 += struct.pack(">f", float(vv)).hex()
cv3 += "\n"
cv3_idx = idx
conv_count += 1
@@ -98,7 +98,7 @@ with open(wts_file, "w") as f:
wts_write += "{} {} ".format(k, len(vr))
for vv in vr:
wts_write += " "
wts_write += struct.pack(">f" ,float(vv)).hex()
wts_write += struct.pack(">f", float(vv)).hex()
wts_write += "\n"
conv_count += 1
f.write("{}\n".format(conv_count))

43
utils/gen_wts_yolor.py Normal file
View File

@@ -0,0 +1,43 @@
import argparse
import os
import struct
import torch
from utils.torch_utils import select_device
from models.models import Darknet
def parse_args():
parser = argparse.ArgumentParser(description="PyTorch YOLOR conversion (main branch)")
parser.add_argument("-w", "--weights", required=True, help="Input weights (.pt) file path (required)")
parser.add_argument("-c", "--cfg", required=True, help="Input cfg (.cfg) file path (required)")
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit("Invalid weights file")
if not os.path.isfile(args.cfg):
raise SystemExit("Invalid cfg file")
return args.weights, args.cfg
pt_file, cfg_file = parse_args()
wts_file = pt_file.split(".pt")[0] + ".wts"
device = select_device("cpu")
model = Darknet(cfg_file).to(device)
model.load_state_dict(torch.load(pt_file, map_location=device)["model"])
model.to(device).eval()
with open(wts_file, "w") as f:
wts_write = ""
conv_count = 0
for k, v in model.state_dict().items():
if not "num_batches_tracked" in k:
vr = v.reshape(-1).cpu().numpy()
wts_write += "{} {} ".format(k, len(vr))
for vv in vr:
wts_write += " "
wts_write += struct.pack(">f", float(vv)).hex()
wts_write += "\n"
conv_count += 1
f.write("{}\n".format(conv_count))
f.write(wts_write)