2025-04-17 15:55:54 -04:00
parent 5ca30d5e11
commit 79f9a1dbc2
12 changed files with 611 additions and 80 deletions

31 clip_endpoint.py Normal file

@@ -0,0 +1,31 @@
import open_clip
import torch

do_load = True
if do_load:
    #model_name = 'hf-hub:timm/ViT-L-16-SigLIP2-512'
    # model, preprocess = open_clip.create_model_from_pretrained('hf-hub:timm/ViT-L-16-SigLIP2-512')
    model_name = 'ViT-SO400M-14-SigLIP-384'
    pretrained_name = 'webli'
    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained_name)
    device = 'cpu'
    model.eval()
    tokenizer = open_clip.get_tokenizer(model_name)

from bottle import route, run, template, request, debug

@route('/encode')
def get_matches():
    query = request.query.get('query', 'A large bird eating corn')
    with torch.no_grad():
        text_tokenized = tokenizer(query)
        vec = model.encode_text(text_tokenized).detach().cpu().tolist()
    return {'vector': vec}

debug(True)
run(host='0.0.0.0', port=53004, server='bjoern')
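
A minimal sketch of a client for this endpoint (the host is an assumption; the service binds 0.0.0.0:53004 above):

# Hypothetical client call against the /encode service defined above.
import requests

resp = requests.get('http://localhost:53004/encode',
                    params={'query': 'A woodpecker on a suet feeder'})
vec = resp.json()['vector'][0]  # encode_text returns a batch; take the first row
print(len(vec))                 # 1152 dims for ViT-SO400M-14-SigLIP-384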

21 dump_qdrant.py Normal file

@@ -0,0 +1,21 @@
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(host="localhost", port=6333)
collection_head = "nuggets_so400m"
out = client.scroll('nuggets_so400m', limit=10, with_vectors=True)
offset_id = None
#while True:
if True:
    res = client.scroll(collection_name=collection_head,
                        limit=1000000,
                        offset=offset_id,
                        with_payload=False,
                        with_vectors=True)
    # ou = np.asarray([x.vector for x in out[0]], dtype=np.float16)
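The commented-out while True: hints that the intent was to page through the whole collection. A sketch of that loop, assuming the standard qdrant_client scroll() return of (records, next_page_offset):

# Hypothetical completion of the dump loop: scroll() returns a
# (records, next_page_offset) tuple, so paging continues until the
# offset comes back as None.
import numpy as np

all_vecs = list()
offset_id = None
while True:
    records, offset_id = client.scroll(collection_name=collection_head,
                                       limit=10000,
                                       offset=offset_id,
                                       with_payload=False,
                                       with_vectors=True)
    all_vecs.extend(r.vector for r in records)
    if offset_id is None:  # no more pages
        break
dump = np.asarray(all_vecs, dtype=np.float16)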


@@ -0,0 +1,15 @@
from pymilvus import MilvusClient, DataType
import numpy as np

client = MilvusClient(
    uri="http://localhost:19530"
)
cname = 'nuggets_so400m'
client.get_collection_stats('nuggets_so400m')
vec = np.random.random(1152).astype(np.float16)
client.search
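
The script stops at a bare client.search; a completed call in the style of the other search scripts in this commit might look like:

# Hypothetical completion of the dangling client.search above, following
# the argument conventions used elsewhere in this commit.
res = client.search(collection_name=cname,
                    data=[vec],
                    limit=10,
                    output_fields=["filepath", "frame_number"])
print(res[0])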


@@ -0,0 +1,56 @@
from pymilvus import MilvusClient, DataType

# 1. Set up a Milvus client
client = MilvusClient(
    uri="http://localhost:19530"
)
client.get_collection_stats('nuggets_so400m')
# %%
schema = MilvusClient.create_schema(
    auto_id=False,
    enable_dynamic_field=False,
)
schema.add_field(field_name="primary_id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="filepath", datatype=DataType.VARCHAR, max_length=128)
schema.add_field(field_name="frame_number", datatype=DataType.INT32)
schema.add_field(field_name="so400m", datatype=DataType.FLOAT16_VECTOR, dim=1152)

index_params = client.prepare_index_params()
index_params.add_index(
    field_name="primary_id",
    index_type="STL_SORT")
index_params.add_index(
    field_name="filepath",
    index_type="Trie")
index_params.add_index(
    field_name="so400m",
    index_type="IVF_FLAT",
    metric_type="COSINE",
    params={"nlist": 128})

client.create_collection(
    collection_name="nuggets_so400m",
    schema=schema,
    index_params=index_params
)
# %%
res = client.get_load_state(
    collection_name="nuggets_so400m"
)
res = client.load_collection(collection_name="nuggets_so400m")
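
A minimal sketch of inserting one row that matches this schema; the key, path, and frame values are made up for illustration:

# Hypothetical row matching the schema above: INT64 primary key, VARCHAR
# filepath, INT32 frame number, and an fp16 vector of dim 1152.
import numpy as np

row = dict(primary_id=117000042,                        # hypothetical key
           filepath='railing/2024/09/20/example.mp4',   # hypothetical path
           frame_number=42,
           so400m=np.random.random(1152).astype(np.float16))
client.insert(collection_name='nuggets_so400m', data=[row])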


@@ -0,0 +1,66 @@
from pymilvus import MilvusClient, DataType

client = MilvusClient(
    uri="http://localhost:19530"
)
for x in client.list_collections():
    client.drop_collection(x)
# %%
import os
out = os.listdir('/mergedfs/ftp/')
# %%
for cam in out:
    schema = MilvusClient.create_schema(
        auto_id=False,
        enable_dynamic_field=False,
    )
    schema.add_field(field_name="primary_id", datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name="filepath", datatype=DataType.VARCHAR, max_length=128)
    schema.add_field(field_name="frame_number", datatype=DataType.INT32)
    schema.add_field(field_name="date", datatype=DataType.VARCHAR, max_length=len('20241220'), is_partition_key=True)
    schema.add_field(field_name="so400m", datatype=DataType.FLOAT16_VECTOR, dim=1152)
    index_params = client.prepare_index_params()
    index_params.add_index(
        field_name="primary_id",
        index_type="STL_SORT")
    index_params.add_index(
        field_name="filepath",
        index_type="Trie")
    index_params.add_index(
        field_name="so400m",
        index_type="IVF_SQ8",
        metric_type="COSINE",
        params={'nlist': 128})
    index_params.add_index(
        field_name='date',
        index_type='Trie')
    client.create_collection(
        collection_name=f"nuggets_{cam}_so400m",
        schema=schema,
        index_params=index_params
    )
    print(cam)
# %%
res = client.get_load_state(
    collection_name="nuggets_so400m"
)
res = client.load_collection(collection_name="nuggets_so400m")


@@ -0,0 +1,64 @@
from pymilvus import MilvusClient, DataType

client = MilvusClient(
    uri="http://localhost:19530"
)
#for x in client.list_collections():
#    client.drop_collection(x)
# %%
import os
out = os.listdir('/mergedfs/ftp/')
# %%
for cam in out:
    schema = MilvusClient.create_schema(
        auto_id=False,
        enable_dynamic_field=False,
    )
    schema.add_field(field_name="primary_id", datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name="filepath", datatype=DataType.VARCHAR, max_length=128)
    schema.add_field(field_name="frame_number", datatype=DataType.INT32)
    schema.add_field(field_name="date", datatype=DataType.VARCHAR, max_length=len('20241220'), is_partition_key=True)
    schema.add_field(field_name="so400m", datatype=DataType.FLOAT16_VECTOR, dim=1024)
    index_params = client.prepare_index_params()
    index_params.add_index(
        field_name="primary_id",
        index_type="STL_SORT")
    index_params.add_index(
        field_name="filepath",
        index_type="Trie")
    index_params.add_index(
        field_name="so400m",
        index_type="IVF_SQ8",
        metric_type="COSINE",
        params={'nlist': 128})
    index_params.add_index(
        field_name='date',
        index_type='Trie')
    client.create_collection(
        collection_name=f"nuggets_{cam}_so400m_siglip2",
        schema=schema,
        index_params=index_params
    )
    print(cam)
#res = client.get_load_state(collection_name="nuggets_ptz_so400m_siglip2")
#res = client.load_collection(collection_name="nuggets_so400m")


@@ -0,0 +1,25 @@
from pymilvus import MilvusClient, DataType
import numpy as np

# 1. Set up a Milvus client
client = MilvusClient(uri="http://localhost:19530")
cname = "nuggets_so400m"
ou = client.get_collection_stats(cname)

import random
vec = [random.random() for x in range(1152)]
# %%
from prettyprinter import cpprint

out = client.search(
    collection_name=cname,
    limit=100,
    data=[vec],
    output_fields=["filepath", "frame_number"],
    filter='(filepath like "%2024/09/20%") or (filepath like "%2024/09/23%")'
)
cpprint([x['entity']['filepath'] for x in out[0]])


@@ -0,0 +1,75 @@
from pymilvus import MilvusClient, DataType
from pymilvus.client.types import LoadState
import numpy as np
import os
import time

client = MilvusClient(
    uri="http://localhost:19530"
)
res = client.get_load_state(
    collection_name="nuggets_so400m"
)
if res['state'] == LoadState.Loaded:
    pass
else:
    client.load_collection(collection_name='nuggets_so400m')
    for i in range(10):
        time.sleep(1)
        # refresh the load state before re-checking
        res = client.get_load_state(collection_name="nuggets_so400m")
        if res['state'] == LoadState.Loaded:
            break

def get_vec_path(vpath):
    return os.path.splitext(vpath)[0] + '.oclip_embeds.npz'

def get_db_embed_done_path(vpath):
    return os.path.splitext(vpath)[0] + '.db_has_oclip_embeds'

def upload_vector_file(vector_file_to_upload):
    # vector_file_to_upload is already the .oclip_embeds.npz path collected below
    if os.path.exists(get_db_embed_done_path(vector_file_to_upload)):
        print('Already exists in DB, skipping upload')
        return
    vf = np.load(vector_file_to_upload)
    embeds = vf['embeds']
    fr_nums = vf['frame_numbers']
    fname_root = vector_file_to_upload.rsplit('/', 1)[-1].split('.')[0]
    fc = fname_root.split('_')[-1]
    data = list()
    filepath = vector_file_to_upload.replace('/srv/ftp/', '').replace('/mergedfs/ftp', '').split('.')[0]
    for embed, frame_num in zip(embeds, fr_nums):
        fg = '{0:05g}'.format(frame_num)
        # primary key: file-counter digits concatenated with the zero-padded frame number
        id_num = int(fc + fg)
        to_put = dict(primary_id=id_num, filepath=filepath, frame_number=int(frame_num), so400m=embed)
        data.append(to_put)
    client.insert(collection_name='nuggets_so400m', data=data)
    print(f'Inserting into DB, {vector_file_to_upload}')
    with open(get_db_embed_done_path(vector_file_to_upload), 'w') as ff:
        ff.write(str(time.time()))

root_path = '/srv/ftp/railing/2024'
to_put = list()
for root, dirs, files in os.walk(root_path):
    for x in files:
        if x.endswith('oclip_embeds.npz'):
            to_put.append(os.path.join(root, x))
for x in to_put:
    upload_vector_file(x)
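
As a worked example of the primary-key scheme in upload_vector_file (the file-counter value here is hypothetical):

fc = '01170'                  # last '_'-separated chunk of a hypothetical filename
fg = '{0:05g}'.format(42)     # zero-padded frame number -> '00042'
print(int(fc + fg))           # -> 117000042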


@@ -0,0 +1,38 @@
import os
import numpy as np
import progressbar
import ngtpy

dim = 1152
index_path = b"/mnt/ssd_nvm/ngt/openclip_so400m"
if not os.path.exists(index_path):
    ngtpy.create(index_path, dim)
    print(f'Created index at {index_path}')
index = ngtpy.Index(index_path)
# %%
to_add_to_index = list()
for root, dirs, files in os.walk('/mergedfs/ftp'):
    for x in files:
        if x.endswith('.oclip_embeds.npz'):
            to_add_to_index.append(os.path.join(root, x))
# %%
total_vecs = 0
bar = progressbar.ProgressBar(max_value=len(to_add_to_index))
for idx, to_add in enumerate(to_add_to_index):
    try:
        emb_vec = np.load(to_add)['embeds']
        total_vecs += emb_vec.shape[0]
        index.batch_insert(emb_vec)
        index.save()
    except Exception:
        pass  # skip files that fail to load or insert
    bar.update(idx)
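
A sketch of querying the finished index, assuming ngtpy's search() returns (object_id, distance) pairs; ids follow insertion order, so mapping back to files and frames would need a separate lookup table:

# Hypothetical query against the NGT index built above.
import ngtpy
import numpy as np

index = ngtpy.Index(b"/mnt/ssd_nvm/ngt/openclip_so400m")
query_vec = np.random.random(1152).astype(np.float32)
for object_id, distance in index.search(query_vec, size=10):
    print(object_id, distance)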


@@ -1,132 +1,104 @@
do_load = True
import requests
from pymilvus import MilvusClient, DataType
from pymilvus.client.types import LoadState
from datetime import datetime, timedelta
import numpy as np
import traceback
from bottle import route, run, template, request, debug
import time
import prettyprinter
# %%
TIMEOUT = 2
# %%
#collection_name = "nuggets_{camera}_so400m_siglip2"
collection_name = "nuggets_{camera}_so400m"
client = MilvusClient(
    uri="http://localhost:19530"
)
# %%
from bottle import route, run, template
@route('/get_text_match')
def get_matches():
    # %%
    valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'}
    query = request.query.get('query','A large bird eating corn')
    cameras = request.query.get('cameras','sidefeeder')
    num_videos = int(request.query.get('num_videos',5))
    cams = set(cameras.split(',')).intersection(valid_cameras)
    max_age = int(request.query.get('age',5))
    print({'Cameras':cams,'Max Age':max_age,'Query':query})
    # %%
    max_date = datetime.now()
    min_date = max_date - timedelta(days=max_age)
    days_step = (max_date - min_date).days
    day_strs = list()
    # %%
    for x in range(days_step):
        day_strs.append((min_date + timedelta(days=x)).strftime('%Y%m%d'))
    str_insert = ','.join([f'"{x}"' for x in day_strs])
    filter_string = 'date in [' + str_insert + ']'
    if max_age == 0:
        filter_string = ''
    # %%
    vec_form = requests.get('http://192.168.1.242:53004/encode', params={'query':query}).json()['vector'][0]
    vec_search = np.asarray(vec_form).astype(np.float16)
    print(f'Doing query for {query}!')
    all_results = list()
    try:
        error = ''
        if True:
            for cam in cams:
                col_name = collection_name.format(camera=cam)
                for i in [num_videos]:
                    results = client.search(collection_name=col_name,
                                            consistency_level="Eventually",
                                            filter=filter_string,
                                            data=[vec_search],
                                            limit=i,
                                            search_params={'metric_type':'COSINE', 'params':{}},
                                            output_fields=['filepath','frame_number'])
                    all_results.extend(results[0])
    except Exception as e:
        print(traceback.format_exc())
        error = traceback.format_exc()
        results = []
    def linux_to_win_path(form):
        form = form.replace('/srv', 'file://192.168.1.242/thebears/Videos/merged/')
        return form
    def normalize_to_merged(path):
        if not path.startswith('/'):
            path = '/mergedfs/ftp/' + path
        return path
    resul = list()
    all_results = sorted(all_results, key=lambda x: x['distance'])
    for x in all_results:
        pload = dict(x['entity'])
        pload['filepath'] = normalize_to_merged(pload['filepath'])
        pload['score'] = x['distance']
        pload['winpath'] = linux_to_win_path(pload['filepath'])
        pload['frame'] = pload['frame_number']
        resul.append(pload)
    # %%
    return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error}
    prettyprinter.pprint(return_this)
    return return_this
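
A sketch of calling the endpoint above; the host and port are assumptions, since this hunk does not show the run() call (the Qdrant variant below serves on 53003):

# Hypothetical request against the /get_text_match handler defined above.
import requests

r = requests.get('http://localhost:53003/get_text_match',
                 params={'query': 'A blue jay eating peanuts',
                         'cameras': 'sidefeeder,railing',
                         'age': 3,
                         'num_videos': 5})
for hit in r.json()['results']:
    print(hit['score'], hit['frame'], hit['winpath'])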

139 search_me_qdrant.py Normal file

@@ -0,0 +1,139 @@
do_load = True
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, Range, MatchText
from datetime import datetime, timedelta
import numpy as np
import traceback
from bottle import route, run, template, request, debug
import open_clip
import torch
# %%
if do_load:
    model_name = 'ViT-SO400M-14-SigLIP-384'
    pretrained_name = 'webli'
    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained_name)
    device = 'cpu'
    model.eval()
    tokenizer = open_clip.get_tokenizer(model_name)
TIMEOUT = 2
# %%
collection_name = "nuggets_so400m"
client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT)
# %%
from bottle import route, run, template
@route('/get_text_match')
def get_matches():
    valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'}
    query = request.query.get('query','A large bird eating corn')
    cameras = request.query.get('cameras','sidefeeder')
    cams = set(cameras.split(',')).intersection(valid_cameras)
    max_age = int(request.query.get('age',5))
    print({'Cameras':cams,'Max Age':max_age,'Query':query})
    # %%
    max_date = datetime.now()
    min_date = max_date - timedelta(days=max_age)
    days_step = (max_date - min_date).days
    date_arrays = list()
    for i in range(days_step):
        date_arrays.append(max_date - timedelta(days=i))
    # %%
    string_filter = list()
    for cand_date in date_arrays:
        string_filter.append(cand_date.strftime('%Y/%m/%d'))
    should_list = list()
    if max_age == 0:
        string_filter = ['']
    for str_filt in string_filter:
        for cam in cams:
            str_use = cam + '/' + str_filt
            print(str_use)
            ccond = FieldCondition(key='filepath', match=MatchText(text=str_use))
            should_list.append(ccond)
    condition_dict = Filter(should=should_list)
    if len(should_list) == 0:
        return_this = {'query':query,'num_videos':0,'results':[]}
        return return_this
    averaged = False
    if isinstance(query, str):
        averaged = bool(averaged)
    num_videos = int(request.query.get('num_videos',5))
    if do_load:
        with torch.no_grad():
            text_tokenized = tokenizer(query)
            vec = model.encode_text(text_tokenized)
            # text_input = txt_processors['eval'](query)
            # sample = {'text_input':text_input}
            # vec = model.extract_features( sample)
            vec_search = vec.cpu().numpy().squeeze().tolist()
    else:
        sz_vec = client.get_collection(collection_name).config.params.vectors.size
        vec_search = np.random.random(sz_vec).tolist()
    if averaged:
        col_name = collection_name + '_averaged'
    else:
        col_name = collection_name
    # %%
    try:
        error = ''
        if True:
            for i in range(num_videos, 100, 10):
                results = client.search(collection_name=col_name,
                                        query_vector=vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT)
                num_video_got = len(set([x.payload['filepath'] for x in results]))
                if num_video_got >= num_videos:
                    break
            else:
                results = client.search(collection_name=col_name,
                                        query_vector=vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT)
    except Exception as e:
        print(traceback.format_exc())
        error = traceback.format_exc()
        results = []
    def linux_to_win_path(form):
        form = form.replace('/srv', 'file://192.168.1.242/thebears/Videos/merged/')
        return form
    def normalize_to_merged(path):
        path = path.replace('/srv/ftp', '/mergedfs/ftp')
        path = path.replace('/mnt/archive2/videos/ftp', '/mergedfs/ftp')
        return path
    resul = list()
    for x in results:
        pload = dict(x.payload)
        pload['filepath'] = normalize_to_merged(pload['filepath'])
        pload['score'] = x.score
        pload['winpath'] = linux_to_win_path(pload['filepath'])
        resul.append(pload)
    # %%
    return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error}
    return return_this

debug(True)
run(host='0.0.0.0', port=53003, server='bjoern')


@@ -3,12 +3,17 @@ from qdrant_client.http import models
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(host="localhost", port=6333)
# %%
collection_head = "nuggets_clip"
collection_head = "nuggets_so400m"
out = client.scroll('nuggets_so400m', limit=10, with_vectors=True)
print(len(out[0][0].vector))
# %%
ou = np.asarray([x.vector for x in out[0]], dtype=np.float16)
# %%
for collection_name in [collection_head, collection_head + '_averaged']:
    try:
        client.create_collection(
@@ -33,3 +38,27 @@ for collection_name in [collection_head, collection_head +'_averaged']:
    from prettyprinter import cpprint
    cpprint(client.get_collection(collection_name).dict()['vectors_count'])
    # %%
    client.update_collection(
        collection_name=f"{collection_name}",
        hnsw_config=models.HnswConfigDiff(
            on_disk=True
        ),
        vectors_config={
            "": models.VectorParamsDiff(
                on_disk=True,
                hnsw_config=models.HnswConfigDiff(
                    on_disk=True
                ),
            )
        }
    )
    # %%
    cpprint(client.get_collection(collection_name).dict())