57 lines
1.9 KiB
Python
57 lines
1.9 KiB
Python
import argparse
|
|
import os
|
|
import struct
|
|
import torch
|
|
from utils.torch_utils import select_device
|
|
from models.models import Darknet
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser(description='PyTorch YOLOR conversion (main branch)')
|
|
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
|
|
parser.add_argument('-c', '--cfg', help='Input cfg (.cfg) file path')
|
|
args = parser.parse_args()
|
|
if not os.path.isfile(args.weights):
|
|
raise SystemExit('Invalid weights file')
|
|
if not os.path.isfile(args.cfg):
|
|
raise SystemExit('Invalid cfg file')
|
|
return args.weights, args.cfg
|
|
|
|
|
|
pt_file, cfg_file = parse_args()
|
|
|
|
|
|
model_name = os.path.basename(pt_file).split('.pt')[0]
|
|
wts_file = model_name + '.wts' if 'yolor' in model_name else 'yolor_' + model_name + '.wts'
|
|
new_cfg_file = model_name + '.cfg' if 'yolor' in model_name else 'yolor_' + model_name + '.cfg'
|
|
|
|
if cfg_file == '':
|
|
cfg_file = 'cfg/' + model_name + '.cfg'
|
|
if not os.path.isfile(cfg_file):
|
|
raise SystemExit('CFG file not found')
|
|
elif not os.path.isfile(cfg_file):
|
|
raise SystemExit('Invalid CFG file')
|
|
|
|
device = select_device('cpu')
|
|
model = Darknet(cfg_file).to(device)
|
|
model.load_state_dict(torch.load(pt_file, map_location=device)['model'])
|
|
model.to(device).eval()
|
|
|
|
with open(wts_file, 'w') as f:
|
|
wts_write = ''
|
|
conv_count = 0
|
|
for k, v in model.state_dict().items():
|
|
if 'num_batches_tracked' not in k:
|
|
vr = v.reshape(-1).cpu().numpy()
|
|
wts_write += '{} {} '.format(k, len(vr))
|
|
for vv in vr:
|
|
wts_write += ' '
|
|
wts_write += struct.pack('>f', float(vv)).hex()
|
|
wts_write += '\n'
|
|
conv_count += 1
|
|
f.write('{}\n'.format(conv_count))
|
|
f.write(wts_write)
|
|
|
|
if not os.path.isfile(new_cfg_file):
|
|
os.system('cp %s %s' % (cfg_file, new_cfg_file))
|