Skip to content

Commit

Permalink
Add atomic keyword (#383)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd authored Oct 31, 2024
1 parent 1e5265e commit 0e5c1b3
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 10 deletions.
32 changes: 24 additions & 8 deletions src/pantab/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def frame_to_hyper(
json_columns: Optional[set[str]] = None,
geo_columns: Optional[set[str]] = None,
process_params: Optional[dict[str, str]] = None,
atomic: bool = True,
) -> None:
"""
Convert a DataFrame to a .hyper extract.
Expand All @@ -68,6 +69,7 @@ def frame_to_hyper(
:param json_columns: Columns to be written as a JSON data type
:param geo_columns: Columns to be written as a GEOGRAPHY data type
:param process_params: Parameters to pass to the Hyper Process constructor.
:param atomic: Whether to treat write as atomic. Disabling gives better performance, but failures during write will likely corrupt the Hyper file.
"""
frames_to_hyper(
{table: df},
Expand All @@ -77,6 +79,7 @@ def frame_to_hyper(
json_columns=json_columns,
geo_columns=geo_columns,
process_params=process_params,
atomic=atomic,
)


Expand All @@ -89,6 +92,7 @@ def frames_to_hyper(
json_columns: Optional[set[str]] = None,
geo_columns: Optional[set[str]] = None,
process_params: Optional[dict[str, str]] = None,
atomic: bool = True,
) -> None:
"""
Writes multiple DataFrames to a .hyper extract.
Expand All @@ -100,6 +104,7 @@ def frames_to_hyper(
:param json_columns: Columns to be written as a JSON data type
:param geo_columns: Columns to be written as a GEOGRAPHY data type
:param process_params: Parameters to pass to the Hyper Process constructor.
:param atomic: Whether to treat write as atomic. Disabling gives better performance, but failures during write will likely corrupt the Hyper file.
"""
_validate_table_mode(table_mode)

Expand All @@ -112,10 +117,20 @@ def frames_to_hyper(
if process_params is None:
process_params = {}

tmp_db = pathlib.Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.hyper"

if table_mode == "a" and pathlib.Path(database).exists():
shutil.copy(database, tmp_db)
if not (atomic and pathlib.Path(database).exists()):
needs_copy = False
needs_move = False
path_to_write = database
else:
path_to_write = pathlib.Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.hyper"
needs_move = True
if table_mode == "a":
needs_copy = True
else:
needs_copy = False

if needs_copy:
shutil.copy(database, path_to_write)

def convert_to_table_name(table: pt_types.TableNameType):
if isinstance(table, pt_types.TableauTableName):
Expand All @@ -135,14 +150,15 @@ def convert_to_table_name(table: pt_types.TableNameType):

libpantab.write_to_hyper(
data,
path=str(tmp_db),
path=str(path_to_write),
table_mode=table_mode,
not_null_columns=not_null_columns,
json_columns=json_columns,
geo_columns=geo_columns,
process_params=process_params,
)

# In Python 3.9+ we can just pass the path object, but due to bpo 32689
# and subsequent typeshed changes it is easier to just pass as str for now
shutil.move(str(tmp_db), database)
if needs_move:
# In Python 3.9+ we can just pass the path object, but due to bpo 32689
# and subsequent typeshed changes it is easier to just pass as str for now
shutil.move(str(path_to_write), database)
76 changes: 74 additions & 2 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import re
import unittest.mock

import narwhals as nw
import pandas as pd
Expand Down Expand Up @@ -187,13 +188,84 @@ def test_failed_write_doesnt_overwrite_file(
frame = compat.add_non_writeable_column(frame)
msg = "Unsupported Arrow type"
with pytest.raises(ValueError, match=msg):
pt.frame_to_hyper(frame, tmp_hyper, table="test", table_mode=table_mode)
pt.frames_to_hyper({"test": frame}, tmp_hyper, table_mode=table_mode)
pt.frame_to_hyper(
frame,
tmp_hyper,
table="test",
table_mode=table_mode,
process_params={"default_database_version": "4"},
)
pt.frames_to_hyper(
{"test": frame},
tmp_hyper,
table_mode=table_mode,
process_params={"default_database_version": "4"},
)

# Neither should not update file stats
assert last_modified == tmp_hyper.stat().st_mtime


@unittest.mock.patch("shutil.copy")
@unittest.mock.patch("shutil.move")
def test_new_file_write_does_not_copy_or_move(
mocked_copy,
mocked_move,
frame,
tmp_hyper,
monkeypatch,
table_mode,
):
pt.frame_to_hyper(
frame,
tmp_hyper,
table="test",
table_mode=table_mode,
process_params={"default_database_version": "4"},
)

mocked_copy.assert_not_called()
mocked_move.assert_not_called()


@unittest.mock.patch("shutil.copy")
@unittest.mock.patch("shutil.move")
def test_atomic_keyword_does_not_copy_or_move(
mocked_copy,
mocked_move,
frame,
tmp_hyper,
monkeypatch,
table_mode,
):
pt.frame_to_hyper(
frame,
tmp_hyper,
table="test",
table_mode=table_mode,
process_params={"default_database_version": "4"},
)

pt.frame_to_hyper(
frame,
tmp_hyper,
table="test",
table_mode=table_mode,
atomic=False,
process_params={"default_database_version": "4"},
)
pt.frames_to_hyper(
{"test": frame},
tmp_hyper,
table_mode=table_mode,
atomic=False,
process_params={"default_database_version": "4"},
)

mocked_copy.assert_not_called()
mocked_move.assert_not_called()


def test_duplicate_columns_raises(tmp_hyper):
frame = pd.DataFrame([[1, 1]], columns=[1, 1])
msg = r"Duplicate column names found: \[1, 1\]"
Expand Down

0 comments on commit 0e5c1b3

Please sign in to comment.