Skip to content

Commit

Permalink
fix(rust): check dtypes of single-column 'by' parameter in asof-join (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Aug 4, 2023
1 parent 30122d0 commit 5a7905c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
9 changes: 7 additions & 2 deletions crates/polars-core/src/frame/asof_join/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,12 @@ fn dispatch_join<T: PolarsNumericType>(
tolerance: Option<AnyValue<'static>>,
) -> PolarsResult<Vec<Option<IdxSize>>> {
let out = if left_by.width() == 1 {
match left_by_s.dtype() {
let left_dtype = left_by_s.dtype();
let right_dtype = right_by_s.dtype();
polars_ensure!(left_dtype == right_dtype,
ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{}` and `{}`", left_dtype, right_dtype
);
match left_dtype {
DataType::Utf8 => asof_join_by_binary(
&left_by_s.utf8().unwrap().as_binary(),
&right_by_s.utf8().unwrap().as_binary(),
Expand Down Expand Up @@ -669,7 +674,7 @@ fn dispatch_join<T: PolarsNumericType>(
} else {
for (lhs, rhs) in left_by.get_columns().iter().zip(right_by.get_columns()) {
polars_ensure!(lhs.dtype() == rhs.dtype(),
ComputeError: "mismatching dtypes in 'on' parameter of asof-join: `{}` and `{}`", lhs.dtype(), rhs.dtype()
ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{}` and `{}`", lhs.dtype(), rhs.dtype()
);
#[cfg(feature = "dtype-categorical")]
_check_categorical_src(lhs.dtype(), rhs.dtype())?;
Expand Down
36 changes: 36 additions & 0 deletions py-polars/tests/unit/operations/test_join_asof.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,42 @@ def test_asof_join_schema_5684() -> None:
)


def test_join_asof_mismatched_dtypes() -> None:
# test 'on' dtype mismatch
df1 = pl.DataFrame(
{"a": pl.Series([1, 2, 3], dtype=pl.Int64), "b": ["a", "b", "c"]}
)
df2 = pl.DataFrame(
{"a": pl.Series([1, 2, 3], dtype=pl.Int32), "c": ["d", "e", "f"]}
)

with pytest.raises(
pl.exceptions.ComputeError, match="datatypes of join keys don't match"
):
df1.join_asof(df2, on="a", strategy="forward")

# test 'by' dtype mismatch
df1 = pl.DataFrame(
{
"time": pl.date_range(date(2018, 1, 1), date(2018, 1, 8), eager=True),
"group": pl.Series([1, 1, 1, 1, 2, 2, 2, 2], dtype=pl.Int32),
"value": [0, 0, None, None, 2, None, 1, None],
}
)
df2 = pl.DataFrame(
{
"time": pl.date_range(date(2018, 1, 1), date(2018, 1, 8), eager=True),
"group": pl.Series([1, 1, 1, 1, 2, 2, 2, 2], dtype=pl.Int64),
"value": [0, 0, None, None, 2, None, 1, None],
}
)

with pytest.raises(
pl.exceptions.ComputeError, match="mismatching dtypes in 'by' parameter"
):
df1.join_asof(df2, on="time", by="group", strategy="forward")


def test_join_asof_floats() -> None:
df1 = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": ["lrow1", "lrow2", "lrow3"]})
df2 = pl.DataFrame({"a": [0.59, 1.49, 2.89], "b": ["rrow1", "rrow2", "rrow3"]})
Expand Down

0 comments on commit 5a7905c

Please sign in to comment.