From 0516a4613872b704ac78ab9325d4c115cd3f2dcb Mon Sep 17 00:00:00 2001 From: Marcos Luciano Date: Sun, 19 Jun 2022 15:47:20 -0300 Subject: [PATCH] Update gen_wts_yoloV5.py --- utils/gen_wts_yoloV5.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/utils/gen_wts_yoloV5.py b/utils/gen_wts_yoloV5.py index 24b49f6..6c176b6 100644 --- a/utils/gen_wts_yoloV5.py +++ b/utils/gen_wts_yoloV5.py @@ -120,11 +120,16 @@ model.model[-1].register_buffer("anchor_grid", anchor_grid) model.to(device).eval() +nc = 0 anchors = "" masks = [] +yolo_idx = 0 +spp_idx = 0 + for k, v in model.state_dict().items(): if "anchor_grid" in k: + yolo_idx = int(k.split(".")[1]) vr = v.cpu().numpy().tolist() a = v.reshape(-1).cpu().numpy().astype(float).tolist() anchors = str(a)[1:-1] @@ -135,8 +140,9 @@ for k, v in model.state_dict().items(): mask.append(num) num += 1 masks.append(mask) - -spp_idx = 0 + elif ".%d.m.0.weight" % yolo_idx in k: + vr = v.cpu().numpy().tolist() + nc = int((len(vr) / len(masks[0])) - 5) with open(cfg_file, "w") as c: with open(yaml_file, "r", encoding="utf-8") as f: @@ -145,16 +151,13 @@ with open(cfg_file, "w") as c: c.write("height=%d\n" % model_height) c.write("channels=%d\n" % model_channels) c.write("letter_box=1\n") - nc = 0 depth_multiple = 0 width_multiple = 0 layers = [] yoloLayers = YoloLayers() f = yaml.load(f, Loader=yaml.FullLoader) for topic in f: - if topic == "nc": - nc = f[topic] - elif topic == "depth_multiple": + if topic == "depth_multiple": depth_multiple = f[topic] elif topic == "width_multiple": width_multiple = f[topic]