Update gen_wts_yoloV5.py
This commit is contained in:
@@ -120,11 +120,16 @@ model.model[-1].register_buffer("anchor_grid", anchor_grid)
|
||||
|
||||
model.to(device).eval()
|
||||
|
||||
nc = 0
|
||||
anchors = ""
|
||||
masks = []
|
||||
|
||||
yolo_idx = 0
|
||||
spp_idx = 0
|
||||
|
||||
for k, v in model.state_dict().items():
|
||||
if "anchor_grid" in k:
|
||||
yolo_idx = int(k.split(".")[1])
|
||||
vr = v.cpu().numpy().tolist()
|
||||
a = v.reshape(-1).cpu().numpy().astype(float).tolist()
|
||||
anchors = str(a)[1:-1]
|
||||
@@ -135,8 +140,9 @@ for k, v in model.state_dict().items():
|
||||
mask.append(num)
|
||||
num += 1
|
||||
masks.append(mask)
|
||||
|
||||
spp_idx = 0
|
||||
elif ".%d.m.0.weight" % yolo_idx in k:
|
||||
vr = v.cpu().numpy().tolist()
|
||||
nc = int((len(vr) / len(masks[0])) - 5)
|
||||
|
||||
with open(cfg_file, "w") as c:
|
||||
with open(yaml_file, "r", encoding="utf-8") as f:
|
||||
@@ -145,16 +151,13 @@ with open(cfg_file, "w") as c:
|
||||
c.write("height=%d\n" % model_height)
|
||||
c.write("channels=%d\n" % model_channels)
|
||||
c.write("letter_box=1\n")
|
||||
nc = 0
|
||||
depth_multiple = 0
|
||||
width_multiple = 0
|
||||
layers = []
|
||||
yoloLayers = YoloLayers()
|
||||
f = yaml.load(f, Loader=yaml.FullLoader)
|
||||
for topic in f:
|
||||
if topic == "nc":
|
||||
nc = f[topic]
|
||||
elif topic == "depth_multiple":
|
||||
if topic == "depth_multiple":
|
||||
depth_multiple = f[topic]
|
||||
elif topic == "width_multiple":
|
||||
width_multiple = f[topic]
|
||||
|
||||
Reference in New Issue
Block a user