Update utils
This commit is contained in:
@@ -45,8 +45,8 @@ def get_depth(x, gd):
|
||||
pt_file, yaml_file, model_width, model_height, model_channels, p6 = parse_args()
|
||||
|
||||
model_name = pt_file.split(".pt")[0]
|
||||
wts_file = model_name + ".wts"
|
||||
cfg_file = model_name + ".cfg"
|
||||
wts_file = model_name + ".wts" if "yolov5" in model_name else "yolov5_" + model_name + ".wts"
|
||||
cfg_file = model_name + ".cfg" if "yolov5" in model_name else "yolov5_" + model_name + ".cfg"
|
||||
|
||||
if yaml_file == "":
|
||||
yaml_file = "models/" + model_name + ".yaml"
|
||||
@@ -326,7 +326,7 @@ with open(cfg_file, "w") as c:
|
||||
layer += "size=1\n"
|
||||
layer += "stride=1\n"
|
||||
layer += "pad=1\n"
|
||||
layer += "filters=%d\n" % ((nc + 5) * 3)
|
||||
layer += "filters=%d\n" % ((nc + 5) * len(masks[i]))
|
||||
layer += "activation=logistic\n"
|
||||
blocks += 1
|
||||
layer += "\n[yolo]\n"
|
||||
|
||||
@@ -20,7 +20,7 @@ def parse_args():
|
||||
|
||||
pt_file, cfg_file = parse_args()
|
||||
|
||||
wts_file = pt_file.split(".pt")[0] + ".wts"
|
||||
wts_file = cfg_file.split(".cfg")[0] + ".wts"
|
||||
|
||||
device = select_device("cpu")
|
||||
model = Darknet(cfg_file).to(device)
|
||||
|
||||
Reference in New Issue
Block a user