diff --git a/deltacat/benchmarking/benchmark_engine.py b/deltacat/benchmarking/benchmark_engine.py index 1f7fe084..e38ca53d 100644 --- a/deltacat/benchmarking/benchmark_engine.py +++ b/deltacat/benchmarking/benchmark_engine.py @@ -1,14 +1,13 @@ import sys import time from contextlib import contextmanager -from typing import Generator, Tuple, Iterator +from typing import Generator, Tuple from deltacat.benchmarking.benchmark_report import BenchmarkMetric, BenchmarkStep from deltacat.storage.rivulet.dataset import Dataset from deltacat.storage.rivulet.reader.query_expression import QueryExpression - @contextmanager def timed_step(description: str) -> Generator[BenchmarkStep, None, None]: """Convenience for computing elapsed time of a block of code as a metric. @@ -24,11 +23,12 @@ def timed_step(description: str) -> Generator[BenchmarkStep, None, None]: class BenchmarkEngine: - def __init__(self, dataset: Dataset): self.dataset = dataset - def load_and_commit(self, field_group, generator, count) -> Tuple[str, BenchmarkStep]: + def load_and_commit( + self, schema_name, generator, count + ) -> Tuple[str, BenchmarkStep]: """Load count number of rows from the generator and commit. :param generator: row generator @@ -36,7 +36,7 @@ def load_and_commit(self, field_group, generator, count) -> Tuple[str, Benchmark :return: tuple of the manifest URI and a operation measurement """ desc = f"load {count} from {generator}" - writer = self.dataset.writer(field_group) + writer = self.dataset.writer(schema_name) with timed_step(desc) as step: rows = [generator.generate() for _ in range(count)] writer.write(rows) @@ -45,21 +45,30 @@ def load_and_commit(self, field_group, generator, count) -> Tuple[str, Benchmark return result, step def scan(self) -> Tuple[set[any], BenchmarkStep]: - """Scans the rows and prints some basic statistics about the manifest""" + """ + Scans the rows in the dataset and prints some basic statistics about the manifest + + :return: Tuple[set[any], BenchmarkStep] - a tuple containing a set of merge keys and a benchmark step with metrics + """ keys = set() object_count = 0 size_b = 0 + # Note that we expect single-column merge keys, so we can return a key set. + # This will fail with a validation error if the dataset has multiple merge keys. + merge_key_name = self.dataset.schemas["all"].get_merge_key() with timed_step("full scan") as step: for row in self.dataset.scan(QueryExpression()).to_pydict(): object_count += 1 size_b += sum([sys.getsizeof(x) for x in row.values()]) - keys.add(row.get(self.dataset.schema.primary_key.name)) + keys.add(row.get(merge_key_name)) # TODO replace with the actual metrics we want to measure step.add(BenchmarkMetric("rows read", object_count)) step.add(BenchmarkMetric("size", size_b / (1024 * 1024), "MB")) return keys, step - def run_queries(self, description, manifest_uri, queries: list[QueryExpression]) -> BenchmarkStep: + def run_queries( + self, description, manifest_uri, queries: list[QueryExpression] + ) -> BenchmarkStep: object_count = 0 size_b = 0 with timed_step(description) as step: diff --git a/deltacat/benchmarking/benchmark_report.py b/deltacat/benchmarking/benchmark_report.py index a2740874..1f897125 100644 --- a/deltacat/benchmarking/benchmark_report.py +++ b/deltacat/benchmarking/benchmark_report.py @@ -65,12 +65,22 @@ def __str__(self): return suite = self.runs[0].suite - headers = [f"{suite} Operation", "Metric", "Unit", *[r.description for r in self.runs]] + headers = [ + f"{suite} Operation", + "Metric", + "Unit", + *[r.description for r in self.runs], + ] 
rows = [] for step_tranche in zip(*[r.steps for r in self.runs]): # TODO zip by metric name instead of assuming all metrics are being measured step_name = step_tranche[0].description for metric_tuple in zip(*[x.list_metrics() for x in step_tranche]): - row = [step_name, metric_tuple[0].name, metric_tuple[0].unit, *[p.value for p in metric_tuple]] + row = [ + step_name, + metric_tuple[0].name, + metric_tuple[0].unit, + *[p.value for p in metric_tuple], + ] rows.append(row) return tabulate(rows, headers=headers, tablefmt="fancy_outline") diff --git a/deltacat/benchmarking/conftest.py b/deltacat/benchmarking/conftest.py index fa2495ca..6cf3a43a 100644 --- a/deltacat/benchmarking/conftest.py +++ b/deltacat/benchmarking/conftest.py @@ -14,17 +14,21 @@ ) -@pytest.fixture(autouse=True, scope='function') +@pytest.fixture(autouse=True, scope="function") def report(request): report = BenchmarkReport(request.node.name) + def final_callback(): - terminal_reporter: TerminalReporter = request.config.pluginmanager.get_plugin("terminalreporter") - capture_manager = request.config.pluginmanager.get_plugin('capturemanager') + terminal_reporter: TerminalReporter = request.config.pluginmanager.get_plugin( + "terminalreporter" + ) + capture_manager = request.config.pluginmanager.get_plugin("capturemanager") with capture_manager.global_and_fixture_disabled(): terminal_reporter.ensure_newline() - terminal_reporter.section(request.node.name, sep='-', blue=True, bold=True) + terminal_reporter.section(request.node.name, sep="-", blue=True, bold=True) terminal_reporter.write(str(report)) terminal_reporter.ensure_newline() + request.addfinalizer(final_callback) return report diff --git a/deltacat/benchmarking/data/random_row_generator.py b/deltacat/benchmarking/data/random_row_generator.py index 75497700..8d75435a 100644 --- a/deltacat/benchmarking/data/random_row_generator.py +++ b/deltacat/benchmarking/data/random_row_generator.py @@ -27,7 +27,9 @@ class ImageStyle(Enum): class RandomRowGenerator(RowGenerator): """Generate rows with 'images' that are just randomly-generated bytes""" - def __init__(self, seed=0, tmp_dir=None, style: ImageStyle = ImageStyle.RANDOM_BYTES): + def __init__( + self, seed=0, tmp_dir=None, style: ImageStyle = ImageStyle.RANDOM_BYTES + ): self.seed = seed self.fake = faker.Faker() self.fake.seed_instance(seed) @@ -51,7 +53,9 @@ def _generate_image(self, width, height) -> bytes: @staticmethod def _generate_with_random_bytes(width, height) -> bytes: """Generate random bytes to simulate an image.""" - target_size = math.floor(width * height / 50) # this isn't actually how file size relates to image size + target_size = math.floor( + width * height / 50 + ) # this isn't actually how file size relates to image size # Assumption: we don't actually need images. It suffices to generate arbitrary-length bytes of random characters. 
return os.urandom(target_size) @@ -59,9 +63,9 @@ def _generate_with_random_bytes(width, height) -> bytes: def _generate_with_pillow(width, height) -> bytes: """Generate actual PNG files in-memory directly using Pillow""" file = BytesIO() - image = Image.new('RGBA', size=(width, height), color=(155, 0, 0)) - image.save(file, 'png') - file.name = 'test.png' + image = Image.new("RGBA", size=(width, height), color=(155, 0, 0)) + image.save(file, "png") + file.name = "test.png" file.seek(0) return file.read() @@ -72,17 +76,19 @@ def _generate_with_faker(self, width, height) -> bytes: root_path=self.temp_dir, rel_path="tmp", ), - size=(width, height)) + size=(width, height), + ) file_name = f"{self.temp_dir}/{rel_name}" - with open(file_name, 'rb') as f: + with open(file_name, "rb") as f: return f.read() - - def generate(self) -> Dict[str,Any]: + def generate(self) -> Dict[str, Any]: return { "id": self.fake.random_int(0, 10_000_000), "source": self.fake.image_url(), - "media": (self._generate_image( - self.fake.random_int(512, 2048), - self.fake.random_int(512, 4096))) + "media": ( + self._generate_image( + self.fake.random_int(512, 2048), self.fake.random_int(512, 4096) + ) + ), } diff --git a/deltacat/benchmarking/data/row_generator.py b/deltacat/benchmarking/data/row_generator.py index 5f35567a..411727d4 100644 --- a/deltacat/benchmarking/data/row_generator.py +++ b/deltacat/benchmarking/data/row_generator.py @@ -1,5 +1,6 @@ from typing import Protocol, Iterator, Dict, Any + class RowGenerator(Protocol): def generate(self) -> Dict[str, Any]: ... diff --git a/deltacat/benchmarking/test_benchmark_pipeline.py b/deltacat/benchmarking/test_benchmark_pipeline.py index 1a1d7957..1264f6c4 100644 --- a/deltacat/benchmarking/test_benchmark_pipeline.py +++ b/deltacat/benchmarking/test_benchmark_pipeline.py @@ -3,32 +3,30 @@ import tempfile from contextlib import contextmanager from random import shuffle - import pytest - -from deltacat.storage.rivulet.field_group import FileSystemFieldGroup, FieldGroup - -pytestmark = pytest.mark.benchmark - -#from rivulet import Datatype, Dataset from deltacat.storage.rivulet.dataset import Dataset from deltacat.storage.rivulet.schema.datatype import Datatype -from deltacat.storage.rivulet.fs.file_store import FileStore from deltacat.storage.rivulet.reader.query_expression import QueryExpression -from deltacat.storage.rivulet.schema.schema import Schema, Field +from deltacat.storage.rivulet.schema.schema import Schema from deltacat.benchmarking.benchmark_engine import BenchmarkEngine from deltacat.benchmarking.benchmark_report import BenchmarkRun, BenchmarkReport from deltacat.benchmarking.benchmark_suite import BenchmarkSuite from deltacat.benchmarking.data.random_row_generator import RandomRowGenerator from deltacat.benchmarking.data.row_generator import RowGenerator +pytestmark = pytest.mark.benchmark + @pytest.fixture def schema(): - return Schema({ - "id": Field("id", Datatype.int32()), - "source": Field("source", Datatype.string()), - "media": Field("media", Datatype.image("png"))}, "id") + return Schema( + [ + ("id", Datatype.int32()), + ("source", Datatype.string()), + ("media", Datatype.image("png")), + ], + "id", + ) @contextmanager @@ -44,10 +42,13 @@ def make_tmpdir(): class LoadAndScanSuite(BenchmarkSuite): """Load some number of rows and scan""" + schema_name = "LoadAndScanSuite" + def __init__(self, dataset: Dataset, schema: Schema, generator, description=None): self.suite = "ReadSuite" self.dataset: Dataset = dataset - self.field_group = 
self.dataset.new_field_group(schema) + self.schema = schema + self.dataset.add_schema(schema, LoadAndScanSuite.schema_name) self.generator: RowGenerator = generator self.description: str = description or f"{self.dataset} x {self.generator}" @@ -55,7 +56,9 @@ def run(self) -> BenchmarkRun: container = BenchmarkEngine(self.dataset) run = BenchmarkRun(self.suite, self.description) # load a large number of rows - manifest_uri, step = container.load_and_commit(self.field_group, self.generator, 1000) + manifest_uri, step = container.load_and_commit( + LoadAndScanSuite.schema_name, self.generator, 1000 + ) run.add(step) # do a full scan of all rows (and eagerly load them) keys, step = container.scan() @@ -63,13 +66,20 @@ def run(self) -> BenchmarkRun: # randomly retrieve all keys one-by-one from the dataset random_keys = list(keys) shuffle(random_keys) - step = container.run_queries("load all keys individually", manifest_uri, - [QueryExpression().with_primary_key(k) for k in random_keys]) + step = container.run_queries( + "load all keys individually", + manifest_uri, + [QueryExpression().with_key(k) for k in random_keys], + ) run.add(step) # split into 4 key ranges and get them individually quartiles = self._generate_quartiles(keys) - expressions = [QueryExpression().with_primary_range(start, end) for (start, end) in quartiles] - step = container.run_queries("load key ranges by quartile", manifest_uri, expressions) + expressions = [ + QueryExpression().with_range(start, end) for (start, end) in quartiles + ] + step = container.run_queries( + "load key ranges by quartile", manifest_uri, expressions + ) run.add(step) return run @@ -77,7 +87,7 @@ def run(self) -> BenchmarkRun: def _generate_quartiles(keys): sorted_keys = sorted(keys) size = len(keys) - starts = list(range(0, size, math.ceil(size/4))) + starts = list(range(0, size, math.ceil(size / 4))) ends = list([x - 1 for x in starts[1:]]) ends.append(size - 1) quartiles = list(zip(starts, ends)) @@ -87,7 +97,22 @@ def _generate_quartiles(keys): def test_suite1(schema: Schema, report: BenchmarkReport): with make_tmpdir() as temp_dir: generator = RandomRowGenerator(123, temp_dir) - report.add(LoadAndScanSuite(Dataset(temp_dir), schema, generator, "SST (rand)").run()) + report.add( + LoadAndScanSuite( + Dataset(dataset_name="test_suite1_ds1", metadata_uri=temp_dir), + schema, + generator, + "SST (rand)", + ).run() + ) + with make_tmpdir() as temp_dir: generator = RandomRowGenerator(123, temp_dir) - report.add(LoadAndScanSuite(Dataset(temp_dir), schema, generator, "dupe").run()) + report.add( + LoadAndScanSuite( + Dataset(dataset_name="test_suite1_ds2", metadata_uri=temp_dir), + schema, + generator, + "dupe", + ).run() + ) diff --git a/deltacat/storage/README.md b/deltacat/storage/README.md index dd5b1d8d..8f1777cf 100644 --- a/deltacat/storage/README.md +++ b/deltacat/storage/README.md @@ -71,9 +71,9 @@ a **Name Resolution Directory** to map the object's mutable name or alias back t **Name Mapping File** -The format of the **Name Mapping File** file is: +The format of the **Name Mapping File** is: `__.` -Where `object_id` is the name of the associated object's **Immutable ID** directory. +Where `object_id` is the name of the associated object's **Immutable ID** directory. Note that (except **Immutable ID**) this is the same format used by **Metadata Revision Files**, and the same process is employed to `create`, `update`, and `delete` name mappings. 
### Transaction Log Directory diff --git a/deltacat/storage/rivulet/dataset.py b/deltacat/storage/rivulet/dataset.py index b169603c..ae86918d 100644 --- a/deltacat/storage/rivulet/dataset.py +++ b/deltacat/storage/rivulet/dataset.py @@ -233,7 +234,9 @@ def from_parquet( dataset_schema = Schema.from_pyarrow(pyarrow_schema, merge_keys) # TODO the file URI never gets stored/saved, do we need to do so? - dataset = cls(dataset_name=name, metadata_uri=metadata_uri, schema=dataset_schema) + dataset = cls( + dataset_name=name, metadata_uri=metadata_uri, schema=dataset_schema + ) # TODO: avoid write! associate fields with their source data. writer = dataset.writer() @@ -246,12 +249,12 @@ def from_parquet( @classmethod def from_json( - cls, - name: str, - file_uri: str, - merge_keys: str | Iterable[str], - metadata_uri: Optional[str] = None, - schema_mode: str = "union", + cls, + name: str, + file_uri: str, + merge_keys: str | Iterable[str], + metadata_uri: Optional[str] = None, + schema_mode: str = "union", ) -> "Dataset": """ Create a Dataset from a single JSON file. @@ -279,7 +282,9 @@ def from_json( dataset_schema = Schema.from_pyarrow(pyarrow_schema, merge_keys) # Create the Dataset instance - dataset = cls(dataset_name=name, metadata_uri=metadata_uri, schema=dataset_schema) + dataset = cls( + dataset_name=name, metadata_uri=metadata_uri, schema=dataset_schema + ) writer = dataset.writer() writer.write(pyarrow_table.to_batches()) @@ -289,12 +294,12 @@ @classmethod def from_csv( - cls, - name: str, - file_uri: str, - merge_keys: str | Iterable[str], - metadata_uri: Optional[str] = None, - schema_mode: str = "union", + cls, + name: str, + file_uri: str, + merge_keys: str | Iterable[str], + metadata_uri: Optional[str] = None, + schema_mode: str = "union", ) -> "Dataset": """ Create a Dataset from a single JSON file. @@ -322,7 +327,9 @@ def from_csv( dataset_schema = Schema.from_pyarrow(pyarrow_schema, merge_keys) # Create the Dataset instance - dataset = cls(dataset_name=name, metadata_uri=metadata_uri, schema=dataset_schema) + dataset = cls( + dataset_name=name, metadata_uri=metadata_uri, schema=dataset_schema + ) writer = dataset.writer() writer.write(table.to_batches()) @@ -334,9 +341,14 @@ def print(self, num_records: int = 10) -> None: """Prints the first `num_records` records in the dataset.""" records = self.scan().to_pydict() for record in itertools.islice(records, num_records): - print(record) + print(record) - def export(self, file_uri: str, format: str = "parquet", query: QueryExpression=QueryExpression()) -> None: + def export( + self, + file_uri: str, + format: str = "parquet", + query: QueryExpression = QueryExpression(), + ) -> None: """Export the dataset to a file. 
Args: diff --git a/deltacat/storage/rivulet/feather/file_reader.py b/deltacat/storage/rivulet/feather/file_reader.py index f74a43b5..755f4c48 100644 --- a/deltacat/storage/rivulet/feather/file_reader.py +++ b/deltacat/storage/rivulet/feather/file_reader.py @@ -24,7 +24,13 @@ class FeatherFileReader(FileReader[RecordBatchRowIndex]): TODO can consider abstracting code between this and ParquetFileReader """ - def __init__(self, sst_row: SSTableRow, file_store: FileStore, primary_key: str, schema: Schema): + def __init__( + self, + sst_row: SSTableRow, + file_store: FileStore, + primary_key: str, + schema: Schema, + ): self.sst_row = sst_row self.input = file_store.new_input_file(self.sst_row.uri) @@ -120,7 +126,11 @@ def __advance_record_batch(self): self._pk_col = self._curr_batch[self.key] # Filter the batch to only include fields in the schema # Pyarrow select will throw a ValueError if the field is not in the schema - fields = [field for field in self.schema.keys() if field in self._curr_batch.schema.names] + fields = [ + field + for field in self.schema.keys() + if field in self._curr_batch.schema.names + ] self._curr_batch = self._curr_batch.select(fields) except ValueError: raise StopIteration(f"Ended iteration at batch {self._curr_batch_index}") diff --git a/deltacat/storage/rivulet/metastore/sst_interval_tree.py b/deltacat/storage/rivulet/metastore/sst_interval_tree.py index d3cb2ac3..e8fee696 100644 --- a/deltacat/storage/rivulet/metastore/sst_interval_tree.py +++ b/deltacat/storage/rivulet/metastore/sst_interval_tree.py @@ -207,8 +207,14 @@ def get_sorted_block_groups( f"min_key {min_key} cannot be greater than max_key {max_key}" ) - min_key_idx = max(0, bisect_left(key_boundaries, min_key) - 1) if min_key is not None else None - max_key_idx = bisect_right(key_boundaries, max_key) + 1 if max_key is not None else None + min_key_idx = ( + max(0, bisect_left(key_boundaries, min_key) - 1) + if min_key is not None + else None + ) + max_key_idx = ( + bisect_right(key_boundaries, max_key) + 1 if max_key is not None else None + ) boundary_table = key_boundaries[min_key_idx:max_key_idx] for lower_bound, upper_bound in pairwise(boundary_table): diff --git a/deltacat/storage/rivulet/reader/data_reader.py b/deltacat/storage/rivulet/reader/data_reader.py index 0041c803..ff4f7891 100644 --- a/deltacat/storage/rivulet/reader/data_reader.py +++ b/deltacat/storage/rivulet/reader/data_reader.py @@ -49,7 +49,11 @@ class FileReader( @abstractmethod def __init__( - self, sst_row: SSTableRow, file_store: FileStore, primary_key: str, schema: Schema + self, + sst_row: SSTableRow, + file_store: FileStore, + primary_key: str, + schema: Schema, ) -> None: """ Required constructor (see: FileReaderRegistrar) diff --git a/deltacat/tests/storage/rivulet/reader/test_data_scan.py b/deltacat/tests/storage/rivulet/reader/test_data_scan.py index 32e9cbf1..130fa175 100644 --- a/deltacat/tests/storage/rivulet/reader/test_data_scan.py +++ b/deltacat/tests/storage/rivulet/reader/test_data_scan.py @@ -1,11 +1,10 @@ import pytest -from deltacat.tests.storage.rivulet.test_utils import ( - verify_pyarrow_scan -) +from deltacat.tests.storage.rivulet.test_utils import verify_pyarrow_scan import pyarrow as pa from deltacat.storage.rivulet import Schema, Field, Datatype from deltacat.storage.rivulet.dataset import Dataset + @pytest.fixture def combined_schema(): return Schema( @@ -18,6 +17,7 @@ def combined_schema(): ] ) + @pytest.fixture def initial_schema(): return Schema( @@ -28,6 +28,7 @@ def initial_schema(): ] ) + 
@pytest.fixture def extended_schema(): return Schema( @@ -38,6 +39,7 @@ def extended_schema(): ] ) + @pytest.fixture def sample_data(): return { @@ -46,6 +48,7 @@ def sample_data(): "age": [25, 30, 35], } + @pytest.fixture def extended_data(): return { @@ -54,12 +57,14 @@ def extended_data(): "gender": ["male", "female", "male"], } + @pytest.fixture def combined_data(sample_data, extended_data): data = sample_data.copy() data.update(extended_data) return data + @pytest.fixture def parquet_data(tmp_path, sample_data): parquet_path = tmp_path / "test.parquet" @@ -67,6 +72,7 @@ def parquet_data(tmp_path, sample_data): pa.parquet.write_table(table, parquet_path) return parquet_path + @pytest.fixture def sample_dataset(parquet_data, tmp_path): return Dataset.from_parquet( @@ -78,7 +84,13 @@ def sample_dataset(parquet_data, tmp_path): def test_end_to_end_scan_with_multiple_schemas( - sample_dataset, initial_schema, extended_schema, combined_schema, sample_data, extended_data, combined_data + sample_dataset, + initial_schema, + extended_schema, + combined_schema, + sample_data, + extended_data, + combined_data, ): # Verify initial scan. verify_pyarrow_scan(sample_dataset.scan().to_arrow(), initial_schema, sample_data) @@ -95,7 +107,13 @@ def test_end_to_end_scan_with_multiple_schemas( writer.flush() # Verify scan with the extended schema retrieves only extended datfa - verify_pyarrow_scan(sample_dataset.scan(schema_name="schema2").to_arrow(), extended_schema, extended_data) + verify_pyarrow_scan( + sample_dataset.scan(schema_name="schema2").to_arrow(), + extended_schema, + extended_data, + ) # Verify a combined scan retrieves data matching the combined schema - verify_pyarrow_scan(sample_dataset.scan().to_arrow(), combined_schema, combined_data) + verify_pyarrow_scan( + sample_dataset.scan().to_arrow(), combined_schema, combined_data + ) diff --git a/deltacat/tests/storage/rivulet/test_sst_interval_tree.py b/deltacat/tests/storage/rivulet/test_sst_interval_tree.py index be066830..0734722f 100644 --- a/deltacat/tests/storage/rivulet/test_sst_interval_tree.py +++ b/deltacat/tests/storage/rivulet/test_sst_interval_tree.py @@ -173,9 +173,8 @@ def test_build_sst_with_bounds( expected = _build_ordered_block_groups(expected_block_groups[0:1]) assert expected == block_groups_filtered -def test_build_sst_with_non_zero_min_key_matching_global_min_key( - manifest_context1 -): + +def test_build_sst_with_non_zero_min_key_matching_global_min_key(manifest_context1): # Using a non-0 value since 0 evaluates to False min_key = 1 max_key = 95 @@ -185,9 +184,19 @@ def test_build_sst_with_non_zero_min_key_matching_global_min_key( t.add_sst_table(SSTable([sst_row], min_key, max_key), manifest_context1) block_groups_filtered = t.get_sorted_block_groups(min_key, min_key + 1) - expected = _build_ordered_block_groups([ - BlockGroup(min_key, max_key, {manifest_context1.schema: frozenset([Block(sst_row, manifest_context1)])}) - ]) + expected = _build_ordered_block_groups( + [ + BlockGroup( + min_key, + max_key, + { + manifest_context1.schema: frozenset( + [Block(sst_row, manifest_context1)] + ) + }, + ) + ] + ) assert expected == block_groups_filtered diff --git a/deltacat/tests/storage/rivulet/test_utils.py b/deltacat/tests/storage/rivulet/test_utils.py index 6b3ae6de..6a9c2f84 100644 --- a/deltacat/tests/storage/rivulet/test_utils.py +++ b/deltacat/tests/storage/rivulet/test_utils.py @@ -8,7 +8,6 @@ from deltacat.storage.rivulet.writer.dataset_writer import DatasetWriter from deltacat.storage.rivulet.mvp.Table import 
MvpTable, MvpRow -from deltacat.storage.rivulet.reader.data_scan import DataScan from deltacat.storage.rivulet import Schema from typing import Dict, List, Generator, Set @@ -96,7 +95,12 @@ def create_dataset_for_method(temp_dir: str): dataset_name=f"dataset-${caller_frame.function}", metadata_uri=dataset_dir ) -def verify_pyarrow_scan(scan_result: Generator[RecordBatch, None, None], expected_schema: Schema, expected_data: dict): + +def verify_pyarrow_scan( + scan_result: Generator[RecordBatch, None, None], + expected_schema: Schema, + expected_data: dict, +): record_batches = list(scan_result) assert record_batches, "Scan should return at least one record batch." @@ -104,12 +108,14 @@ def verify_pyarrow_scan(scan_result: Generator[RecordBatch, None, None], expecte expected_fields = {field.name for field in expected_schema.values()} scanned_fields = set(combined_table.schema.names) - assert scanned_fields == expected_fields, ( - f"Scanned fields {scanned_fields} do not match expected fields {expected_fields}." - ) + assert ( + scanned_fields == expected_fields + ), f"Scanned fields {scanned_fields} do not match expected fields {expected_fields}." for field in expected_fields: - assert field in combined_table.column_names, f"Field '{field}' is missing in the scan result." - assert combined_table[field].to_pylist() == expected_data[field], ( - f"Field '{field}' data does not match expected values." - ) + assert ( + field in combined_table.column_names + ), f"Field '{field}' is missing in the scan result." + assert ( + combined_table[field].to_pylist() == expected_data[field] + ), f"Field '{field}' data does not match expected values." diff --git a/deltacat/tests/test_utils/message_pack_utils.py b/deltacat/tests/test_utils/message_pack_utils.py index 3cbe7fa4..298770f5 100644 --- a/deltacat/tests/test_utils/message_pack_utils.py +++ b/deltacat/tests/test_utils/message_pack_utils.py @@ -4,31 +4,34 @@ import os import shutil + def _convert_bytes_to_base64_str(obj): - if isinstance(obj, dict): - for key, value in obj.items(): - if isinstance(value, bytes): - obj[key] = base64.b64encode(value).decode('utf-8') - elif isinstance(value, list): - _convert_bytes_to_base64_str(value) - elif isinstance(value,dict): - _convert_bytes_to_base64_str(value) - elif isinstance(obj, list): - for i, item in enumerate(obj): - if isinstance(item, bytes): - obj[i] = base64.b64encode(item).decode('utf-8') - elif isinstance(item, (dict, list)): - _convert_bytes_to_base64_str(item) + if isinstance(obj, dict): + for key, value in obj.items(): + if isinstance(value, bytes): + obj[key] = base64.b64encode(value).decode("utf-8") + elif isinstance(value, list): + _convert_bytes_to_base64_str(value) + elif isinstance(value, dict): + _convert_bytes_to_base64_str(value) + elif isinstance(obj, list): + for i, item in enumerate(obj): + if isinstance(item, bytes): + obj[i] = base64.b64encode(item).decode("utf-8") + elif isinstance(item, (dict, list)): + _convert_bytes_to_base64_str(item) + def copy_and_convert(src_dir, dst_dir=None): - ''' + """ Helper function for copying a metastore recursively and converting all messagepack files to json This can be used manually to more easily introspect metastore metadata - ''' + """ if dst_dir is None: from tempfile import mkdtemp + dst_dir = mkdtemp() - print(f'destination is: {dst_dir}') + print(f"destination is: {dst_dir}") if not os.path.exists(dst_dir): os.makedirs(dst_dir) @@ -39,12 +42,12 @@ def copy_and_convert(src_dir, dst_dir=None): if os.path.isdir(src_path): 
copy_and_convert(src_path, dst_path) else: - if item.endswith('.mpk'): - with open(src_path, 'rb') as f: + if item.endswith(".mpk"): + with open(src_path, "rb") as f: data = msgpack.unpackb(f.read(), raw=False) _convert_bytes_to_base64_str(data) - dst_path = dst_path[:-4] + '.json' - with open(dst_path, 'w') as f: + dst_path = dst_path[:-4] + ".json" + with open(dst_path, "w") as f: json.dump(data, f) else: shutil.copy2(src_path, dst_path) diff --git a/deltacat/utils/export.py b/deltacat/utils/export.py index 9a36a546..72a27e87 100644 --- a/deltacat/utils/export.py +++ b/deltacat/utils/export.py @@ -6,22 +6,26 @@ from deltacat.storage.rivulet.reader.query_expression import QueryExpression -def export_parquet(dataset, file_uri: str, query: QueryExpression=QueryExpression()): + +def export_parquet(dataset, file_uri: str, query: QueryExpression = QueryExpression()): records = dataset.scan(query).to_arrow() table = pa.Table.from_batches(records) pyarrow.parquet.write_table(table, file_uri) -def export_feather(dataset, file_uri: str, query: QueryExpression=QueryExpression()): + +def export_feather(dataset, file_uri: str, query: QueryExpression = QueryExpression()): records = dataset.scan(query).to_arrow() table = pa.Table.from_batches(records) pyarrow.feather.write_feather(table, file_uri) -def export_json(dataset, file_uri: str, query: QueryExpression=QueryExpression()): + +def export_json(dataset, file_uri: str, query: QueryExpression = QueryExpression()): with open(file_uri, "w") as f: for batch in dataset.scan(query).to_pydict(): json.dump(batch, f, indent=2) f.write("\n") + def export_dataset(dataset, file_uri: str, format: str = "parquet", query=None): """ Export the dataset to a file. @@ -42,7 +46,9 @@ def export_dataset(dataset, file_uri: str, format: str = "parquet", query=None): } if format not in export_handlers: - raise ValueError(f"Unsupported format: {format}. Supported formats are {list(export_handlers.keys())}") + raise ValueError( + f"Unsupported format: {format}. Supported formats are {list(export_handlers.keys())}" + ) export_handlers[format](dataset, file_uri, query or QueryExpression())