testing: Weaviate impl #88

Open · wants to merge 6 commits into base: main
144 changes: 109 additions & 35 deletions src/vdf_io/export_vdf/weaviate_export.py
@@ -1,15 +1,17 @@
import os

from tqdm import tqdm
import weaviate
import json
from typing import Dict, List
from tqdm import tqdm

from weaviate.classes.query import MetadataQuery

from vdf_io.export_vdf.vdb_export_cls import ExportVDB
from vdf_io.meta_types import NamespaceMeta
from vdf_io.names import DBNames
from vdf_io.util import set_arg_from_input, set_arg_from_password

# Set these environment variables
URL = os.getenv("YOUR_WCS_URL")
APIKEY = os.getenv("YOUR_WCS_API_KEY")
from vdf_io.util import set_arg_from_input
from vdf_io.constants import DEFAULT_BATCH_SIZE
from vdf_io.weaviate_util import prompt_for_creds


class ExportWeaviate(ExportVDB):
@@ -23,24 +25,32 @@ def make_parser(cls, subparsers):

parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
parser_weaviate.add_argument(
"--openai_api_key", type=str, help="Openai API key"
)
parser_weaviate.add_argument(
"--batch_size",
type=int,
help="batch size for fetching",
default=DEFAULT_BATCH_SIZE,
)
parser_weaviate.add_argument(
"--offset", type=int, help="offset for fetching", default=None
)
parser_weaviate.add_argument(
"--connection-type",
type=str,
choices=["local", "cloud"],
default="cloud",
help="Type of connection to Weaviate (local or cloud)",
)
parser_weaviate.add_argument(
"--classes", type=str, help="Classes to export (comma-separated)"
)

@classmethod
def export_vdb(cls, args):
set_arg_from_input(
args,
"url",
"Enter the URL of Weaviate instance: ",
str,
)
set_arg_from_password(
args,
"api_key",
"Enter the Weaviate API key: ",
"WEAVIATE_API_KEY",
)
prompt_for_creds(args)
weaviate_export = ExportWeaviate(args)
weaviate_export.all_classes = list(
weaviate_export.client.collections.list_all().keys()
@@ -55,14 +65,20 @@ def export_vdb(cls, args):
weaviate_export.get_data()
return weaviate_export

# Connect to a WCS instance
# Connect to a WCS or local instance
def __init__(self, args):
super().__init__(args)
self.client = weaviate.connect_to_wcs(
cluster_url=self.args["url"],
auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
skip_init_checks=True,
)
if self.args["connection_type"] == "local":
self.client = weaviate.connect_to_local()
else:
self.client = weaviate.connect_to_wcs(
cluster_url=self.args["url"],
auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
headers={"X-OpenAI-Api-key": self.args["openai_api_key"]}
if self.args["openai_api_key"]
else None,
skip_init_checks=True,
)

def get_index_names(self):
if self.args.get("classes") is None:
@@ -75,15 +91,73 @@ def get_index_names(self):
)
return [c for c in self.all_classes if c in input_classes]

def metadata_to_dict(self, metadata):
meta_data = {}
meta_data["creation_time"] = metadata.creation_time
meta_data["distance"] = metadata.distance
meta_data["certainty"] = metadata.certainty
meta_data["explain_score"] = metadata.explain_score
meta_data["is_consistent"] = metadata.is_consistent
meta_data["last_update_time"] = metadata.last_update_time
meta_data["rerank_score"] = metadata.rerank_score
meta_data["score"] = metadata.score

return meta_data

def get_data(self):
# Get all objects of a class
# Get the index names to export
index_names = self.get_index_names()
for class_name in index_names:
collection = self.client.collections.get(class_name)
response = collection.aggregate.over_all(total_count=True)
print(f"{response.total_count=}")

# objects = self.client.query.get(
# wvq.Objects(wvq.Class(class_name)).with_limit(1000)
# )
# print(objects)
index_metas: Dict[str, List[NamespaceMeta]] = {}

# Export data in batches
batch_size = self.args["batch_size"]
offset = self.args["offset"]

# Iterate over index names and fetch data
for index_name in index_names:
collection = self.client.collections.get(index_name)
response = collection.query.fetch_objects(
limit=batch_size,
offset=offset,
include_vector=True,
return_metadata=MetadataQuery.full(),
)
res = collection.aggregate.over_all(total_count=True)
total_vector_count = res.total_count

# Create vectors directory for this index
vectors_directory = self.create_vec_dir(index_name)

for obj in response.objects:
vectors = obj.vector
metadata = obj.metadata
metadata = self.metadata_to_dict(metadata=metadata)

# Save vectors and metadata to Parquet file
num_vectors_exported = self.save_vectors_to_parquet(
vectors, metadata, vectors_directory
)

# Create NamespaceMeta for this index
namespace_metas = [
self.get_namespace_meta(
index_name,
vectors_directory,
total=total_vector_count,
num_vectors_exported=num_vectors_exported,
dim=-1,
distance="Cosine",
)
]
index_metas[index_name] = namespace_metas

# Write VDFMeta to JSON file
self.file_structure.append(os.path.join(self.vdf_directory, "VDF_META.json"))
internal_metadata = self.get_basic_vdf_meta(index_metas)
meta_text = json.dumps(internal_metadata.model_dump(), indent=4)
tqdm.write(meta_text)
with open(os.path.join(self.vdf_directory, "VDF_META.json"), "w") as json_file:
json_file.write(meta_text)
print("Data export complete.")

return True
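
Note: the new get_data path fetches a single page of batch_size objects per collection via fetch_objects(limit=..., offset=...), so collections larger than one batch would only be partially exported. A minimal sketch of exhaustive paging with the v4 client's collection iterator (the helper name iter_all_objects is illustrative, not part of this PR):

import weaviate
from weaviate.classes.query import MetadataQuery


def iter_all_objects(client, class_name):
    """Yield every object in a collection, vectors included (hypothetical helper)."""
    collection = client.collections.get(class_name)
    # The v4 iterator pages through the whole collection with an internal cursor,
    # so no manual limit/offset bookkeeping is needed.
    for obj in collection.iterator(
        include_vector=True, return_metadata=MetadataQuery.full()
    ):
        yield obj.uuid, obj.vector, obj.properties, obj.metadata
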
122 changes: 122 additions & 0 deletions src/vdf_io/import_vdf/weaviate_import.py
@@ -0,0 +1,122 @@
import os
import weaviate
from tqdm import tqdm
from vdf_io.import_vdf.vdf_import_cls import ImportVDB
from vdf_io.names import DBNames
from vdf_io.constants import INT_MAX, DEFAULT_BATCH_SIZE
from vdf_io.weaviate_util import prompt_for_creds

# Set these environment variables
URL = os.getenv("YOUR_WCS_URL")
APIKEY = os.getenv("YOUR_WCS_API_KEY")


class ImportWeaviate(ImportVDB):
DB_NAME_SLUG = DBNames.WEAVIATE

@classmethod
def make_parser(cls, subparsers):
parser_weaviate = subparsers.add_parser(
cls.DB_NAME_SLUG, help="Import data into Weaviate"
)

parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
parser_weaviate.add_argument(
"--connection-type",
type=str,
choices=["local", "cloud"],
default="cloud",
help="Type of connection to Weaviate (local or cloud)",
)
parser_weaviate.add_argument(
"--batch_size",
type=int,
help="batch size for fetching",
default=DEFAULT_BATCH_SIZE,
)

@classmethod
def import_vdb(cls, args):
prompt_for_creds(args)
weaviate_import = ImportWeaviate(args)
weaviate_import.upsert_data()
return weaviate_import

def __init__(self, args):
The connection_type argument is used here but it is not defined in the argument parser for the import script. This will cause an error when trying to access self.args["connection_type"].

To fix this, add the connection_type argument to the parser in the make_parser method:

Suggested change
def __init__(self, args):
parser_weaviate.add_argument(
"--connection-type", type=str, choices=["local", "cloud"], default="cloud",
help="Type of connection to Weaviate (local or cloud)"
)

super().__init__(args)
if self.args["connection_type"] == "local":
self.client = weaviate.connect_to_local()
else:
self.client = weaviate.connect_to_wcs(
cluster_url=self.args["url"],
auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
headers={"X-OpenAI-Api-key": self.args.get("openai_api_key", "")},
skip_init_checks=True,
)

def upsert_data(self):
max_hit = False
total_imported_count = 0

# Iterate over the indexes and import the data
for index_name, index_meta in tqdm(
self.vdf_meta["indexes"].items(), desc="Importing indexes"
):
tqdm.write(f"Importing data for index '{index_name}'")
for namespace_meta in index_meta:
self.set_dims(namespace_meta, index_name)

# Create or get the index
index_name = self.create_new_name(
index_name, self.client.collections.list_all().keys()
)

# Load data from the Parquet files
data_path = namespace_meta["data_path"]
final_data_path = self.get_final_data_path(data_path)
parquet_files = self.get_parquet_files(final_data_path)

vectors = {}
metadata = {}
vector_column_names, vector_column_name = self.get_vector_column_name(
index_name, namespace_meta
)

for file in tqdm(parquet_files, desc="Loading data from parquet files"):
file_path = os.path.join(final_data_path, file)
df = self.read_parquet_progress(file_path)

if len(vectors) > (self.args.get("max_num_rows") or INT_MAX):
max_hit = True
break
if len(vectors) + len(df) > (self.args.get("max_num_rows") or INT_MAX):
df = df.head(
(self.args.get("max_num_rows") or INT_MAX) - len(vectors)
)
max_hit = True
self.update_vectors(vectors, vector_column_name, df)
self.update_metadata(metadata, vector_column_names, df)
if max_hit:
break

tqdm.write(
f"Loaded {len(vectors)} vectors from {len(parquet_files)} parquet files"
)

# Upsert the vectors and metadata to the Weaviate index in batches
BATCH_SIZE = self.args.get("batch_size")

with self.client.batch.fixed_size(batch_size=BATCH_SIZE) as batch:
for _, vector in vectors.items():
batch.add_object(
vector=vector,
collection=index_name,
# TODO: Find way to add Metadata
)
total_imported_count += 1

tqdm.write(
f"Data import completed successfully. Imported {total_imported_count} vectors"
)
self.args["imported_count"] = total_imported_count
31 changes: 31 additions & 0 deletions src/vdf_io/weaviate_util.py
@@ -0,0 +1,31 @@
from vdf_io.util import set_arg_from_input, set_arg_from_password


def prompt_for_creds(args):
set_arg_from_input(
args,
"connection_type",
"Enter 'local' or 'cloud' for connection types: ",
choices=["local", "cloud"],
)
if args["connection_type"] == "cloud":
set_arg_from_input(
args,
"url",
"Enter the URL of Weaviate instance: ",
str,
env_var="WEAVIATE_URL",
)
set_arg_from_password(
args,
"api_key",
"Enter the Weaviate API key: ",
"WEAVIATE_API_KEY",
)

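
A minimal usage sketch of prompt_for_creds (the args dict stands in for the parsed CLI arguments that export_vdb/import_vdb pass through; the keys shown are placeholders left unset on purpose):

from vdf_io.weaviate_util import prompt_for_creds

args = {"connection_type": None, "url": None, "api_key": None}
prompt_for_creds(args)
# For a cloud connection, args ends up holding the URL and API key
# (taken from WEAVIATE_URL / WEAVIATE_API_KEY or prompted for interactively);
# a local connection only needs connection_type.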