Update gen_wts_yoloV5.py

This commit is contained in:
Marcos Luciano
2022-06-19 15:47:20 -03:00
parent 42abd4a80a
commit 0516a46138

View File

@@ -120,11 +120,16 @@ model.model[-1].register_buffer("anchor_grid", anchor_grid)
model.to(device).eval() model.to(device).eval()
nc = 0
anchors = "" anchors = ""
masks = [] masks = []
yolo_idx = 0
spp_idx = 0
for k, v in model.state_dict().items(): for k, v in model.state_dict().items():
if "anchor_grid" in k: if "anchor_grid" in k:
yolo_idx = int(k.split(".")[1])
vr = v.cpu().numpy().tolist() vr = v.cpu().numpy().tolist()
a = v.reshape(-1).cpu().numpy().astype(float).tolist() a = v.reshape(-1).cpu().numpy().astype(float).tolist()
anchors = str(a)[1:-1] anchors = str(a)[1:-1]
@@ -135,8 +140,9 @@ for k, v in model.state_dict().items():
mask.append(num) mask.append(num)
num += 1 num += 1
masks.append(mask) masks.append(mask)
elif ".%d.m.0.weight" % yolo_idx in k:
spp_idx = 0 vr = v.cpu().numpy().tolist()
nc = int((len(vr) / len(masks[0])) - 5)
with open(cfg_file, "w") as c: with open(cfg_file, "w") as c:
with open(yaml_file, "r", encoding="utf-8") as f: with open(yaml_file, "r", encoding="utf-8") as f:
@@ -145,16 +151,13 @@ with open(cfg_file, "w") as c:
c.write("height=%d\n" % model_height) c.write("height=%d\n" % model_height)
c.write("channels=%d\n" % model_channels) c.write("channels=%d\n" % model_channels)
c.write("letter_box=1\n") c.write("letter_box=1\n")
nc = 0
depth_multiple = 0 depth_multiple = 0
width_multiple = 0 width_multiple = 0
layers = [] layers = []
yoloLayers = YoloLayers() yoloLayers = YoloLayers()
f = yaml.load(f, Loader=yaml.FullLoader) f = yaml.load(f, Loader=yaml.FullLoader)
for topic in f: for topic in f:
if topic == "nc": if topic == "depth_multiple":
nc = f[topic]
elif topic == "depth_multiple":
depth_multiple = f[topic] depth_multiple = f[topic]
elif topic == "width_multiple": elif topic == "width_multiple":
width_multiple = f[topic] width_multiple = f[topic]