Update gen_wts

This commit is contained in:
Marcos Luciano
2022-06-28 01:14:44 -03:00
parent fdd9092284
commit 416a8e0108
3 changed files with 169 additions and 170 deletions

View File

@@ -7,39 +7,50 @@ from models.models import Darknet
def parse_args():
parser = argparse.ArgumentParser(description="PyTorch YOLOR conversion (main branch)")
parser.add_argument("-w", "--weights", required=True, help="Input weights (.pt) file path (required)")
parser.add_argument("-c", "--cfg", required=True, help="Input cfg (.cfg) file path (required)")
parser = argparse.ArgumentParser(description='PyTorch YOLOR conversion (main branch)')
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
parser.add_argument('-c', '--cfg', help='Input cfg (.cfg) file path')
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit("Invalid weights file")
raise SystemExit('Invalid weights file')
if not os.path.isfile(args.cfg):
raise SystemExit("Invalid cfg file")
raise SystemExit('Invalid cfg file')
return args.weights, args.cfg
pt_file, cfg_file = parse_args()
wts_file = "%s.wts" % cfg_file.rsplit("/")[1].split(".cfg")[0]
device = select_device("cpu")
model_name = os.path.basename(pt_file).split('.pt')[0]
wts_file = model_name + '.wts' if 'yolor' in model_name else 'yolor_' + model_name + '.wts'
new_cfg_file = model_name + '.cfg' if 'yolor' in model_name else 'yolor_' + model_name + '.cfg'
if cfg_file == '':
cfg_file = 'cfg/' + model_name + '.cfg'
if not os.path.isfile(cfg_file):
raise SystemExit('CFG file not found')
elif not os.path.isfile(cfg_file):
raise SystemExit('Invalid CFG file')
device = select_device('cpu')
model = Darknet(cfg_file).to(device)
model.load_state_dict(torch.load(pt_file, map_location=device)["model"])
model.load_state_dict(torch.load(pt_file, map_location=device)['model'])
model.to(device).eval()
with open(wts_file, "w") as f:
wts_write = ""
with open(wts_file, 'w') as f:
wts_write = ''
conv_count = 0
for k, v in model.state_dict().items():
if "num_batches_tracked" not in k:
if 'num_batches_tracked' not in k:
vr = v.reshape(-1).cpu().numpy()
wts_write += "{} {} ".format(k, len(vr))
wts_write += '{} {} '.format(k, len(vr))
for vv in vr:
wts_write += " "
wts_write += struct.pack(">f", float(vv)).hex()
wts_write += "\n"
wts_write += ' '
wts_write += struct.pack('>f', float(vv)).hex()
wts_write += '\n'
conv_count += 1
f.write("{}\n".format(conv_count))
f.write('{}\n'.format(conv_count))
f.write(wts_write)
os.system("cp %s ./" % cfg_file)
if not os.path.isfile(new_cfg_file):
os.system('cp %s %s' % (cfg_file, new_cfg_file))