This commit is contained in:
2021-09-27 16:02:11 -04:00
parent 90edf9bd45
commit e18232df84
35 changed files with 3037 additions and 78 deletions

40
convert_to_onnx.py Normal file
View File

@@ -0,0 +1,40 @@
import sys
sys.path.append('/home/thebears/Seafile/Designs/ML')
from model import Model
import torch
device = 'cpu'
model_rt_path = '/home/thebears/Seafile/Designs/ML/inaturalist_models/models/'#0210701_202822.json
newest_model = os.path.join(model_rt_path, max(os.listdir(model_rt_path)).replace('.pth',''))
with open(newest_model + '.json','r') as nmj:
model_json = json.load(nmj)
cats = model_json['categories']
cats.sort(key=lambda x: x['new_id'])
num_cat = len(cats) + 1
model_type = model_json['model_type']
model = Model(num_cat, model_type)
labels = [x['name'] for x in cats]
model.load_state_dict(
torch.load(newest_model + '.pth', map_location = torch.device(device))
)
model.eval()
# %%
onnx_model_path = "models"
onnx_model_name = "hbirds.onnx"
os.makedirs(onnx_model_path, exist_ok=True)
full_model_path = os.path.join(onnx_model_path, onnx_model_name)
# model export into ONNX format
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
torch.onnx.export(model, x, full_model_path, opset_version = 12)
# %%
import cv2
opencv_net = cv2.dnn.readNetFromONNX(full_model_path)
print("OpenCV model was successfully read. Layer IDs: \n", opencv_net.getLayerNames())