diff --git a/dataset.py b/dataset.py
index c7b417c..46477c5 100644
--- a/dataset.py
+++ b/dataset.py
@@ -46,6 +46,7 @@ def list():
         return datasets
 
     def load(self, skip_download: bool = False, load_queries: bool = True,
+             limit: int = 0,
              doc_sample_fraction: float = 1.0):
         """
         Load the dataset, populating the 'documents' and 'queries' DataFrames.
@@ -53,9 +54,9 @@ def load(self, skip_download: bool = False, load_queries: bool = True,
         if not skip_download:
             self._download_dataset_files()
 
-        # Load all the parquet dataset (made up of one or more parquet files),
+        # Load the parquet dataset (made up of one or more parquet files),
         # to use for documents into a pandas dataframe.
-        self.documents = self._load_parquet_dataset("documents")
+        self.documents = self._load_parquet_dataset("documents", limit=limit)
 
         # If there is an explicit 'queries' dataset, then load that and use
         # for querying, otherwise use documents directly.
@@ -146,7 +147,7 @@ def should_download(blob):
                     blob.download_to_filename(self.cache / blob.name)
                     pbar.update(blob.size)
 
-    def _load_parquet_dataset(self, kind):
+    def _load_parquet_dataset(self, kind, limit=0):
         parquet_files = [f for f in (self.cache / self.name).glob(kind + '/*.parquet')]
         if not len(parquet_files):
             return pandas.DataFrame
@@ -167,6 +168,8 @@ def _load_parquet_dataset(self, kind):
         # and hence significantly reduces memory usage when we later prune away the underlying
         # parrow data (see prune_documents).
         df = dataset.read(columns=columns).to_pandas(types_mapper=pandas.ArrowDtype)
+        if limit:
+            df = df.iloc[:limit]
 
         # And drop any columns which all values are missing - e.g. not all
         # datasets have sparse_values, but the parquet file may still have
diff --git a/locustfile.py b/locustfile.py
index f57ba49..7e55c3e 100644
--- a/locustfile.py
+++ b/locustfile.py
@@ -71,6 +71,8 @@ def _(parser):
                             " list full details of available datasets.")
     pc_options.add_argument("--pinecone-dataset-ignore-queries", action=argparse.BooleanOptionalAction,
                             help="Ignore and do not load the 'queries' table from the specified dataset.")
+    pc_options.add_argument("--pinecone-dataset-limit", type=int, default=0,
+                            help="If non-zero, limit the dataset to the first N vectors.")
     pc_options.add_argument("--pinecone-dataset-docs-sample-for-query", type=float, default=0.01,
                             metavar=" (0.0 - 1.0)",
                             help="Specify the fraction of docs which should be sampled when the documents vectorset "
@@ -141,8 +143,10 @@ def setup_dataset(environment: Environment, skip_download_and_populate: bool = F
     environment.dataset = Dataset(dataset_name, environment.parsed_options.pinecone_dataset_cache)
     ignore_queries = environment.parsed_options.pinecone_dataset_ignore_queries
     sample_ratio = environment.parsed_options.pinecone_dataset_docs_sample_for_query
+    limit = environment.parsed_options.pinecone_dataset_limit
     environment.dataset.load(skip_download=skip_download_and_populate,
                              load_queries=not ignore_queries,
+                             limit=limit,
                              doc_sample_fraction=sample_ratio)
     populate = environment.parsed_options.pinecone_populate_index
     if not skip_download_and_populate and populate != "never":
diff --git a/tests/integration/test_requests.py b/tests/integration/test_requests.py
index bcbe54b..f3fac94 100644
--- a/tests/integration/test_requests.py
+++ b/tests/integration/test_requests.py
@@ -104,13 +104,16 @@ def test_datasets_list_details(self):
     def test_dataset_load(self, index_host):
         # Choosing a small dataset ("only" 60,000 documents) which also
         # has a non-zero queries set.
+        # We also test the --pinecone-dataset-limit option here (which has the
+        # bonus effect of speeding up the test - note that complete
+        # dataset loading is tested in test_dataset_load_multiprocess).
         test_dataset = "ANN_MNIST_d784_euclidean"
         self.do_request(index_host, "sdk", 'query', 'Vector (Query only)', timeout=60,
                         extra_args=["--pinecone-dataset", test_dataset,
+                                    "--pinecone-dataset-limit", "123",
                                     "--pinecone-populate-index", "always"])
-
     def test_dataset_load_multiprocess(self, index_host):
         # Choosing a small dataset ("only" 60,000 documents) which also
         # has a non-zero queries set.
diff --git a/tests/unit/test_dataset.py b/tests/unit/test_dataset.py
new file mode 100644
index 0000000..58ee6e2
--- /dev/null
+++ b/tests/unit/test_dataset.py
@@ -0,0 +1,22 @@
+import os
+import sys
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
+
+from dataset import Dataset
+import pytest
+
+
+class TestDataset:
+
+    def test_limit(self):
+        limit = 123
+        name = "langchain-python-docs-text-embedding-ada-002"
+        dataset = Dataset(name)
+        # Sanity check that the complete dataset size is greater than what
+        # we are going to limit to.
+        dataset_info = ([d for d in dataset.list() if d["name"] == name][0])
+        assert dataset_info["documents"] > limit, \
+            "Too few documents in dataset to be able to limit"
+
+        dataset.load(limit=limit, load_queries=False)
+        assert len(dataset.documents) == limit
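Note on the truncation added to _load_parquet_dataset above: a minimal, standalone sketch of the same behaviour (the toy DataFrame and values below are illustrative, not part of the patch). A limit of 0 is falsy and leaves the frame untouched; a positive limit keeps only the first N rows via iloc.

    import pandas

    df = pandas.DataFrame({"id": ["a", "b", "c", "d"],
                           "values": [[0.1], [0.2], [0.3], [0.4]]})
    limit = 2
    if limit:
        # Keep only the first `limit` rows, mirroring the patch.
        df = df.iloc[:limit]
    assert len(df) == 2
    assert list(df["id"]) == ["a", "b"]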