Update gen_wts_yoloV5.py
This commit is contained in:
@@ -120,11 +120,16 @@ model.model[-1].register_buffer("anchor_grid", anchor_grid)
|
|||||||
|
|
||||||
model.to(device).eval()
|
model.to(device).eval()
|
||||||
|
|
||||||
|
nc = 0
|
||||||
anchors = ""
|
anchors = ""
|
||||||
masks = []
|
masks = []
|
||||||
|
|
||||||
|
yolo_idx = 0
|
||||||
|
spp_idx = 0
|
||||||
|
|
||||||
for k, v in model.state_dict().items():
|
for k, v in model.state_dict().items():
|
||||||
if "anchor_grid" in k:
|
if "anchor_grid" in k:
|
||||||
|
yolo_idx = int(k.split(".")[1])
|
||||||
vr = v.cpu().numpy().tolist()
|
vr = v.cpu().numpy().tolist()
|
||||||
a = v.reshape(-1).cpu().numpy().astype(float).tolist()
|
a = v.reshape(-1).cpu().numpy().astype(float).tolist()
|
||||||
anchors = str(a)[1:-1]
|
anchors = str(a)[1:-1]
|
||||||
@@ -135,8 +140,9 @@ for k, v in model.state_dict().items():
|
|||||||
mask.append(num)
|
mask.append(num)
|
||||||
num += 1
|
num += 1
|
||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
|
elif ".%d.m.0.weight" % yolo_idx in k:
|
||||||
spp_idx = 0
|
vr = v.cpu().numpy().tolist()
|
||||||
|
nc = int((len(vr) / len(masks[0])) - 5)
|
||||||
|
|
||||||
with open(cfg_file, "w") as c:
|
with open(cfg_file, "w") as c:
|
||||||
with open(yaml_file, "r", encoding="utf-8") as f:
|
with open(yaml_file, "r", encoding="utf-8") as f:
|
||||||
@@ -145,16 +151,13 @@ with open(cfg_file, "w") as c:
|
|||||||
c.write("height=%d\n" % model_height)
|
c.write("height=%d\n" % model_height)
|
||||||
c.write("channels=%d\n" % model_channels)
|
c.write("channels=%d\n" % model_channels)
|
||||||
c.write("letter_box=1\n")
|
c.write("letter_box=1\n")
|
||||||
nc = 0
|
|
||||||
depth_multiple = 0
|
depth_multiple = 0
|
||||||
width_multiple = 0
|
width_multiple = 0
|
||||||
layers = []
|
layers = []
|
||||||
yoloLayers = YoloLayers()
|
yoloLayers = YoloLayers()
|
||||||
f = yaml.load(f, Loader=yaml.FullLoader)
|
f = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
for topic in f:
|
for topic in f:
|
||||||
if topic == "nc":
|
if topic == "depth_multiple":
|
||||||
nc = f[topic]
|
|
||||||
elif topic == "depth_multiple":
|
|
||||||
depth_multiple = f[topic]
|
depth_multiple = f[topic]
|
||||||
elif topic == "width_multiple":
|
elif topic == "width_multiple":
|
||||||
width_multiple = f[topic]
|
width_multiple = f[topic]
|
||||||
|
|||||||
Reference in New Issue
Block a user