Update YOLOv5 conversor
This commit is contained in:
@@ -59,8 +59,16 @@ elif not os.path.isfile(yaml_file):
|
||||
|
||||
device = select_device("cpu")
|
||||
model = torch.load(pt_file, map_location=device)["model"].float()
|
||||
|
||||
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
|
||||
delattr(model.model[-1], "anchor_grid")
|
||||
model.model[-1].register_buffer("anchor_grid", anchor_grid)
|
||||
|
||||
model.to(device).eval()
|
||||
|
||||
anchors = ""
|
||||
masks = []
|
||||
|
||||
with open(wts_file, "w") as f:
|
||||
wts_write = ""
|
||||
conv_count = 0
|
||||
@@ -101,6 +109,17 @@ with open(wts_file, "w") as f:
|
||||
wts_write += struct.pack(">f", float(vv)).hex()
|
||||
wts_write += "\n"
|
||||
conv_count += 1
|
||||
elif "anchor_grid" in k:
|
||||
vr = v.cpu().numpy().tolist()
|
||||
a = v.reshape(-1).cpu().numpy().astype(int).tolist()
|
||||
anchors = str(a)[1:-1]
|
||||
num = 0
|
||||
for m in vr:
|
||||
mask = []
|
||||
for _ in range(len(m)):
|
||||
mask.append(num)
|
||||
num += 1
|
||||
masks.append(mask)
|
||||
f.write("{}\n".format(conv_count))
|
||||
f.write(wts_write)
|
||||
|
||||
@@ -109,9 +128,6 @@ with open(cfg_file, "w") as c:
|
||||
nc = 0
|
||||
depth_multiple = 0
|
||||
width_multiple = 0
|
||||
anchors = ""
|
||||
masks = []
|
||||
num = 0
|
||||
detections = []
|
||||
layers = []
|
||||
f = yaml.load(f,Loader=yaml.FullLoader)
|
||||
@@ -126,16 +142,6 @@ with open(cfg_file, "w") as c:
|
||||
depth_multiple = f[l]
|
||||
elif l == "width_multiple":
|
||||
width_multiple = f[l]
|
||||
elif l == "anchors":
|
||||
a = []
|
||||
for v in f[l]:
|
||||
a.extend(v)
|
||||
mask = []
|
||||
for _ in range(int(len(v) / 2)):
|
||||
mask.append(num)
|
||||
num += 1
|
||||
masks.append(mask)
|
||||
anchors = str(a)[1:-1]
|
||||
elif l == "backbone" or l == "head":
|
||||
for v in f[l]:
|
||||
if v[2] == "Conv":
|
||||
|
||||
Reference in New Issue
Block a user