Skip to content

Commit

Permalink
use pinecone grpc client
Browse files Browse the repository at this point in the history
list collections in qdrant and pinecone for input
  • Loading branch information
dhruv-anand-aintech committed May 17, 2024
1 parent 888d769 commit c418950
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 22 deletions.
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy
pandas
pinecone-client~=4.0.0
pinecone-client[grpc]~=4.0
pyarrow
qdrant_client
tqdm
Expand Down Expand Up @@ -33,4 +33,5 @@ datasets~=2.16,>=2.19.0
mlx_embedding_models
azure-search-documents
azure-identity
turbopuffer[fast]
turbopuffer[fast]
psycopg2
19 changes: 12 additions & 7 deletions src/vdf_io/export_vdf/pinecone_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from tqdm import tqdm
from halo import Halo

from pinecone import Pinecone, Vector
from pinecone.grpc import PineconeGRPC as Pinecone
from pinecone import Vector

from vdf_io.constants import ID_COLUMN
from vdf_io.names import DBNames
Expand Down Expand Up @@ -114,11 +115,6 @@ def export_vdb(cls, args):
str,
"us-west-2",
)
set_arg_from_input(
args,
"index",
"Enter the name of index to export (hit return to export all): ",
)
set_arg_from_password(
args,
"pinecone_api_key",
Expand Down Expand Up @@ -158,6 +154,13 @@ def export_vdb(cls, args):
"Enter the path to id list file (hit return to skip): ",
)
pinecone_export = ExportPinecone(args)
pinecone_export.all_indexes = pinecone_export.get_all_index_names()
set_arg_from_input(
args,
"index",
"Enter the name of indexes to export (comma-separated) (hit return to export all):",
choices=pinecone_export.all_indexes,
)
pinecone_export.get_data()
return pinecone_export

Expand All @@ -171,6 +174,8 @@ def __init__(self, args):
self.collected_ids_by_modifying = False

def get_index_names(self):
if self.args.get("index"):
return self.args["index"].split(",")
return self.get_all_index_names()

def get_all_index_names(self):
Expand Down Expand Up @@ -239,7 +244,7 @@ def get_ids_from_vector_query(self, input_vector, namespace, all_ids, hash_value
if mark_batch_size < 1:
raise Exception("Could not upsert vectors")
continue
i += resp["upserted_count"]
i += resp.upserted_count
mark_pbar.update(len(batch_ids))
self.collected_ids_by_modifying = True
tqdm.write(f"Marked {len(ids_to_mark)} vectors as exported.")
Expand Down
24 changes: 15 additions & 9 deletions src/vdf_io/export_vdf/qdrant_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,18 @@ def export_vdb(cls, args):
bool,
True,
)
set_arg_from_password(
args, "qdrant_api_key", "Enter your Qdrant API key: ", "QDRANT_API_KEY"
)
qdrant_export = ExportQdrant(args)
qdrant_export.all_collections = qdrant_export.get_all_index_names()
set_arg_from_input(
args,
"collections",
"Enter the name of collection(s) to export (comma-separated) (hit return to export all):",
str,
choices=qdrant_export.all_collections,
)
set_arg_from_password(
args, "qdrant_api_key", "Enter your Qdrant API key: ", "QDRANT_API_KEY"
)
qdrant_export = ExportQdrant(args)
qdrant_export.get_data()
return qdrant_export

Expand All @@ -81,20 +83,24 @@ def __init__(self, args):
prefer_grpc=self.args.get("prefer_grpc", True),
)

def get_index_names(self) -> List[str]:
def get_all_index_names(self) -> List[str]:
"""
Get all collection names from Qdrant
"""
collections = self.client.get_collections().collections
collection_names = [collection.name for collection in collections]
return collection_names

def get_data(self):
def get_index_names(self) -> List[str]:
"""
Get collection names from args or all collection names
"""
if "collections" not in self.args or self.args["collections"] is None:
collection_names = self.get_index_names()
else:
collection_names = self.args["collections"].split(",")
return self.get_all_index_names()
return self.args["collections"].split(",")

def get_data(self):
collection_names = self.get_index_names()
index_metas: Dict[str, List[NamespaceMeta]] = {}
for collection_name in tqdm(collection_names, desc="Fetching indexes"):
index_meta = self.get_data_for_collection(collection_name)
Expand Down
8 changes: 8 additions & 0 deletions src/vdf_io/export_vdf/vdb_export_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ def get_index_names(self) -> List[str]:
# raise NotImplementedError()
pass

@abc.abstractmethod
def get_all_index_names(self) -> List[str]:
"""
Get all index names from vector database
"""
# raise NotImplementedError()
pass

@abc.abstractmethod
def get_data(self) -> ExportVDB:
"""
Expand Down
9 changes: 5 additions & 4 deletions src/vdf_io/import_vdf/pinecone_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import os
from dotenv import load_dotenv

from pinecone import Pinecone, ServerlessSpec, PodSpec, Vector
from pinecone.grpc import PineconeGRPC as Pinecone
from pinecone import ServerlessSpec, PodSpec, Vector

from vdf_io.constants import INT_MAX
from vdf_io.names import DBNames
Expand Down Expand Up @@ -260,9 +261,9 @@ def upsert_data(self):
]
try:
resp = index.upsert(vectors=batch_vectors, namespace=namespace)
self.total_imported_count += resp["upserted_count"]
pbar.update(resp["upserted_count"])
start_idx += resp["upserted_count"]
self.total_imported_count += resp.upserted_count
pbar.update(resp.upserted_count)
start_idx += resp.upserted_count
except Exception as e:
tqdm.write(
f"Error upserting vectors for index '{compliant_index_name}', {e}"
Expand Down

0 comments on commit c418950

Please sign in to comment.