From 79f9a1dbc215f4ea1d471ebcd8f8562863745d17 Mon Sep 17 00:00:00 2001
From: "Ishan S. Patel"
Date: Thu, 17 Apr 2025 15:55:54 -0400
Subject: [PATCH] YACWC: migrate video search from Qdrant to Milvus, add a
 SigLIP text-encoding endpoint

---
 clip_endpoint.py                       |  31 ++++++
 dump_qdrant.py                         |  21 ++++
 milvus_migrate/*dashboard*             |  15 +++
 milvus_migrate/create_collection.py    |  56 ++++++++++
 milvus_migrate/create_collection_v2.py |  66 ++++++++++++
 milvus_migrate/create_collection_v3.py |  64 ++++++++++++
 milvus_migrate/search_try.py           |  25 +++++
 milvus_migrate/upload_from_folder.py   |  75 +++++++++++++
 ngt_migrate/create_index.py            |  38 +++++++
 search_me.py                           | 130 +++++++++--------------
 search_me_qdrant.py                    | 139 +++++++++++++++++++++++++
 update_qdrant.py                       |  31 +++++-
 12 files changed, 611 insertions(+), 80 deletions(-)
 create mode 100644 clip_endpoint.py
 create mode 100644 dump_qdrant.py
 create mode 100644 milvus_migrate/*dashboard*
 create mode 100644 milvus_migrate/create_collection.py
 create mode 100644 milvus_migrate/create_collection_v2.py
 create mode 100644 milvus_migrate/create_collection_v3.py
 create mode 100644 milvus_migrate/search_try.py
 create mode 100644 milvus_migrate/upload_from_folder.py
 create mode 100644 ngt_migrate/create_index.py
 create mode 100644 search_me_qdrant.py

diff --git a/clip_endpoint.py b/clip_endpoint.py
new file mode 100644
index 0000000..6ab766e
--- /dev/null
+++ b/clip_endpoint.py
@@ -0,0 +1,31 @@
+import open_clip
+import torch
+
+do_load = True
+if do_load:
+
+    #model_name = 'hf-hub:timm/ViT-L-16-SigLIP2-512'
+#    model, preprocess = open_clip.create_model_from_pretrained('hf-hub:timm/ViT-L-16-SigLIP2-512')
+
+
+    model_name = 'ViT-SO400M-14-SigLIP-384'
+    pretrained_name = 'webli'
+    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained_name)
+    device = 'cpu'
+    model.eval()
+    tokenizer = open_clip.get_tokenizer(model_name)
+
+from bottle import route, run, template, request, debug
+@route('/encode')
+def encode_query():
+    query = request.query.get('query','A large bird eating corn')
+    with torch.no_grad():
+        text_tokenized = tokenizer(query)
+        vec = model.encode_text(text_tokenized).detach().cpu().tolist()
+
+    return {'vector':vec}
+
+
+
+debug(True)
+run(host='0.0.0.0', port=53004, server='bjoern')
diff --git a/dump_qdrant.py b/dump_qdrant.py
new file mode 100644
index 0000000..a30055a
--- /dev/null
+++ b/dump_qdrant.py
@@ -0,0 +1,21 @@
+from qdrant_client import QdrantClient
+from qdrant_client.http import models
+from qdrant_client.models import Distance, VectorParams
+client = QdrantClient(host="localhost", port=6333)
+
+
+collection_head = "nuggets_so400m"
+out = client.scroll('nuggets_so400m',limit=10,with_vectors=True)
+
+
+
+offset_id = None
+#while True:
+if True:
+    res = client.scroll(collection_name=collection_head,
+                        limit=1000000,
+                        offset=offset_id,
+                        with_payload=False,
+                        with_vectors=True)
+
+# ou = np.asarray([x.vector for x in out[0]], dtype=np.float16)
diff --git a/milvus_migrate/*dashboard* b/milvus_migrate/*dashboard*
new file mode 100644
index 0000000..a1b915b
--- /dev/null
+++ b/milvus_migrate/*dashboard*
@@ -0,0 +1,15 @@
+from pymilvus import MilvusClient, DataType
+import numpy as np
+
+client = MilvusClient(
+    uri="http://localhost:19530"
+)
+cname = 'nuggets_so400m'
+client.get_collection_stats(cname)
+
+
+
+vec = np.random.random(1152).astype(np.float16)
+
+
+client.search
\ No newline at end of file
diff --git a/milvus_migrate/create_collection.py b/milvus_migrate/create_collection.py
new file mode 100644
index 0000000..bb972b8
--- /dev/null
+++ b/milvus_migrate/create_collection.py
@@ -0,0 +1,56 @@
+from pymilvus import MilvusClient, DataType
+
+# 1. Set up a Milvus client
+client = MilvusClient(
+    uri="http://localhost:19530"
+)
+client.get_collection_stats('nuggets_so400m')
+# %%
+
+schema = MilvusClient.create_schema(
+    auto_id=False,
+    enable_dynamic_field=False,
+)
+schema.add_field(field_name="primary_id",datatype=DataType.INT64, is_primary=True)
+schema.add_field(field_name="filepath", datatype=DataType.VARCHAR, max_length=128)
+schema.add_field(field_name="frame_number", datatype=DataType.INT32)
+schema.add_field(field_name="so400m", datatype=DataType.FLOAT16_VECTOR, dim=1152)
+
+
+index_params = client.prepare_index_params()
+
+
+index_params.add_index(
+    field_name="primary_id",
+    index_type="STL_SORT")
+
+index_params.add_index(
+    field_name="filepath",
+    index_type="Trie")
+
+index_params.add_index(
+    field_name="so400m",
+    index_type="IVF_FLAT",
+    metric_type="COSINE",
+    params={ "nlist": 128 })
+
+
+
+client.create_collection(
+    collection_name="nuggets_so400m",
+    schema=schema,
+    index_params=index_params
+)
+
+# %%
+res = client.get_load_state(
+    collection_name="nuggets_so400m"
+)
+
+
+res = client.load_collection(collection_name="nuggets_so400m")
+
+
+
+
+
diff --git a/milvus_migrate/create_collection_v2.py b/milvus_migrate/create_collection_v2.py
new file mode 100644
index 0000000..54d0519
--- /dev/null
+++ b/milvus_migrate/create_collection_v2.py
@@ -0,0 +1,66 @@
+from pymilvus import MilvusClient, DataType
+
+
+client = MilvusClient(
+    uri="http://localhost:19530"
+)
+for x in client.list_collections():
+    client.drop_collection(x)
+
+# %%
+
+import os
+out = os.listdir('/mergedfs/ftp/')
+
+# %%
+for cam in out:
+    schema = MilvusClient.create_schema(
+        auto_id=False,
+        enable_dynamic_field=False,
+    )
+    schema.add_field(field_name="primary_id",datatype=DataType.INT64, is_primary=True)
+    schema.add_field(field_name="filepath", datatype=DataType.VARCHAR, max_length=128)
+    schema.add_field(field_name="frame_number", datatype=DataType.INT32)
+    schema.add_field(field_name="date", datatype=DataType.VARCHAR, max_length=len('20241220'), is_partition_key=True)
+    schema.add_field(field_name="so400m", datatype=DataType.FLOAT16_VECTOR, dim=1152)
+
+
+    index_params = client.prepare_index_params()
+
+    index_params.add_index(
+        field_name="primary_id",
+        index_type="STL_SORT")
+
+    index_params.add_index(
+        field_name="filepath",
+        index_type="Trie")
+
+    index_params.add_index(
+        field_name="so400m",
+        index_type="IVF_SQ8",
+        metric_type="COSINE",
+        params={'nlist':128})
+
+    index_params.add_index(
+        field_name='date',
+        index_type='Trie')
+
+    client.create_collection(
+        collection_name=f"nuggets_{cam}_so400m",
+        schema=schema,
+        index_params=index_params
+    )
+    print(cam)
+
+# %%
+res = client.get_load_state(
+    collection_name="nuggets_so400m"
+)
+
+
+res = client.load_collection(collection_name="nuggets_so400m")
+
+
+
+
diff --git a/milvus_migrate/create_collection_v3.py b/milvus_migrate/create_collection_v3.py
new file mode 100644
index 0000000..aa83c27
--- /dev/null
+++ b/milvus_migrate/create_collection_v3.py
@@ -0,0 +1,64 @@
+from pymilvus import MilvusClient, DataType
+
+
+client = MilvusClient(
+    uri="http://localhost:19530"
+)
+#for x in client.list_collections():
+#    client.drop_collection(x)
+
+
+
+
+# %%
+
+import os
+out = os.listdir('/mergedfs/ftp/')
+
+# %%
+for cam in out:
+    schema = MilvusClient.create_schema(
+        auto_id=False,
+        enable_dynamic_field=False,
+    )
+    schema.add_field(field_name="primary_id",datatype=DataType.INT64, is_primary=True)
+    schema.add_field(field_name="filepath", datatype=DataType.VARCHAR, max_length=128)
+    schema.add_field(field_name="frame_number", datatype=DataType.INT32)
+    schema.add_field(field_name="date", datatype=DataType.VARCHAR, max_length=len('20241220'), is_partition_key=True)
+    schema.add_field(field_name="so400m", datatype=DataType.FLOAT16_VECTOR, dim=1024)
+
+
+    index_params = client.prepare_index_params()
+
+    index_params.add_index(
+        field_name="primary_id",
+        index_type="STL_SORT")
+
+    index_params.add_index(
+        field_name="filepath",
+        index_type="Trie")
+
+    index_params.add_index(
+        field_name="so400m",
+        index_type="IVF_SQ8",
+        metric_type="COSINE",
+        params={'nlist':128})
+
+    index_params.add_index(
+        field_name='date',
+        index_type='Trie')
+
+    client.create_collection(
+        collection_name=f"nuggets_{cam}_so400m_siglip2",
+        schema=schema,
+        index_params=index_params
+    )
+    print(cam)
+
+#res = client.get_load_state( collection_name="nuggets_ptz_so400m_siglip2")
+#res = client.load_collection(collection_name="nuggets_so400m")
+
+
+
+
diff --git a/milvus_migrate/search_try.py b/milvus_migrate/search_try.py
new file mode 100644
index 0000000..6d49ba8
--- /dev/null
+++ b/milvus_migrate/search_try.py
@@ -0,0 +1,25 @@
+from pymilvus import MilvusClient, DataType
+import numpy as np
+
+# 1. Set up a Milvus client
+client = MilvusClient(uri="http://localhost:19530")
+cname = "nuggets_so400m"
+ou = client.get_collection_stats(cname)
+
+import random
+vec = [random.random() for x in range(1152)]
+# %%
+
+from prettyprinter import cpprint
+# reuse the random query vector generated above
+out = client.search(
+    collection_name=cname,
+    limit = 100,
+    data=[vec],
+    output_fields=["filepath", "frame_number"],
+    filter='(filepath like "%2024/09/20%") or (filepath like "%2024/09/23%")'
+    )
+
+
+
+cpprint([x['entity']['filepath'] for x in out[0]])
diff --git a/milvus_migrate/upload_from_folder.py b/milvus_migrate/upload_from_folder.py
new file mode 100644
index 0000000..6c002b1
--- /dev/null
+++ b/milvus_migrate/upload_from_folder.py
@@ -0,0 +1,75 @@
+from pymilvus import MilvusClient, DataType
+import numpy as np
+import time
+import os
+from pymilvus.client.types import LoadState
+client = MilvusClient(
+    uri="http://localhost:19530"
+)
+
+
+res = client.get_load_state(
+    collection_name="nuggets_so400m"
+)
+if res['state'] == LoadState.Loaded:
+    pass
+else:
+    client.load_collection(collection_name = 'nuggets_so400m')
+    for i in range(10):
+        time.sleep(1)
+        # re-query the load state; the cached `res` above never changes
+        if client.get_load_state(collection_name='nuggets_so400m')['state'] == LoadState.Loaded:
+            break
+
+
+def get_vec_path(vpath):
+    return os.path.splitext(vpath)[0]+'.oclip_embeds.npz'
+
+def get_embed_done_path(vpath):
+    return os.path.splitext(vpath)[0]+'.db_has_oclip_embeds'
+
+
+def upload_vector_file(vector_file_to_upload):
+    if os.path.exists(get_embed_done_path(vector_file_to_upload)):
+        print('Already exists in DB, skipping upload')
+        return
+
+    # the walker below already passes .npz paths; only convert video paths
+    if not vector_file_to_upload.endswith('.npz'):
+        vector_file_to_upload = get_vec_path(vector_file_to_upload)
+    vf = np.load(vector_file_to_upload)
+
+    embeds = vf['embeds']
+    fr_nums = vf['frame_numbers']
+
+    fname_root = vector_file_to_upload.rsplit('/',1)[-1].split('.')[0]
+    fc = fname_root.split('_')[-1]
+
+    data = list()
+    filepath = vector_file_to_upload.replace('/srv/ftp/','').replace('/mergedfs/ftp/','').split('.')[0]
+
+    for embed, frame_num in zip(embeds, fr_nums):
+        fg = '{0:05g}'.format(frame_num)
+        id_num = int(fc+fg)
+        to_put = dict(primary_id= id_num, filepath=filepath, frame_number = int(frame_num), so400m=embed)
+        data.append(to_put)
+
+    client.insert(collection_name = 'nuggets_so400m', data = data)
+    print(f'Inserting into DB, {vector_file_to_upload}')
+
+    with open(get_embed_done_path(vector_file_to_upload),'w') as ff:
+        ff.write(str(time.time()))
+
+
+
+
+
+root_path = '/srv/ftp/railing/2024'
+to_put = list()
+for root, dirs, files in os.walk(root_path):
+    for x in files:
+        if x.endswith('oclip_embeds.npz'):
+            to_put.append(os.path.join(root, x))
+
+
+for x in to_put:
+    upload_vector_file(x)
+
+
+
diff --git a/ngt_migrate/create_index.py b/ngt_migrate/create_index.py
new file mode 100644
index 0000000..d76b3d6
--- /dev/null
+++ b/ngt_migrate/create_index.py
@@ -0,0 +1,38 @@
+import os
+import ngtpy
+dim = 1152
+index_path = b"/mnt/ssd_nvm/ngt/openclip_so400m"
+if not os.path.exists(index_path):
+    ngtpy.create(index_path, dim)
+    print(f'Created index at {index_path}')
+
+index = ngtpy.Index(index_path)
+# %%
+
+to_add_to_index = list()
+for root, dirs, files in os.walk('/mergedfs/ftp'):
+    for x in files:
+        if x.endswith('.oclip_embeds.npz'):
+            to_add_to_index.append(os.path.join(root, x))
+
+# %%
+import numpy as np
+import progressbar
+# %%
+total_vecs = 0
+# stream every embedding file into the NGT index, saving after each batch
+bar = progressbar.ProgressBar(max_value = len(to_add_to_index))
+for idx, to_add in enumerate(to_add_to_index):
+    try:
+        emb_vec = np.load(to_add)['embeds']
+        total_vecs+= emb_vec.shape[0]
+        index.batch_insert(emb_vec)
+        index.save()
+    except Exception as exc:
+        print(f'Skipping {to_add}: {exc}')
+
+    bar.update(idx)
+
+
+
+
diff --git a/search_me.py b/search_me.py
index 3c1f81a..b860eae 100644
--- a/search_me.py
+++ b/search_me.py
@@ -1,132 +1,104 @@
 do_load = True
-from qdrant_client import QdrantClient
-from qdrant_client.models import Filter, FieldCondition, Range, MatchText
+import requests
+from pymilvus import MilvusClient, DataType
+from pymilvus.client.types import LoadState
 from datetime import datetime, timedelta
 import numpy as np
 import traceback
 from bottle import route, run, template, request, debug
-import open_clip
+import time
+import prettyprinter
 # %%
-if do_load:
-
-    from lavis.models import load_model_and_preprocess, model_zoo
-    import torch
-    device = 'cpu'
-    model, vis_processors, txt_processors = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-B-16", is_eval=True, device=device)
-    model.eval()
 
 TIMEOUT=2
 # %%
-collection_name = "nuggets_so400m"
-client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT)
+#collection_name = "nuggets_{camera}_so400m_siglip2"
+collection_name = "nuggets_{camera}_so400m"
+client = MilvusClient(
+    uri="http://localhost:19530"
+)
 # %%
+
 from bottle import route, run, template
 
 @route('/get_text_match')
 def get_matches():
+    # %%
     valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'}
     query = request.query.get('query','A large bird eating corn')
     cameras = request.query.get('cameras','sidefeeder')
+    num_videos = int(request.query.get('num_videos',5))
+
     cams = set(cameras.split(',')).intersection(valid_cameras)
 
     max_age = int(request.query.get('age',5));
    print({'Cameras':cams,'Max Age':max_age,'Query':query})
-# %%
     max_date = datetime.now()
     min_date = max_date - timedelta(days=(max_age))
-
     days_step = (max_date- min_date).days
-    date_arrays = list()
-    for i in range(days_step):
-        date_arrays.append(max_date - timedelta(days=i))
-# %%
-    string_filter = list()
-    for cand_date in date_arrays:
-        string_filter.append(cand_date.strftime('%Y/%m/%d'))
+    day_strs = list()
+    # %%
+    for x in range(days_step):
+        # count back from max_date so the current day stays inside the window
+        day_strs.append( (max_date - timedelta(days=x)).strftime('%Y%m%d') )
+    str_insert = ','.join([f'"{x}"' for x in day_strs])
+    filter_string = 'date in [' + str_insert + ']'
 
-    should_list = list()
     if max_age == 0:
-        string_filter=['']
-
-    for str_filt in string_filter:
-        for cam in cams:
-            str_use = cam+'/'+str_filt
-            print(str_use)
-            ccond = FieldCondition(key='filepath',match=MatchText(text=str_use))
-            should_list.append(ccond)
-
-
-
-
-    condition_dict = Filter(should = should_list)
-    if len(should_list) == 0:
-        return_this = {'query':query,'num_videos':0,'results':[]}
-        return return_this
-
-
-    averaged = False
-    if isinstance(query, str):
-        averaged = bool(averaged)
-
-    num_videos = int(request.query.get('num_videos',5))
-
-    if do_load:
-        with torch.no_grad():
-            text_input = txt_processors['eval'](query)
-            sample = {'text_input':text_input}
-            vec = model.extract_features( sample)
-            vec_search = vec.cpu().numpy().squeeze().tolist()
-    else:
-        sz_vec= client.get_collection(collection_name).config.params.vectors.size
-        vec_search = np.random.random(sz_vec).tolist()
-
-
-
-    if averaged:
-        col_name=collection_name+'_averaged'
-    else:
-        col_name = collection_name
+        filter_string = ''
+    # %%
+    vec_form = requests.get('http://192.168.1.242:53004/encode',params={'query':query}).json()['vector'][0]
+    vec_search = np.asarray(vec_form).astype(np.float16)
+    print(f'Doing query for {query}!')
+
+    all_results=list()
     try:
         error = ''
         if True:
-            for i in range(num_videos, 100, 10):
-                results = client.search(collection_name = col_name,
-                        query_vector = vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT)
-                num_video_got = len(set([x.payload['filepath'] for x in results]))
-                if num_video_got >= num_videos:
-                    break
-        else:
-            results = client.search(collection_name = col_name,
-                    query_vector = vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT)
+
+            for cam in cams:
+                col_name = collection_name.format(camera=cam)
+                for i in [num_videos]:
+
+                    results = client.search(collection_name = col_name,
+                                            consistency_level="Eventually",
+                                            filter=filter_string,
+                                            data = [vec_search],
+                                            limit=i,
+                                            search_params={'metric_type':'COSINE', 'params':{}},
+                                            output_fields=['filepath','frame_number']
+                                            )
+                    all_results.extend(results[0])
+
     except Exception as e:
         print(traceback.format_exc())
         error = traceback.format_exc()
-        results = [];
+        results = []
 
-
+# %%
     def linux_to_win_path(form):
         form = form.replace('/srv','file://192.168.1.242/thebears/Videos/merged/')
         return form
 
     def normalize_to_merged(path):
-        path = path.replace('/srv/ftp','/mergedfs/ftp')
-        path = path.replace('/mnt/archive2/videos/ftp','/mergedfs/ftp')
+        if not path.startswith('/'):
+            path= '/mergedfs/ftp/'+path
         return path
 
     resul = list()
-    for x in results:
-        pload =dict( x.payload)
+    # COSINE metric: a larger distance means a closer match, so rank best hits first
+    all_results = sorted(all_results, key=lambda x: x['distance'], reverse=True)
+    for x in all_results:
+        pload =dict( x['entity'])
         pload['filepath'] = normalize_to_merged(pload['filepath'])
-        pload['score'] = x.score
+        pload['score'] = x['distance']
         pload['winpath'] = linux_to_win_path(pload['filepath'])
+        pload['frame'] = pload['frame_number']
         resul.append(pload)
-    # %%
     return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error}
+    prettyprinter.pprint(return_this)
     return return_this
diff --git a/search_me_qdrant.py b/search_me_qdrant.py
new file mode 100644
index 0000000..a0ee031
--- /dev/null
+++ b/search_me_qdrant.py
@@ -0,0 +1,139 @@
+do_load = True
+from qdrant_client import QdrantClient
+from qdrant_client.models import Filter, FieldCondition, Range, MatchText
+from datetime import datetime, timedelta
+import numpy as np
+import traceback
+from bottle import route, run, template, request, debug
+import open_clip
+import torch
+# %%
+if do_load:
+    model_name = 'ViT-SO400M-14-SigLIP-384'
+    pretrained_name = 'webli'
+
+    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained_name)
+    device = 'cpu'
+    model.eval()
+    tokenizer = open_clip.get_tokenizer(model_name)
+
+TIMEOUT=2
+# %%
+
+collection_name = "nuggets_so400m"
+client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT)
+# %%
+from bottle import route, run, template
+
+@route('/get_text_match')
+def get_matches():
+    valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'}
+    query = request.query.get('query','A large bird eating corn')
+    cameras = request.query.get('cameras','sidefeeder')
+    cams = set(cameras.split(',')).intersection(valid_cameras)
+
+    max_age = int(request.query.get('age',5));
+    print({'Cameras':cams,'Max Age':max_age,'Query':query})
+# %%
+    max_date = datetime.now()
+    min_date = max_date - timedelta(days=(max_age))
+
+    days_step = (max_date- min_date).days
+    date_arrays = list()
+    for i in range(days_step):
+        date_arrays.append(max_date - timedelta(days=i))
+# %%
+    string_filter = list()
+    for cand_date in date_arrays:
+        string_filter.append(cand_date.strftime('%Y/%m/%d'))
+
+
+    should_list = list()
+    if max_age == 0:
+        string_filter=['']
+
+    for str_filt in string_filter:
+        for cam in cams:
+            str_use = cam+'/'+str_filt
+            print(str_use)
+            ccond = FieldCondition(key='filepath',match=MatchText(text=str_use))
+            should_list.append(ccond)
+
+
+
+
+    condition_dict = Filter(should = should_list)
+    if len(should_list) == 0:
+        return_this = {'query':query,'num_videos':0,'results':[]}
+        return return_this
+
+
+    averaged = False
+    if isinstance(query, str):
+        averaged = bool(averaged)
+
+    num_videos = int(request.query.get('num_videos',5))
+
+    if do_load:
+        with torch.no_grad():
+            text_tokenized = tokenizer(query)
+            vec = model.encode_text(text_tokenized)
+#            text_input = txt_processors['eval'](query)
+#            sample = {'text_input':text_input}
+#            vec = model.extract_features( sample)
+            vec_search = vec.cpu().numpy().squeeze().tolist()
+    else:
+        sz_vec= client.get_collection(collection_name).config.params.vectors.size
+        vec_search = np.random.random(sz_vec).tolist()
+
+
+
+    if averaged:
+        col_name=collection_name+'_averaged'
+    else:
+        col_name = collection_name
+
+# %%
+    try:
+        error = ''
+        if True:
+            for i in range(num_videos, 100, 10):
+                results = client.search(collection_name = col_name,
+                        query_vector = vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT)
+                num_video_got = len(set([x.payload['filepath'] for x in results]))
+                if num_video_got >= num_videos:
+                    break
+        else:
+            results = client.search(collection_name = col_name,
+                    query_vector = vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT)
+    except Exception as e:
+        print(traceback.format_exc())
+        error = traceback.format_exc()
+        results = [];
+
+
+    def linux_to_win_path(form):
+        form = form.replace('/srv','file://192.168.1.242/thebears/Videos/merged/')
+        return form
+
+    def normalize_to_merged(path):
+        path = path.replace('/srv/ftp','/mergedfs/ftp')
+        path = path.replace('/mnt/archive2/videos/ftp','/mergedfs/ftp')
+        return path
+
+    resul = list()
+    for x in results:
+        pload =dict( x.payload)
+        pload['filepath'] = normalize_to_merged(pload['filepath'])
+        pload['score'] = x.score
+        pload['winpath'] = linux_to_win_path(pload['filepath'])
+        resul.append(pload)
+# %%
+    return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error}
+    return return_this
+
+
+
+
+debug(True)
+run(host='0.0.0.0', port=53003, server='bjoern')
diff --git a/update_qdrant.py b/update_qdrant.py
index debd01a..f008783 100644
--- a/update_qdrant.py
+++ b/update_qdrant.py
@@ -3,12 +3,17 @@ from qdrant_client.http import models
 from qdrant_client.models import Distance, VectorParams
 client = QdrantClient(host="localhost", port=6333)
-
+# %%
 collection_head = "nuggets_clip"
 collection_head = "nuggets_so400m"
 
+out = client.scroll('nuggets_so400m',limit=10,with_vectors=True)
+print(len(out[0][0].vector))
+# %%
+ou = np.asarray([x.vector for x in out[0]], dtype=np.float16)
+# %%
 
 
 for collection_name in [collection_head, collection_head +'_averaged']:
     try:
         client.create_collection(
@@ -33,3 +38,27 @@ for collection_name in [collection_head, collection_head +'_averaged']:
 
 from prettyprinter import cpprint
 cpprint(client.get_collection(collection_name).dict()['vectors_count'])
+# %%
+
+
+
+client.update_collection(
+    collection_name=f"{collection_name}",
+
+    hnsw_config=models.HnswConfigDiff(
+        on_disk=True
+    ),
+
+    vectors_config={
+        "": models.VectorParamsDiff(
+            on_disk=True,
+            hnsw_config=models.HnswConfigDiff(
+                on_disk=True
+            ),
+        )
+    }
+)
+
+
+# %%
+cpprint(client.get_collection(collection_name).dict())
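
Not part of the patch: a rough end-to-end smoke-test sketch for the two Bottle services added above. The host is taken from the value hard-coded in search_me.py; the search port 53003 is an assumption borrowed from the Qdrant variant, since the Milvus-backed search_me.py does not show its run() call in this diff. Adjust host and ports to the actual deployment before using it.

import requests

HOST = 'http://192.168.1.242'            # assumed host, copied from search_me.py
ENCODE_PORT, SEARCH_PORT = 53004, 53003  # 53003 is an assumption for search_me.py

# 1. text -> SigLIP embedding via clip_endpoint.py
vec = requests.get(f'{HOST}:{ENCODE_PORT}/encode',
                   params={'query': 'A large bird eating corn'}).json()['vector'][0]
print(len(vec))  # expect 1152 for ViT-SO400M-14-SigLIP-384

# 2. text -> ranked video matches via the Milvus-backed /get_text_match
resp = requests.get(f'{HOST}:{SEARCH_PORT}/get_text_match',
                    params={'query': 'A large bird eating corn',
                            'cameras': 'sidefeeder,railing',
                            'age': 5,
                            'num_videos': 5}).json()
for hit in resp['results']:
    print(round(hit['score'], 3), hit['frame'], hit['winpath'])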