From 99ae1d24b9b9ef0563701de302643f38d79a2999 Mon Sep 17 00:00:00 2001
From: Isaak Willett
Date: Thu, 29 Sep 2022 20:02:20 +0000
Subject: [PATCH 1/8] adds non spark persistence

---
 tests/core/__init__.py         |   0
 tests/core/test_persistence.py |  78 ++++++++++++
 wicker/core/persistance.py     | 213 ++++++++++++++++++++++++++++++++-
 wicker/plugins/spark.py        |  77 +-----------
 4 files changed, 291 insertions(+), 77 deletions(-)
 create mode 100644 tests/core/__init__.py
 create mode 100644 tests/core/test_persistence.py

diff --git a/tests/core/__init__.py b/tests/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/core/test_persistence.py b/tests/core/test_persistence.py
new file mode 100644
index 0000000..f73b962
--- /dev/null
+++ b/tests/core/test_persistence.py
@@ -0,0 +1,78 @@
+import copy
+import os
+import random
+import uuid
+
+import pyarrow.parquet as papq
+import pytest
+
+from wicker import schema
+from wicker.core.config import get_config
+from wicker.core.persistance import BasicPersistor
+from wicker.core.storage import S3PathFactory
+from wicker.testing.storage import FakeS3DataStorage
+
+DATASET_NAME = "dataset"
+DATASET_VERSION = "0.0.1"
+SCHEMA = schema.DatasetSchema(
+    primary_keys=["bar", "foo"],
+    fields=[
+        schema.IntField("foo"),
+        schema.StringField("bar"),
+        schema.BytesField("bytescol"),
+    ],
+)
+EXAMPLES = [
+    (
+        "train" if i % 2 == 0 else "test",
+        {
+            "foo": random.randint(0, 10000),
+            "bar": str(uuid.uuid4()),
+            "bytescol": b"0",
+        },
+    )
+    for i in range(10000)
+]
+# Examples with a duplicated key
+EXAMPLES_DUPES = copy.deepcopy(EXAMPLES)
+
+
+@pytest.fixture
+def mock_basic_persistor(request, tmpdir):
+    storage = request.param.get("storage", FakeS3DataStorage(tmpdir=tmpdir))
+    path_factory = request.param.get("path_factory", S3PathFactory())
+    return BasicPersistor(storage, path_factory), tmpdir
+
+
+def assert_written_correctness(tmpdir: str) -> None:
+    """Asserts that all files are written as expected by the L5MLDatastore"""
+    # Check that files are correctly written locally by Spark/Parquet with a _SUCCESS marker file
+    prefix = get_config().aws_s3_config.s3_datasets_path.replace("s3://", "")
+    assert DATASET_NAME in os.listdir(os.path.join(tmpdir, prefix))
+    assert DATASET_VERSION in os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME))
+    for partition in ["train", "test"]:
+        print(os.listdir(os.path.join(tmpdir, prefix)))
+        columns_path = os.path.join(tmpdir, prefix, "__COLUMN_CONCATENATED_FILES__")
+        all_read_bytes = b""
+        for filename in os.listdir(columns_path):
+            concatenated_bytes_filepath = os.path.join(columns_path, filename)
+            with open(concatenated_bytes_filepath, "rb") as bytescol_file:
+                all_read_bytes += bytescol_file.read()
+        assert all_read_bytes == b"0" * 10000
+
+        # Load parquet file and assert ordering of primary_key
+        assert f"{partition}.parquet" in os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION))
+        tbl = papq.read_table(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION, f"{partition}.parquet"))
+        foobar = [(barval.as_py(), fooval.as_py()) for fooval, barval in zip(tbl["foo"], tbl["bar"])]
+        assert foobar == sorted(foobar)
+
+
+@pytest.mark.parametrize(
+    "mock_basic_persistor, dataset_name, dataset_version, dataset_schema, dataset",
+    [({}, DATASET_NAME, DATASET_VERSION, SCHEMA, copy.deepcopy(EXAMPLES_DUPES))],
+    indirect=["mock_basic_persistor"],
+)
+def test_basic_persistor(mock_basic_persistor: BasicPersistor, dataset_name, dataset_version, dataset_schema, dataset):
mock_basic_persistor, tempdir = mock_basic_persistor + mock_basic_persistor.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset) + assert_written_correctness(tempdir) diff --git a/wicker/core/persistance.py b/wicker/core/persistance.py index fe6525e..166413d 100644 --- a/wicker/core/persistance.py +++ b/wicker/core/persistance.py @@ -1,12 +1,22 @@ import abc -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterable, Optional, Tuple -from wicker import schema +import pyarrow as pa +import pyarrow.compute as pc + +from wicker import schema as schema_module +from wicker.core.column_files import ColumnBytesFileWriter +from wicker.core.definitions import DatasetID +from wicker.core.shuffle import save_index from wicker.core.storage import S3DataStorage, S3PathFactory -from wicker.schema import dataparsing +from wicker.schema import dataparsing, serialization + +PARTITION_SIZE = 256 +MAX_COL_FILE_NUMROW = 50 # TODO(isaak-willett): Magic number, we should derive this based on row size UnparsedExample = Dict[str, Any] ParsedExample = Dict[str, Any] +PointerParsedExample = Dict[str, Any] class AbstractDataPersistor(abc.ABC): @@ -40,7 +50,7 @@ def persist_wicker_dataset( self, dataset_name: str, dataset_version: str, - dataset_schema: schema.DatasetSchema, + dataset_schema: schema_module.DatasetSchema, dataset: Any, ) -> Optional[Dict[str, int]]: """ @@ -58,7 +68,7 @@ def persist_wicker_dataset( raise NotImplementedError("Method, persist_wicker_dataset, needs to be implemented in inhertiance class.") @staticmethod - def parse_row(data_row: UnparsedExample, schema: schema.DatasetSchema) -> ParsedExample: + def parse_row(data_row: UnparsedExample, schema: schema_module.DatasetSchema) -> ParsedExample: """ Parse a row to test for validation errors. @@ -68,3 +78,196 @@ def parse_row(data_row: UnparsedExample, schema: schema.DatasetSchema) -> Parsed :rtype: ParsedExample """ return dataparsing.parse_example(data_row, schema) + + # Write data to Column Byte Files + + @staticmethod + def persist_wicker_partition( + spark_partition_iter: Iterable[Tuple[str, ParsedExample]], + schema: schema_module.DatasetSchema, + s3_storage: S3DataStorage, + s3_path_factory: S3PathFactory, + target_max_column_file_numrows: int = 50, + ) -> Iterable[Tuple[str, PointerParsedExample]]: + """Persists a Spark partition of examples with parsed bytes into S3Storage as ColumnBytesFiles, + returning a new Spark partition of examples with heavy-pointers and metadata only. + :param spark_partition_iter: Spark partition of `(partition_str, example)`, where `example` + is a dictionary of parsed bytes that needs to be uploaded to S3 + :param target_max_column_file_numrows: Maximum number of rows in column files. Defaults to 50. 
+ :return: a Generator of `(partition_str, example)`, where `example` is a dictionary with heavy-pointers + that point to ColumnBytesFiles in S3 in place of the parsed bytes + """ + column_bytes_file_writers: Dict[str, ColumnBytesFileWriter] = {} + heavy_pointer_columns = schema.get_pointer_columns() + metadata_columns = schema.get_non_pointer_columns() + + for partition, example in spark_partition_iter: + # Create ColumnBytesFileWriter lazily as required, for each partition + if partition not in column_bytes_file_writers: + column_bytes_file_writers[partition] = ColumnBytesFileWriter( + s3_storage, + s3_path_factory, + target_file_rowgroup_size=target_max_column_file_numrows, + ) + + # Write to ColumnBytesFileWriter and return only metadata + heavy-pointers + parquet_metadata: Dict[str, Any] = {col: example[col] for col in metadata_columns} + for col in heavy_pointer_columns: + loc = column_bytes_file_writers[partition].add(col, example[col]) + parquet_metadata[col] = loc.to_bytes() + yield partition, parquet_metadata + + # Flush all writers when finished + for partition in column_bytes_file_writers: + column_bytes_file_writers[partition].close() + + @staticmethod + def save_partition_tbl( + partition_table_tuple: Tuple[str, pa.Table], + dataset_name: str, + dataset_version: str, + s3_storage: S3DataStorage, + s3_path_factory: S3PathFactory, + ) -> Tuple[str, int]: + """ + Save a partition table to s3 under the dataset name and version. + + :param partition_table_tuple: Tuple of partition id and pyarrow table to save + :type partition_table_tuple: Tuple[str, pyarrow.Table] + :return: A tuple containing the paritiion id and the num of saved rows + :rtype: Tuple[str, int] + """ + partition, pa_tbl = partition_table_tuple + save_index( + dataset_name, + dataset_version, + {partition: pa_tbl}, + s3_storage=s3_storage, + s3_path_factory=s3_path_factory, + ) + return (partition, pa_tbl.num_rows) + + +def persist_wicker_dataset( + dataset_name: str, + dataset_version: str, + dataset_schema: schema_module.DatasetSchema, + dataset: Any, + s3_storage: S3DataStorage = S3DataStorage(), + s3_path_factory: S3PathFactory = S3PathFactory(), +) -> Optional[Dict[str, int]]: + """ + Persist wicker dataset public facing api function, for api consistency. + :param dataset_name: name of dataset persisted + :type dataset_name: str + :param dataset_version: version of dataset persisted + :type dataset_version: str + :param dataset_schema: schema of dataset to be persisted + :type dataset_schema: DatasetSchema + :param rdd: rdd of data to persist + :type rdd: RDD + :param s3_storage: s3 storage abstraction + :type s3_storage: S3DataStorage + :param s3_path_factory: s3 path abstraction + :type s3_path_factory: S3PathFactory + """ + return BasicPersistor(s3_storage, s3_path_factory).persist_wicker_dataset( + dataset_name, dataset_version, dataset_schema, dataset + ) + + +class BasicPersistor(AbstractDataPersistor): + """ + Basic persistor class that persists wicker data on s3 in a non sorted manner. + + We will move to supporting other features like shuffling, other data engines, etc... 
+ """ + + def __init__( + self, s3_storage: S3DataStorage = S3DataStorage(), s3_path_factory: S3PathFactory = S3PathFactory() + ) -> None: + super().__init__(s3_storage, s3_path_factory) + + def persist_wicker_dataset( + self, dataset_name: str, dataset_version: str, dataset_schema: schema_module.DatasetSchema, dataset: Any + ) -> Optional[Dict[str, int]]: + """ + Persist a user defined dataset, pushing data to s3 in a basic manner + + :param dataset_name: Name of the dataset + :type dataset_name: str + :param dataset_version: Version of the dataset + :type: dataset_version: str + :param dataset_schema: Schema of the dataset + :type dataset_schema: wicker.schema.schema.DatasetSchema + :param dataset: Data of the dataset + :type dataset: User defined + """ + # what needs to be done within this function + # 1. Check if the variables are set + # check if variables have been set ie: not None + if ( + not isinstance(dataset_name, str) + or not isinstance(dataset_version, str) + or not isinstance(dataset_schema, schema_module.DatasetSchema) + ): + raise ValueError("Current dataset variables not all set, set all to proper not None values") + + # 6. Put the schema up on + schema_path = self.s3_path_factory.get_dataset_schema_path( + DatasetID(name=dataset_name, version=dataset_version) + ) + self.s3_storage.put_object_s3(serialization.dumps(dataset_schema).encode("utf-8"), schema_path) + + # 2. Validate the rows and ensure data is well formed, sort while doing + dataset_0 = [(row[0], self.parse_row(row[1], dataset_schema)) for row in dataset] + + # 3. Sort the dataset if not sorted + sorted_dataset_0 = sorted(dataset_0, key=lambda tup: tup[0]) + + # 4. Partition the dataset into K partitions + num_paritions = len(sorted_dataset_0) // PARTITION_SIZE + partitions = [] + + def divide_chunks(list_to_divide, num_chunks): + # looping till length l + for i in range(0, len(list_to_divide), num_chunks): + partitions.append(list_to_divide[i : i + num_chunks]) + + divide_chunks(sorted_dataset_0, num_paritions) + + # 5. Persist the partitions to S3 + for partition in partitions: + iterator = self.persist_wicker_partition( + partition, dataset_schema, self.s3_storage, self.s3_path_factory, MAX_COL_FILE_NUMROW + ) + # make sure all yields get called + list(iterator) + + # 6. Create the parition table, need to combine keys in a way we can form table + merged_dicts = {} + for partition_key, row in sorted_dataset_0: + current_dict = merged_dicts.get(partition_key, {}) + for col in row.keys(): + if col in current_dict: + current_dict[col].append(row[col]) + else: + current_dict[col] = [row[col]] + merged_dicts[partition_key] = current_dict + arrow_dict = {} + for partition_key, data_dict in merged_dicts.items(): + data_table = pa.Table.from_pydict(data_dict) + arrow_dict[partition_key] = pc.take( + pa.Table.from_pydict(data_dict), + pc.sort_indices(data_table, sort_keys=[(pk, "ascending") for pk in dataset_schema.primary_keys]), + ) + + # 7. 
Persist the partition table to s3 + written_dict = {} + for partition_key, pa_table in arrow_dict.items(): + self.save_partition_tbl( + (partition_key, pa_table), dataset_name, dataset_version, self.s3_storage, self.s3_path_factory + ) + written_dict[partition_key] = pa_table.num_rows + + return written_dict diff --git a/wicker/plugins/spark.py b/wicker/plugins/spark.py index b4d0d04..f2a249d 100644 --- a/wicker/plugins/spark.py +++ b/wicker/plugins/spark.py @@ -21,11 +21,9 @@ from operator import add from wicker import schema as schema_module -from wicker.core.column_files import ColumnBytesFileWriter from wicker.core.definitions import DatasetID from wicker.core.errors import WickerDatastoreException from wicker.core.persistance import AbstractDataPersistor -from wicker.core.shuffle import save_index from wicker.core.storage import S3DataStorage, S3PathFactory from wicker.schema import serialization @@ -108,7 +106,7 @@ def persist_wicker_dataset( s3_path_factory = self.s3_path_factory parse_row = self.parse_row get_row_keys = self.get_row_keys - persist_spark_partition_wicker = self.persist_spark_partition_wicker + persist_wicker_partition = self.persist_wicker_partition save_partition_tbl = self.save_partition_tbl # put the schema up on to s3 @@ -148,8 +146,9 @@ def set_partition(iterator: Iterable[PrimaryKeyTuple]) -> Iterable[int]: # persist the spark partition to S3Storage rdd3 = rdd2.values() + rdd4 = rdd3.mapPartitions( - lambda spark_iterator: persist_spark_partition_wicker( + lambda spark_iterator: persist_wicker_partition( spark_iterator, schema, s3_storage, @@ -176,6 +175,7 @@ def set_partition(iterator: Iterable[PrimaryKeyTuple]) -> Iterable[int]: ] ), ) + # create the partition tables rdd6 = rdd5.mapValues( lambda pa_tbl: pc.take( @@ -186,6 +186,7 @@ def set_partition(iterator: Iterable[PrimaryKeyTuple]) -> Iterable[int]: ), ) ) + # save the parition table to s3 rdd7 = rdd6.map( lambda partition_table: save_partition_tbl( @@ -210,71 +211,3 @@ def get_row_keys( """ partition, data = partition_data_tup return (partition,) + tuple(data[pk] for pk in schema.primary_keys) - - # Write data to Column Byte Files - @staticmethod - def persist_spark_partition_wicker( - spark_partition_iter: Iterable[Tuple[str, ParsedExample]], - schema: schema_module.DatasetSchema, - s3_storage: S3DataStorage, - s3_path_factory: S3PathFactory, - target_max_column_file_numrows: int = 50, - ) -> Iterable[Tuple[str, PointerParsedExample]]: - """Persists a Spark partition of examples with parsed bytes into S3Storage as ColumnBytesFiles, - returning a new Spark partition of examples with heavy-pointers and metadata only. - :param spark_partition_iter: Spark partition of `(partition_str, example)`, where `example` - is a dictionary of parsed bytes that needs to be uploaded to S3 - :param target_max_column_file_numrows: Maximum number of rows in column files. Defaults to 50. 
- :return: a Generator of `(partition_str, example)`, where `example` is a dictionary with heavy-pointers - that point to ColumnBytesFiles in S3 in place of the parsed bytes - """ - column_bytes_file_writers: Dict[str, ColumnBytesFileWriter] = {} - heavy_pointer_columns = schema.get_pointer_columns() - metadata_columns = schema.get_non_pointer_columns() - - for partition, example in spark_partition_iter: - # Create ColumnBytesFileWriter lazily as required, for each partition - if partition not in column_bytes_file_writers: - column_bytes_file_writers[partition] = ColumnBytesFileWriter( - s3_storage, - s3_path_factory, - target_file_rowgroup_size=target_max_column_file_numrows, - ) - - # Write to ColumnBytesFileWriter and return only metadata + heavy-pointers - parquet_metadata: Dict[str, Any] = {col: example[col] for col in metadata_columns} - for col in heavy_pointer_columns: - loc = column_bytes_file_writers[partition].add(col, example[col]) - parquet_metadata[col] = loc.to_bytes() - yield partition, parquet_metadata - - # Flush all writers when finished - for partition in column_bytes_file_writers: - column_bytes_file_writers[partition].close() - - # sort the indices of the primary keys in ascending order - @staticmethod - def save_partition_tbl( - partition_table_tuple: Tuple[str, pa.Table], - dataset_name: str, - dataset_version: str, - s3_storage: S3DataStorage, - s3_path_factory: S3PathFactory, - ) -> Tuple[str, int]: - """ - Save a partition table to s3 under the dataset name and version. - - :param partition_table_tuple: Tuple of partition id and pyarrow table to save - :type partition_table_tuple: Tuple[str, pyarrow.Table] - :return: A tuple containing the paritiion id and the num of saved rows - :rtype: Tuple[str, int] - """ - partition, pa_tbl = partition_table_tuple - save_index( - dataset_name, - dataset_version, - {partition: pa_tbl}, - s3_storage=s3_storage, - s3_path_factory=s3_path_factory, - ) - return (partition, pa_tbl.num_rows) From c803741166e40fd4ebc86cbe3713f8aa84fde134 Mon Sep 17 00:00:00 2001 From: Isaak Willett Date: Thu, 29 Sep 2022 20:15:07 +0000 Subject: [PATCH 2/8] linting and typing --- tests/core/test_persistence.py | 16 ++++++++++++---- wicker/core/persistance.py | 6 +++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/core/test_persistence.py b/tests/core/test_persistence.py index f73b962..e522932 100644 --- a/tests/core/test_persistence.py +++ b/tests/core/test_persistence.py @@ -2,6 +2,7 @@ import os import random import uuid +from typing import Any, Dict, List, Tuple import pyarrow.parquet as papq import pytest @@ -10,6 +11,7 @@ from wicker.core.config import get_config from wicker.core.persistance import BasicPersistor from wicker.core.storage import S3PathFactory +from wicker.schema.schema import DatasetSchema from wicker.testing.storage import FakeS3DataStorage DATASET_NAME = "dataset" @@ -38,7 +40,7 @@ @pytest.fixture -def mock_basic_persistor(request, tmpdir): +def mock_basic_persistor(request, tmpdir) -> Tuple[BasicPersistor, str]: storage = request.param.get("storage", FakeS3DataStorage(tmpdir=tmpdir)) path_factory = request.param.get("path_factory", S3PathFactory()) return BasicPersistor(storage, path_factory), tmpdir @@ -72,7 +74,13 @@ def assert_written_correctness(tmpdir: str) -> None: [({}, DATASET_NAME, DATASET_VERSION, SCHEMA, copy.deepcopy(EXAMPLES_DUPES))], indirect=["mock_basic_persistor"], ) -def test_basic_persistor(mock_basic_persistor: BasicPersistor, dataset_name, dataset_version, dataset_schema, dataset): 
- mock_basic_persistor, tempdir = mock_basic_persistor - mock_basic_persistor.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset) +def test_basic_persistor( + mock_basic_persistor: Tuple[BasicPersistor, str], + dataset_name: str, + dataset_version: str, + dataset_schema: DatasetSchema, + dataset: List[Tuple[str, Dict[str, Any]]], +): + mock_basic_persistor_obj, tempdir = mock_basic_persistor + mock_basic_persistor_obj.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset) assert_written_correctness(tempdir) diff --git a/wicker/core/persistance.py b/wicker/core/persistance.py index 166413d..76551a8 100644 --- a/wicker/core/persistance.py +++ b/wicker/core/persistance.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Dict, Iterable, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import pyarrow as pa import pyarrow.compute as pc @@ -245,9 +245,9 @@ def divide_chunks(list_to_divide, num_chunks): list(iterator) # 6. Create the parition table, need to combine keys in a way we can form table - merged_dicts = {} + merged_dicts: Dict[str, Dict[str, List[Any]]] = {} for partition_key, row in sorted_dataset_0: - current_dict = merged_dicts.get(partition_key, {}) + current_dict: Dict[str, List[Any]] = merged_dicts.get(partition_key, {}) for col in row.keys(): if col in current_dict: current_dict[col].append(row[col]) From 3f8eb4aed6a9e33445b52866a5c915245e31c97b Mon Sep 17 00:00:00 2001 From: Isaak Willett Date: Thu, 29 Sep 2022 21:14:31 +0000 Subject: [PATCH 3/8] update logs, add parity test --- tests/core/test_persistence.py | 9 +++++++++ wicker/core/persistance.py | 19 ++++++++++++------- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/core/test_persistence.py b/tests/core/test_persistence.py index e522932..990717d 100644 --- a/tests/core/test_persistence.py +++ b/tests/core/test_persistence.py @@ -81,6 +81,15 @@ def test_basic_persistor( dataset_schema: DatasetSchema, dataset: List[Tuple[str, Dict[str, Any]]], ): + """ + Test if the basic persistor can persist data in the format we have established. + + Ensure we read the right file locations, the right amount of bytes, + and the ordering is correct. + """ + # create the mock basic persistor mock_basic_persistor_obj, tempdir = mock_basic_persistor + # persist the dataset mock_basic_persistor_obj.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset) + # assert the dataset is correctly written assert_written_correctness(tempdir) diff --git a/wicker/core/persistance.py b/wicker/core/persistance.py index 76551a8..6061638 100644 --- a/wicker/core/persistance.py +++ b/wicker/core/persistance.py @@ -213,19 +213,19 @@ def persist_wicker_dataset( ): raise ValueError("Current dataset variables not all set, set all to proper not None values") - # 6. Put the schema up on + # 2. Put the schema up on s3 schema_path = self.s3_path_factory.get_dataset_schema_path( DatasetID(name=dataset_name, version=dataset_version) ) self.s3_storage.put_object_s3(serialization.dumps(dataset_schema).encode("utf-8"), schema_path) - # 2. Validate the rows and ensure data is well formed, sort while doing + # 3. Validate the rows and ensure data is well formed, sort while doing dataset_0 = [(row[0], self.parse_row(row[1], dataset_schema)) for row in dataset] - # 3. Sort the dataset if not sorted + # 4. Sort the dataset if not sorted sorted_dataset_0 = sorted(dataset_0, key=lambda tup: tup[0]) - # 4. Partition the dataset into K partitions + # 5. 
Partition the dataset into K partitions num_paritions = len(sorted_dataset_0) // PARTITION_SIZE partitions = [] @@ -236,15 +236,18 @@ def divide_chunks(list_to_divide, num_chunks): divide_chunks(sorted_dataset_0, num_paritions) - # 5. Persist the partitions to S3 + # 6. Persist the partitions to S3 for partition in partitions: + # build a persistence iterator for each parition iterator = self.persist_wicker_partition( partition, dataset_schema, self.s3_storage, self.s3_path_factory, MAX_COL_FILE_NUMROW ) # make sure all yields get called list(iterator) - # 6. Create the parition table, need to combine keys in a way we can form table + # 7. Create the parition table, need to combine keys in a way we can form table + # split into k dicts where k is partition number and the data is a list of values + # for each key for all the dicts in the partition merged_dicts: Dict[str, Dict[str, List[Any]]] = {} for partition_key, row in sorted_dataset_0: current_dict: Dict[str, List[Any]] = merged_dicts.get(partition_key, {}) @@ -254,6 +257,8 @@ def divide_chunks(list_to_divide, num_chunks): else: current_dict[col] = [row[col]] merged_dicts[partition_key] = current_dict + # convert each of the dicts to a pyarrow table in the same way SparkPersistor + # converts, needed to ensure parity between the two arrow_dict = {} for partition_key, data_dict in merged_dicts.items(): data_table = pa.Table.from_pydict(data_dict) @@ -262,7 +267,7 @@ def divide_chunks(list_to_divide, num_chunks): pc.sort_indices(data_table, sort_keys=[(pk, "ascending") for pk in dataset_schema.primary_keys]), ) - # 7. Persist the partition table to s3 + # 8. Persist the partition table to s3 written_dict = {} for partition_key, pa_table in arrow_dict.items(): self.save_partition_tbl( From d59de32aa212064e1b558eec23f212b3e01adc03 Mon Sep 17 00:00:00 2001 From: Isaak Willett Date: Wed, 5 Oct 2022 17:42:43 +0000 Subject: [PATCH 4/8] adds synchornous shuffling to basic persistor --- tests/core/test_persistence.py | 169 ++++++++++++++++++++++++++++++--- wicker/core/column_files.py | 1 + wicker/core/datasets.py | 21 ++-- wicker/core/persistance.py | 44 ++++++--- 4 files changed, 204 insertions(+), 31 deletions(-) diff --git a/tests/core/test_persistence.py b/tests/core/test_persistence.py index 990717d..40a16c6 100644 --- a/tests/core/test_persistence.py +++ b/tests/core/test_persistence.py @@ -2,14 +2,20 @@ import os import random import uuid -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, Iterable, List, Tuple +from unittest.mock import patch import pyarrow.parquet as papq import pytest from wicker import schema from wicker.core.config import get_config -from wicker.core.persistance import BasicPersistor +from wicker.core.persistance import ( + BasicPersistor, + ColumnBytesFileWriter, + ParsedExample, + PointerParsedExample, +) from wicker.core.storage import S3PathFactory from wicker.schema.schema import DatasetSchema from wicker.testing.storage import FakeS3DataStorage @@ -17,8 +23,9 @@ DATASET_NAME = "dataset" DATASET_VERSION = "0.0.1" SCHEMA = schema.DatasetSchema( - primary_keys=["bar", "foo"], + primary_keys=["global_index", "bar", "foo"], fields=[ + schema.IntField("global_index"), schema.IntField("foo"), schema.StringField("bar"), schema.BytesField("bytescol"), @@ -28,6 +35,7 @@ ( "train" if i % 2 == 0 else "test", { + "global_index": i, "foo": random.randint(0, 10000), "bar": str(uuid.uuid4()), "bytescol": b"0", @@ -53,7 +61,6 @@ def assert_written_correctness(tmpdir: str) -> None: assert DATASET_NAME in 
os.listdir(os.path.join(tmpdir, prefix)) assert DATASET_VERSION in os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME)) for partition in ["train", "test"]: - print(os.listdir(os.path.join(tmpdir, prefix))) columns_path = os.path.join(tmpdir, prefix, "__COLUMN_CONCATENATED_FILES__") all_read_bytes = b"" for filename in os.listdir(columns_path): @@ -65,7 +72,10 @@ def assert_written_correctness(tmpdir: str) -> None: # Load parquet file and assert ordering of primary_key assert f"{partition}.parquet" in os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION)) tbl = papq.read_table(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION, f"{partition}.parquet")) - foobar = [(barval.as_py(), fooval.as_py()) for fooval, barval in zip(tbl["foo"], tbl["bar"])] + foobar = [ + (glo_idx.as_py(), barval.as_py(), fooval.as_py()) + for glo_idx, fooval, barval in zip(tbl["global_index"], tbl["foo"], tbl["bar"]) + ] assert foobar == sorted(foobar) @@ -74,7 +84,7 @@ def assert_written_correctness(tmpdir: str) -> None: [({}, DATASET_NAME, DATASET_VERSION, SCHEMA, copy.deepcopy(EXAMPLES_DUPES))], indirect=["mock_basic_persistor"], ) -def test_basic_persistor( +def test_basic_persistor_no_shuffle( mock_basic_persistor: Tuple[BasicPersistor, str], dataset_name: str, dataset_version: str, @@ -87,9 +97,146 @@ def test_basic_persistor( Ensure we read the right file locations, the right amount of bytes, and the ordering is correct. """ + # in order to assert that we are not shuffling we are going to sub out the + # persist partition function and get average distance on global index + # if it is == 2 (ie: samples are adjacent in partitions) then shuffling has occured + def mock_persist_wicker_partition( + self, + spark_partition_iter: Iterable[Tuple[str, ParsedExample]], + schema: schema.DatasetSchema, + s3_storage: FakeS3DataStorage, + s3_path_factory: S3PathFactory, + target_max_column_file_numrows: int = 50, + ) -> Iterable[Tuple[str, PointerParsedExample]]: + # set up the global sum and counter for calcing mean + global_sum = 0 + global_counter = 0 + # we still have to do all of the regular logic to test writing + column_bytes_file_writers: Dict[str, ColumnBytesFileWriter] = {} + heavy_pointer_columns = schema.get_pointer_columns() + metadata_columns = schema.get_non_pointer_columns() + previous_value, previous_parition = None, None + + for partition, example in spark_partition_iter: + # if the previous value is unset or the parition has changed + if not previous_value or previous_parition != partition: + previous_value = example["global_index"] + previous_parition = partition + # if we can calculate the distance because we are on same parition + # and the previous value is not None + else: + current_diff = abs(example["global_index"] - previous_value) + previous_value = example["global_index"] + previous_parition = partition + global_sum += current_diff + global_counter += 1 + # Create ColumnBytesFileWriter lazily as required, for each partition + if partition not in column_bytes_file_writers: + column_bytes_file_writers[partition] = ColumnBytesFileWriter( + s3_storage, + s3_path_factory, + target_file_rowgroup_size=target_max_column_file_numrows, + ) + + # Write to ColumnBytesFileWriter and return only metadata + heavy-pointers + parquet_metadata: Dict[str, Any] = {col: example[col] for col in metadata_columns} + for col in heavy_pointer_columns: + loc = column_bytes_file_writers[partition].add(col, example[col]) + parquet_metadata[col] = loc.to_bytes() + yield partition, parquet_metadata 
+ + # Flush all writers when finished + for partition in column_bytes_file_writers: + column_bytes_file_writers[partition].close() + # assert that we are at mean 2 and that we have not shuffled + mean = global_sum / global_counter + assert mean == 2.0 + + with patch("wicker.core.persistance.AbstractDataPersistor.persist_wicker_partition", mock_persist_wicker_partition): + # create the mock basic persistor + mock_basic_persistor_obj, tempdir = mock_basic_persistor + # persist the dataset + mock_basic_persistor_obj.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset) + # assert the dataset is correctly written + assert_written_correctness(tempdir) + + +@pytest.mark.parametrize( + "mock_basic_persistor, dataset_name, dataset_version, dataset_schema, dataset", + [({}, DATASET_NAME, DATASET_VERSION, SCHEMA, copy.deepcopy(EXAMPLES_DUPES))], + indirect=["mock_basic_persistor"], +) +def test_basic_persistor_shuffle( + mock_basic_persistor: Tuple[BasicPersistor, str], + dataset_name: str, + dataset_version: str, + dataset_schema: DatasetSchema, + dataset: List[Tuple[str, Dict[str, Any]]], +): + """Test if the basic persistor saves the correct data and shuffles it into different partitions + + Ensure we read the right file locations, the right amount of bytes, + and the ordering is correct. + """ + # in order to assert that we are shuffling we are going to sub out the + # persist partition function and get average distance on global index + # if it is != 2 (ie: samples are adjacent in partitions) then shuffling has occured + def mock_persist_wicker_partition( + self, + spark_partition_iter: Iterable[Tuple[str, ParsedExample]], + schema: schema.DatasetSchema, + s3_storage: FakeS3DataStorage, + s3_path_factory: S3PathFactory, + target_max_column_file_numrows: int = 50, + ) -> Iterable[Tuple[str, PointerParsedExample]]: + # set up the global sum and counter for calcing mean + global_sum = 0 + global_counter = 0 + # we still have to do all of the regular logic to test writing + column_bytes_file_writers: Dict[str, ColumnBytesFileWriter] = {} + heavy_pointer_columns = schema.get_pointer_columns() + metadata_columns = schema.get_non_pointer_columns() + previous_value, previous_parition = None, None + + for partition, example in spark_partition_iter: + # if the previous value is unset or the parition has changed + if not previous_value or previous_parition != partition: + previous_value = example["global_index"] + previous_parition = partition + # if we can calculate the distance because we are on same parition + # and the previous value is not None + else: + current_diff = abs(example["global_index"] - previous_value) + previous_value = example["global_index"] + previous_parition = partition + global_sum += current_diff + global_counter += 1 + # Create ColumnBytesFileWriter lazily as required, for each partition + if partition not in column_bytes_file_writers: + column_bytes_file_writers[partition] = ColumnBytesFileWriter( + s3_storage, + s3_path_factory, + target_file_rowgroup_size=target_max_column_file_numrows, + ) + + # Write to ColumnBytesFileWriter and return only metadata + heavy-pointers + parquet_metadata: Dict[str, Any] = {col: example[col] for col in metadata_columns} + for col in heavy_pointer_columns: + loc = column_bytes_file_writers[partition].add(col, example[col]) + parquet_metadata[col] = loc.to_bytes() + yield partition, parquet_metadata + + # Flush all writers when finished + for partition in column_bytes_file_writers: + 
column_bytes_file_writers[partition].close() + # assert that we are not at mean 2 and that we have shuffled successfully + mean = global_sum / global_counter + assert mean != 2.0 + # create the mock basic persistor - mock_basic_persistor_obj, tempdir = mock_basic_persistor - # persist the dataset - mock_basic_persistor_obj.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset) - # assert the dataset is correctly written - assert_written_correctness(tempdir) + with patch("wicker.core.persistance.AbstractDataPersistor.persist_wicker_partition", mock_persist_wicker_partition): + mock_basic_persistor_obj, tempdir = mock_basic_persistor + # persist and shuffle the dataset + mock_basic_persistor_obj.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset, False) + # assert the dataset is correctly written + assert_written_correctness(tempdir) diff --git a/wicker/core/column_files.py b/wicker/core/column_files.py index c70606d..ae3dd50 100644 --- a/wicker/core/column_files.py +++ b/wicker/core/column_files.py @@ -51,6 +51,7 @@ def to_bytes(self) -> bytes: @classmethod def from_bytes(cls, b: bytes) -> ColumnBytesFileLocationV1: protocol_version = int.from_bytes(b[0:1], "little") + if protocol_version != 1: raise ValueError(f"Unable to parse ColumnBytesFileLocation with protocol_version={protocol_version}") _, file_id, byte_offset, data_size = struct.unpack(ColumnBytesFileLocationV1.STRUCT_PACK_FMT, b) diff --git a/wicker/core/datasets.py b/wicker/core/datasets.py index 94d04a6..7bc1a98 100644 --- a/wicker/core/datasets.py +++ b/wicker/core/datasets.py @@ -92,9 +92,10 @@ def __init__( self._partition = DatasetPartition(dataset_id=self._dataset_id, partition=dataset_partition_name) self._dataset_definition = DatasetDefinition( self._dataset_id, - schema=self.schema(), + schema=self.schema, ) + @property def schema(self) -> DatasetSchema: if self._schema is None: schema_path = self._s3_path_factory.get_dataset_schema_path(self._dataset_id) @@ -107,20 +108,24 @@ def schema(self) -> DatasetSchema: ) return self._schema + @property def arrow_table(self) -> pyarrow.Table: - path = self._s3_path_factory.get_dataset_partition_path(self._partition, s3_prefix=False) if not self._arrow_table: + path = self._s3_path_factory.get_dataset_partition_path(self._partition, s3_prefix=False) self._arrow_table = papq.read_table(path, filesystem=self._pa_filesystem) return self._arrow_table def __len__(self) -> int: - return len(self.arrow_table()) + return len(self.arrow_table) def __getitem__(self, idx: int) -> Dict[str, Any]: - tbl = self.arrow_table() - columns = self._columns_to_load if self._columns_to_load is not None else tbl.column_names - row = {col: tbl[col][idx].as_py() for col in columns} return dataloading.load_example( - self._column_bytes_file_cache.resolve_pointers(row, self.schema()), - self.schema(), + self._column_bytes_file_cache.resolve_pointers(self._get_row_pq_table(idx), self.schema), + self.schema, ) + + def _get_row_pq_table(self, idx: int): + tbl = self.arrow_table + columns = self._columns_to_load if self._columns_to_load is not None else tbl.column_names + row = {col: tbl[col][idx].as_py() for col in columns} + return row diff --git a/wicker/core/persistance.py b/wicker/core/persistance.py index 6061638..952f814 100644 --- a/wicker/core/persistance.py +++ b/wicker/core/persistance.py @@ -1,4 +1,5 @@ import abc +import random from typing import Any, Dict, Iterable, List, Optional, Tuple import pyarrow as pa @@ -155,6 +156,7 @@ def 
persist_wicker_dataset( dataset: Any, s3_storage: S3DataStorage = S3DataStorage(), s3_path_factory: S3PathFactory = S3PathFactory(), + shuffle: bool = False, ) -> Optional[Dict[str, int]]: """ Persist wicker dataset public facing api function, for api consistency. @@ -170,9 +172,15 @@ def persist_wicker_dataset( :type s3_storage: S3DataStorage :param s3_path_factory: s3 path abstraction :type s3_path_factory: S3PathFactory + :param shuffle: to shuffle or not, is this a question? + :type shuffle: str """ return BasicPersistor(s3_storage, s3_path_factory).persist_wicker_dataset( - dataset_name, dataset_version, dataset_schema, dataset + dataset_name, + dataset_version, + dataset_schema, + dataset, + shuffle, ) @@ -189,7 +197,12 @@ def __init__( super().__init__(s3_storage, s3_path_factory) def persist_wicker_dataset( - self, dataset_name: str, dataset_version: str, dataset_schema: schema_module.DatasetSchema, dataset: Any + self, + dataset_name: str, + dataset_version: str, + dataset_schema: schema_module.DatasetSchema, + dataset: Any, + shuffle: bool = False, ) -> Optional[Dict[str, int]]: """ Persist a user defined dataset, pushing data to s3 in a basic manner @@ -202,6 +215,8 @@ def persist_wicker_dataset( :type dataset_schema: wicker.schema.schema.DatasetSchema :param dataset: Data of the dataset :type dataset: User defined + :param shuffle: to shuffle or not, is this a question? + :type shuffle: str """ # what needs to be done within this function # 1. Check if the variables are set @@ -223,20 +238,25 @@ def persist_wicker_dataset( dataset_0 = [(row[0], self.parse_row(row[1], dataset_schema)) for row in dataset] # 4. Sort the dataset if not sorted - sorted_dataset_0 = sorted(dataset_0, key=lambda tup: tup[0]) + dataset_1 = sorted(dataset_0, key=lambda tup: tup[0]) + + # 5. if we have shuffling, shuffle the dataset before partition + # ensures proper and random shuffling + if shuffle: + random.shuffle(dataset_1) - # 5. Partition the dataset into K partitions - num_paritions = len(sorted_dataset_0) // PARTITION_SIZE + # 6. Partition the dataset into K partitions partitions = [] + num_partitions = (len(dataset_1) // PARTITION_SIZE) + 1 def divide_chunks(list_to_divide, num_chunks): # looping till length l - for i in range(0, len(list_to_divide), num_chunks): - partitions.append(list_to_divide[i : i + num_chunks]) + for i in range(0, len(list_to_divide), PARTITION_SIZE): + partitions.append(list_to_divide[i : i + PARTITION_SIZE]) - divide_chunks(sorted_dataset_0, num_paritions) + divide_chunks(dataset_1, num_partitions) - # 6. Persist the partitions to S3 + # 7. Persist the partitions to S3 for partition in partitions: # build a persistence iterator for each parition iterator = self.persist_wicker_partition( @@ -245,11 +265,11 @@ def divide_chunks(list_to_divide, num_chunks): # make sure all yields get called list(iterator) - # 7. Create the parition table, need to combine keys in a way we can form table + # 8. 
Create the parition table, need to combine keys in a way we can form table # split into k dicts where k is partition number and the data is a list of values # for each key for all the dicts in the partition merged_dicts: Dict[str, Dict[str, List[Any]]] = {} - for partition_key, row in sorted_dataset_0: + for partition_key, row in dataset_1: current_dict: Dict[str, List[Any]] = merged_dicts.get(partition_key, {}) for col in row.keys(): if col in current_dict: @@ -267,7 +287,7 @@ def divide_chunks(list_to_divide, num_chunks): pc.sort_indices(data_table, sort_keys=[(pk, "ascending") for pk in dataset_schema.primary_keys]), ) - # 8. Persist the partition table to s3 + # 9. Persist the partition table to s3 written_dict = {} for partition_key, pa_table in arrow_dict.items(): self.save_partition_tbl( From 15147f4d3b69f4dc4cfb771ea545df49fd2f3c1a Mon Sep 17 00:00:00 2001 From: Isaak Willett Date: Wed, 5 Oct 2022 17:47:02 +0000 Subject: [PATCH 5/8] bugfix partition divisions --- wicker/core/persistance.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/wicker/core/persistance.py b/wicker/core/persistance.py index 6061638..fac1bc3 100644 --- a/wicker/core/persistance.py +++ b/wicker/core/persistance.py @@ -226,15 +226,14 @@ def persist_wicker_dataset( sorted_dataset_0 = sorted(dataset_0, key=lambda tup: tup[0]) # 5. Partition the dataset into K partitions - num_paritions = len(sorted_dataset_0) // PARTITION_SIZE partitions = [] - def divide_chunks(list_to_divide, num_chunks): + def divide_chunks(list_to_divide): # looping till length l - for i in range(0, len(list_to_divide), num_chunks): - partitions.append(list_to_divide[i : i + num_chunks]) + for i in range(0, len(list_to_divide), PARTITION_SIZE): + partitions.append(list_to_divide[i : i + PARTITION_SIZE]) - divide_chunks(sorted_dataset_0, num_paritions) + divide_chunks(sorted_dataset_0) # 6. Persist the partitions to S3 for partition in partitions: From f2b736f6431d712887248d942a4055dedf2f5185 Mon Sep 17 00:00:00 2001 From: Isaak Willett Date: Wed, 5 Oct 2022 17:47:56 +0000 Subject: [PATCH 6/8] remove unused variables --- wicker/core/persistance.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/wicker/core/persistance.py b/wicker/core/persistance.py index 952f814..0e70422 100644 --- a/wicker/core/persistance.py +++ b/wicker/core/persistance.py @@ -247,14 +247,13 @@ def persist_wicker_dataset( # 6. Partition the dataset into K partitions partitions = [] - num_partitions = (len(dataset_1) // PARTITION_SIZE) + 1 - def divide_chunks(list_to_divide, num_chunks): + def divide_chunks(list_to_divide): # looping till length l for i in range(0, len(list_to_divide), PARTITION_SIZE): partitions.append(list_to_divide[i : i + PARTITION_SIZE]) - divide_chunks(dataset_1, num_partitions) + divide_chunks(dataset_1) # 7. 
Persist the partitions to S3
         for partition in partitions:
             # build a persistence iterator for each parition
             iterator = self.persist_wicker_partition(

From f242fc6d2489e1dd06b3f2ef9c2931e8a30072a9 Mon Sep 17 00:00:00 2001
From: Isaak Willett
Date: Wed, 5 Oct 2022 18:01:34 +0000
Subject: [PATCH 7/8] remove unneeded space

---
 wicker/core/column_files.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/wicker/core/column_files.py b/wicker/core/column_files.py
index ae3dd50..c70606d 100644
--- a/wicker/core/column_files.py
+++ b/wicker/core/column_files.py
@@ -51,7 +51,6 @@ def to_bytes(self) -> bytes:
     @classmethod
     def from_bytes(cls, b: bytes) -> ColumnBytesFileLocationV1:
         protocol_version = int.from_bytes(b[0:1], "little")
-
         if protocol_version != 1:
             raise ValueError(f"Unable to parse ColumnBytesFileLocation with protocol_version={protocol_version}")
         _, file_id, byte_offset, data_size = struct.unpack(ColumnBytesFileLocationV1.STRUCT_PACK_FMT, b)

From 77e05e9b347ed2f8bbe9e9974a0499a9452db618 Mon Sep 17 00:00:00 2001
From: isaak-willett
Date: Wed, 26 Oct 2022 19:20:42 +0000
Subject: [PATCH 8/8] fix merge

---
 wicker/core/persistance.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/wicker/core/persistance.py b/wicker/core/persistance.py
index ac2ba54..0e70422 100644
--- a/wicker/core/persistance.py
+++ b/wicker/core/persistance.py
@@ -1,8 +1,5 @@
 import abc
-<<<<<<< HEAD
 import random
-=======
->>>>>>> main
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import pyarrow as pa
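
Usage sketch (editor's note, not part of the patch series): a minimal example of driving the non-Spark persistence path added above. It assumes the public entry point introduced in wicker/core/persistance.py by PATCH 1/8 and extended with the optional shuffle flag in PATCH 4/8; the dataset name, version, schema fields, and example rows below are hypothetical placeholders, not values from the patches.

    from wicker import schema
    from wicker.core.persistance import persist_wicker_dataset
    from wicker.core.storage import S3DataStorage, S3PathFactory

    # Hypothetical schema: two primary keys plus one heavy bytes column.
    dataset_schema = schema.DatasetSchema(
        primary_keys=["example_id", "timestamp"],
        fields=[
            schema.StringField("example_id"),
            schema.IntField("timestamp"),
            schema.BytesField("payload"),
        ],
    )

    # The dataset is an iterable of (partition_name, unparsed_example) tuples,
    # mirroring the EXAMPLES fixture in tests/core/test_persistence.py.
    rows = [
        ("train", {"example_id": "a", "timestamp": 1, "payload": b"\x00\x01"}),
        ("test", {"example_id": "b", "timestamp": 2, "payload": b"\x02\x03"}),
    ]

    # shuffle=True randomizes row order before the data is split into fixed-size
    # chunks (PARTITION_SIZE); each partition's parquet index is still sorted by
    # primary key before it is written.
    written_row_counts = persist_wicker_dataset(
        dataset_name="my_dataset",      # hypothetical name
        dataset_version="0.0.1",
        dataset_schema=dataset_schema,
        dataset=rows,
        s3_storage=S3DataStorage(),     # or FakeS3DataStorage(tmpdir=...) as in the tests
        s3_path_factory=S3PathFactory(),
        shuffle=True,
    )
    print(written_row_counts)  # e.g. {"test": 1, "train": 1}

The return value is the per-partition row count produced by BasicPersistor.persist_wicker_dataset; in the tests the storage and path factory are swapped for fakes so the same call writes to a local temporary directory instead of S3.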