YACWC
31  clip_endpoint.py  Normal file
@@ -0,0 +1,31 @@
import open_clip
import torch

do_load = True
if do_load:

    #model_name = 'hf-hub:timm/ViT-L-16-SigLIP2-512'
    # model, preprocess = open_clip.create_model_from_pretrained('hf-hub:timm/ViT-L-16-SigLIP2-512')


    model_name = 'ViT-SO400M-14-SigLIP-384'
    pretrained_name = 'webli'
    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained_name)
    device = 'cpu'
    model.eval()
    tokenizer = open_clip.get_tokenizer(model_name)

from bottle import route, run, template, request, debug
@route('/encode')
def get_matches():
    query = request.query.get('query', 'A large bird eating corn')
    with torch.no_grad():
        text_tokenized = tokenizer(query)
        vec = model.encode_text(text_tokenized).detach().cpu().tolist()

    return {'vector': vec}



debug(True)
run(host='0.0.0.0', port=53004, server='bjoern')
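
Note: search_me.py below consumes this endpoint over HTTP; a minimal client sketch, with the host and port taken from the requests.get call in search_me.py:

import requests

# Fetch the SigLIP text embedding for a query string from the bottle endpoint.
resp = requests.get('http://192.168.1.242:53004/encode',
                    params={'query': 'A large bird eating corn'})
vec = resp.json()['vector'][0]  # one embedding per query; 1152 floats for SO400M
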
21  dump_qdrant.py  Normal file
@@ -0,0 +1,21 @@
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.models import Distance, VectorParams
client = QdrantClient(host="localhost", port=6333)


collection_head = "nuggets_so400m"
out = client.scroll('nuggets_so400m', limit=10, with_vectors=True)



offset_id = None
#while True:
if True:
    res = client.scroll(collection_name=collection_head,
                        limit=1000000,
                        offset=offset_id,
                        with_payload=False,
                        with_vectors=True)

# ou = np.asarray([x.vector for x in out[0]], dtype=np.float16)
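
Note: the commented-out `while True:` hints at paginating the whole collection; a sketch of the full dump loop, assuming qdrant_client's scroll returning a (points, next_offset) tuple:

all_vecs = []
offset_id = None
while True:
    points, offset_id = client.scroll(collection_name=collection_head,
                                      limit=10000,
                                      offset=offset_id,
                                      with_payload=False,
                                      with_vectors=True)
    all_vecs.extend(p.vector for p in points)
    if offset_id is None:  # no further pages
        break
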
15  milvus_migrate/*dashboard*  Normal file
@@ -0,0 +1,15 @@
from pymilvus import MilvusClient, DataType
import numpy as np

client = MilvusClient(
    uri="http://localhost:19530"
)
cname = 'nuggets_so400m'
client.get_collection_stats('nuggets_so400m')



vec = np.random.random(1152).astype(np.float16)


client.search
56  milvus_migrate/create_collection.py  Normal file
@@ -0,0 +1,56 @@
from pymilvus import MilvusClient, DataType

# 1. Set up a Milvus client
client = MilvusClient(
    uri="http://localhost:19530"
)
client.get_collection_stats('nuggets_so400m')
# %%

schema = MilvusClient.create_schema(
    auto_id=False,
    enable_dynamic_field=False,
)
schema.add_field(field_name="primary_id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="filepath", datatype=DataType.VARCHAR, max_length=128)
schema.add_field(field_name="frame_number", datatype=DataType.INT32)
schema.add_field(field_name="so400m", datatype=DataType.FLOAT16_VECTOR, dim=1152)


index_params = client.prepare_index_params()


index_params.add_index(
    field_name="primary_id",
    index_type="STL_SORT")

index_params.add_index(
    field_name="filepath",
    index_type="Trie")

index_params.add_index(
    field_name="so400m",
    index_type="IVF_FLAT",
    metric_type="COSINE",
    params={ "nlist": 128 })



client.create_collection(
    collection_name="nuggets_so400m",
    schema=schema,
    index_params=index_params
)

# %%
res = client.get_load_state(
    collection_name="nuggets_so400m"
)


res = client.load_collection(collection_name="nuggets_so400m")
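
Note: rows for this schema are plain dicts keyed by field name, which is how upload_from_folder.py below builds them; a sketch with hypothetical values:

import numpy as np

row = dict(primary_id=101000042,                     # hypothetical id (camera code + zero-padded frame)
           filepath='railing/2024/09/20/clip_0001',  # hypothetical path, '/srv/ftp/' prefix stripped
           frame_number=42,
           so400m=np.random.random(1152).astype(np.float16))
client.insert(collection_name='nuggets_so400m', data=[row])
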
66  milvus_migrate/create_collection_v2.py  Normal file
@@ -0,0 +1,66 @@
from pymilvus import MilvusClient, DataType


client = MilvusClient(
    uri="http://localhost:19530"
)
for x in client.list_collections():
    client.drop_collection(x)

# %%

import os
out = os.listdir('/mergedfs/ftp/')

# %%
for cam in out:
    schema = MilvusClient.create_schema(
        auto_id=False,
        enable_dynamic_field=False,
    )
    schema.add_field(field_name="primary_id", datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name="filepath", datatype=DataType.VARCHAR, max_length=128)
    schema.add_field(field_name="frame_number", datatype=DataType.INT32)
    schema.add_field(field_name="date", datatype=DataType.VARCHAR, max_length=len('20241220'), is_partition_key=True)
    schema.add_field(field_name="so400m", datatype=DataType.FLOAT16_VECTOR, dim=1152)


    index_params = client.prepare_index_params()

    index_params.add_index(
        field_name="primary_id",
        index_type="STL_SORT")

    index_params.add_index(
        field_name="filepath",
        index_type="Trie")

    index_params.add_index(
        field_name="so400m",
        index_type="IVF_SQ8",
        metric_type="COSINE",
        params={'nlist': 128})

    index_params.add_index(
        field_name='date',
        index_type='Trie')

    client.create_collection(
        collection_name=f"nuggets_{cam}_so400m",
        schema=schema,
        index_params=index_params
    )
    print(cam)

# %%
res = client.get_load_state(
    collection_name="nuggets_so400m"
)


res = client.load_collection(collection_name="nuggets_so400m")
64  milvus_migrate/create_collection_v3.py  Normal file
@@ -0,0 +1,64 @@
from pymilvus import MilvusClient, DataType


client = MilvusClient(
    uri="http://localhost:19530"
)
#for x in client.list_collections():
#    client.drop_collection(x)




# %%

import os
out = os.listdir('/mergedfs/ftp/')

# %%
for cam in out:
    schema = MilvusClient.create_schema(
        auto_id=False,
        enable_dynamic_field=False,
    )
    schema.add_field(field_name="primary_id", datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name="filepath", datatype=DataType.VARCHAR, max_length=128)
    schema.add_field(field_name="frame_number", datatype=DataType.INT32)
    schema.add_field(field_name="date", datatype=DataType.VARCHAR, max_length=len('20241220'), is_partition_key=True)
    schema.add_field(field_name="so400m", datatype=DataType.FLOAT16_VECTOR, dim=1024)


    index_params = client.prepare_index_params()

    index_params.add_index(
        field_name="primary_id",
        index_type="STL_SORT")

    index_params.add_index(
        field_name="filepath",
        index_type="Trie")

    index_params.add_index(
        field_name="so400m",
        index_type="IVF_SQ8",
        metric_type="COSINE",
        params={'nlist': 128})

    index_params.add_index(
        field_name='date',
        index_type='Trie')

    client.create_collection(
        collection_name=f"nuggets_{cam}_so400m_siglip2",
        schema=schema,
        index_params=index_params
    )
    print(cam)

#res = client.get_load_state( collection_name="nuggets_ptz_so400m_siglip2")
#res = client.load_collection(collection_name="nuggets_so400m")
25  milvus_migrate/search_try.py  Normal file
@@ -0,0 +1,25 @@
from pymilvus import MilvusClient, DataType
import numpy as np

# 1. Set up a Milvus client
client = MilvusClient(uri="http://localhost:19530")
cname = "nuggets_so400m"
ou = client.get_collection_stats(cname)

import random
vec = [random.random() for x in range(1152)]
# %%

from prettyprinter import cpprint
# vec is already a random 1152-dim query vector (see above)
out = client.search(
    collection_name=cname,
    limit=100,
    data=[vec],
    output_fields=["filepath", "frame_number"],
    filter='(filepath like "%2024/09/20%") or (filepath like "%2024/09/23%")'
)



cpprint([x['entity']['filepath'] for x in out[0]])
75  milvus_migrate/upload_from_folder.py  Normal file
@@ -0,0 +1,75 @@
import os

from pymilvus import MilvusClient, DataType
import numpy as np
import time
from pymilvus.client.types import LoadState
client = MilvusClient(
    uri="http://localhost:19530"
)


res = client.get_load_state(
    collection_name="nuggets_so400m"
)
if res['state'] == LoadState.Loaded:
    pass
else:
    client.load_collection(collection_name='nuggets_so400m')
    for i in range(10):
        time.sleep(1)
        # re-poll the load state; checking the stale result would loop forever
        res = client.get_load_state(collection_name="nuggets_so400m")
        if res['state'] == LoadState.Loaded:
            break


def get_vec_path(vpath):
    return os.path.splitext(vpath)[0] + '.oclip_embeds.npz'

def get_db_embed_done_path(vpath):
    return os.path.splitext(vpath)[0] + '.db_has_oclip_embeds'


def upload_vector_file(vector_file_to_upload):
    if os.path.exists(get_db_embed_done_path(vector_file_to_upload)):
        print('Already exists in DB, skipping upload')
        return

    # the os.walk below already yields .oclip_embeds.npz paths, so only map
    # the name when a raw video path is passed in
    if not vector_file_to_upload.endswith('.oclip_embeds.npz'):
        vector_file_to_upload = get_vec_path(vector_file_to_upload)
    vf = np.load(vector_file_to_upload)

    embeds = vf['embeds']
    fr_nums = vf['frame_numbers']

    fname_root = vector_file_to_upload.rsplit('/', 1)[-1].split('.')[0]
    fc = fname_root.split('_')[-1]

    data = list()
    filepath = vector_file_to_upload.replace('/srv/ftp/', '').replace('/mergedfs/ftp', '').split('.')[0]

    for embed, frame_num in zip(embeds, fr_nums):
        fg = '{0:05g}'.format(frame_num)
        id_num = int(fc + fg)
        to_put = dict(primary_id=id_num, filepath=filepath, frame_number=int(frame_num), so400m=embed)
        data.append(to_put)

    client.insert(collection_name='nuggets_so400m', data=data)
    print(f'Inserting into DB, {vector_file_to_upload}')

    with open(get_db_embed_done_path(vector_file_to_upload), 'w') as ff:
        ff.write(str(time.time()))


root_path = '/srv/ftp/railing/2024'
to_put = list()
for root, dirs, files in os.walk(root_path):
    for x in files:
        if x.endswith('oclip_embeds.npz'):
            to_put.append(os.path.join(root, x))


for x in to_put:
    upload_vector_file(x)
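
Note: the uploader expects each .oclip_embeds.npz to hold two aligned arrays, embeds and frame_numbers; a sketch of a writer producing that layout (shapes and filename hypothetical, the two array names taken from the vf[...] reads above):

import numpy as np

embeds = np.zeros((3, 1152), dtype=np.float16)  # hypothetical per-frame embeddings
frame_numbers = np.array([0, 30, 60])           # frame index each embedding came from
np.savez('clip_0001.oclip_embeds.npz', embeds=embeds, frame_numbers=frame_numbers)
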
38  ngt_migrate/create_index.py  Normal file
@@ -0,0 +1,38 @@
import os

import ngtpy
dim = 1152
index_path = b"/mnt/ssd_nvm/ngt/openclip_so400m"
if not os.path.exists(index_path):
    ngtpy.create(index_path, dim)
    print(f'Created index at {index_path}')

index = ngtpy.Index(index_path)
# %%

to_add_to_index = list()
for root, dirs, files in os.walk('/mergedfs/ftp'):
    for x in files:
        if x.endswith('.oclip_embeds.npz'):
            to_add_to_index.append(os.path.join(root, x))

# %%
import numpy as np
import progressbar
# %%
total_vecs = 0
bar = progressbar.ProgressBar(max_value=len(to_add_to_index))
for idx, to_add in enumerate(to_add_to_index):
    try:
        emb_vec = np.load(to_add)['embeds']
        total_vecs += emb_vec.shape[0]
        index.batch_insert(emb_vec)
        index.save()
    except Exception:
        # skip unreadable or malformed embedding files
        pass

    bar.update(idx)
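
Note: a sketch of querying the finished index, assuming ngtpy's usual search(query, size=k) returning (id, distance) pairs:

import numpy as np

query = np.random.random(1152)  # hypothetical query embedding
for obj_id, distance in index.search(query, size=10):
    print(obj_id, distance)
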
128  search_me.py
@@ -1,132 +1,104 @@
 do_load = True
-from qdrant_client import QdrantClient
-from qdrant_client.models import Filter, FieldCondition, Range, MatchText
+import requests
+from pymilvus import MilvusClient, DataType
+from pymilvus.client.types import LoadState
 from datetime import datetime, timedelta
 import numpy as np
 import traceback
 from bottle import route, run, template, request, debug
-import open_clip
+import time
+import prettyprinter
 # %%
-if do_load:
-
-    from lavis.models import load_model_and_preprocess, model_zoo
-    import torch
-    device = 'cpu'
-    model, vis_processors, txt_processors = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-B-16", is_eval=True, device=device)
-    model.eval()

 TIMEOUT=2
 # %%

-collection_name = "nuggets_so400m"
-client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT)
+#collection_name = "nuggets_{camera}_so400m_siglip2"
+collection_name = "nuggets_{camera}_so400m"
+client = MilvusClient(
+    uri="http://localhost:19530"
+)
 # %%

 from bottle import route, run, template

 @route('/get_text_match')
 def get_matches():
+    # %%
     valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'}
     query = request.query.get('query','A large bird eating corn')
     cameras = request.query.get('cameras','sidefeeder')
+    num_videos = int(request.query.get('num_videos',5))
+
     cams = set(cameras.split(',')).intersection(valid_cameras)

     max_age = int(request.query.get('age',5));
     print({'Cameras':cams,'Max Age':max_age,'Query':query})
-    # %%
     max_date = datetime.now()
     min_date = max_date - timedelta(days=(max_age))

     days_step = (max_date- min_date).days
-    date_arrays = list()
-    for i in range(days_step):
-        date_arrays.append(max_date - timedelta(days=i))
-    # %%
-    string_filter = list()
-    for cand_date in date_arrays:
-        string_filter.append(cand_date.strftime('%Y/%m/%d'))
+    day_strs = list()
+    # %%
+    for x in range(days_step):
+        day_strs.append( (min_date + timedelta(days=x)).strftime('%Y%m%d') )

-    should_list = list()
+    str_insert = ','.join([f'"{x}"'for x in day_strs])
+    filter_string = 'date in [' + str_insert + ']'

     if max_age == 0:
-        string_filter=['']
+        filter_string = ''
+    # %%
-    for str_filt in string_filter:
-        for cam in cams:
-            str_use = cam+'/'+str_filt
-            print(str_use)
-            ccond = FieldCondition(key='filepath',match=MatchText(text=str_use))
-            should_list.append(ccond)
-
-
-
-    condition_dict = Filter(should = should_list)
-    if len(should_list) == 0:
-        return_this = {'query':query,'num_videos':0,'results':[]}
-        return return_this
-
-    averaged = False
-    if isinstance(query, str):
-        averaged = bool(averaged)
-
-    num_videos = int(request.query.get('num_videos',5))
-
-    if do_load:
-        with torch.no_grad():
-            text_input = txt_processors['eval'](query)
-            sample = {'text_input':text_input}
-            vec = model.extract_features( sample)
-            vec_search = vec.cpu().numpy().squeeze().tolist()
-    else:
-        sz_vec= client.get_collection(collection_name).config.params.vectors.size
-        vec_search = np.random.random(sz_vec).tolist()
-
-    if averaged:
-        col_name=collection_name+'_averaged'
-    else:
-        col_name = collection_name
+    vec_form = requests.get('http://192.168.1.242:53004/encode',params={'query':query}).json()['vector'][0]
+    vec_search = np.asarray(vec_form).astype(np.float16)

+    print(f'Doing query for {query}!')
+    all_results=list()
     try:
         error = ''
         if True:
-            for i in range(num_videos, 100, 10):
-                results = client.search(collection_name = col_name,
-                    query_vector = vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT)
-                num_video_got = len(set([x.payload['filepath'] for x in results]))
-                if num_video_got >= num_videos:
-                    break
-        else:
-            results = client.search(collection_name = col_name,
-                query_vector = vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT)
+            for cam in cams:
+                col_name = collection_name.format(camera=cam)
+                for i in [num_videos]:
+
+                    results = client.search(collection_name = col_name,
+                        consistency_level="Eventually",
+                        filter=filter_string,
+                        data = [vec_search],
+                        limit=i,
+                        search_params={'metric_type':'COSINE', 'params':{}},
+                        output_fields=['filepath','frame_number']
+                    )
+                    all_results.extend(results[0])

     except Exception as e:
         print(traceback.format_exc())
         error = traceback.format_exc()
-        results = [];
+        results = []

+    # %%
     def linux_to_win_path(form):
         form = form.replace('/srv','file://192.168.1.242/thebears/Videos/merged/')
         return form

     def normalize_to_merged(path):
-        path = path.replace('/srv/ftp','/mergedfs/ftp')
-        path = path.replace('/mnt/archive2/videos/ftp','/mergedfs/ftp')
+        if not path.startswith('/'):
+            path= '/mergedfs/ftp/'+path
         return path

     resul = list()
-    for x in results:
-        pload =dict( x.payload)
+    all_results = sorted(all_results, key=lambda x: x['distance'])
+    for x in all_results:
+        pload =dict( x['entity'])
         pload['filepath'] = normalize_to_merged(pload['filepath'])
-        pload['score'] = x.score
+        pload['score'] = x['distance']
         pload['winpath'] = linux_to_win_path(pload['filepath'])
+        pload['frame'] = pload['frame_number']
         resul.append(pload)
     # %%
     return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error}
+    prettyprinter.pprint(return_this)
     return return_this
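
Note: the day loop in the new search_me.py builds a Milvus boolean expression over the date partition key; a worked example for a hypothetical two-day window:

day_strs = ['20241219', '20241220']  # hypothetical dates
str_insert = ','.join([f'"{x}"' for x in day_strs])
filter_string = 'date in [' + str_insert + ']'
# -> 'date in ["20241219","20241220"]'
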
139  search_me_qdrant.py  Normal file
@@ -0,0 +1,139 @@
do_load = True
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, Range, MatchText
from datetime import datetime, timedelta
import numpy as np
import traceback
from bottle import route, run, template, request, debug
import open_clip
import torch
# %%
if do_load:
    model_name = 'ViT-SO400M-14-SigLIP-384'
    pretrained_name = 'webli'

    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained_name)
    device = 'cpu'
    model.eval()
    tokenizer = open_clip.get_tokenizer(model_name)

TIMEOUT=2
# %%

collection_name = "nuggets_so400m"
client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True, timeout=TIMEOUT)
# %%
from bottle import route, run, template

@route('/get_text_match')
def get_matches():
    valid_cameras = {'sidefeeder','ptz','railing','hummingbird','pond'}
    query = request.query.get('query','A large bird eating corn')
    cameras = request.query.get('cameras','sidefeeder')
    cams = set(cameras.split(',')).intersection(valid_cameras)

    max_age = int(request.query.get('age',5));
    print({'Cameras':cams,'Max Age':max_age,'Query':query})
    # %%
    max_date = datetime.now()
    min_date = max_date - timedelta(days=(max_age))

    days_step = (max_date- min_date).days
    date_arrays = list()
    for i in range(days_step):
        date_arrays.append(max_date - timedelta(days=i))
    # %%
    string_filter = list()
    for cand_date in date_arrays:
        string_filter.append(cand_date.strftime('%Y/%m/%d'))


    should_list = list()
    if max_age == 0:
        string_filter=['']

    for str_filt in string_filter:
        for cam in cams:
            str_use = cam+'/'+str_filt
            print(str_use)
            ccond = FieldCondition(key='filepath',match=MatchText(text=str_use))
            should_list.append(ccond)




    condition_dict = Filter(should = should_list)
    if len(should_list) == 0:
        return_this = {'query':query,'num_videos':0,'results':[]}
        return return_this


    averaged = False
    if isinstance(query, str):
        averaged = bool(averaged)

    num_videos = int(request.query.get('num_videos',5))

    if do_load:
        with torch.no_grad():
            text_tokenized = tokenizer(query)
            vec = model.encode_text(text_tokenized)
            # text_input = txt_processors['eval'](query)
            # sample = {'text_input':text_input}
            # vec = model.extract_features( sample)
            vec_search = vec.cpu().numpy().squeeze().tolist()
    else:
        sz_vec= client.get_collection(collection_name).config.params.vectors.size
        vec_search = np.random.random(sz_vec).tolist()



    if averaged:
        col_name=collection_name+'_averaged'
    else:
        col_name = collection_name

    # %%
    try:
        error = ''
        if True:
            for i in range(num_videos, 100, 10):
                results = client.search(collection_name = col_name,
                    query_vector = vec_search, limit=i, query_filter=condition_dict, timeout=TIMEOUT)
                num_video_got = len(set([x.payload['filepath'] for x in results]))
                if num_video_got >= num_videos:
                    break
        else:
            results = client.search(collection_name = col_name,
                query_vector = vec_search, limit=1, query_filter=condition_dict, timeout=TIMEOUT)
    except Exception as e:
        print(traceback.format_exc())
        error = traceback.format_exc()
        results = [];


    def linux_to_win_path(form):
        form = form.replace('/srv','file://192.168.1.242/thebears/Videos/merged/')
        return form

    def normalize_to_merged(path):
        path = path.replace('/srv/ftp','/mergedfs/ftp')
        path = path.replace('/mnt/archive2/videos/ftp','/mergedfs/ftp')
        return path

    resul = list()
    for x in results:
        pload =dict( x.payload)
        pload['filepath'] = normalize_to_merged(pload['filepath'])
        pload['score'] = x.score
        pload['winpath'] = linux_to_win_path(pload['filepath'])
        resul.append(pload)
    # %%
    return_this = {'query':query,'num_videos':num_videos,'results':resul,'error':error}
    return return_this




debug(True)
run(host='0.0.0.0', port=53003, server='bjoern')
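
Note: a sketch of exercising the endpoint with the query parameters get_matches reads (query, cameras, age, num_videos); the port is from the run() line above:

import requests

resp = requests.get('http://localhost:53003/get_text_match',
                    params={'query': 'A large bird eating corn',
                            'cameras': 'sidefeeder,railing',
                            'age': 5,
                            'num_videos': 5})
print(resp.json()['results'])
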
@@ -3,12 +3,17 @@ from qdrant_client.http import models
 from qdrant_client.models import Distance, VectorParams

 client = QdrantClient(host="localhost", port=6333)
+# %%
 collection_head = "nuggets_clip"
 collection_head = "nuggets_so400m"
+out = client.scroll('nuggets_so400m',limit=10,with_vectors=True)
+print(len(out[0][0].vector))
+# %%
+ou = np.asarray([x.vector for x in out[0]], dtype=np.float16)
+
+
+
+# %%
 for collection_name in [collection_head, collection_head +'_averaged']:
     try:
         client.create_collection(
@@ -33,3 +38,27 @@ for collection_name in [collection_head, collection_head +'_averaged']:
     from prettyprinter import cpprint
     cpprint(client.get_collection(collection_name).dict()['vectors_count'])
+
+    # %%
+
+
+
+    client.update_collection(
+        collection_name=f"{collection_name}",
+
+        hnsw_config=models.HnswConfigDiff(
+            on_disk=True
+        ),
+
+        vectors_config={
+            "": models.VectorParamsDiff(
+                on_disk=True,
+                hnsw_config=models.HnswConfigDiff(
+                    on_disk=True
+                ),
+            )
+        }
+    )
+
+
+# %%
+cpprint(client.get_collection(collection_name).dict())