Skip to content

Commit

Permalink
feat(python): Support decimals in assert utils (#12119)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Oct 30, 2023
1 parent f3ee4d0 commit a366bc9
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 71 deletions.
14 changes: 10 additions & 4 deletions py-polars/polars/testing/asserts/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,19 +127,25 @@ def assert_frame_equal(
left, right = _sort_dataframes(left, right)

for c in left.columns:
s_left, s_right = left.get_column(c), right.get_column(c)
try:
_assert_series_values_equal(
left.get_column(c),
right.get_column(c),
s_left,
s_right,
check_exact=check_exact,
rtol=rtol,
atol=atol,
nans_compare_equal=nans_compare_equal,
categorical_as_str=categorical_as_str,
)
except AssertionError as exc:
msg = f"values for column {c!r} are different"
raise AssertionError(msg) from exc
raise_assertion_error(
objects,
f"value mismatch for column {c!r}",
s_left.to_list(),
s_right.to_list(),
cause=exc,
)


def _assert_correct_input_type(
Expand Down
123 changes: 67 additions & 56 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
UNSIGNED_INTEGER_DTYPES,
Array,
Categorical,
Decimal,
Float64,
Int64,
List,
Struct,
Expand Down Expand Up @@ -151,6 +153,14 @@ def _assert_series_values_equal(
if right.dtype == Categorical:
right = right.cast(Utf8)

# Handle decimals
# TODO: Delete this branch when Decimal equality is implemented
# https://github.com/pola-rs/polars/issues/12118
if left.dtype == Decimal:
left = left.cast(Float64)
if right.dtype == Decimal:
right = right.cast(Float64)

# Determine unequal elements
try:
unequal = left.ne_missing(right)
Expand All @@ -170,15 +180,25 @@ def _assert_series_values_equal(

# Check nested dtypes in separate function
if _comparing_nested_numerics(left.dtype, right.dtype):
if _assert_series_nested(
left=left.filter(unequal),
right=right.filter(unequal),
check_exact=check_exact,
rtol=rtol,
atol=atol,
nans_compare_equal=nans_compare_equal,
categorical_as_str=categorical_as_str,
):
try:
_assert_series_nested_values_equal(
left=left.filter(unequal),
right=right.filter(unequal),
check_exact=check_exact,
rtol=rtol,
atol=atol,
nans_compare_equal=nans_compare_equal,
categorical_as_str=categorical_as_str,
)
except AssertionError as exc:
raise_assertion_error(
"Series",
"nested value mismatch",
left=left.to_list(),
right=right.to_list(),
cause=exc,
)
else: # All nested values match
return

# If no differences found during exact checking, we're done
Expand All @@ -192,10 +212,7 @@ def _assert_series_values_equal(
or right.dtype not in NUMERIC_DTYPES
):
raise_assertion_error(
"Series",
"exact value mismatch",
left=left.to_list(),
right=right.to_list(),
"Series", "exact value mismatch", left=left.to_list(), right=right.to_list()
)

_assert_series_null_values_match(left, right)
Expand All @@ -209,40 +226,7 @@ def _assert_series_values_equal(
)


def _assert_series_null_values_match(left: Series, right: Series) -> None:
null_value_mismatch = left.is_null() != right.is_null()
if null_value_mismatch.any():
raise_assertion_error(
"Series", "null value mismatch", left.to_list(), right.to_list()
)


def _assert_series_nan_values_match(
left: Series, right: Series, *, nans_compare_equal: bool
) -> None:
if not _comparing_floats(left.dtype, right.dtype):
return

if nans_compare_equal:
nan_value_mismatch = left.is_nan() != right.is_nan()
if nan_value_mismatch.any():
raise_assertion_error(
"Series",
"nan value mismatch - nans compare equal",
left.to_list(),
right.to_list(),
)

elif left.is_nan().any() or right.is_nan().any():
raise_assertion_error(
"Series",
"nan value mismatch - nans compare unequal",
left.to_list(),
right.to_list(),
)


def _assert_series_nested(
def _assert_series_nested_values_equal(
left: Series,
right: Series,
*,
Expand All @@ -251,11 +235,11 @@ def _assert_series_nested(
atol: float,
nans_compare_equal: bool,
categorical_as_str: bool,
) -> bool:
) -> None:
# compare nested lists element-wise
if _comparing_lists(left.dtype, right.dtype):
for s1, s2 in zip(left, right):
if (s1 is None and s2 is not None) or (s2 is None and s1 is not None):
if s1 is None or s2 is None:
raise_assertion_error("Series", "nested value mismatch", s1, s2)

_assert_series_values_equal(
Expand All @@ -267,10 +251,9 @@ def _assert_series_nested(
nans_compare_equal=nans_compare_equal,
categorical_as_str=categorical_as_str,
)
return True

# unnest structs as series and compare
elif _comparing_structs(left.dtype, right.dtype):
else:
ls, rs = left.struct.unnest(), right.struct.unnest()
for s1, s2 in zip(ls, rs):
_assert_series_values_equal(
Expand All @@ -282,11 +265,39 @@ def _assert_series_nested(
nans_compare_equal=nans_compare_equal,
categorical_as_str=categorical_as_str,
)
return True
else:
# fall-back to outer codepath (if mismatched dtypes we would expect
# the equality check to fail - unless ALL series values are null)
return False


def _assert_series_null_values_match(left: Series, right: Series) -> None:
null_value_mismatch = left.is_null() != right.is_null()
if null_value_mismatch.any():
raise_assertion_error(
"Series", "null value mismatch", left.to_list(), right.to_list()
)


def _assert_series_nan_values_match(
left: Series, right: Series, *, nans_compare_equal: bool
) -> None:
if not _comparing_floats(left.dtype, right.dtype):
return

if nans_compare_equal:
nan_value_mismatch = left.is_nan() != right.is_nan()
if nan_value_mismatch.any():
raise_assertion_error(
"Series",
"nan value mismatch - nans compare equal",
left.to_list(),
right.to_list(),
)

elif left.is_nan().any() or right.is_nan().any():
raise_assertion_error(
"Series",
"nan value mismatch - nans compare unequal",
left.to_list(),
right.to_list(),
)


def _comparing_floats(left: PolarsDataType, right: PolarsDataType) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions py-polars/tests/unit/testing/test_assert_frame_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test_compare_frame_equal_nans() -> None:
schema=[("x", pl.Float32), ("y", pl.Float64)],
)
assert_frame_not_equal(df1, df2)
with pytest.raises(AssertionError, match="values for column 'y' are different"):
with pytest.raises(AssertionError, match="value mismatch for column 'y'"):
assert_frame_equal(df1, df2, check_exact=True)


Expand All @@ -215,7 +215,7 @@ def test_compare_frame_equal_nested_nans() -> None:
schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))],
)
assert_frame_not_equal(df1, df2)
with pytest.raises(AssertionError, match="values for column 'y' are different"):
with pytest.raises(AssertionError, match="value mismatch for column 'y'"):
assert_frame_equal(df1, df2, check_exact=True)

# struct dtype
Expand Down Expand Up @@ -328,7 +328,7 @@ def test_assert_frame_equal_ignore_row_order() -> None:
df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]})
df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]})
df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]})
with pytest.raises(AssertionError, match="values for column 'a' are different"):
with pytest.raises(AssertionError, match="value mismatch for column 'a'"):
assert_frame_equal(df1, df2)

assert_frame_equal(df1, df2, check_row_order=False)
Expand Down
55 changes: 47 additions & 8 deletions py-polars/tests/unit/testing/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import math
from datetime import datetime, time, timedelta
from decimal import Decimal as D
from typing import Any

import pytest
Expand Down Expand Up @@ -561,11 +562,28 @@ def test_assert_series_equal_nested_struct_float() -> None:
assert_series_equal(s1, s2)


def test_assert_series_equal_nested_list_full_null() -> None:
def test_assert_series_equal_full_null_incompatible_dtypes_raises() -> None:
s1 = pl.Series([None, None], dtype=pl.Categorical)
s2 = pl.Series([None, None], dtype=pl.Int16)

# You could argue this should pass, but it's rare enough not to warrant the
# additional check
with pytest.raises(AssertionError, match="incompatible data types"):
assert_series_equal(s1, s2, check_dtype=False)


def test_assert_series_equal_full_null_nested_list() -> None:
s = pl.Series([None, None], dtype=pl.List(pl.Float64))
assert_series_equal(s, s)


def test_assert_series_equal_full_null_nested_not_nested() -> None:
s1 = pl.Series([None, None], dtype=pl.List(pl.Float64))
s2 = pl.Series([None, None], dtype=pl.Float64)

assert_series_equal(s1, s2, check_dtype=False)


def test_assert_series_equal_nested_list_nan() -> None:
s = pl.Series([[1.0, 2.0], [3.0, nan]], dtype=pl.List(pl.Float64))
assert_series_equal(s, s)
Expand All @@ -578,13 +596,6 @@ def test_assert_series_equal_nested_list_none() -> None:
assert_series_equal(s1, s2)


def test_assert_series_equal_full_none_nested_not_nested() -> None:
s1 = pl.Series([None, None], dtype=pl.List(pl.Float64))
s2 = pl.Series([None, None], dtype=pl.Float64)

assert_series_equal(s1, s2, check_dtype=False)


def test_assert_series_equal_unsigned_ints_underflow() -> None:
s1 = pl.Series([1, 3], dtype=pl.UInt8)
s2 = pl.Series([2, 4], dtype=pl.Int64)
Expand All @@ -610,6 +621,34 @@ def test_assert_series_equal_nested_int() -> None:
assert_series_equal(s1, s2, check_exact=True)


def test_series_equal_nested_lengths_mismatch() -> None:
s1 = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.List(pl.Float64))
s2 = pl.Series([[1.0, 2.0, 3.0], [4.0]], dtype=pl.List(pl.Float64))

with pytest.raises(AssertionError, match="nested value mismatch"):
assert_series_equal(s1, s2)


def test_series_equal_decimals_exact() -> None:
s1 = pl.Series([D("1.00000"), D("2.00000")], dtype=pl.Decimal)
s2 = pl.Series([D("1.00000"), D("2.00001")], dtype=pl.Decimal)
with pytest.raises(AssertionError, match="exact value mismatch"):
assert_series_equal(s1, s2, check_exact=True)


def test_series_equal_decimals_inexact() -> None:
s1 = pl.Series([D("1.00000"), D("2.00000")], dtype=pl.Decimal)
s2 = pl.Series([D("1.00000"), D("2.00001")], dtype=pl.Decimal)
assert_series_equal(s1, s2, check_exact=False)


def test_series_equal_decimals_inexact_fail() -> None:
s1 = pl.Series([D("1.00000"), D("2.00000")], dtype=pl.Decimal)
s2 = pl.Series([D("1.00000"), D("2.00001")], dtype=pl.Decimal)
with pytest.raises(AssertionError, match="value mismatch"):
assert_series_equal(s1, s2, check_exact=False, rtol=0)


def test_compare_series_nans_assert_equal_deprecated() -> None:
srs1 = pl.Series([1.0, 2.0, nan, 4.0, None, 6.0])
srs2 = pl.Series([1.0, nan, 3.0, 4.0, None, 6.0])
Expand Down

0 comments on commit a366bc9

Please sign in to comment.