yawc
This commit is contained in:
155
quantize_model.py
Normal file
155
quantize_model.py
Normal file
@@ -0,0 +1,155 @@
|
||||
|
||||
|
||||
import torchvision
|
||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||
from collections import defaultdict as ddict
|
||||
import json
|
||||
import torch
|
||||
from torchvision import datasets, transforms as T
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
sys.path.append('/home/thebears/Seafile/Designs/ML')
|
||||
import json
|
||||
import cv2
|
||||
import random
|
||||
|
||||
from model import Model
|
||||
import socket
|
||||
from torchvision.utils import draw_bounding_boxes
|
||||
import torch as t
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
|
||||
|
||||
|
||||
|
||||
no_cuda = socket.gethostname() == 'tree'
|
||||
device='cpu'
|
||||
model_path = '/home/thebears/Seafile/Designs/ML/inaturalist_models/models/hummingbird'#0210701_202822.json
|
||||
with open(model_path + '.json','r') as nmj:
|
||||
model_json = json.load(nmj)
|
||||
|
||||
cats = model_json['categories']
|
||||
cats.sort(key=lambda x: x['new_id'])
|
||||
num_cat = len(cats) + 1
|
||||
model_type = model_json['model_type']
|
||||
model = Model(num_cat, model_type)
|
||||
labels = [x['name'] for x in cats]
|
||||
model.load_state_dict(
|
||||
torch.load(model_path + '.pth', map_location = torch.device(device))
|
||||
)
|
||||
model.eval()
|
||||
# %%
|
||||
backend = "fbgemm"
|
||||
model.qconfig = torch.quantization.get_default_qconfig(backend)
|
||||
torch.backends.quantized.engine = backend
|
||||
model_static_quantized = torch.quantization.prepare(model, inplace=False)
|
||||
model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)
|
||||
|
||||
def print_model_size(mdl):
|
||||
torch.save(mdl.state_dict(), "tmp.pt")
|
||||
print("%.2f MB" %(os.path.getsize("tmp.pt")/1e6))
|
||||
os.remove('tmp.pt')
|
||||
|
||||
print_model_size(model_static_quantized)
|
||||
# %%
|
||||
|
||||
|
||||
|
||||
|
||||
results = list()
|
||||
vid_path = '/srv/ftp/hummingbird/2021/07/28/Hummingbird_01_20210728063745.mp4'
|
||||
cap = cv2.VideoCapture(vid_path)
|
||||
frame_num = 0
|
||||
|
||||
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
step_frame = 15
|
||||
|
||||
import time
|
||||
idces = 0
|
||||
st = time.time()
|
||||
for frame_num in range(0, total_frames, step_frame):
|
||||
srcimg = cap.read()[1]
|
||||
print(frame_num)
|
||||
if srcimg is None:
|
||||
break
|
||||
|
||||
|
||||
image = srcimg[:, :, ::-1].copy()
|
||||
o = T.ToTensor()(image)
|
||||
img = o[None, :, :, :]
|
||||
with torch.no_grad():
|
||||
ou = model(img)
|
||||
|
||||
|
||||
print(ou)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
for i in range(step_frame):
|
||||
|
||||
img = cap.read()[1];
|
||||
if img is None:
|
||||
break
|
||||
# %%
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
et = time.time()
|
||||
|
||||
model(img)
|
||||
st = time.time()
|
||||
print(st-et)
|
||||
# %%
|
||||
|
||||
|
||||
img_use = img
|
||||
|
||||
|
||||
st = time.time()
|
||||
features = model.backbone(img_use)
|
||||
print(time.time() - st)
|
||||
|
||||
st = time.time()
|
||||
proposals = model.rpn(img_use, features)
|
||||
print(time.time() - st)
|
||||
|
||||
st = time.time()
|
||||
head = model.head(features, proposals)
|
||||
print(time.time() - st)
|
||||
|
||||
# %%
|
||||
# vid_path = '/srv/ftp/hummingbird/2021/06/27/Hummingbird_01_20210627101803.mp4'
|
||||
# import time
|
||||
# import cv2
|
||||
# video = cv2.VideoCapture(vid_path)
|
||||
|
||||
# total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
# # %%
|
||||
# st = time.time()
|
||||
|
||||
# while True:
|
||||
# ret, read = video.read()
|
||||
# if not ret:
|
||||
# break
|
||||
|
||||
# et = time.time()
|
||||
|
||||
# print(et-st)
|
||||
|
||||
# st = time.time()
|
||||
# frs = list()
|
||||
# for i in range(0,total_frames, 150):
|
||||
# video.set(cv2.CAP_PROP_POS_FRAMES, i)
|
||||
# ret, frame = video.read()
|
||||
# frs.append(frame)
|
||||
# et = time.time()
|
||||
# print(et-st)
|
||||
Reference in New Issue
Block a user