This commit is contained in:
2024-05-24 21:04:37 -04:00
parent cc448ab44b
commit 3c8c99c186
3 changed files with 238 additions and 23 deletions

136
flycheck_search_me.py Normal file
View File

@@ -0,0 +1,136 @@
do_load = True
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, Range, MatchText
from datetime import datetime, timedelta
import numpy as np
import traceback
from bottle import route, run, template, request, debug
import open_clip
# %%
if do_load:
from lavis.models import load_model_and_preprocess, model_zoo
import torch
device = 'cpu'
model, vis_processors, txt_processors = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-B-16", is_eval=True, device=device)
model.eval()
TIMEOUT=2
# %%
collection_name = "nuggets_so400m"
client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT)
# %%
from bottle import route, run, template
@route('/get_text_match')
def get_matches():
valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'}
query = request.query.get('query','A large bird eating corn')
cameras = request.query.get('cameras','sidefeeder')
cams = set(cameras.split(',')).intersection(valid_cameras)
max_age = int(request.query.get('age',5));
print({'Cameras':cams,'Max Age':max_age,'Query':query})
# %%
max_date = datetime.now()
min_date = max_date - timedelta(days=(max_age))
days_step = (max_date- min_date).days
date_arrays = list()
for i in range(days_step):
date_arrays.append(max_date - timedelta(days=i))
# %%
string_filter = list()
for cand_date in date_arrays:
string_filter.append(cand_date.strftime('%Y/%m/%d'))
should_list = list()
if max_age == 0:
string_filter=['']
for str_filt in string_filter:
for cam in cams:
str_use = cam+'/'+str_filt
print(str_use)
ccond = FieldCondition(key='filepath',match=MatchText(text=str_use))
should_list.append(ccond)
condition_dict = Filter(should = should_list)
if len(should_list) == 0:
return_this = {'query':query,'num_videos':0,'results':[]}
return return_this
averaged = False
if isinstance(query, str):
averaged = bool(averaged)
num_videos = int(request.query.get('num_videos',5))
if do_load:
with torch.no_grad():
text_input = txt_processors['eval'](query)
sample = {'text_input':text_input}
vec = model.extract_features( sample)
vec_search = vec.cpu().numpy().squeeze().tolist()
else:
sz_vec= client.get_collection(collection_name).config.params.vectors.size
vec_search = np.random.random(sz_vec).tolist()
if averaged:
col_name=collection_name+'_averaged'
else:
col_name = collection_name
try:
error = ''
if True:
for i in range(num_videos, 100, 10):
results = client.search(collection_name = col_name,
query_vector = vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT)
num_video_got = len(set([x.payload['filepath'] for x in results]))
if num_video_got >= num_videos:
break
else:
results = client.search(collection_name = col_name,
query_vector = vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT)
except Exception as e:
print(traceback.format_exc())
error = traceback.format_exc()
results = [];
def linux_to_win_path(form):
form = form.replace('/srv','file://192.168.1.242/thebears/Videos/merged/')
return form
def normalize_to_merged(path):
path = path.replace('/srv/ftp','/mergedfs/ftp')
path = path.replace('/mnt/archive2/videos/ftp','/mergedfs/ftp')
return path
resul = list()
for x in results:
pload =dict( x.payload)
pload['filepath'] = normalize_to_merged(pload['filepath'])
pload['score'] = x.score
pload['winpath'] = linux_to_win_path(pload['filepath'])
resul.append(pload)
# %%
return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error}
return return_this
debug(True)
run(host='0.0.0.0', port=53003, server='bjoern')

View File

@@ -1,39 +1,76 @@
do_load = True
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, Range, MatchText
from datetime import datetime, timedelta
import numpy as np
import traceback
from bottle import route, run, template, request, debug
import open_clip
# %%
if do_load:
from lavis.models import load_model_and_preprocess, model_zoo
import torch
device = 'cpu'
model, vis_processors, txt_processors = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-B-16", is_eval=True, device=device)
model.eval()
collection_name="nuggets_clip"
TIMEOUT=2
# %%
client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True)
collection_name = "nuggets_so400m"
client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT)
# %%
from bottle import route, run, template
@route('/get_text_match')
def get_matches():
valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'}
query = request.query.get('query','A large bird eating corn')
# averaged = request.query.get('averaged',False)
cameras = request.query.get('cameras','sidefeeder')
cams = set(cameras.split(',')).intersection(valid_cameras)
max_age = int(request.query.get('age',5));
print({'Cameras':cams,'Max Age':max_age,'Query':query})
# %%
max_age = request.query.get('age',5);
max_date = datetime.now()
min_date = max_date - timedelta(days=(max_age))
days_step = (max_date- min_date).days
date_arrays = list()
for i in range(days_step):
date_arrays.append(max_date - timedelta(days=i))
# %%
max_age = 5
string_filter = list()
for cand_date in date_arrays:
string_filter.append(cand_date.strftime('%Y/%m/%d'))
should_list = list()
if max_age == 0:
string_filter=['']
for str_filt in string_filter:
for cam in cams:
str_use = cam+'/'+str_filt
print(str_use)
ccond = FieldCondition(key='filepath',match=MatchText(text=str_use))
should_list.append(ccond)
# %%
condition_dict = Filter(should = should_list)
if len(should_list) == 0:
return_this = {'query':query,'num_videos':0,'results':[]}
return return_this
averaged = False
if isinstance(query, str):
averaged = bool(averaged)
num_videos = int(request.query.get('num_videos',5))
print(query, num_videos, averaged)
if do_load:
with torch.no_grad():
@@ -42,7 +79,8 @@ def get_matches():
vec = model.extract_features( sample)
vec_search = vec.cpu().numpy().squeeze().tolist()
else:
pass
sz_vec= client.get_collection(collection_name).config.params.vectors.size
vec_search = np.random.random(sz_vec).tolist()
@@ -51,17 +89,23 @@ def get_matches():
else:
col_name = collection_name
try:
error = ''
if True:
for i in range(num_videos, 100, 10):
results = client.search(collection_name = col_name,
query_vector = vec_search, limit=i)
query_vector = vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT)
num_video_got = len(set([x.payload['filepath'] for x in results]))
if num_video_got >= num_videos:
break
else:
results = client.search(collection_name = col_name,
query_vector = vec_search, limit=1)
query_vector = vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT)
except Exception as e:
print(traceback.format_exc())
error = traceback.format_exc()
results = [];
def linux_to_win_path(form):
@@ -82,11 +126,11 @@ def get_matches():
resul.append(pload)
# %%
return_this = {'query':query,'num_videos':num_videos,'results':resul}
return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error}
return return_this
debug(True)
run(host='0.0.0.0', port=53003)
run(host='0.0.0.0', port=53003, server='bjoern')

35
update_qdrant.py Normal file
View File

@@ -0,0 +1,35 @@
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.models import Distance, VectorParams
client = QdrantClient(host="localhost", port=6333)
collection_head = "nuggets_clip"
collection_head = "nuggets_so400m"
for collection_name in [collection_head, collection_head +'_averaged']:
try:
client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=1152, distance=Distance.COSINE),
on_disk_payload = True
)
client.create_payload_index(
collection_name=collection_name,
field_name="filepath",
field_schema=models.TextIndexParams(
type="text",
tokenizer=models.TokenizerType.WORD,
min_token_len=1,
max_token_len=15,
lowercase=True,
),
)
except Exception as e:
print(e)
from prettyprinter import cpprint
cpprint(client.get_collection(collection_name).dict()['vectors_count'])