# inaturalist_pytorch_model/quantize_model.py
import json
import os
import socket
import sys
import time

import cv2
import torch
from torchvision import transforms as T

sys.path.append('/home/thebears/Seafile/Designs/ML')
from model import Model  # custom detector wrapper
no_cuda = socket.gethostname() == 'tree'  # host 'tree' has no GPU
device = 'cpu'  # fbgemm static quantization targets CPU inference
model_path = '/home/thebears/Seafile/Designs/ML/inaturalist_models/models/hummingbird'  # 0210701_202822.json
with open(model_path + '.json', 'r') as nmj:
    model_json = json.load(nmj)

cats = model_json['categories']
cats.sort(key=lambda x: x['new_id'])
num_cat = len(cats) + 1  # +1 for the background class
model_type = model_json['model_type']
model = Model(num_cat, model_type)
labels = [x['name'] for x in cats]
model.load_state_dict(
    torch.load(model_path + '.pth', map_location=torch.device(device))
)
model.eval()
# %%
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
model_static_quantized = torch.quantization.prepare(model, inplace=False)
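# NOTE: static quantization expects a calibration pass between prepare()
# and convert() so the observers can record activation ranges. A minimal
# sketch with a dummy input (real calibration would feed representative
# video frames; the input shape here is an assumption, chosen to match
# the batched-tensor call used in the inference loop below):
with torch.no_grad():
    model_static_quantized(torch.rand(1, 3, 480, 640))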
model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)
def print_model_size(mdl):
    """Serialize the state dict to disk and report its size in MB."""
    torch.save(mdl.state_dict(), 'tmp.pt')
    print('%.2f MB' % (os.path.getsize('tmp.pt') / 1e6))
    os.remove('tmp.pt')
print_model_size(model_static_quantized)
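# For comparison, the size of the original float32 model:
print_model_size(model)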
# %%
results = list()  # placeholder for collected detections
vid_path = '/srv/ftp/hummingbird/2021/07/28/Hummingbird_01_20210728063745.mp4'
cap = cv2.VideoCapture(vid_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
step_frame = 15  # run detection on every step_frame-th frame
for frame_num in range(0, total_frames, step_frame):
    srcimg = cap.read()[1]
    print(frame_num)
    if srcimg is None:
        break
    image = srcimg[:, :, ::-1].copy()  # OpenCV reads BGR; convert to RGB
    img = T.ToTensor()(image)[None, :, :, :]  # add batch dimension
    with torch.no_grad():
        ou = model(img)
    print(ou)
    # Discard intermediate frames so the next read lands step_frame ahead
    # (reading step_frame here, as before, drifted by one frame per pass).
    for _ in range(step_frame - 1):
        if cap.read()[1] is None:
            break
# %%
t0 = time.time()
with torch.no_grad():
    model(img)
print('full model: %.3f s' % (time.time() - t0))
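# For comparison, time the statically quantized model on the same frame.
# This is exploratory: eager-mode convert() on a detection model may
# produce ops without quantized kernels, in which case this call fails.
t0 = time.time()
with torch.no_grad():
    model_static_quantized(img)
print('quantized model: %.3f s' % (time.time() - t0))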
# %%
# Time the detector stage by stage; this assumes the custom Model wrapper
# exposes backbone / rpn / head callables with these signatures.
img_use = img
t0 = time.time()
features = model.backbone(img_use)
print('backbone: %.3f s' % (time.time() - t0))
t0 = time.time()
proposals = model.rpn(img_use, features)
print('rpn: %.3f s' % (time.time() - t0))
t0 = time.time()
head = model.head(features, proposals)
print('head: %.3f s' % (time.time() - t0))
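# A steadier end-to-end number averages several passes; time_model is a
# hypothetical helper, not part of the original script.
def time_model(m, x, n=10):
    with torch.no_grad():
        t0 = time.time()
        for _ in range(n):
            m(x)
    return (time.time() - t0) / n

print('avg forward: %.3f s' % time_model(model, img_use))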
# %%
# A previous experiment: compare reading every frame sequentially vs.
# seeking with CAP_PROP_POS_FRAMES to sample every 150th frame.
# vid_path = '/srv/ftp/hummingbird/2021/06/27/Hummingbird_01_20210627101803.mp4'
# video = cv2.VideoCapture(vid_path)
# total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
# # %%
# st = time.time()
# while True:
#     ret, read = video.read()
#     if not ret:
#         break
# et = time.time()
# print(et - st)
# st = time.time()
# frs = list()
# for i in range(0, total_frames, 150):
#     video.set(cv2.CAP_PROP_POS_FRAMES, i)
#     ret, frame = video.read()
#     frs.append(frame)
# et = time.time()
# print(et - st)