diff --git a/flycheck_search_me.py b/flycheck_search_me.py new file mode 100644 index 0000000..3c1f81a --- /dev/null +++ b/flycheck_search_me.py @@ -0,0 +1,136 @@ +do_load = True +from qdrant_client import QdrantClient +from qdrant_client.models import Filter, FieldCondition, Range, MatchText +from datetime import datetime, timedelta +import numpy as np +import traceback +from bottle import route, run, template, request, debug +import open_clip +# %% +if do_load: + + from lavis.models import load_model_and_preprocess, model_zoo + import torch + device = 'cpu' + model, vis_processors, txt_processors = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-B-16", is_eval=True, device=device) + model.eval() + +TIMEOUT=2 +# %% + +collection_name = "nuggets_so400m" +client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT) +# %% +from bottle import route, run, template + +@route('/get_text_match') +def get_matches(): + valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'} + query = request.query.get('query','A large bird eating corn') + cameras = request.query.get('cameras','sidefeeder') + cams = set(cameras.split(',')).intersection(valid_cameras) + + max_age = int(request.query.get('age',5)); + print({'Cameras':cams,'Max Age':max_age,'Query':query}) +# %% + max_date = datetime.now() + min_date = max_date - timedelta(days=(max_age)) + + days_step = (max_date- min_date).days + date_arrays = list() + for i in range(days_step): + date_arrays.append(max_date - timedelta(days=i)) +# %% + string_filter = list() + for cand_date in date_arrays: + string_filter.append(cand_date.strftime('%Y/%m/%d')) + + + should_list = list() + if max_age == 0: + string_filter=[''] + + for str_filt in string_filter: + for cam in cams: + str_use = cam+'/'+str_filt + print(str_use) + ccond = FieldCondition(key='filepath',match=MatchText(text=str_use)) + should_list.append(ccond) + + + + + condition_dict = Filter(should = should_list) + if len(should_list) == 0: + return_this = {'query':query,'num_videos':0,'results':[]} + return return_this + + + averaged = False + if isinstance(query, str): + averaged = bool(averaged) + + num_videos = int(request.query.get('num_videos',5)) + + if do_load: + with torch.no_grad(): + text_input = txt_processors['eval'](query) + sample = {'text_input':text_input} + vec = model.extract_features( sample) + vec_search = vec.cpu().numpy().squeeze().tolist() + else: + sz_vec= client.get_collection(collection_name).config.params.vectors.size + vec_search = np.random.random(sz_vec).tolist() + + + + if averaged: + col_name=collection_name+'_averaged' + else: + col_name = collection_name + + + try: + error = '' + if True: + for i in range(num_videos, 100, 10): + results = client.search(collection_name = col_name, + query_vector = vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT) + num_video_got = len(set([x.payload['filepath'] for x in results])) + if num_video_got >= num_videos: + break + else: + results = client.search(collection_name = col_name, + query_vector = vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT) + except Exception as e: + print(traceback.format_exc()) + error = traceback.format_exc() + results = []; + + + def linux_to_win_path(form): + form = form.replace('/srv','file://192.168.1.242/thebears/Videos/merged/') + return form + + def normalize_to_merged(path): + path = path.replace('/srv/ftp','/mergedfs/ftp') + path = path.replace('/mnt/archive2/videos/ftp','/mergedfs/ftp') + return path + + resul = list() + for x in results: + pload =dict( x.payload) + pload['filepath'] = normalize_to_merged(pload['filepath']) + pload['score'] = x.score + pload['winpath'] = linux_to_win_path(pload['filepath']) + resul.append(pload) + +# %% + return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error} + return return_this + + + + +debug(True) +run(host='0.0.0.0', port=53003, server='bjoern') diff --git a/search_me.py b/search_me.py index 90ced97..3c1f81a 100644 --- a/search_me.py +++ b/search_me.py @@ -1,39 +1,76 @@ do_load = True from qdrant_client import QdrantClient +from qdrant_client.models import Filter, FieldCondition, Range, MatchText +from datetime import datetime, timedelta import numpy as np +import traceback from bottle import route, run, template, request, debug +import open_clip # %% if do_load: + from lavis.models import load_model_and_preprocess, model_zoo import torch device = 'cpu' model, vis_processors, txt_processors = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-B-16", is_eval=True, device=device) model.eval() - collection_name="nuggets_clip" +TIMEOUT=2 # %% -client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True) + +collection_name = "nuggets_so400m" +client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT) # %% from bottle import route, run, template @route('/get_text_match') def get_matches(): + valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'} query = request.query.get('query','A large bird eating corn') -# averaged = request.query.get('averaged',False) -# %% - max_age = request.query.get('age',5); -# %% - max_age = 5 - - + cameras = request.query.get('cameras','sidefeeder') + cams = set(cameras.split(',')).intersection(valid_cameras) + max_age = int(request.query.get('age',5)); + print({'Cameras':cams,'Max Age':max_age,'Query':query}) # %% + max_date = datetime.now() + min_date = max_date - timedelta(days=(max_age)) + + days_step = (max_date- min_date).days + date_arrays = list() + for i in range(days_step): + date_arrays.append(max_date - timedelta(days=i)) +# %% + string_filter = list() + for cand_date in date_arrays: + string_filter.append(cand_date.strftime('%Y/%m/%d')) + + + should_list = list() + if max_age == 0: + string_filter=[''] + + for str_filt in string_filter: + for cam in cams: + str_use = cam+'/'+str_filt + print(str_use) + ccond = FieldCondition(key='filepath',match=MatchText(text=str_use)) + should_list.append(ccond) + + + + + condition_dict = Filter(should = should_list) + if len(should_list) == 0: + return_this = {'query':query,'num_videos':0,'results':[]} + return return_this + + averaged = False if isinstance(query, str): averaged = bool(averaged) num_videos = int(request.query.get('num_videos',5)) - print(query, num_videos, averaged) if do_load: with torch.no_grad(): @@ -42,7 +79,8 @@ def get_matches(): vec = model.extract_features( sample) vec_search = vec.cpu().numpy().squeeze().tolist() else: - pass + sz_vec= client.get_collection(collection_name).config.params.vectors.size + vec_search = np.random.random(sz_vec).tolist() @@ -51,17 +89,23 @@ def get_matches(): else: col_name = collection_name - if True: - for i in range(num_videos, 100, 10): + + try: + error = '' + if True: + for i in range(num_videos, 100, 10): + results = client.search(collection_name = col_name, + query_vector = vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT) + num_video_got = len(set([x.payload['filepath'] for x in results])) + if num_video_got >= num_videos: + break + else: results = client.search(collection_name = col_name, - query_vector = vec_search, limit=i) - num_video_got = len(set([x.payload['filepath'] for x in results])) - if num_video_got >= num_videos: - break - else: - results = client.search(collection_name = col_name, - query_vector = vec_search, limit=1) - + query_vector = vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT) + except Exception as e: + print(traceback.format_exc()) + error = traceback.format_exc() + results = []; def linux_to_win_path(form): @@ -82,11 +126,11 @@ def get_matches(): resul.append(pload) # %% - return_this = {'query':query,'num_videos':num_videos,'results':resul} + return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error} return return_this debug(True) -run(host='0.0.0.0', port=53003) +run(host='0.0.0.0', port=53003, server='bjoern') diff --git a/update_qdrant.py b/update_qdrant.py new file mode 100644 index 0000000..debd01a --- /dev/null +++ b/update_qdrant.py @@ -0,0 +1,35 @@ +from qdrant_client import QdrantClient +from qdrant_client.http import models +from qdrant_client.models import Distance, VectorParams + +client = QdrantClient(host="localhost", port=6333) + +collection_head = "nuggets_clip" +collection_head = "nuggets_so400m" + + + +for collection_name in [collection_head, collection_head +'_averaged']: + try: + client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams(size=1152, distance=Distance.COSINE), + on_disk_payload = True + ) + + client.create_payload_index( + collection_name=collection_name, + field_name="filepath", + field_schema=models.TextIndexParams( + type="text", + tokenizer=models.TokenizerType.WORD, + min_token_len=1, + max_token_len=15, + lowercase=True, + ), + ) + except Exception as e: + print(e) + from prettyprinter import cpprint + cpprint(client.get_collection(collection_name).dict()['vectors_count']) +