This commit is contained in:
2025-04-17 15:55:54 -04:00
parent 5ca30d5e11
commit 79f9a1dbc2
12 changed files with 611 additions and 80 deletions

View File

@@ -1,132 +1,104 @@
do_load = True
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, Range, MatchText
import requests
from pymilvus import MilvusClient, DataType
from pymilvus.client.types import LoadState
from datetime import datetime, timedelta
import numpy as np
import traceback
from bottle import route, run, template, request, debug
import open_clip
import time
import prettyprinter
# %%
if do_load:
from lavis.models import load_model_and_preprocess, model_zoo
import torch
device = 'cpu'
model, vis_processors, txt_processors = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-B-16", is_eval=True, device=device)
model.eval()
TIMEOUT=2
# %%
collection_name = "nuggets_so400m"
client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT)
#collection_name = "nuggets_{camera}_so400m_siglip2"
collection_name = "nuggets_{camera}_so400m"
client = MilvusClient(
uri="http://localhost:19530"
)
# %%
from bottle import route, run, template
@route('/get_text_match')
def get_matches():
# %%
valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'}
query = request.query.get('query','A large bird eating corn')
cameras = request.query.get('cameras','sidefeeder')
num_videos = int(request.query.get('num_videos',5))
cams = set(cameras.split(',')).intersection(valid_cameras)
max_age = int(request.query.get('age',5));
print({'Cameras':cams,'Max Age':max_age,'Query':query})
# %%
max_date = datetime.now()
min_date = max_date - timedelta(days=(max_age))
days_step = (max_date- min_date).days
date_arrays = list()
for i in range(days_step):
date_arrays.append(max_date - timedelta(days=i))
# %%
string_filter = list()
for cand_date in date_arrays:
string_filter.append(cand_date.strftime('%Y/%m/%d'))
day_strs = list()
# %%
for x in range(days_step):
day_strs.append( (min_date + timedelta(days=x)).strftime('%Y%m%d') )
str_insert = ','.join([f'"{x}"'for x in day_strs])
filter_string = 'date in [' + str_insert + ']'
should_list = list()
if max_age == 0:
string_filter=['']
for str_filt in string_filter:
for cam in cams:
str_use = cam+'/'+str_filt
print(str_use)
ccond = FieldCondition(key='filepath',match=MatchText(text=str_use))
should_list.append(ccond)
condition_dict = Filter(should = should_list)
if len(should_list) == 0:
return_this = {'query':query,'num_videos':0,'results':[]}
return return_this
averaged = False
if isinstance(query, str):
averaged = bool(averaged)
num_videos = int(request.query.get('num_videos',5))
if do_load:
with torch.no_grad():
text_input = txt_processors['eval'](query)
sample = {'text_input':text_input}
vec = model.extract_features( sample)
vec_search = vec.cpu().numpy().squeeze().tolist()
else:
sz_vec= client.get_collection(collection_name).config.params.vectors.size
vec_search = np.random.random(sz_vec).tolist()
if averaged:
col_name=collection_name+'_averaged'
else:
col_name = collection_name
filter_string = ''
# %%
vec_form = requests.get('http://192.168.1.242:53004/encode',params={'query':query}).json()['vector'][0]
vec_search = np.asarray(vec_form).astype(np.float16)
print(f'Doing query for {query}!')
all_results=list()
try:
error = ''
if True:
for i in range(num_videos, 100, 10):
results = client.search(collection_name = col_name,
query_vector = vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT)
num_video_got = len(set([x.payload['filepath'] for x in results]))
if num_video_got >= num_videos:
break
else:
results = client.search(collection_name = col_name,
query_vector = vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT)
for cam in cams:
col_name = collection_name.format(camera=cam)
for i in [num_videos]:
results = client.search(collection_name = col_name,
consistency_level="Eventually",
filter=filter_string,
data = [vec_search],
limit=i,
search_params={'metric_type':'COSINE', 'params':{}},
output_fields=['filepath','frame_number']
)
all_results.extend(results[0])
except Exception as e:
print(traceback.format_exc())
error = traceback.format_exc()
results = [];
results = []
# %%
def linux_to_win_path(form):
form = form.replace('/srv','file://192.168.1.242/thebears/Videos/merged/')
return form
def normalize_to_merged(path):
path = path.replace('/srv/ftp','/mergedfs/ftp')
path = path.replace('/mnt/archive2/videos/ftp','/mergedfs/ftp')
if not path.startswith('/'):
path= '/mergedfs/ftp/'+path
return path
resul = list()
for x in results:
pload =dict( x.payload)
all_results = sorted(all_results, key=lambda x: x['distance'])
for x in all_results:
pload =dict( x['entity'])
pload['filepath'] = normalize_to_merged(pload['filepath'])
pload['score'] = x.score
pload['score'] = x['distance']
pload['winpath'] = linux_to_win_path(pload['filepath'])
pload['frame'] = pload['frame_number']
resul.append(pload)
# %%
return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error}
prettyprinter.pprint(return_this)
return return_this