yawc
This commit is contained in:
40
convert_to_onnx.py
Normal file
40
convert_to_onnx.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import sys
|
||||
sys.path.append('/home/thebears/Seafile/Designs/ML')
|
||||
from model import Model
|
||||
import torch
|
||||
device = 'cpu'
|
||||
|
||||
model_rt_path = '/home/thebears/Seafile/Designs/ML/inaturalist_models/models/'#0210701_202822.json
|
||||
newest_model = os.path.join(model_rt_path, max(os.listdir(model_rt_path)).replace('.pth',''))
|
||||
with open(newest_model + '.json','r') as nmj:
|
||||
model_json = json.load(nmj)
|
||||
|
||||
cats = model_json['categories']
|
||||
cats.sort(key=lambda x: x['new_id'])
|
||||
num_cat = len(cats) + 1
|
||||
model_type = model_json['model_type']
|
||||
model = Model(num_cat, model_type)
|
||||
labels = [x['name'] for x in cats]
|
||||
model.load_state_dict(
|
||||
torch.load(newest_model + '.pth', map_location = torch.device(device))
|
||||
)
|
||||
model.eval()
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
onnx_model_path = "models"
|
||||
onnx_model_name = "hbirds.onnx"
|
||||
os.makedirs(onnx_model_path, exist_ok=True)
|
||||
full_model_path = os.path.join(onnx_model_path, onnx_model_name)
|
||||
|
||||
|
||||
# model export into ONNX format
|
||||
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
|
||||
torch.onnx.export(model, x, full_model_path, opset_version = 12)
|
||||
# %%
|
||||
|
||||
import cv2
|
||||
opencv_net = cv2.dnn.readNetFromONNX(full_model_path)
|
||||
print("OpenCV model was successfully read. Layer IDs: \n", opencv_net.getLayerNames())
|
||||
Reference in New Issue
Block a user