# pipelines_orin/01_do_obj_det/model_runner.py
import sys

# Make the local yolov7 checkout importable before pulling in its modules.
sys.path.insert(0, "/home/thebears/source/models/yolov7")

import base64 as b64
import inspect
import json
import os
import time
from datetime import datetime

import cv2
import numpy as np
import open_clip
import torch
import torch.nn.functional as F
import yaml
from pymediainfo import MediaInfo
from torchvision import transforms

from models.experimental import attempt_load
from utils.general import check_img_size, non_max_suppression

device = torch.device("cuda")
# %%
class ModelRunner:
    def __init__(self):
        self.pretrained_name = "webli"
        self.model_name = "ViT-SO400M-16-SigLIP2-512"
        self.det_root_path = "/home/thebears/source/model_weights"
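
    # --- SigLIP2 setup ---
    # Loads the open_clip model/preprocess pair once and caches everything on
    # the instance; repeated calls return immediately.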
    def init_model_clip(self):
        if hasattr(self, 'clip_preprocess'):
            return
        model_name = self.model_name
        pretrained_name = self.pretrained_name
        clip_model, _, clip_preprocess_og = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained_name
        )
        tokenizer = open_clip.get_tokenizer("hf-hub:timm/" + model_name)
        clip_model = clip_model.half().to(device)
        clip_dtype = next(clip_model.parameters()).dtype
        clip_img_size = clip_preprocess_og.transforms[0].size
        # Warm-up pass so CUDA kernels are compiled before real frames arrive.
        clip_model.encode_image(
            torch.rand(1, 3, *clip_img_size, dtype=clip_dtype, device=device)
        )
        # Keep only Resize and Normalize from the stock pipeline; frames arrive
        # as tensors, so CenterCrop/ToTensor are skipped.
        clip_preprocess = transforms.Compose(
            [clip_preprocess_og.transforms[x] for x in [0, 3]]
        )
        self.clip_model = clip_model
        self.clip_preprocess_og = clip_preprocess_og
        self.clip_tokenizer = tokenizer
        self.clip_dtype = clip_dtype
        self.clip_img_size = clip_img_size
        self.clip_preprocess = clip_preprocess
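
    # --- YOLOv7 setup ---
    # Loads the fine-tuned detector plus its label map from the inaturalist
    # data YAML; cached on the instance like the CLIP model.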
    def init_model_det(self):
        if hasattr(self, 'det_model'):
            return
        det_root_path = self.det_root_path
        det_model_weights_root = os.path.join(det_root_path, "yolov7")
        det_model_weights_path = os.path.join(det_model_weights_root, "best.pt")
        det_data_yaml_path = os.path.join(det_model_weights_root, "inaturalist.yaml")
        det_model = attempt_load(det_model_weights_path, map_location=device)
        det_model = det_model.half().to(device)
        det_dtype = next(det_model.parameters()).dtype
        det_imgsz = 1280
        det_stride = int(det_model.stride.max())
        # Round the inference size to a multiple of the model stride.
        det_imgsz = check_img_size(det_imgsz, s=det_stride)
        # Warm-up pass so CUDA kernels are compiled before real frames arrive.
        _ = det_model(
            torch.zeros(1, 3, det_imgsz, det_imgsz, dtype=det_dtype).to(device)
        )
        with open(det_data_yaml_path, "r") as ff:
            det_model_info = yaml.safe_load(ff)
        det_labels = det_model_info["names"]
        self.det_dtype = det_dtype
        self.det_imgsz = det_imgsz
        self.det_stride = det_stride
        self.det_model_info = det_model_info
        self.det_labels = det_labels
        self.det_model = det_model
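
    # Builds (and memoizes, keyed by source resolution) a Resize+Pad transform
    # that letterboxes decoded frames to the detector's stride-aligned size.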
    def get_det_vid_preprocessor(self, vid_h, vid_w):
        if not hasattr(self, "_det_vid_preprocessors"):
            self._det_vid_preprocessors = dict()
            self.curr_det_vid_preprocessor = None
        dict_key = (vid_h, vid_w)
        det_stride = self.det_stride
        if dict_key in self._det_vid_preprocessors:
            self.curr_det_vid_preprocessor = self._det_vid_preprocessors[dict_key]
            return self.curr_det_vid_preprocessor
        # Scale the longer side to det_imgsz while preserving aspect ratio.
        target_max = self.det_imgsz
        if vid_h > vid_w:
            target_h = target_max
            target_w = target_max * vid_w / vid_h
        elif vid_h == vid_w:
            target_h = target_max
            target_w = target_max
        else:
            target_h = target_max * vid_h / vid_w
            target_w = target_max
        target_h = int(target_h)
        target_w = int(target_w)
        # Pad each side so both dimensions are multiples of the model stride;
        # pad_amt is (left, top, right, bottom), as transforms.Pad expects.
        pad_amt = [None, None, None, None]
        if target_w % det_stride != 0:
            off = det_stride - target_w % det_stride
            new_w = target_w + off
            pad_diff = new_w - target_w
            pad_left = round(pad_diff / 2)
            pad_right = pad_diff - pad_left
            pad_amt[0] = pad_left
            pad_amt[2] = pad_right
        else:
            pad_amt[0] = 0
            pad_amt[2] = 0
        if target_h % det_stride != 0:
            off = det_stride - target_h % det_stride
            new_h = target_h + off
            pad_diff = new_h - target_h
            pad_up = round(pad_diff / 2)
            pad_down = pad_diff - pad_up
            pad_amt[1] = pad_up
            pad_amt[3] = pad_down
        else:
            pad_amt[1] = 0
            pad_amt[3] = 0
        det_vid_preprocess = transforms.Compose(
            [transforms.Resize((target_h, target_w)), transforms.Pad(pad_amt, fill=127)]
        )
        self.target_h = target_h
        self.target_w = target_w
        self.pad_amt = pad_amt
        self._det_vid_preprocessors[dict_key] = det_vid_preprocess
        self.curr_det_vid_preprocessor = self._det_vid_preprocessors[dict_key]
        return self.curr_det_vid_preprocessor
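
    # Runs one detector batch: array_score is a list of (frame_number, frame)
    # pairs holding BGR uint8 arrays straight from OpenCV.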
    def score_frames_det(self, array_score, det_vid_preprocess=None):
        det_model = self.det_model
        if det_vid_preprocess is None:
            det_vid_preprocess = self.curr_det_vid_preprocessor
        frame_numbers = [x[0] for x in array_score]
        frame_values = [x[1] for x in array_score]
        # Stack HWC uint8 frames into an NCHW half-precision batch on the GPU.
        frame_as_tensor = (
            torch.from_numpy(np.stack(frame_values)[:, :, :, 0:3])
            .to(torch.float16)
            .to(device)
            .permute([0, 3, 1, 2])
        )
        with torch.no_grad():
            # Resize/pad, scale to [0, 1], and swap BGR -> RGB for the model.
            frame_for_model = det_vid_preprocess(frame_as_tensor).div(255)[
                :, [2, 1, 0], :, :
            ]
            det_preds = det_model(frame_for_model)[0]
        # conf_thres=0.25, iou_thres=0.5
        det_pred_post_nms = non_max_suppression(det_preds, 0.25, 0.5)
        det_cpu_pred = [x.detach().cpu().numpy() for x in det_pred_post_nms]
        return {"det": det_cpu_pred, "fr#": frame_numbers}
    def score_frames_clip(self, clip_array_score):
        frame_numbers = [x[0] for x in clip_array_score]
        frame_values = [x[1] for x in clip_array_score]
        frame_as_tensor = (
            torch.from_numpy(np.stack(frame_values)[:, :, :, 0:3])
            .to(torch.float16)
            .to(device)
            .permute([0, 3, 1, 2])
        )
        with torch.no_grad():
            # Match the detector path: scale to [0, 1] and swap BGR -> RGB
            # before Resize/Normalize (ToTensor was dropped from the stock
            # pipeline, so the /255 scaling has to happen here).
            frame_for_clip = self.clip_preprocess(
                frame_as_tensor.div(255)[:, [2, 1, 0], :, :]
            )
            clip_pred = (
                self.clip_model.encode_image(frame_for_clip).detach().cpu().numpy()
            )
        return {"clip": clip_pred, "fr#": frame_numbers}
    def get_video_info(self, file_path):
        file_info = MediaInfo.parse(file_path)
        video_info = None
        if len(file_info.video_tracks) > 0:
            video_info = file_info.video_tracks[0]
            video_info.frame_count = int(video_info.frame_count)
        return video_info
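
    # End-to-end scoring: decode with a GStreamer/NVDEC pipeline, run YOLOv7
    # on every frame in batches, embed every clip_interval-th frame with
    # SigLIP2, and return a JSON-friendly dict of detections plus base64
    # embeddings.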
    def score_video(self, file_to_score, batch_size=6, clip_interval=10):
        # Both init methods are guarded and return immediately if already run.
        self.init_model_det()
        self.init_model_clip()
        video_info = self.get_video_info(file_to_score)
        vid_parser = "h264parse"
        if video_info.format.lower() == "hevc":
            vid_parser = "h265parse"
        # Hardware-decode via NVDEC and scale to a fixed 1280x1280 buffer; the
        # torch Resize in the preprocessor restores the original aspect ratio.
        gst_cmd = "filesrc location={file_to_score} ! qtdemux name=demux demux.video_0 ! queue ! {vid_parser} ! nvv4l2decoder ! nvvidconv ! videoscale method=1 add-borders=false ! video/x-raw,width=1280,height=1280 ! appsink sync=false".format(
            file_to_score=file_to_score, vid_parser=vid_parser
        )
        cap_handle = cv2.VideoCapture(gst_cmd, cv2.CAP_GSTREAMER)
        vid_h = video_info.height
        vid_w = video_info.width
        vid_preprocessor = self.get_det_vid_preprocessor(vid_h, vid_w)
        target_w = self.target_w
        target_h = self.target_h
        pad_amt = self.pad_amt
        array_score = list()
        final_output = dict()
        final_output["start_score_time"] = time.time()
        final_output["num_frames"] = video_info.frame_count
        frame_numbers = list()
        det_results = list()
        clip_results = list()
        clip_frame_numbers = list()
        for i in range(video_info.frame_count):
            success, frame_matrix = cap_handle.read()
            if not success:
                break
            array_score.append((i, frame_matrix))
            # Run the detector whenever a full batch has accumulated.
            if len(array_score) >= batch_size:
                score_result = self.score_frames_det(
                    array_score, det_vid_preprocess=vid_preprocessor
                )
                det_results.extend(score_result["det"])
                frame_numbers.extend(score_result["fr#"])
                array_score = list()
            # Embed every clip_interval-th frame with SigLIP2.
            if not (i % clip_interval):
                clip_score_result = self.score_frames_clip([(i, frame_matrix)])
                clip_results.extend(clip_score_result["clip"])
                clip_frame_numbers.extend(clip_score_result["fr#"])
        # Flush any remaining partial batch.
        if len(array_score) > 0:
            score_result = self.score_frames_det(
                array_score, det_vid_preprocess=vid_preprocessor
            )
            det_results.extend(score_result["det"])
            frame_numbers.extend(score_result["fr#"])
        cap_handle.release()
        final_output["end_score_time"] = time.time()
        final_output["video"] = {
            "w": vid_w,
            "h": vid_h,
            "path": file_to_score,
            "target_w": target_w,
            "target_h": target_h,
            "pad_amt": pad_amt,
        }
        try:
            final_output["scoring_fps"] = final_output["num_frames"] / (
                final_output["end_score_time"] - final_output["start_score_time"]
            )
        except ZeroDivisionError:
            pass
        final_output["scores"] = list()
        clip_results_as_np = np.asarray(clip_results)
        # Each detection row is (x1, y1, x2, y2, confidence, class_idx).
        for frame_number, frame in zip(frame_numbers, det_results):
            cframe_dict = dict()
            cframe_dict["frame"] = frame_number
            cframe_dict["detections"] = list()
            for det in frame:
                data = dict()
                data["coords"] = [float(x) for x in list(det[0:4])]
                data["score"] = float(det[4])
                data["idx"] = int(det[5])
                try:
                    data["name"] = self.det_labels[data["idx"]]
                except (IndexError, KeyError):
                    data["name"] = "Code failed"
                cframe_dict["detections"].append(data)
            final_output["scores"].append(cframe_dict)
        # Embeddings travel as base64 so the whole payload stays JSON-safe.
        emb_dict = dict()
        emb_dict["frame_numbers"] = clip_frame_numbers
        emb_dict["array_size"] = clip_results_as_np.shape
        emb_dict["array_dtype"] = str(clip_results_as_np.dtype)
        emb_dict["array_binary"] = b64.b64encode(clip_results_as_np).decode()
        final_output["embeds"] = emb_dict
        return final_output
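

# %%
# Minimal usage sketch: the video path is a placeholder, and running it
# assumes an NVIDIA GStreamer stack (nvv4l2decoder) plus the model weights at
# the paths configured above.
if __name__ == "__main__":
    runner = ModelRunner()
    result = runner.score_video("/path/to/video.mp4")
    print("scored", result["num_frames"], "frames")
    with open("scores.json", "w") as fh:
        json.dump(result, fh)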