fixup! Add support for populating index from sample dataset
daverigby committed Feb 13, 2024
1 parent a041e4c commit 657e5ed
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions locustfile.py
@@ -20,7 +20,7 @@
 from pinecone import Pinecone
 from pinecone.grpc import PineconeGRPC
 import tempfile
-from tqdm import tqdm
+from tqdm import tqdm, trange
 import sys
 
 # patch grpc so that it uses gevent instead of asyncio. This is required to
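For reference, a minimal sketch of the two tqdm entry points imported above; trange(n) is simply shorthand for tqdm(range(n)), and the loop bodies here are placeholders.

    from tqdm import tqdm, trange

    # trange(n) is shorthand for tqdm(range(n)); both render a progress bar.
    for i in trange(100):
        pass  # placeholder work

    # tqdm() wraps any iterable directly.
    for item in tqdm(["a", "b", "c"], desc="Processing"):
        pass  # placeholder work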
@@ -92,7 +92,6 @@ def check_for_dataset(environment: Environment, **kwargs):
         print()
         sys.exit(1)
 
-    logging.info(f"Downloading dataset '{dataset_name}'")
     environment.dataset = Dataset(dataset_name, environment.parsed_options.pinecone_dataset_cache)
     environment.dataset.load()
     populate = environment.parsed_options.pinecone_populate_index
@@ -166,8 +165,7 @@ def list():
         metadata_blobs = bucket.list_blobs(match_glob="*/metadata.json")
         datasets = []
         for m in metadata_blobs:
-            with m.open() as f:
-                datasets.append(json.load(f))
+            datasets.append(json.loads(m.download_as_string()))
         return datasets
 
     def load(self):
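As a standalone illustration of the pattern above, here is a minimal sketch of listing dataset metadata with a single-request fetch per blob; json.loads() accepts the bytes returned by download_as_string() directly, so no file handle is needed. The bucket name is an assumption for illustration.

    import json

    from google.cloud import storage

    client = storage.Client()
    bucket = client.bucket("pinecone-datasets-dev")  # assumed bucket name

    datasets = []
    for blob in bucket.list_blobs(match_glob="*/metadata.json"):
        # download_as_string() fetches the whole object in one request and
        # returns bytes, which json.loads() accepts directly.
        datasets.append(json.loads(blob.download_as_string()))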
@@ -227,19 +225,19 @@ def should_download(blob):
             remote_size = blob.size
             return local_size != remote_size
 
-        to_download = [b.name for b in filter(lambda b: should_download(b), blobs)]
-        logging.debug(f"Dataset files not found in cache - will be downloaded: '{to_download}'")
-        results = transfer_manager.download_many_to_path(bucket,
-                                                         to_download,
-                                                         destination_directory=str(self.cache))
-        for name, result in zip(to_download, results):
-            # The results list is either `None` or an exception for each blob in
-            # the input list, in order.
-            if isinstance(result, Exception):
-                logging.error("Failed to download {} due to exception: {}".format(name, result))
-                raise result
-            else:
-                logging.debug("Downloaded {} to {}.".format(name, self.cache / name))
+        to_download = [b for b in filter(lambda b: should_download(b), blobs)]
+        if not to_download:
+            return
+        pbar = tqdm(desc="Downloading dataset",
+                    total=sum([b.size for b in to_download]),
+                    unit="Bytes",
+                    unit_scale=True)
+        for blob in to_download:
+            logging.debug(f"Dataset file '{blob.name}' not found in cache - will be downloaded")
+            dest_path = self.cache / blob.name
+            dest_path.parent.mkdir(parents=True, exist_ok=True)
+            blob.download_to_filename(dest_path)
+            pbar.update(blob.size)
 
     def _load_parquet_dataset(self, kind):
         parquet_files = [f for f in (self.cache / self.name).glob(kind + '/*.parquet')]
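A self-contained sketch of the sequential download-with-progress pattern introduced above; the bucket name, dataset prefix, and cache directory are assumptions for illustration.

    from pathlib import Path

    from google.cloud import storage
    from tqdm import tqdm

    client = storage.Client()
    bucket = client.bucket("pinecone-datasets-dev")  # assumed bucket name
    cache = Path("/tmp/dataset-cache")  # assumed cache directory

    blobs = list(bucket.list_blobs(match_glob="my-dataset/*"))  # assumed prefix
    pbar = tqdm(desc="Downloading dataset",
                total=sum(b.size for b in blobs),
                unit="Bytes",
                unit_scale=True)
    for blob in blobs:
        # Blob names contain '/' separators, so create intermediate directories.
        dest_path = cache / blob.name
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        blob.download_to_filename(dest_path)
        pbar.update(blob.size)
    pbar.close()

Compared with transfer_manager.download_many_to_path(), this downloads blobs one at a time, but the total byte count is known up front, which allows accurate scaled progress reporting.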
