Skip to content

Commit

Permalink
More refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Oct 18, 2023
1 parent 27939b5 commit c68ab3b
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 25 deletions.
55 changes: 37 additions & 18 deletions py-polars/polars/testing/asserts/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def assert_frame_equal(
assert_series_equal
assert_frame_not_equal
Notes
-----
When using pytest, it may be worthwhile to shorten Python traceback printing
by passing ``--tb=short``. The default mode tends to be unhelpfully verbose.
More information in the
`pytest docs <https://docs.pytest.org/en/latest/how-to/output.html#modifying-python-traceback-printing>`_.
Examples
--------
>>> from polars.testing import assert_frame_equal
Expand All @@ -82,27 +89,27 @@ def assert_frame_equal(
"""
lazy = _assert_correct_input_type(left, right)
objs = "LazyFrames" if lazy else "DataFrames"
objects = "LazyFrames" if lazy else "DataFrames"

_assert_frame_schema_equal(
left, right, check_column_order=check_column_order, check_dtype=check_dtype
left,
right,
check_column_order=check_column_order,
check_dtype=check_dtype,
objects=objects,
)

if lazy:
left, right = left.collect(), right.collect() # type: ignore[union-attr]

left, right = cast(DataFrame, left), cast(DataFrame, right)

if left.height != right.height:
raise_assertion_error(objs, "length mismatch", left.height, right.height)
raise_assertion_error(
objects, "number of rows does not match", left.height, right.height
)

if not check_row_order:
try:
left = left.sort(by=left.columns)
right = right.sort(by=left.columns)
except ComputeError as exc:
msg = "cannot set `check_row_order=False` on frame with unsortable columns"
raise InvalidAssert(msg) from exc
left, right = _sort_dataframes(left, right)

for c in left.columns:
try:
Expand Down Expand Up @@ -140,8 +147,9 @@ def _assert_frame_schema_equal(
left: DataFrame | LazyFrame,
right: DataFrame | LazyFrame,
*,
check_dtype: bool = True,
check_column_order: bool = True,
check_dtype: bool,
check_column_order: bool,
objects: str,
) -> None:
left_schema, right_schema = left.schema, right.schema

Expand All @@ -152,24 +160,35 @@ def _assert_frame_schema_equal(
# Special error message for when column names do not match
if left_schema.keys() != right_schema.keys():
if left_not_right := [c for c in left_schema if c not in right_schema]:
msg = f"columns {left_not_right!r} in left frame, but not in right"
msg = f"columns {left_not_right!r} in left {objects[:-1]}, but not in right"
raise AssertionError(msg)
else:
right_not_left = [c for c in right_schema if c not in left_schema]
msg = f"columns {right_not_left!r} in right frame, but not in left"
msg = f"columns {right_not_left!r} in right {objects[:-1]}, but not in left"
raise AssertionError(msg)

if check_column_order:
left_columns, right_columns = list(left_schema), list(right_schema)
if left_columns != right_columns:
msg = "columns are not in the same order"
raise_assertion_error("Frames", msg, left_columns, right_columns)
detail = "columns are not in the same order"
raise_assertion_error(objects, detail, left_columns, right_columns)

if check_dtype:
left_schema_dict, right_schema_dict = dict(left_schema), dict(right_schema)
if check_column_order or left_schema_dict != right_schema_dict:
msg = "dtypes do not match"
raise_assertion_error("frames", msg, left_schema_dict, right_schema_dict)
detail = "dtypes do not match"
raise_assertion_error(objects, detail, left_schema_dict, right_schema_dict)


def _sort_dataframes(left: DataFrame, right: DataFrame) -> tuple[DataFrame, DataFrame]:
    """
    Sort both frames by the columns of the left frame.

    Used when row order should be ignored: sorting both frames on the same
    column list lets the rows be compared positionally afterwards.

    Parameters
    ----------
    left
        Frame whose column names define the sort keys.
    right
        Frame sorted by the same keys as ``left``.

    Returns
    -------
    tuple of DataFrame
        The sorted ``left`` and ``right`` frames, in that order.

    Raises
    ------
    InvalidAssert
        If sorting either frame raises a ``ComputeError``.

    """
    # The sort keys always come from `left`, so both frames are ordered by the
    # exact same column list.
    sort_columns = left.columns
    try:
        sorted_left = left.sort(sort_columns)
        sorted_right = right.sort(sort_columns)
    except ComputeError as exc:
        msg = "cannot set `check_row_order=False` on frame with unsortable columns"
        raise InvalidAssert(msg) from exc
    else:
        return sorted_left, sorted_right


def assert_frame_not_equal(
Expand Down
7 changes: 7 additions & 0 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def assert_series_equal(
assert_frame_equal
assert_series_not_equal
Notes
-----
When using pytest, it may be worthwhile to shorten Python traceback printing
by passing ``--tb=short``. The default mode tends to be unhelpfully verbose.
More information in the
`pytest docs <https://docs.pytest.org/en/latest/how-to/output.html#modifying-python-traceback-printing>`_.
Examples
--------
>>> from polars.testing import assert_series_equal
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/testing/asserts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@


def raise_assertion_error(
obj: str,
objects: str,
detail: str,
left: Any,
right: Any,
) -> NoReturn:
"""Raise a detailed assertion error."""
__tracebackhide__ = True
msg = f"{obj} are different ({detail})\n[left]: {left}\n[right]: {right}"
msg = f"{objects} are different ({detail})\n[left]: {left}\n[right]: {right}"
raise AssertionError(msg)
12 changes: 7 additions & 5 deletions py-polars/tests/unit/testing/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ def test_assert_frame_equal_length_mismatch() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2, 3]})
with pytest.raises(
AssertionError, match=r"DataFrames are different \(length mismatch\)"
AssertionError,
match=r"DataFrames are different \(number of rows does not match\)",
):
assert_frame_equal(df1, df2)

Expand All @@ -298,16 +299,17 @@ def test_assert_frame_equal_column_mismatch() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"b": [1, 2]})
with pytest.raises(
AssertionError, match="columns \\['a'\\] in left frame, but not in right"
AssertionError, match="columns \\['a'\\] in left DataFrame, but not in right"
):
assert_frame_equal(df1, df2)


def test_assert_frame_equal_column_mismatch2() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
df1 = pl.LazyFrame({"a": [1, 2]})
df2 = pl.LazyFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
with pytest.raises(
AssertionError, match="columns \\['b', 'c'\\] in right frame, but not in left"
AssertionError,
match="columns \\['b', 'c'\\] in right LazyFrame, but not in left",
):
assert_frame_equal(df1, df2)

Expand Down

0 comments on commit c68ab3b

Please sign in to comment.