Files
deepstream_yolo/utils/gen_wts_yolor.py
Marcos Luciano 2aa52a8e8c Minor fixes
2022-04-08 13:31:36 -03:00

44 lines
1.4 KiB
Python

import argparse
import os
import struct
import torch
from utils.torch_utils import select_device
from models.models import Darknet
def parse_args(argv=None):
    """Parse and validate the converter's command-line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what a
            plain ``parse_args()`` call does.

    Returns:
        A ``(weights_path, cfg_path)`` tuple of the two paths as given.

    Raises:
        SystemExit: If a required option is missing (raised by argparse), or
            if either path does not name an existing regular file.
    """
    parser = argparse.ArgumentParser(description="PyTorch YOLOR conversion (main branch)")
    parser.add_argument("-w", "--weights", required=True, help="Input weights (.pt) file path (required)")
    parser.add_argument("-c", "--cfg", required=True, help="Input cfg (.cfg) file path (required)")
    args = parser.parse_args(argv)

    # Fail early with a short message instead of a traceback from deeper in
    # the model-loading code.
    if not os.path.isfile(args.weights):
        raise SystemExit("Invalid weights file")
    if not os.path.isfile(args.cfg):
        raise SystemExit("Invalid cfg file")
    return args.weights, args.cfg
def wts_path_for(cfg_file):
    """Return the output ``.wts`` path derived from *cfg_file*.

    Only a trailing ``.cfg`` is replaced. Splitting on the first ``.cfg``
    found anywhere in the path would truncate a path such as
    ``my.cfgs/yolor.cfg`` to ``my.wts``. A path that does not end in
    ``.cfg`` simply gets ``.wts`` appended.
    """
    suffix = ".cfg"
    stem = cfg_file[:-len(suffix)] if cfg_file.endswith(suffix) else cfg_file
    return stem + ".wts"


def write_wts(state_dict, out):
    """Write *state_dict* to the text stream *out* in ``.wts`` text form.

    Layout: a first line holding the number of exported entries, then one
    line per entry: the key, the element count, and every element as the hex
    of its big-endian 32-bit float encoding. Entries whose key contains
    ``num_batches_tracked`` are skipped and not counted.

    Args:
        state_dict: Mapping of parameter name to tensor; each value must
            support ``.reshape(-1).cpu().numpy()``.
        out: Writable text stream.
    """
    entries = [(name, tensor) for name, tensor in state_dict.items()
               if "num_batches_tracked" not in name]

    # The count is known up front, so the header can be written first and
    # each tensor streamed as one line instead of accumulating the whole
    # file in a single, repeatedly re-concatenated string.
    out.write("{}\n".format(len(entries)))
    for name, tensor in entries:
        values = tensor.reshape(-1).cpu().numpy()
        hex_words = "".join(" " + struct.pack(">f", float(x)).hex() for x in values)
        # The prefix ends with a space and every word starts with one, so
        # the count is followed by two spaces; kept to leave the output
        # byte-identical to what this script has always produced.
        out.write("{} {} ".format(name, len(values)) + hex_words + "\n")


def main():
    """Convert the checkpoint given on the command line into a ``.wts`` file.

    The output is written next to the cfg file, with ``.cfg`` replaced by
    ``.wts``.
    """
    pt_file, cfg_file = parse_args()
    wts_file = wts_path_for(cfg_file)

    device = select_device("cpu")
    model = Darknet(cfg_file).to(device)
    # NOTE(security): torch.load deserializes with pickle unless restricted,
    # so a crafted checkpoint can run arbitrary code. Only convert weights
    # obtained from a trusted source.
    checkpoint = torch.load(pt_file, map_location=device)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    with open(wts_file, "w") as f:
        write_wts(model.state_dict(), f)


if __name__ == "__main__":
    main()