# Notebook-style Qdrant maintenance script (cells are delimited by "# %%").
# Metadata from the original file listing: Python, 65 lines, 1.6 KiB.
import numpy as np
from prettyprinter import cpprint
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.models import Distance, VectorParams

# Connection settings for the Qdrant server used throughout this notebook
# (6333 is Qdrant's default REST port).
QDRANT_HOST = "localhost"
QDRANT_PORT = 6333

client = QdrantClient(port=QDRANT_PORT, host=QDRANT_HOST)

# %%
# Base name of the collection pair this notebook manages. To work on the CLIP
# collection instead, swap which assignment is commented out.
# NOTE(review): the create cell below hard-codes size=1152; confirm the CLIP
# embedding width matches before switching.
# collection_head = "nuggets_clip"
collection_head = "nuggets_so400m"

# Fetch a small sample of points (with vectors) from the active collection.
# scroll() returns a tuple whose first element is the list of records.
out = client.scroll(collection_head, limit=10, with_vectors=True)
sample_points = out[0]

# Report the stored vector width; guard against an empty collection, where
# indexing the first record would raise IndexError.
if sample_points:
    print(len(sample_points[0].vector))
else:
    print(f"{collection_head}: no points returned by scroll()")

# %%
# Pack the sampled vectors into a NumPy array, one row per scrolled point,
# stored as float16.
ou = np.array([record.vector for record in out[0]], dtype=np.float16)

# %%
# Ensure the base collection and its "_averaged" companion exist, each with a
# full-text index on the "filepath" payload field. Failures (for example the
# collection already existing) are printed rather than raised so this cell can
# be re-run safely.
for collection_name in [collection_head, collection_head + "_averaged"]:
    try:
        client.create_collection(
            collection_name=collection_name,
            # NOTE(review): 1152 is presumably the so400m embedding width
            # (compare the length printed by the inspection cell) — confirm
            # before reusing this cell for another collection_head.
            vectors_config=VectorParams(size=1152, distance=Distance.COSINE),
            on_disk_payload=True,
        )
    except Exception as e:
        print(f"{collection_name}: {e}")

    # Attempted independently of create_collection so that a collection which
    # already exists still gets its payload index.
    try:
        client.create_payload_index(
            collection_name=collection_name,
            field_name="filepath",
            field_schema=models.TextIndexParams(
                type="text",
                tokenizer=models.TokenizerType.WORD,
                min_token_len=1,
                max_token_len=15,
                lowercase=True,
            ),
        )
    except Exception as e:
        print(f"{collection_name}: {e}")

    # NOTE(review): newer Qdrant versions deprecate vectors_count in favour of
    # points_count — confirm against the server/client version in use.
    cpprint(client.get_collection(collection_name).dict()["vectors_count"])

# %%
# Move the "_averaged" collection's vectors and HNSW index to disk. The target
# is named explicitly rather than relying on a loop variable from another
# cell; the base collection is not updated by this cell.
target_collection = f"{collection_head}_averaged"

client.update_collection(
    collection_name=target_collection,
    # Collection-wide HNSW setting.
    hnsw_config=models.HnswConfigDiff(on_disk=True),
    # "" addresses the default (unnamed) vector created via VectorParams.
    vectors_config={
        "": models.VectorParamsDiff(
            on_disk=True,
            hnsw_config=models.HnswConfigDiff(on_disk=True),
        ),
    },
)

# %%
# Pretty-print the full collection info (config and status) of the
# "_averaged" collection, named explicitly instead of via a leaked loop
# variable from another cell.
cpprint(client.get_collection(f"{collection_head}_averaged").dict())