diff --git a/utils/gen_wts_yoloV5.py b/utils/gen_wts_yoloV5.py index 1f6ea3c..963428c 100644 --- a/utils/gen_wts_yoloV5.py +++ b/utils/gen_wts_yoloV5.py @@ -59,8 +59,16 @@ elif not os.path.isfile(yaml_file): device = select_device("cpu") model = torch.load(pt_file, map_location=device)["model"].float() + +anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None] +delattr(model.model[-1], "anchor_grid") +model.model[-1].register_buffer("anchor_grid", anchor_grid) + model.to(device).eval() +anchors = "" +masks = [] + with open(wts_file, "w") as f: wts_write = "" conv_count = 0 @@ -101,6 +109,17 @@ with open(wts_file, "w") as f: wts_write += struct.pack(">f", float(vv)).hex() wts_write += "\n" conv_count += 1 + elif "anchor_grid" in k: + vr = v.cpu().numpy().tolist() + a = v.reshape(-1).cpu().numpy().astype(int).tolist() + anchors = str(a)[1:-1] + num = 0 + for m in vr: + mask = [] + for _ in range(len(m)): + mask.append(num) + num += 1 + masks.append(mask) f.write("{}\n".format(conv_count)) f.write(wts_write) @@ -109,9 +128,6 @@ with open(cfg_file, "w") as c: nc = 0 depth_multiple = 0 width_multiple = 0 - anchors = "" - masks = [] - num = 0 detections = [] layers = [] f = yaml.load(f,Loader=yaml.FullLoader) @@ -126,16 +142,6 @@ with open(cfg_file, "w") as c: depth_multiple = f[l] elif l == "width_multiple": width_multiple = f[l] - elif l == "anchors": - a = [] - for v in f[l]: - a.extend(v) - mask = [] - for _ in range(int(len(v) / 2)): - mask.append(num) - num += 1 - masks.append(mask) - anchors = str(a)[1:-1] elif l == "backbone" or l == "head": for v in f[l]: if v[2] == "Conv":