Commit ff988b6

feat(exporter): Add support for multiple data export using `LocalExporter`.
Erik Båvenstrand committed May 7, 2024
1 parent b200651 commit ff988b6
Showing 3 changed files with 165 additions and 65 deletions.
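In short, `LocalExporter.export` now accepts either a single data item with a single config or two parallel lists, and always returns a list of paths. A minimal usage sketch (assuming the import path mirrors the file layout below; the file names are illustrative):

    from mleko.dataset.export.local_exporter import LocalExporter

    exporter = LocalExporter()

    # Single item: one config as before, but the result is now a list of paths.
    paths = exporter.export("test data", {"export_destination": "test.txt", "export_type": "string"})

    # Multiple items: pair each data item with its own config.
    paths = exporter.export(
        ["test data", {"key": "value"}],
        [
            {"export_destination": "test.txt", "export_type": "string"},
            {"export_destination": "data.json", "export_type": "json"},
        ],
    )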
14 changes: 4 additions & 10 deletions mleko/dataset/export/base_exporter.py
@@ -4,13 +4,7 @@

 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, TypeVar
-
-from typing_extensions import TypedDict
-
-
-TypedDictType = TypeVar("T", bound=TypedDict)  # type: ignore
-"""Type variable for TypedDict type annotations."""
+from typing import Any


 class BaseExporter(ABC):
@@ -19,15 +13,15 @@ class BaseExporter(ABC):
     @abstractmethod
     def export(
         self,
-        data: Any,
-        exporter_config: dict[str, Any] | TypedDictType,  # type: ignore
+        data: Any | list[Any],
+        config: dict[str, Any] | list[dict[str, Any]],
         force_recompute: bool = False,
     ) -> str | Path:
         """Exports the data to a destination.

         Args:
             data: Data to be exported.
-            exporter_config: Configuration for the export destination.
+            config: Configuration for the export destination.
             force_recompute: If set to True, forces the data to be exported even if it already exists
                 at the destination.
         """
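The abstract signature now admits both single items and lists, and the bespoke `TypedDictType` type variable is dropped in favor of plain `dict[str, Any]`. A hypothetical minimal subclass under the new contract (illustrative only, not part of the commit):

    from pathlib import Path
    from typing import Any

    from mleko.dataset.export.base_exporter import BaseExporter

    class StdoutExporter(BaseExporter):
        """Hypothetical exporter used purely for illustration; not part of the commit."""

        def export(
            self,
            data: Any | list[Any],
            config: dict[str, Any] | list[dict[str, Any]],
            force_recompute: bool = False,
        ) -> Path:
            # Normalize to a list, mirroring what LocalExporter does below.
            items = data if isinstance(data, list) else [data]
            for item in items:
                print(item)
            return Path("stdout.log")  # nominal return value for the sketch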
102 changes: 74 additions & 28 deletions mleko/dataset/export/local_exporter.py
@@ -5,9 +5,8 @@
 import hashlib
 import pickle
 from pathlib import Path
-from typing import Any, Literal, Union
+from typing import Any, Callable, Literal, Union

-import vaex
 from typing_extensions import TypedDict

 from mleko.cache.fingerprinters.json_fingerprinter import JsonFingerprinter
@@ -22,6 +21,9 @@
 logger = CustomLogger()
 """A module-level logger instance."""

+ExportType = Literal["vaex", "json", "pickle", "joblib", "string"]
+"""Type alias for the supported export types."""
+

 class LocalExporterConfig(TypedDict):
     """Configuration for the LocalExporter."""
@@ -54,15 +56,25 @@ class LocalExporter(BaseExporter):
     def __init__(self) -> None:
         """Initializes the LocalExporter."""
         super().__init__()
+        self._exporters: dict[ExportType, Callable[[Path, Any], None]] = {
+            "vaex": write_vaex_dataframe,
+            "json": write_json,
+            "pickle": write_pickle,
+            "joblib": write_joblib,
+            "string": write_string,
+        }

     def export(  # pyright: ignore [reportIncompatibleMethodOverride]
-        self, data: Any, exporter_config: LocalExporterConfig, force_recompute: bool = False
-    ) -> Path:
+        self,
+        data: Any | list[Any],
+        config: LocalExporterConfig | list[LocalExporterConfig],
+        force_recompute: bool = False,
+    ) -> list[Path]:
         """Exports the data to a local file.

         Args:
             data: Data to be exported.
-            exporter_config: Configuration for the export destination following the `LocalExporterConfig` schema.
+            config: Configuration for the export destination following the `LocalExporterConfig` schema.
             force_recompute: If set to True, forces the data to be exported even if it already exists on disk.

         Examples:
@@ -71,16 +83,55 @@ def export(  # pyright: ignore [reportIncompatibleMethodOverride]
             >>> exporter.export("test data", {"export_destination": "test.txt", "export_type": "string"})
             Path('test.txt')
         """
-        export_type = exporter_config["export_type"]
-        destination = exporter_config["export_destination"]
+        if isinstance(config, list) and not isinstance(data, list):
+            msg = "Data is not a list, but the config is a list. Please provide a single config for a single data item."
+            logger.error(msg)
+            raise ValueError(msg)
+
+        if isinstance(config, list) and isinstance(data, list) and len(config) != len(data):
+            msg = "Number of data items and number of configs do not match."
+            logger.error(msg)
+            raise ValueError(msg)
+
+        if not isinstance(config, list) and isinstance(data, list) and config["export_type"] != "json":
+            msg = (
+                "Data is a list, but the export type is not 'json'. Please provide a list of configs for a "
+                "list of data items."
+            )
+            logger.error(msg)
+            raise ValueError(msg)
+
+        if not isinstance(config, list):
+            data = [data]
+            config = [config]
+
+        results: list[Path] = []
+        for d, c in zip(data, config):
+            results.append(self._export_single(d, c, force_recompute))
+
+        return results
+
+    def _export_single(self, data: Any, config: LocalExporterConfig, force_recompute: bool) -> Path:
+        """Exports a single data item to a local file.
+
+        Args:
+            data: Data to be exported.
+            config: Configuration for the export destination following the `LocalExporterConfig` schema.
+            force_recompute: If set to True, forces the data to be exported even if it already exists on disk.
+
+        Returns:
+            The path to the exported file.
+        """
+        export_type = config["export_type"]
+        destination = config["export_destination"]
         if isinstance(destination, str):
             destination = Path(destination)

         self._ensure_path_exists(destination)
         suffix = destination.suffix

         hash_destination = destination.with_suffix(suffix + ".hash")
-        data_hash = self._hash_data(data)
+        data_hash = self._hash_data(data, export_type)
         if (
             not force_recompute
             and hash_destination.exists()
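The three guard clauses above fail fast before anything is written. A sketch of calls that would raise (destinations illustrative):

    exporter = LocalExporter()

    # A list of configs for a single data item raises ValueError.
    exporter.export("x", [{"export_destination": "a.txt", "export_type": "string"}])

    # Mismatched lengths raise ValueError.
    exporter.export(["x", "y"], [{"export_destination": "a.txt", "export_type": "string"}])

    # A list data item with one non-JSON config raises ValueError; only "json"
    # treats a list as a single exportable value.
    exporter.export(["x"], {"export_destination": "a.txt", "export_type": "string"})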
@@ -95,21 +146,7 @@ def export(  # pyright: ignore [reportIncompatibleMethodOverride]
         else:
             logger.info(f"\033[31mCache Miss\033[0m: Exporting data to {str(destination)!r}.")

-        if export_type == "vaex" and isinstance(data, vaex.DataFrame):
-            write_vaex_dataframe(destination, data)
-        elif export_type == "json" and (isinstance(data, dict) or isinstance(data, list)):
-            write_json(destination, data)
-        elif export_type == "pickle":
-            write_pickle(destination, data)
-        elif export_type == "joblib":
-            write_joblib(destination, data)
-        elif export_type == "string" and isinstance(data, str):
-            write_string(destination, data)
-        else:
-            msg = f"Unsupported data type: {type(data)} for the export type: {export_type}."
-            logger.error(msg)
-            raise ValueError(msg)
-
+        self._run_export_function(data, destination, export_type)
         hash_destination.write_text(data_hash)

         return destination
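Caching works through a sidecar hash file: the export path gets a sibling with an extra `.hash` suffix, and a later call with identical data short-circuits the write. Roughly (the middle of the condition is collapsed in this view, so the exact checks are an assumption):

    from pathlib import Path

    destination = Path("out.json")
    hash_destination = destination.with_suffix(destination.suffix + ".hash")  # -> out.json.hash
    data_hash = "0" * 32  # stand-in for the md5/fingerprint computed by _hash_data

    # Assumed shape of the collapsed condition: skip the write when the stored
    # hash matches the freshly computed one.
    if hash_destination.exists() and hash_destination.read_text() == data_hash:
        pass  # cache hit: the export is skipped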
@@ -122,7 +159,16 @@ def _ensure_path_exists(self, path: Path) -> None:
         """
         path.parent.mkdir(parents=True, exist_ok=True)

-    def _hash_data(self, data: Any) -> str:
+    def _run_export_function(self, data: Any, destination: Path, export_type: ExportType) -> None:
+        exporter = self._exporters.get(export_type)
+        if exporter is None:
+            msg = f"Unsupported data type: {type(data)} for the export type: {export_type}."
+            logger.error(msg)
+            raise ValueError(msg)
+
+        exporter(destination, data)
+
+    def _hash_data(self, data: Any, export_type: ExportType) -> str:
         """Generates a hash for the given data.

         Args:
@@ -131,10 +177,10 @@ def _hash_data(self, data: Any) -> str:
         Returns:
             A hash of the data.
         """
-        if isinstance(data, vaex.DataFrame):
+        if export_type == "vaex":
             return VaexFingerprinter().fingerprint(data)
-        if isinstance(data, (dict, list)):
+        if export_type == "json":
             return JsonFingerprinter().fingerprint(data)
-        if isinstance(data, str):
-            return hashlib.md5((data).encode()).hexdigest()
+        if export_type == "string":
+            return hashlib.md5(str(data).encode()).hexdigest()
         return hashlib.md5(pickle.dumps(data)).hexdigest()
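Note the behavioral shift in `_hash_data`: hashing now keys on the declared `export_type` instead of the runtime type of `data`, so e.g. a dict exported as "pickle" is hashed via `pickle.dumps` rather than via the JSON fingerprinter. A quick sketch of the two fallback branches:

    import hashlib
    import pickle

    data = {"key": "value"}
    # "string" branch: hash the text representation.
    string_hash = hashlib.md5(str(data).encode()).hexdigest()
    # Default branch: hash the pickle serialization.
    pickle_hash = hashlib.md5(pickle.dumps(data)).hexdigest()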
114 changes: 87 additions & 27 deletions tests/dataset/export/test_local_exporter.py
@@ -48,8 +48,8 @@ def test_cached_string_export(self, temporary_directory: Path):
         file_path = test_exporter.export(
             string_data, {"export_destination": file_save_location, "export_type": "string"}
         )
-        assert file_path.exists()
-        assert file_path.read_text() == "test data"
+        assert file_path[0].exists()
+        assert file_path[0].read_text() == "test data"

         with patch("mleko.cache.handlers.write_string") as mocked_write_string:
             test_exporter.export(string_data, {"export_destination": file_save_location, "export_type": "string"})
@@ -63,14 +63,14 @@ def test_force_recompute_string_export(self, temporary_directory: Path):
         file_path = test_exporter.export(
             string_data, {"export_destination": file_save_location, "export_type": "string"}
         )
-        assert file_path.exists()
-        assert file_path.read_text() == "test data"
+        assert file_path[0].exists()
+        assert file_path[0].read_text() == "test data"

-        with patch("mleko.dataset.export.local_exporter.write_string") as mocked_write_string:
+        with patch.object(LocalExporter, "_run_export_function") as mocked_export:
             test_exporter.export(
                 string_data, {"export_destination": file_save_location, "export_type": "string"}, force_recompute=True
             )
-            mocked_write_string.assert_called()
+            mocked_export.assert_called()

     def test_cached_json_export(self, temporary_directory: Path):
         """Should export the data to a local file."""
@@ -79,16 +79,16 @@ def test_cached_json_export(self, temporary_directory: Path):
         file_save_location = temporary_directory / "test.json"
         test_exporter = LocalExporter()
         file_path = test_exporter.export(json_data_1, {"export_destination": file_save_location, "export_type": "json"})
-        assert file_path.exists()
-        assert json.loads(file_path.read_text()) == json_data_1
+        assert file_path[0].exists()
+        assert json.loads(file_path[0].read_text()) == json_data_1

-        with patch("mleko.dataset.export.local_exporter.write_json") as mocked_write_json:
+        with patch.object(LocalExporter, "_run_export_function") as mocked_export:
             file_path = test_exporter.export(
                 json_data_2, {"export_destination": file_save_location, "export_type": "json"}
             )
-            mocked_write_json.assert_not_called()
-            assert file_path.exists()
-            assert json.loads(file_path.read_text()) == json_data_1
+            mocked_export.assert_not_called()
+            assert file_path[0].exists()
+            assert json.loads(file_path[0].read_text()) == json_data_1

     def test_cached_pickle_export(self, temporary_directory: Path):
         """Should export the data to a local file."""
@@ -98,12 +98,12 @@ def test_cached_pickle_export(self, temporary_directory: Path):
         file_path = test_exporter.export(
             pickle_data, {"export_destination": file_save_location, "export_type": "pickle"}
         )
-        assert file_path.exists()
-        assert pickle.loads(file_path.read_bytes()) == pickle_data
+        assert file_path[0].exists()
+        assert pickle.loads(file_path[0].read_bytes()) == pickle_data

-        with patch("mleko.dataset.export.local_exporter.write_pickle") as mocked_write_pickle:
+        with patch.object(LocalExporter, "_run_export_function") as mocked_export:
             test_exporter.export(pickle_data, {"export_destination": file_save_location, "export_type": "pickle"})
-            mocked_write_pickle.assert_not_called()
+            mocked_export.assert_not_called()

     def test_cached_joblib_export(self, temporary_directory: Path):
         """Should export the data to a local file."""
@@ -113,12 +113,12 @@ def test_cached_joblib_export(self, temporary_directory: Path):
         file_path = test_exporter.export(
             joblib_data, {"export_destination": file_save_location, "export_type": "joblib"}
         )
-        assert file_path.exists()
-        assert pickle.loads(file_path.read_bytes()) == joblib_data
+        assert file_path[0].exists()
+        assert pickle.loads(file_path[0].read_bytes()) == joblib_data

-        with patch("mleko.dataset.export.local_exporter.write_joblib") as mocked_write_joblib:
+        with patch.object(LocalExporter, "_run_export_function") as mocked_export:
             test_exporter.export(joblib_data, {"export_destination": file_save_location, "export_type": "joblib"})
-            mocked_write_joblib.assert_not_called()
+            mocked_export.assert_not_called()

     def test_cached_vae_dataframe_export(self, temporary_directory: Path, example_vaex_dataframe: vaex.DataFrame):
         """Should export the data to a local file."""
@@ -127,18 +127,18 @@ def test_cached_vae_dataframe_export(self, temporary_directory: Path, example_vaex_dataframe: vaex.DataFrame):
         file_path = test_exporter.export(
             example_vaex_dataframe, {"export_destination": file_save_location, "export_type": "vaex"}
         )
-        assert file_path.exists()
+        assert file_path[0].exists()

-        with patch("mleko.dataset.export.local_exporter.write_vaex_dataframe") as mocked_write_vaex_dataframe:
+        with patch.object(LocalExporter, "_run_export_function") as mocked_export:
             test_exporter.export(
                 example_vaex_dataframe, {"export_destination": file_save_location, "export_type": "vaex"}
             )
-            mocked_write_vaex_dataframe.assert_not_called()
+            mocked_export.assert_not_called()

     def test_unsupported_data_type_export(self, temporary_directory: Path):
         """Should raise an error for unsupported data types."""
         test_exporter = LocalExporter()
-        with pytest.raises(ValueError):
+        with pytest.raises(TypeError):
             test_exporter.export(1, {"export_destination": temporary_directory / "test.txt", "export_type": "string"})

     def test_pickle_diff_instance_object_export(self, temporary_directory: Path):
@@ -149,8 +149,68 @@ def test_pickle_diff_instance_object_export(self, temporary_directory: Path):
         file_path = test_exporter.export(
             data_schema, {"export_destination": file_save_location, "export_type": "pickle"}
         )
-        assert file_path.exists()
+        assert file_path[0].exists()

-        with patch("mleko.dataset.export.local_exporter.write_pickle") as mocked_write_pickle:
+        with patch.object(LocalExporter, "_run_export_function") as mocked_export:
             test_exporter.export(DataSchema(), {"export_destination": file_save_location, "export_type": "pickle"})
-            mocked_write_pickle.assert_not_called()
+            mocked_export.assert_not_called()
+
+    def test_config_list_data_single(self, temporary_directory: Path):
+        """Should raise ValueError if config is a list but data is not."""
+        test_exporter = LocalExporter()
+        with pytest.raises(ValueError):
+            test_exporter.export(
+                "Test", [{"export_destination": temporary_directory / "test.txt", "export_type": "string"}]
+            )
+
+    def test_config_data_lengths_not_matching(self, temporary_directory: Path):
+        """Should raise ValueError if config and data lengths do not match."""
+        test_exporter = LocalExporter()
+        with pytest.raises(ValueError):
+            test_exporter.export(
+                ["Test", "Test"],
+                [{"export_destination": temporary_directory / "test.txt", "export_type": "string"}],
+            )
+
+    def test_config_list_data_single_json(self, temporary_directory: Path):
+        """Should raise ValueError if data is a list but export type is not json."""
+        test_exporter = LocalExporter()
+        with pytest.raises(ValueError):
+            test_exporter.export(
+                ["Test"],
+                {"export_destination": temporary_directory / "test.txt", "export_type": "string"},
+            )
+
+    def test_unsupported_export_type(self, temporary_directory: Path):
+        """Should raise an error for unsupported export types."""
+        test_exporter = LocalExporter()
+        with pytest.raises(ValueError):
+            test_exporter.export(
+                "Test",
+                {
+                    "export_destination": temporary_directory / "test.txt",
+                    "export_type": "unsupported",  # type: ignore
+                },
+            )
+
+    def test_multiple_exports_partially_cached(self, temporary_directory: Path):
+        """Should individually cache and export data, only exporting the uncached data."""
+        test_exporter = LocalExporter()
+        test_exporter.export(
+            ["Test", DataSchema()],
+            [
+                {"export_destination": temporary_directory / "test.txt", "export_type": "string"},
+                {"export_destination": temporary_directory / "data_schema.pkl", "export_type": "pickle"},
+            ],
+        )
+
+        with patch.object(LocalExporter, "_run_export_function") as mocked_export:
+            test_exporter.export(
+                ["Test", DataSchema(), "New"],
+                [
+                    {"export_destination": temporary_directory / "test.txt", "export_type": "string"},
+                    {"export_destination": temporary_directory / "data_schema.pkl", "export_type": "pickle"},
+                    {"export_destination": temporary_directory / "new.txt", "export_type": "string"},
+                ],
+            )
+            mocked_export.assert_called_once()
