From 54d76c2ef15ea51607a17d1ec86e9cafa507a499 Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Thu, 3 Aug 2023 17:02:55 -0400 Subject: [PATCH] Check dtypes of single-column 'by' parameter in asof-join --- .../polars-core/src/frame/asof_join/groups.rs | 9 +++-- .../tests/unit/operations/test_join_asof.py | 36 +++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/crates/polars-core/src/frame/asof_join/groups.rs b/crates/polars-core/src/frame/asof_join/groups.rs index 2ab7490505b0..edd3f5be8a68 100644 --- a/crates/polars-core/src/frame/asof_join/groups.rs +++ b/crates/polars-core/src/frame/asof_join/groups.rs @@ -633,7 +633,12 @@ fn dispatch_join( tolerance: Option>, ) -> PolarsResult>> { let out = if left_by.width() == 1 { - match left_by_s.dtype() { + let left_dtype = left_by_s.dtype(); + let right_dtype = right_by_s.dtype(); + polars_ensure!(left_dtype == right_dtype, + ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{}` and `{}`", left_dtype, right_dtype + ); + match left_dtype { DataType::Utf8 => asof_join_by_binary( &left_by_s.utf8().unwrap().as_binary(), &right_by_s.utf8().unwrap().as_binary(), @@ -669,7 +674,7 @@ fn dispatch_join( } else { for (lhs, rhs) in left_by.get_columns().iter().zip(right_by.get_columns()) { polars_ensure!(lhs.dtype() == rhs.dtype(), - ComputeError: "mismatching dtypes in 'on' parameter of asof-join: `{}` and `{}`", lhs.dtype(), rhs.dtype() + ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{}` and `{}`", lhs.dtype(), rhs.dtype() ); #[cfg(feature = "dtype-categorical")] _check_categorical_src(lhs.dtype(), rhs.dtype())?; diff --git a/py-polars/tests/unit/operations/test_join_asof.py b/py-polars/tests/unit/operations/test_join_asof.py index 953255c3f974..12c24af58034 100644 --- a/py-polars/tests/unit/operations/test_join_asof.py +++ b/py-polars/tests/unit/operations/test_join_asof.py @@ -115,6 +115,42 @@ def test_asof_join_schema_5684() -> None: ) +def test_join_asof_mismatched_dtypes() -> None: + # test 'on' dtype mismatch + df1 = pl.DataFrame( + {"a": pl.Series([1, 2, 3], dtype=pl.Int64), "b": ["a", "b", "c"]} + ) + df2 = pl.DataFrame( + {"a": pl.Series([1, 2, 3], dtype=pl.Int32), "c": ["d", "e", "f"]} + ) + + with pytest.raises( + pl.exceptions.ComputeError, match="datatypes of join keys don't match" + ): + df1.join_asof(df2, on="a", strategy="forward") + + # test 'by' dtype mismatch + df1 = pl.DataFrame( + { + "time": pl.date_range(date(2018, 1, 1), date(2018, 1, 8), eager=True), + "group": pl.Series([1, 1, 1, 1, 2, 2, 2, 2], dtype=pl.Int32), + "value": [0, 0, None, None, 2, None, 1, None], + } + ) + df2 = pl.DataFrame( + { + "time": pl.date_range(date(2018, 1, 1), date(2018, 1, 8), eager=True), + "group": pl.Series([1, 1, 1, 1, 2, 2, 2, 2], dtype=pl.Int64), + "value": [0, 0, None, None, 2, None, 1, None], + } + ) + + with pytest.raises( + pl.exceptions.ComputeError, match="mismatching dtypes in 'by' parameter" + ): + df1.join_asof(df2, on="time", by="group", strategy="forward") + + def test_join_asof_floats() -> None: df1 = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": ["lrow1", "lrow2", "lrow3"]}) df2 = pl.DataFrame({"a": [0.59, 1.49, 2.89], "b": ["rrow1", "rrow2", "rrow3"]})