Skip to content

Commit

Permalink
tests(python): tighten assert_frame_equal for LazyFrames (don't collect until after the schema has been checked)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Sep 26, 2023
1 parent 1f0450a commit 65de008
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions py-polars/polars/testing/asserts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import textwrap
from typing import Any
from typing import Any, NoReturn

from polars import functions as F
from polars.dataframe import DataFrame
Expand Down Expand Up @@ -75,24 +75,21 @@ def assert_frame_equal(
>>> assert_frame_equal(df1, df2) # doctest: +SKIP
AssertionError: Values for column 'a' are different.
"""
if isinstance(left, LazyFrame) and isinstance(right, LazyFrame):
left, right = left.collect(), right.collect()
obj = "LazyFrames"
if collect_input_frames := (
isinstance(left, LazyFrame) and isinstance(right, LazyFrame)
):
objs = "LazyFrames"
elif isinstance(left, DataFrame) and isinstance(right, DataFrame):
obj = "DataFrames"
objs = "DataFrames"
else:
raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right))

if left.shape[0] != right.shape[0]: # type: ignore[union-attr]
raise_assert_detail(obj, "Length mismatch", left.shape, right.shape) # type: ignore[union-attr]

left_not_right = [c for c in left.columns if c not in right.columns]
if left_not_right:
if left_not_right := [c for c in left.columns if c not in right.columns]:
raise AssertionError(
f"columns {left_not_right!r} in left frame, but not in right"
)
right_not_left = [c for c in right.columns if c not in left.columns]
if right_not_left:

if right_not_left := [c for c in right.columns if c not in left.columns]:
raise AssertionError(
f"columns {right_not_left!r} in right frame, but not in left"
)
Expand All @@ -102,6 +99,14 @@ def assert_frame_equal(
f"columns are not in the same order:\n{left.columns!r}\n{right.columns!r}"
)

if collect_input_frames:
if check_dtype: # check this _before_ we collect
assert left.schema == right.schema, "schema dtypes are not equal"
left, right = left.collect(), right.collect() # type: ignore[union-attr]

if left.shape[0] != right.shape[0]: # type: ignore[union-attr]
raise_assert_detail(objs, "Length mismatch", left.shape, right.shape) # type: ignore[union-attr]

if not check_row_order:
try:
left = left.sort(by=left.columns)
Expand Down Expand Up @@ -525,7 +530,7 @@ def raise_assert_detail(
left: Any,
right: Any,
exc: AssertionError | None = None,
) -> None:
) -> NoReturn:
"""Raise a detailed assertion error."""
__tracebackhide__ = True

Expand Down

0 comments on commit 65de008

Please sign in to comment.