This commit is contained in:
2024-05-24 21:04:37 -04:00
parent cc448ab44b
commit 3c8c99c186
3 changed files with 238 additions and 23 deletions

136
flycheck_search_me.py Normal file
View File

@@ -0,0 +1,136 @@
do_load = True
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, Range, MatchText
from datetime import datetime, timedelta
import numpy as np
import traceback
from bottle import route, run, template, request, debug
import open_clip
# %%
if do_load:
from lavis.models import load_model_and_preprocess, model_zoo
import torch
device = 'cpu'
model, vis_processors, txt_processors = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-B-16", is_eval=True, device=device)
model.eval()
TIMEOUT=2
# %%
collection_name = "nuggets_so400m"
client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT)
# %%
from bottle import route, run, template
@route('/get_text_match')
def get_matches():
valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'}
query = request.query.get('query','A large bird eating corn')
cameras = request.query.get('cameras','sidefeeder')
cams = set(cameras.split(',')).intersection(valid_cameras)
max_age = int(request.query.get('age',5));
print({'Cameras':cams,'Max Age':max_age,'Query':query})
# %%
max_date = datetime.now()
min_date = max_date - timedelta(days=(max_age))
days_step = (max_date- min_date).days
date_arrays = list()
for i in range(days_step):
date_arrays.append(max_date - timedelta(days=i))
# %%
string_filter = list()
for cand_date in date_arrays:
string_filter.append(cand_date.strftime('%Y/%m/%d'))
should_list = list()
if max_age == 0:
string_filter=['']
for str_filt in string_filter:
for cam in cams:
str_use = cam+'/'+str_filt
print(str_use)
ccond = FieldCondition(key='filepath',match=MatchText(text=str_use))
should_list.append(ccond)
condition_dict = Filter(should = should_list)
if len(should_list) == 0:
return_this = {'query':query,'num_videos':0,'results':[]}
return return_this
averaged = False
if isinstance(query, str):
averaged = bool(averaged)
num_videos = int(request.query.get('num_videos',5))
if do_load:
with torch.no_grad():
text_input = txt_processors['eval'](query)
sample = {'text_input':text_input}
vec = model.extract_features( sample)
vec_search = vec.cpu().numpy().squeeze().tolist()
else:
sz_vec= client.get_collection(collection_name).config.params.vectors.size
vec_search = np.random.random(sz_vec).tolist()
if averaged:
col_name=collection_name+'_averaged'
else:
col_name = collection_name
try:
error = ''
if True:
for i in range(num_videos, 100, 10):
results = client.search(collection_name = col_name,
query_vector = vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT)
num_video_got = len(set([x.payload['filepath'] for x in results]))
if num_video_got >= num_videos:
break
else:
results = client.search(collection_name = col_name,
query_vector = vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT)
except Exception as e:
print(traceback.format_exc())
error = traceback.format_exc()
results = [];
def linux_to_win_path(form):
form = form.replace('/srv','file://192.168.1.242/thebears/Videos/merged/')
return form
def normalize_to_merged(path):
path = path.replace('/srv/ftp','/mergedfs/ftp')
path = path.replace('/mnt/archive2/videos/ftp','/mergedfs/ftp')
return path
resul = list()
for x in results:
pload =dict( x.payload)
pload['filepath'] = normalize_to_merged(pload['filepath'])
pload['score'] = x.score
pload['winpath'] = linux_to_win_path(pload['filepath'])
resul.append(pload)
# %%
return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error}
return return_this
debug(True)
run(host='0.0.0.0', port=53003, server='bjoern')

View File

@@ -1,39 +1,76 @@
do_load = True do_load = True
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, Range, MatchText
from datetime import datetime, timedelta
import numpy as np import numpy as np
import traceback
from bottle import route, run, template, request, debug from bottle import route, run, template, request, debug
import open_clip
# %% # %%
if do_load: if do_load:
from lavis.models import load_model_and_preprocess, model_zoo from lavis.models import load_model_and_preprocess, model_zoo
import torch import torch
device = 'cpu' device = 'cpu'
model, vis_processors, txt_processors = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-B-16", is_eval=True, device=device) model, vis_processors, txt_processors = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-B-16", is_eval=True, device=device)
model.eval() model.eval()
collection_name="nuggets_clip"
TIMEOUT=2
# %% # %%
client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True)
collection_name = "nuggets_so400m"
client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT)
# %% # %%
from bottle import route, run, template from bottle import route, run, template
@route('/get_text_match') @route('/get_text_match')
def get_matches(): def get_matches():
valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'}
query = request.query.get('query','A large bird eating corn') query = request.query.get('query','A large bird eating corn')
# averaged = request.query.get('averaged',False) cameras = request.query.get('cameras','sidefeeder')
# %% cams = set(cameras.split(',')).intersection(valid_cameras)
max_age = request.query.get('age',5);
# %%
max_age = 5
max_age = int(request.query.get('age',5));
print({'Cameras':cams,'Max Age':max_age,'Query':query})
# %% # %%
max_date = datetime.now()
min_date = max_date - timedelta(days=(max_age))
days_step = (max_date- min_date).days
date_arrays = list()
for i in range(days_step):
date_arrays.append(max_date - timedelta(days=i))
# %%
string_filter = list()
for cand_date in date_arrays:
string_filter.append(cand_date.strftime('%Y/%m/%d'))
should_list = list()
if max_age == 0:
string_filter=['']
for str_filt in string_filter:
for cam in cams:
str_use = cam+'/'+str_filt
print(str_use)
ccond = FieldCondition(key='filepath',match=MatchText(text=str_use))
should_list.append(ccond)
condition_dict = Filter(should = should_list)
if len(should_list) == 0:
return_this = {'query':query,'num_videos':0,'results':[]}
return return_this
averaged = False averaged = False
if isinstance(query, str): if isinstance(query, str):
averaged = bool(averaged) averaged = bool(averaged)
num_videos = int(request.query.get('num_videos',5)) num_videos = int(request.query.get('num_videos',5))
print(query, num_videos, averaged)
if do_load: if do_load:
with torch.no_grad(): with torch.no_grad():
@@ -42,7 +79,8 @@ def get_matches():
vec = model.extract_features( sample) vec = model.extract_features( sample)
vec_search = vec.cpu().numpy().squeeze().tolist() vec_search = vec.cpu().numpy().squeeze().tolist()
else: else:
pass sz_vec= client.get_collection(collection_name).config.params.vectors.size
vec_search = np.random.random(sz_vec).tolist()
@@ -51,17 +89,23 @@ def get_matches():
else: else:
col_name = collection_name col_name = collection_name
if True:
for i in range(num_videos, 100, 10): try:
error = ''
if True:
for i in range(num_videos, 100, 10):
results = client.search(collection_name = col_name,
query_vector = vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT)
num_video_got = len(set([x.payload['filepath'] for x in results]))
if num_video_got >= num_videos:
break
else:
results = client.search(collection_name = col_name, results = client.search(collection_name = col_name,
query_vector = vec_search, limit=i) query_vector = vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT)
num_video_got = len(set([x.payload['filepath'] for x in results])) except Exception as e:
if num_video_got >= num_videos: print(traceback.format_exc())
break error = traceback.format_exc()
else: results = [];
results = client.search(collection_name = col_name,
query_vector = vec_search, limit=1)
def linux_to_win_path(form): def linux_to_win_path(form):
@@ -82,11 +126,11 @@ def get_matches():
resul.append(pload) resul.append(pload)
# %% # %%
return_this = {'query':query,'num_videos':num_videos,'results':resul} return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error}
return return_this return return_this
debug(True) debug(True)
run(host='0.0.0.0', port=53003) run(host='0.0.0.0', port=53003, server='bjoern')

35
update_qdrant.py Normal file
View File

@@ -0,0 +1,35 @@
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.models import Distance, VectorParams
client = QdrantClient(host="localhost", port=6333)
collection_head = "nuggets_clip"
collection_head = "nuggets_so400m"
for collection_name in [collection_head, collection_head +'_averaged']:
try:
client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=1152, distance=Distance.COSINE),
on_disk_payload = True
)
client.create_payload_index(
collection_name=collection_name,
field_name="filepath",
field_schema=models.TextIndexParams(
type="text",
tokenizer=models.TokenizerType.WORD,
min_token_len=1,
max_token_len=15,
lowercase=True,
),
)
except Exception as e:
print(e)
from prettyprinter import cpprint
cpprint(client.get_collection(collection_name).dict()['vectors_count'])