Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Oct 25, 2023
1 parent 73bb90f commit 46c405a
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 202 deletions.
4 changes: 2 additions & 2 deletions py-polars/tests/unit/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def test_init_ndarray(monkeypatch: Any) -> None:
data={"x": np.array([1.0, 2.5, np.nan]), "y": np.array([4.0, np.nan, 6.5])},
nan_to_null=True,
)
assert_frame_equal(df0, df1, nans_compare_equal=True)
assert_frame_equal(df0, df1)
assert df2.rows() == [(1.0, 4.0), (2.5, None), (None, 6.5)]


Expand Down Expand Up @@ -714,7 +714,7 @@ def test_init_series() -> None:
s1 = pl.Series("n", np.array([1.0, 2.5, float("nan")]))
s2 = pl.Series("n", np.array([1.0, 2.5, float("nan")]), nan_to_null=True)

assert_series_equal(s0, s1, nans_compare_equal=True)
assert_series_equal(s0, s1)
assert s2.to_list() == [1.0, 2.5, None]


Expand Down
156 changes: 95 additions & 61 deletions py-polars/tests/unit/testing/test_assert_frame_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from polars.exceptions import InvalidAssert
from polars.testing import assert_frame_equal, assert_frame_not_equal

nan = float("NaN")


@pytest.mark.parametrize(
("df1", "df2", "kwargs"),
Expand Down Expand Up @@ -103,23 +105,11 @@
{"rtol": 1},
id="list_of_none_and_float_integer_rtol",
),
pytest.param(
pl.DataFrame({"a": [[None, 1.3]]}),
pl.DataFrame({"a": [[None, 0.9]]}),
{"rtol": 1, "nans_compare_equal": False},
id="list_of_none_and_float_integer_rtol",
),
pytest.param(
pl.DataFrame({"a": [[[0.2, 3.0]]]}),
pl.DataFrame({"a": [[[0.2, 3.00000001]]]}),
{"atol": 0.1, "nans_compare_equal": True},
id="nested_list_of_float_atol_high_nans_compare_equal_true",
),
pytest.param(
pl.DataFrame({"a": [[[0.2, 3.0]]]}),
pl.DataFrame({"a": [[[0.2, 3.00000001]]]}),
{"atol": 0.1, "nans_compare_equal": False},
id="nested_list_of_float_atol_high_nans_compare_equal_false",
{"atol": 0.1},
id="nested_list_of_float_atol_high",
),
],
)
Expand Down Expand Up @@ -160,12 +150,6 @@ def test_assert_frame_equal_passes_assertion(
{"atol": -1, "rtol": 0},
id="list_of_float_negative_atol",
),
pytest.param(
pl.DataFrame({"a": [[math.nan, 1.3]]}),
pl.DataFrame({"a": [[math.nan, 0.9]]}),
{"rtol": 1, "nans_compare_equal": False},
id="list_of_nan_and_float_integer_rtol",
),
pytest.param(
pl.DataFrame({"a": [[2.0, 3.0]]}),
pl.DataFrame({"a": [[2, 3]]}),
Expand All @@ -175,50 +159,20 @@ def test_assert_frame_equal_passes_assertion(
pytest.param(
pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}),
pl.DataFrame({"a": [[[0.2, math.nan, 3.11]]]}),
{"atol": 0.1, "rtol": 0, "nans_compare_equal": True},
id="nested_list_of_float_and_nan_atol_high_nans_compare_equal_true",
),
pytest.param(
pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}),
pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}),
{"nans_compare_equal": False},
id="nested_list_of_float_and_nan_atol_high_nans_compare_equal_false",
),
pytest.param(
pl.DataFrame({"a": [[[0.2, 3.0]]]}),
pl.DataFrame({"a": [[[0.2, 3.11]]]}),
{"atol": 0.1, "nans_compare_equal": False},
id="nested_list_of_float_atol_high_nans_compare_equal_false",
{"atol": 0.1, "rtol": 0},
id="nested_list_of_float_and_nan_atol_high",
),
pytest.param(
pl.DataFrame({"a": [[[[0.2, 3.0]]]]}),
pl.DataFrame({"a": [[[[0.2, 3.11]]]]}),
{"atol": 0.1, "rtol": 0, "nans_compare_equal": True},
id="double_nested_list_of_float_atol_high_nans_compare_equal_true",
),
pytest.param(
pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}),
pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}),
{"atol": 0.1, "nans_compare_equal": False},
id="double_nested_list_of_float_and_nan_atol_high_nans_compare_equal_false",
),
pytest.param(
pl.DataFrame({"a": [[[[0.2, 3.0]]]]}),
pl.DataFrame({"a": [[[[0.2, 3.11]]]]}),
{"atol": 0.1, "rtol": 0, "nans_compare_equal": False},
id="double_nested_list_of_float_atol_high_nans_compare_equal_false",
{"atol": 0.1, "rtol": 0},
id="double_nested_list_of_float_atol_high",
),
pytest.param(
pl.DataFrame({"a": [[[[[0.2, 3.0]]]]]}),
pl.DataFrame({"a": [[[[[0.2, 3.11]]]]]}),
{"atol": 0.1, "rtol": 0, "nans_compare_equal": True},
id="triple_nested_list_of_float_atol_high_nans_compare_equal_true",
),
pytest.param(
pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}),
pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}),
{"atol": 0.1, "nans_compare_equal": False},
id="triple_nested_list_of_float_and_nan_atol_high_nans_compare_equal_true",
{"atol": 0.1, "rtol": 0},
id="triple_nested_list_of_float_atol_high",
),
],
)
Expand All @@ -233,7 +187,6 @@ def test_assert_frame_equal_raises_assertion_error(


def test_compare_frame_equal_nans() -> None:
nan = float("NaN")
df1 = pl.DataFrame(
data={"x": [1.0, nan], "y": [nan, 2.0]},
schema=[("x", pl.Float32), ("y", pl.Float64)],
Expand All @@ -250,8 +203,6 @@ def test_compare_frame_equal_nans() -> None:


def test_compare_frame_equal_nested_nans() -> None:
nan = float("NaN")

# list dtype
df1 = pl.DataFrame(
data={"x": [[1.0, nan]], "y": [[nan, 2.0]]},
Expand Down Expand Up @@ -306,10 +257,13 @@ def test_compare_frame_equal_nested_nans() -> None:
)

assert_frame_equal(df3, df3)
assert_frame_not_equal(df3, df3, nans_compare_equal=False)
with pytest.deprecated_call():
assert_frame_not_equal(df3, df3, nans_compare_equal=False)

assert_frame_equal(df4, df4)
assert_frame_not_equal(df4, df4, nans_compare_equal=False)

with pytest.deprecated_call():
assert_frame_not_equal(df4, df4, nans_compare_equal=False)

assert_frame_not_equal(df3, df4)
for check_dtype in (True, False):
Expand Down Expand Up @@ -418,3 +372,83 @@ def test_assert_frame_not_equal() -> None:
df = pl.DataFrame({"a": [1, 2]})
with pytest.raises(AssertionError, match="frames are equal"):
assert_frame_not_equal(df, df)


@pytest.mark.parametrize(
("df1", "df2", "kwargs"),
[
pytest.param(
pl.DataFrame({"a": [[None, 1.3]]}),
pl.DataFrame({"a": [[None, 0.9]]}),
{"rtol": 1},
id="list_of_none_and_float_integer_rtol",
),
pytest.param(
pl.DataFrame({"a": [[[0.2, 3.0]]]}),
pl.DataFrame({"a": [[[0.2, 3.00000001]]]}),
{"atol": 0.1},
id="nested_list_of_float_atol_high_nans_compare_equal_false",
),
],
)
def test_assert_frame_equal_passes_assertion_deprecated_nans_compare_equal_false(
df1: pl.DataFrame,
df2: pl.DataFrame,
kwargs: dict[str, Any],
) -> None:
with pytest.deprecated_call():
assert_frame_equal(df1, df2, nans_compare_equal=False, **kwargs)
with pytest.raises(AssertionError), pytest.deprecated_call():
assert_frame_not_equal(df1, df2, nans_compare_equal=False, **kwargs)


@pytest.mark.parametrize(
("df1", "df2", "kwargs"),
[
pytest.param(
pl.DataFrame({"a": [[math.nan, 1.3]]}),
pl.DataFrame({"a": [[math.nan, 0.9]]}),
{"rtol": 1},
id="list_of_nan_and_float_integer_rtol",
),
pytest.param(
pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}),
pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}),
{},
id="nested_list_of_float_and_nan_atol_high",
),
pytest.param(
pl.DataFrame({"a": [[[0.2, 3.0]]]}),
pl.DataFrame({"a": [[[0.2, 3.11]]]}),
{"atol": 0.1},
id="nested_list_of_float_atol_high",
),
pytest.param(
pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}),
pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}),
{"atol": 0.1},
id="double_nested_list_of_float_and_nan_atol_high",
),
pytest.param(
pl.DataFrame({"a": [[[[0.2, 3.0]]]]}),
pl.DataFrame({"a": [[[[0.2, 3.11]]]]}),
{"atol": 0.1, "rtol": 0},
id="double_nested_list_of_float_atol_high",
),
pytest.param(
pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}),
pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}),
{"atol": 0.1},
id="triple_nested_list_of_float_and_nan_atol_high",
),
],
)
def test_assert_frame_equal_raises_assertion_error_deprecated_nans_compare_equal_false(
df1: pl.DataFrame,
df2: pl.DataFrame,
kwargs: dict[str, Any],
) -> None:
with pytest.raises(AssertionError), pytest.deprecated_call():
assert_frame_equal(df1, df2, nans_compare_equal=False, **kwargs)
with pytest.deprecated_call():
assert_frame_not_equal(df1, df2, nans_compare_equal=False, **kwargs)
Loading

0 comments on commit 46c405a

Please sign in to comment.