Update YOLOv5 conversor
This commit is contained in:
@@ -59,8 +59,16 @@ elif not os.path.isfile(yaml_file):
|
|||||||
|
|
||||||
device = select_device("cpu")
|
device = select_device("cpu")
|
||||||
model = torch.load(pt_file, map_location=device)["model"].float()
|
model = torch.load(pt_file, map_location=device)["model"].float()
|
||||||
|
|
||||||
|
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
|
||||||
|
delattr(model.model[-1], "anchor_grid")
|
||||||
|
model.model[-1].register_buffer("anchor_grid", anchor_grid)
|
||||||
|
|
||||||
model.to(device).eval()
|
model.to(device).eval()
|
||||||
|
|
||||||
|
anchors = ""
|
||||||
|
masks = []
|
||||||
|
|
||||||
with open(wts_file, "w") as f:
|
with open(wts_file, "w") as f:
|
||||||
wts_write = ""
|
wts_write = ""
|
||||||
conv_count = 0
|
conv_count = 0
|
||||||
@@ -101,6 +109,17 @@ with open(wts_file, "w") as f:
|
|||||||
wts_write += struct.pack(">f", float(vv)).hex()
|
wts_write += struct.pack(">f", float(vv)).hex()
|
||||||
wts_write += "\n"
|
wts_write += "\n"
|
||||||
conv_count += 1
|
conv_count += 1
|
||||||
|
elif "anchor_grid" in k:
|
||||||
|
vr = v.cpu().numpy().tolist()
|
||||||
|
a = v.reshape(-1).cpu().numpy().astype(int).tolist()
|
||||||
|
anchors = str(a)[1:-1]
|
||||||
|
num = 0
|
||||||
|
for m in vr:
|
||||||
|
mask = []
|
||||||
|
for _ in range(len(m)):
|
||||||
|
mask.append(num)
|
||||||
|
num += 1
|
||||||
|
masks.append(mask)
|
||||||
f.write("{}\n".format(conv_count))
|
f.write("{}\n".format(conv_count))
|
||||||
f.write(wts_write)
|
f.write(wts_write)
|
||||||
|
|
||||||
@@ -109,9 +128,6 @@ with open(cfg_file, "w") as c:
|
|||||||
nc = 0
|
nc = 0
|
||||||
depth_multiple = 0
|
depth_multiple = 0
|
||||||
width_multiple = 0
|
width_multiple = 0
|
||||||
anchors = ""
|
|
||||||
masks = []
|
|
||||||
num = 0
|
|
||||||
detections = []
|
detections = []
|
||||||
layers = []
|
layers = []
|
||||||
f = yaml.load(f,Loader=yaml.FullLoader)
|
f = yaml.load(f,Loader=yaml.FullLoader)
|
||||||
@@ -126,16 +142,6 @@ with open(cfg_file, "w") as c:
|
|||||||
depth_multiple = f[l]
|
depth_multiple = f[l]
|
||||||
elif l == "width_multiple":
|
elif l == "width_multiple":
|
||||||
width_multiple = f[l]
|
width_multiple = f[l]
|
||||||
elif l == "anchors":
|
|
||||||
a = []
|
|
||||||
for v in f[l]:
|
|
||||||
a.extend(v)
|
|
||||||
mask = []
|
|
||||||
for _ in range(int(len(v) / 2)):
|
|
||||||
mask.append(num)
|
|
||||||
num += 1
|
|
||||||
masks.append(mask)
|
|
||||||
anchors = str(a)[1:-1]
|
|
||||||
elif l == "backbone" or l == "head":
|
elif l == "backbone" or l == "head":
|
||||||
for v in f[l]:
|
for v in f[l]:
|
||||||
if v[2] == "Conv":
|
if v[2] == "Conv":
|
||||||
|
|||||||
Reference in New Issue
Block a user