diff --git a/py-polars/polars/testing/asserts/frame.py b/py-polars/polars/testing/asserts/frame.py index 1d2793bc9056..298001073db6 100644 --- a/py-polars/polars/testing/asserts/frame.py +++ b/py-polars/polars/testing/asserts/frame.py @@ -62,6 +62,13 @@ def assert_frame_equal( assert_series_equal assert_frame_not_equal + Notes + ----- + When using pytest, it may be worthwhile to shorten Python traceback printing + by passing ``--tb=short``. The default mode tends to be unhelpfully verbose. + More information in the + `pytest docs <https://docs.pytest.org/en/latest/how-to/output.html#modifying-python-traceback-printing>`_. + Examples -------- >>> from polars.testing import assert_frame_equal @@ -82,27 +89,27 @@ def assert_frame_equal( """ lazy = _assert_correct_input_type(left, right) - objs = "LazyFrames" if lazy else "DataFrames" + objects = "LazyFrames" if lazy else "DataFrames" _assert_frame_schema_equal( - left, right, check_column_order=check_column_order, check_dtype=check_dtype + left, + right, + check_column_order=check_column_order, + check_dtype=check_dtype, + objects=objects, ) if lazy: left, right = left.collect(), right.collect() # type: ignore[union-attr] - left, right = cast(DataFrame, left), cast(DataFrame, right) if left.height != right.height: - raise_assertion_error(objs, "length mismatch", left.height, right.height) + raise_assertion_error( + objects, "number of rows does not match", left.height, right.height + ) if not check_row_order: - try: - left = left.sort(by=left.columns) - right = right.sort(by=left.columns) - except ComputeError as exc: - msg = "cannot set `check_row_order=False` on frame with unsortable columns" - raise InvalidAssert(msg) from exc + left, right = _sort_dataframes(left, right) for c in left.columns: try: @@ -140,8 +147,9 @@ def _assert_frame_schema_equal( left: DataFrame | LazyFrame, right: DataFrame | LazyFrame, *, - check_dtype: bool = True, - check_column_order: bool = True, + check_dtype: bool, + check_column_order: bool, + objects: str, ) -> None: left_schema, right_schema = left.schema, right.schema @@ 
-152,24 +160,35 @@ def _assert_frame_schema_equal( # Special error message for when column names do not match if left_schema.keys() != right_schema.keys(): if left_not_right := [c for c in left_schema if c not in right_schema]: - msg = f"columns {left_not_right!r} in left frame, but not in right" + msg = f"columns {left_not_right!r} in left {objects[:-1]}, but not in right" raise AssertionError(msg) else: right_not_left = [c for c in right_schema if c not in left_schema] - msg = f"columns {right_not_left!r} in right frame, but not in left" + msg = f"columns {right_not_left!r} in right {objects[:-1]}, but not in left" raise AssertionError(msg) if check_column_order: left_columns, right_columns = list(left_schema), list(right_schema) if left_columns != right_columns: - msg = "columns are not in the same order" - raise_assertion_error("Frames", msg, left_columns, right_columns) + detail = "columns are not in the same order" + raise_assertion_error(objects, detail, left_columns, right_columns) if check_dtype: left_schema_dict, right_schema_dict = dict(left_schema), dict(right_schema) if check_column_order or left_schema_dict != right_schema_dict: - msg = "dtypes do not match" - raise_assertion_error("frames", msg, left_schema_dict, right_schema_dict) + detail = "dtypes do not match" + raise_assertion_error(objects, detail, left_schema_dict, right_schema_dict) + + +def _sort_dataframes(left: DataFrame, right: DataFrame) -> tuple[DataFrame, DataFrame]: + by = left.columns + try: + left = left.sort(by) + right = right.sort(by) + except ComputeError as exc: + msg = "cannot set `check_row_order=False` on frame with unsortable columns" + raise InvalidAssert(msg) from exc + return left, right def assert_frame_not_equal( diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index 99bfb4c71e5a..57db626494ac 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -62,6 +62,13 @@ 
def assert_series_equal( assert_frame_equal assert_series_not_equal + Notes + ----- + When using pytest, it may be worthwhile to shorten Python traceback printing + by passing ``--tb=short``. The default mode tends to be unhelpfully verbose. + More information in the + `pytest docs <https://docs.pytest.org/en/latest/how-to/output.html#modifying-python-traceback-printing>`_. + Examples -------- >>> from polars.testing import assert_series_equal diff --git a/py-polars/polars/testing/asserts/utils.py b/py-polars/polars/testing/asserts/utils.py index bc99259efd11..713e57170ac1 100644 --- a/py-polars/polars/testing/asserts/utils.py +++ b/py-polars/polars/testing/asserts/utils.py @@ -4,12 +4,12 @@ def raise_assertion_error( - obj: str, + objects: str, detail: str, left: Any, right: Any, ) -> NoReturn: """Raise a detailed assertion error.""" __tracebackhide__ = True - msg = f"{obj} are different ({detail})\n[left]: {left}\n[right]: {right}" + msg = f"{objects} are different ({detail})\n[left]: {left}\n[right]: {right}" raise AssertionError(msg) diff --git a/py-polars/tests/unit/testing/test_testing.py b/py-polars/tests/unit/testing/test_testing.py index ae7a6058b500..bafafc51c1d6 100644 --- a/py-polars/tests/unit/testing/test_testing.py +++ b/py-polars/tests/unit/testing/test_testing.py @@ -289,7 +289,8 @@ def test_assert_frame_equal_length_mismatch() -> None: df1 = pl.DataFrame({"a": [1, 2]}) df2 = pl.DataFrame({"a": [1, 2, 3]}) with pytest.raises( - AssertionError, match=r"DataFrames are different \(length mismatch\)" + AssertionError, + match=r"DataFrames are different \(number of rows does not match\)", ): assert_frame_equal(df1, df2) @@ -298,16 +299,17 @@ def test_assert_frame_equal_column_mismatch() -> None: df1 = pl.DataFrame({"a": [1, 2]}) df2 = pl.DataFrame({"b": [1, 2]}) with pytest.raises( - AssertionError, match="columns \\['a'\\] in left frame, but not in right" + AssertionError, match="columns \\['a'\\] in left DataFrame, but not in right" ): assert_frame_equal(df1, df2) def test_assert_frame_equal_column_mismatch2() -> None: - df1 = 
pl.DataFrame({"a": [1, 2]}) - df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + df1 = pl.LazyFrame({"a": [1, 2]}) + df2 = pl.LazyFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) with pytest.raises( - AssertionError, match="columns \\['b', 'c'\\] in right frame, but not in left" + AssertionError, + match="columns \\['b', 'c'\\] in right LazyFrame, but not in left", ): assert_frame_equal(df1, df2)