Files
vector_search/search_me.py
2025-04-17 15:55:54 -04:00

109 lines
3.3 KiB
Python

# Module-level feature flag.
# NOTE(review): nothing in this file's visible code reads `do_load`; confirm
# no external importer depends on it before removing.
do_load: bool = True
import requests
from pymilvus import MilvusClient, DataType
from pymilvus.client.types import LoadState
from datetime import datetime, timedelta
import numpy as np
import traceback
from bottle import route, run, template, request, debug
import time
import prettyprinter
# %%
# NOTE(review): TIMEOUT is not referenced anywhere in this file's visible
# code; confirm which call (encoder request? Milvus search?) it was meant to
# bound before relying on it.
TIMEOUT = 2
# %%
# Per-camera Milvus collection name template; the search handler fills the
# "{camera}" slot via collection_name.format(camera=...).
# Alternative (siglip2) collection, kept for switching back:
#   collection_name = "nuggets_{camera}_so400m_siglip2"
collection_name = "nuggets_{camera}_so400m"
# One module-level Milvus client, shared by every request this process serves.
client = MilvusClient(uri="http://localhost:19530")
# %%
from bottle import route, run, template
def _recent_dates_filter(max_age, now=None):
    """Return a Milvus filter expression restricting hits to recent days.

    The expression has the form ``date in ["YYYYMMDD",...]``. It lists every
    calendar day from ``now - max_age`` days up to and including ``now``.
    ``max_age == 0`` disables date filtering and returns ``''``. A negative
    ``max_age`` produces an empty list, which matches nothing.

    Args:
        max_age: Look-back window in days.
        now: Reference time; defaults to ``datetime.now()`` (local time).
    """
    if max_age == 0:
        return ''
    if now is None:
        now = datetime.now()
    start = now - timedelta(days=max_age)
    # range(max_age + 1) so the window ends on *today*; the previous
    # range(max_age) stopped at yesterday and never matched current footage.
    days = [(start + timedelta(days=offset)).strftime('%Y%m%d')
            for offset in range(max_age + 1)]
    return 'date in [' + ','.join(f'"{day}"' for day in days) + ']'


def _linux_to_win_path(form):
    """Map a Linux path under '/srv' onto the equivalent file-share URL.

    Only a *leading* '/srv' is rewritten; any other path is returned
    unchanged. (The previous ``str.replace`` rewrote every '/srv' substring,
    mangling paths such as '/srv/ftp/srvcam.mp4'.)
    """
    prefix = '/srv'
    if form.startswith(prefix):
        # NOTE(review): the share root ends with '/' and the remainder starts
        # with '/', so the result contains '//'. Kept as before; confirm
        # clients accept it before normalising.
        return 'file://192.168.1.242/thebears/Videos/merged/' + form[len(prefix):]
    return form


def _normalize_to_merged(path):
    """Anchor a relative stored path under '/mergedfs/ftp/'; absolute paths pass through."""
    if not path.startswith('/'):
        path = '/mergedfs/ftp/' + path
    return path


def _format_hit(hit):
    """Flatten one Milvus search hit into the payload dict returned to clients."""
    payload = dict(hit['entity'])
    payload['filepath'] = _normalize_to_merged(payload['filepath'])
    payload['score'] = hit['distance']
    payload['winpath'] = _linux_to_win_path(payload['filepath'])
    payload['frame'] = payload['frame_number']
    return payload


@route('/get_text_match')
def get_matches():
    """Search recent camera frames for the best matches to a text query.

    Query parameters:
        query: Free text to embed (default 'A large bird eating corn').
        cameras: Comma-separated camera names; names outside
            ``valid_cameras`` are silently dropped (default 'sidefeeder').
        num_videos: Hits requested from *each* camera's collection, so up to
            num_videos x cameras results are returned (default 5).
        age: Look-back window in days; 0 disables date filtering (default 5).

    Returns:
        dict with 'query', 'num_videos', 'results' (best match first) and
        'error' ('' on success, otherwise the formatted traceback; results
        gathered before the failure are still returned).
    """
    valid_cameras = {'sidefeeder', 'ptz', 'railing', 'hummingbird', 'pond'}
    query = request.query.get('query', 'A large bird eating corn')
    cameras = request.query.get('cameras', 'sidefeeder')
    num_videos = int(request.query.get('num_videos', 5))
    # Intersecting with a fixed whitelist keeps user text out of the collection
    # names built below.
    cams = set(cameras.split(',')).intersection(valid_cameras)
    max_age = int(request.query.get('age', 5))
    print({'Cameras': cams, 'Max Age': max_age, 'Query': query})
    filter_string = _recent_dates_filter(max_age)

    all_results = []
    error = ''
    try:
        # Embedding runs inside the try so an encoder outage is reported in
        # 'error' rather than surfacing as an HTTP 500.
        # NOTE(review): no timeout is passed, so a stalled encoder blocks this
        # worker; the module-level TIMEOUT constant may have been intended
        # here. Confirm before wiring it in.
        response = requests.get('http://192.168.1.242:53004/encode',
                                params={'query': query})
        response.raise_for_status()
        vec_form = response.json()['vector'][0]
        vec_search = np.asarray(vec_form).astype(np.float16)
        print(f'Doing query for {query}!')
        for cam in cams:
            hits = client.search(
                collection_name=collection_name.format(camera=cam),
                consistency_level="Eventually",
                filter=filter_string,
                data=[vec_search],
                limit=num_videos,
                search_params={'metric_type': 'COSINE', 'params': {}},
                output_fields=['filepath', 'frame_number'],
            )
            all_results.extend(hits[0])
    except Exception:
        # Request-level boundary: log and report, keeping any hits already
        # collected from earlier cameras.
        error = traceback.format_exc()
        print(error)

    # With the COSINE metric, Milvus' "distance" is a similarity (larger means
    # closer), so sort descending to list the best matches first when merging
    # cameras. The old ascending sort put the weakest match at the top.
    all_results.sort(key=lambda hit: hit['distance'], reverse=True)
    payloads = [_format_hit(hit) for hit in all_results]

    return_this = {'query': query, 'num_videos': num_videos,
                   'results': payloads, 'error': error}
    prettyprinter.pprint(return_this)
    return return_this
if __name__ == "__main__":
    # Start the HTTP service only when executed as a script (including via
    # `python -m` or an interactive "# %%" cell run), so merely importing this
    # module does not block forever in the server loop.
    # NOTE(review): Bottle debug mode puts exception details in error pages;
    # consider disabling it outside development.
    debug(True)
    run(host='0.0.0.0', port=53003, server='bjoern')