From e02aa3205fe2c1c3b1211ef58472279d84b5c83e Mon Sep 17 00:00:00 2001
From: barak1412
Date: Fri, 8 Nov 2024 09:53:53 +0200
Subject: [PATCH] working version with left

---
 crates/polars-ops/src/frame/join/args.rs      |  2 +-
 .../join/hash_join/single_keys_dispatch.rs    | 69 ++++++++++++++++---
 .../frame/join/hash_join/single_keys_inner.rs |  7 +-
 .../frame/join/hash_join/single_keys_left.rs  |  7 +-
 py-polars/tests/unit/sql/test_joins.py        | 23 +++++++
 5 files changed, 94 insertions(+), 14 deletions(-)

diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs
index 68044514f49a..def36b76a677 100644
--- a/crates/polars-ops/src/frame/join/args.rs
+++ b/crates/polars-ops/src/frame/join/args.rs
@@ -257,7 +257,7 @@ impl JoinValidation {
             OneToMany | OneToOne => {
                 if !join_nulls && probe.null_count() > 0 {
                     probe.n_unique()? - 1 == probe.len() - probe.null_count()
-                }else {
+                } else {
                     probe.n_unique()? == probe.len()
                 }
             },
diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs
index e9916e372e0d..7c365210b208 100644
--- a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs
+++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs
@@ -35,7 +35,8 @@ pub trait SeriesJoin: SeriesSealed + Sized {
                 let (lhs, rhs, _, _) = prepare_binary::<BinaryType>(lhs, rhs, false);
                 let lhs = lhs.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
                 let rhs = rhs.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
-                hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls)
+                let build_null_count = other.null_count();
+                hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls, build_null_count)
             },
             T::BinaryOffset => {
                 let lhs = lhs.binary_offset().unwrap();
@@ -44,7 +45,8 @@ pub trait SeriesJoin: SeriesSealed + Sized {
                 let rhs = rhs.binary_offset().unwrap();
                 // Take slices so that vecs are not copied
                 let lhs = lhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
                 let rhs = rhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
-                hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls)
+                let build_null_count = other.null_count();
+                hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls, build_null_count)
             },
             x if x.is_float() => {
                 with_match_physical_float_polars_type!(lhs.dtype(), |$T| {
@@ -184,8 +186,20 @@ pub trait SeriesJoin: SeriesSealed + Sized {
                 // Take slices so that vecs are not copied
                 let lhs = lhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
                 let rhs = rhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
+                let build_null_count = if swapped {
+                    s_self.null_count()
+                } else {
+                    other.null_count()
+                };
                 Ok((
-                    hash_join_tuples_inner(lhs, rhs, swapped, validate, join_nulls)?,
+                    hash_join_tuples_inner(
+                        lhs,
+                        rhs,
+                        swapped,
+                        validate,
+                        join_nulls,
+                        build_null_count,
+                    )?,
                     !swapped,
                 ))
             },
@@ -196,8 +210,20 @@ pub trait SeriesJoin: SeriesSealed + Sized {
                 // Take slices so that vecs are not copied
                 let lhs = lhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
                 let rhs = rhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
+                let build_null_count = if swapped {
+                    s_self.null_count()
+                } else {
+                    other.null_count()
+                };
                 Ok((
-                    hash_join_tuples_inner(lhs, rhs, swapped, validate, join_nulls)?,
+                    hash_join_tuples_inner(
+                        lhs,
+                        rhs,
+                        swapped,
+                        validate,
+                        join_nulls,
+                        build_null_count,
+                    )?,
                     !swapped,
                 ))
             },
@@ -352,20 +378,38 @@ where
                     .map(|arr| arr.as_slice().unwrap())
                     .collect::<Vec<_>>();
                 Ok((
-                    hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate, join_nulls)?,
+                    hash_join_tuples_inner(
+                        splitted_a, splitted_b, swapped, validate, join_nulls, 0,
+                    )?,
                     !swapped,
                 ))
             } else {
                 Ok((
-                    hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate, join_nulls)?,
+                    hash_join_tuples_inner(
+                        splitted_a, splitted_b, swapped, validate, join_nulls, 0,
+                    )?,
                     !swapped,
                 ))
             }
         },
-        _ => Ok((
-            hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate, join_nulls)?,
-            !swapped,
-        )),
+        _ => {
+            let build_null_count = if swapped {
+                left.null_count()
+            } else {
+                right.null_count()
+            };
+            Ok((
+                hash_join_tuples_inner(
+                    splitted_a,
+                    splitted_b,
+                    swapped,
+                    validate,
+                    join_nulls,
+                    build_null_count,
+                )?,
+                !swapped,
+            ))
+        },
     }
 }
@@ -430,7 +474,7 @@ where
         (0, 0, 1, 1) => {
             let keys_a = chunks_as_slices(&splitted_a);
             let keys_b = chunks_as_slices(&splitted_b);
-            hash_join_tuples_left(keys_a, keys_b, None, None, validate, join_nulls)
+            hash_join_tuples_left(keys_a, keys_b, None, None, validate, join_nulls, 0)
         },
         (0, 0, _, _) => {
             let keys_a = chunks_as_slices(&splitted_a);
@@ -445,6 +489,7 @@ where
                 mapping_right.as_deref(),
                 validate,
                 join_nulls,
+                0,
             )
         },
         _ => {
@@ -452,6 +497,7 @@ where
             let keys_a = get_arrays(&splitted_a);
             let keys_b = get_arrays(&splitted_b);
             let (mapping_left, mapping_right) =
                 create_mappings(left.chunks(), right.chunks(), left.len(), right.len());
+            let build_null_count = right.null_count();
            hash_join_tuples_left(
                 keys_a,
                 keys_b,
                 mapping_left.as_deref(),
                 mapping_right.as_deref(),
                 validate,
                 join_nulls,
+                build_null_count,
             )
         },
     }
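Note on the dispatch changes above: which relation ends up as the build side depends on `swapped`, so the null count handed down to the join kernels is `s_self`'s when the sides were exchanged and `other`'s otherwise; for left joins the build side is always `other`, and the numeric fast paths that only run when both sides are null-free simply pass 0. A minimal standalone sketch of that rule (the `Side` struct and `build_side_null_count` are illustrative names, not the polars API):

    // Illustrative only: mirrors the `if swapped { .. } else { .. }` blocks above.
    // The kernel subtracts this value from its expected key count when join_nulls is false.
    struct Side {
        null_count: usize,
    }

    fn build_side_null_count(s_self: &Side, other: &Side, swapped: bool) -> usize {
        if swapped {
            s_self.null_count
        } else {
            other.null_count
        }
    }

    fn main() {
        let s_self = Side { null_count: 2 };
        let other = Side { null_count: 1 };
        assert_eq!(build_side_null_count(&s_self, &other, false), 1);
        assert_eq!(build_side_null_count(&s_self, &other, true), 2);
    }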
diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs
index f01c99529aea..a97431c67e12 100644
--- a/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs
+++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs
@@ -44,6 +44,8 @@ pub(super) fn hash_join_tuples_inner(
     swapped: bool,
     validate: JoinValidation,
     join_nulls: bool,
+    // Null count of the build side, passed in so it does not have to be recomputed here.
+    build_null_count: usize,
 ) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)>
 where
     I: IntoIterator + Send + Sync + Clone,
@@ -53,10 +55,13 @@ where
     // NOTE: see the left join for more elaborate comments
     // first we hash one relation
     let hash_tbls = if validate.needs_checks() {
-        let expected_size = build
+        let mut expected_size = build
             .iter()
             .map(|v| v.clone().into_iter().size_hint().1.unwrap())
             .sum();
+        if !join_nulls {
+            expected_size -= build_null_count;
+        }
         let hash_tbls = build_tables(build, join_nulls);
         let build_size = hash_tbls.iter().map(|m| m.len()).sum();
         validate.validate_build(build_size, expected_size, swapped)?;
diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs
index 91c4f0cd1008..0367770701e2 100644
--- a/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs
+++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs
@@ -112,6 +112,8 @@ pub(super) fn hash_join_tuples_left(
     chunk_mapping_right: Option<&[ChunkId]>,
     validate: JoinValidation,
     join_nulls: bool,
+    // Null count of the build side, passed in so it does not have to be recomputed here.
+    build_null_count: usize,
 ) -> PolarsResult<LeftJoinIds>
 where
     I: IntoIterator,
@@ -123,7 +125,10 @@ where
     let build = build.into_iter().map(|i| i.into_iter()).collect::<Vec<_>>();
     // first we hash one relation
     let hash_tbls = if validate.needs_checks() {
-        let expected_size = build.iter().map(|v| v.size_hint().1.unwrap()).sum();
+        let mut expected_size = build.iter().map(|v| v.size_hint().1.unwrap()).sum();
+        if !join_nulls {
+            expected_size -= build_null_count;
+        }
         let hash_tbls = build_tables(build, join_nulls);
         let build_size = hash_tbls.iter().map(|m| m.len()).sum();
         validate.validate_build(build_size, expected_size, false)?;
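Both kernels now apply the same adjustment before calling `validate_build`: since `build_tables(build, join_nulls)` leaves nulls out of the table when `join_nulls` is false, the expected key count has to exclude them as well, otherwise a unique build side that contains nulls is reported as duplicated. A self-contained sketch of that arithmetic (`expected_build_size` is a hypothetical helper, not part of the patch):

    // Hypothetical helper showing the adjustment applied to `expected_size` above.
    fn expected_build_size(total_build_keys: usize, build_null_count: usize, join_nulls: bool) -> usize {
        if join_nulls {
            total_build_keys
        } else {
            total_build_keys - build_null_count
        }
    }

    fn main() {
        // Five build-side keys, one of them null, nulls not joined:
        assert_eq!(expected_build_size(5, 1, false), 4);
        // With join_nulls = true the nulls are real keys and nothing is subtracted:
        assert_eq!(expected_build_size(5, 1, true), 5);
    }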
diff --git a/py-polars/tests/unit/sql/test_joins.py b/py-polars/tests/unit/sql/test_joins.py
index d25610eb6763..c423fc4c45f4 100644
--- a/py-polars/tests/unit/sql/test_joins.py
+++ b/py-polars/tests/unit/sql/test_joins.py
@@ -663,3 +663,26 @@ def test_nested_join(join_clause: str) -> None:
             "Species": "Human",
         },
     ]
+
+
+def test_join_nulls_19624() -> None:
+    df1 = pl.DataFrame({"a": [1, 2, None, None]})
+    df2 = pl.DataFrame({"a": [1, 1, 2, 2, None], "b": [0, 1, 2, 3, 4]})
+
+    # left join
+    result_df = df1.join(df2, how="left", on="a", join_nulls=False, validate="1:m")
+    expected_df = pl.DataFrame(
+        {"a": [1, 1, 2, 2, None, None], "b": [0, 1, 2, 3, None, None]}
+    )
+    assert_frame_equal(result_df, expected_df)
+    result_df = df2.join(df1, how="left", on="a", join_nulls=False, validate="m:1")
+    expected_df = pl.DataFrame({"a": [1, 1, 2, 2, None], "b": [0, 1, 2, 3, 4]})
+    assert_frame_equal(result_df, expected_df)
+
+    # inner join
+    result_df = df1.join(df2, how="inner", on="a", join_nulls=False, validate="1:m")
+    expected_df = pl.DataFrame({"a": [1, 1, 2, 2], "b": [0, 1, 2, 3]})
+    assert_frame_equal(result_df, expected_df)
+    result_df = df2.join(df1, how="inner", on="a", join_nulls=False, validate="m:1")
+    expected_df = pl.DataFrame({"a": [1, 1, 2, 2], "b": [0, 1, 2, 3]})
+    assert_frame_equal(result_df, expected_df)
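The regression test (named after issue 19624) exercises exactly the failure mode the patch removes: `df1["a"]` is unique apart from two nulls, yet `validate="1:m"` / `validate="m:1"` used to fail because nulls were excluded from the hash table (and handled via `n_unique` on the probe side) but not from the counts they were validated against. A simplified model of that mismatch, using a plain std HashMap rather than the polars hash tables:

    use std::collections::HashMap;

    // Simplified model: a build table that, like the real one with join_nulls == false,
    // never stores null keys.
    fn build_table(keys: &[Option<i32>], join_nulls: bool) -> HashMap<Option<i32>, Vec<usize>> {
        let mut table: HashMap<Option<i32>, Vec<usize>> = HashMap::new();
        for (idx, key) in keys.iter().enumerate() {
            if key.is_none() && !join_nulls {
                continue; // nulls do not participate in the join
            }
            table.entry(*key).or_default().push(idx);
        }
        table
    }

    fn main() {
        // Same keys as df1["a"] in the test above.
        let keys = [Some(1), Some(2), None, None];
        let table = build_table(&keys, false);
        let non_null_keys = keys.iter().filter(|k| k.is_some()).count();
        // Two distinct non-null keys for two non-null rows: the column is unique,
        // so the 1:m / m:1 checks pass once nulls are excluded from both counts.
        assert_eq!(table.len(), 2);
        assert_eq!(table.len(), non_null_keys);
        // The pre-fix expectation compared against the raw length instead:
        assert_ne!(table.len(), keys.len());
    }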