diff --git a/utils/gen_wts_yoloV5.py b/utils/gen_wts_yoloV5.py index ff12353..cc01c85 100644 --- a/utils/gen_wts_yoloV5.py +++ b/utils/gen_wts_yoloV5.py @@ -45,8 +45,8 @@ def get_depth(x, gd): pt_file, yaml_file, model_width, model_height, model_channels, p6 = parse_args() model_name = pt_file.split(".pt")[0] -wts_file = model_name + ".wts" -cfg_file = model_name + ".cfg" +wts_file = model_name + ".wts" if "yolov5" in model_name else "yolov5_" + model_name + ".wts" +cfg_file = model_name + ".cfg" if "yolov5" in model_name else "yolov5_" + model_name + ".cfg" if yaml_file == "": yaml_file = "models/" + model_name + ".yaml" @@ -326,7 +326,7 @@ with open(cfg_file, "w") as c: layer += "size=1\n" layer += "stride=1\n" layer += "pad=1\n" - layer += "filters=%d\n" % ((nc + 5) * 3) + layer += "filters=%d\n" % ((nc + 5) * len(masks[i])) layer += "activation=logistic\n" blocks += 1 layer += "\n[yolo]\n" diff --git a/utils/gen_wts_yolor.py b/utils/gen_wts_yolor.py index 6358b72..b9e2011 100644 --- a/utils/gen_wts_yolor.py +++ b/utils/gen_wts_yolor.py @@ -20,7 +20,7 @@ def parse_args(): pt_file, cfg_file = parse_args() -wts_file = pt_file.split(".pt")[0] + ".wts" +wts_file = cfg_file.split(".cfg")[0] + ".wts" device = select_device("cpu") model = Darknet(cfg_file).to(device)