YACWC
This commit is contained in:
92
search_me.py
Normal file
92
search_me.py
Normal file
@@ -0,0 +1,92 @@
|
||||
do_load = True
|
||||
from qdrant_client import QdrantClient
|
||||
import numpy as np
|
||||
from bottle import route, run, template, request, debug
|
||||
# %%
|
||||
if do_load:
|
||||
from lavis.models import load_model_and_preprocess, model_zoo
|
||||
import torch
|
||||
device = 'cpu'
|
||||
model, vis_processors, txt_processors = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-B-16", is_eval=True, device=device)
|
||||
model.eval()
|
||||
collection_name="nuggets_clip"
|
||||
|
||||
# %%
|
||||
client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True)
|
||||
# %%
|
||||
from bottle import route, run, template
|
||||
|
||||
@route('/get_text_match')
|
||||
def get_matches():
|
||||
query = request.query.get('query','A large bird eating corn')
|
||||
# averaged = request.query.get('averaged',False)
|
||||
# %%
|
||||
max_age = request.query.get('age',5);
|
||||
# %%
|
||||
max_age = 5
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
averaged = False
|
||||
if isinstance(query, str):
|
||||
averaged = bool(averaged)
|
||||
|
||||
num_videos = int(request.query.get('num_videos',5))
|
||||
print(query, num_videos, averaged)
|
||||
|
||||
if do_load:
|
||||
with torch.no_grad():
|
||||
text_input = txt_processors['eval'](query)
|
||||
sample = {'text_input':text_input}
|
||||
vec = model.extract_features( sample)
|
||||
vec_search = vec.cpu().numpy().squeeze().tolist()
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if averaged:
|
||||
col_name=collection_name+'_averaged'
|
||||
else:
|
||||
col_name = collection_name
|
||||
|
||||
if True:
|
||||
for i in range(num_videos, 100, 10):
|
||||
results = client.search(collection_name = col_name,
|
||||
query_vector = vec_search, limit=i)
|
||||
num_video_got = len(set([x.payload['filepath'] for x in results]))
|
||||
if num_video_got >= num_videos:
|
||||
break
|
||||
else:
|
||||
results = client.search(collection_name = col_name,
|
||||
query_vector = vec_search, limit=1)
|
||||
|
||||
|
||||
|
||||
def linux_to_win_path(form):
|
||||
form = form.replace('/srv','file://192.168.1.242/thebears/Videos/merged/')
|
||||
return form
|
||||
|
||||
def normalize_to_merged(path):
|
||||
path = path.replace('/srv/ftp','/mergedfs/ftp')
|
||||
path = path.replace('/mnt/archive2/videos/ftp','/mergedfs/ftp')
|
||||
return path
|
||||
|
||||
resul = list()
|
||||
for x in results:
|
||||
pload =dict( x.payload)
|
||||
pload['filepath'] = normalize_to_merged(pload['filepath'])
|
||||
pload['score'] = x.score
|
||||
pload['winpath'] = linux_to_win_path(pload['filepath'])
|
||||
resul.append(pload)
|
||||
|
||||
# %%
|
||||
return_this = {'query':query,'num_videos':num_videos,'results':resul}
|
||||
return return_this
|
||||
|
||||
|
||||
|
||||
|
||||
debug(True)
|
||||
run(host='0.0.0.0', port=53003)
|
||||
Reference in New Issue
Block a user