From c2be3cba23abf8c9d6e50f494399675addf7447d Mon Sep 17 00:00:00 2001 From: Lava <34743145+CanglongCl@users.noreply.github.com> Date: Fri, 4 Aug 2023 20:23:19 +0800 Subject: [PATCH] Revert `validate probe` and fix bugs --- .../polars-core/src/frame/hash_join/args.rs | 71 +++++++------------ .../frame/hash_join/single_keys_dispatch.rs | 29 ++++---- .../src/frame/hash_join/single_keys_inner.rs | 11 +-- .../src/frame/hash_join/single_keys_left.rs | 11 +-- .../src/frame/hash_join/single_keys_outer.rs | 13 +--- 5 files changed, 45 insertions(+), 90 deletions(-) diff --git a/crates/polars-core/src/frame/hash_join/args.rs b/crates/polars-core/src/frame/hash_join/args.rs index 3b1ac2d49f7a..88e9ea79d932 100644 --- a/crates/polars-core/src/frame/hash_join/args.rs +++ b/crates/polars-core/src/frame/hash_join/args.rs @@ -131,48 +131,28 @@ impl JoinValidation { Ok(()) } - pub(super) fn validate_probe<'a, F, I, T>( + pub(super) fn validate_probe( &self, - probe: F, - // In a left join, probe is always in lhs. - // In a inner or outer join, it is the longest relationship of both sides. - is_rhs: bool, - ) -> PolarsResult<()> - where - F: Fn() -> I + Send + Sync, - I: 'a + Iterator + Send + Sync, - T: Send + Hash + Eq + Sync + Copy + AsU64 + 'a, - { + s_left: &Series, + s_right: &Series, + build_shortest_table: bool, + ) -> PolarsResult<()> { + // Only check the `build` side. + // The other side use `validate_build` to check + + // In default, probe is the left series. + // the shortest relation is built, and will be put in the right + // If left is shorter, swap. + // If left == right, apply swap which is the same logic as `det_hash_prone_order` + let should_swap = build_shortest_table && s_left.len() >= s_right.len(); + let probe = if should_swap { s_right } else { s_left }; + use JoinValidation::*; - let fail = match self.swap(is_rhs) { - // Only check the `prone` side. - // The other side use `validate_build` to check - ManyToMany | ManyToOne => false, - OneToMany | OneToOne => { - // check any key in prone is duplicated - let n_partitions = _set_partition_size(); - POOL.install(|| { - (0..n_partitions) - .into_par_iter() - .find_any(|partition_no| { - let partition_no = *partition_no as u64; - let n_partitions = n_partitions as u64; - - let mut hash_set: PlHashSet = - PlHashSet::with_capacity(HASHMAP_INIT_SIZE); - probe().any(|key| { - if this_partition(key.as_u64(), partition_no, n_partitions) { - !hash_set.insert(key) - } else { - false - } - }) - }) - .is_some() - }) - } + let valid = match self.swap(should_swap) { + ManyToMany | ManyToOne => true, + OneToMany | OneToOne => probe.n_unique()? == probe.len(), }; - polars_ensure!(!fail, ComputeError: "the join keys did not fulfil {} validation", self); + polars_ensure!(valid, ComputeError: "the join keys did not fulfil {} validation", self); Ok(()) } @@ -180,15 +160,16 @@ impl JoinValidation { &self, build_size: usize, expected_size: usize, - is_rhs: bool, + swapped: bool, ) -> PolarsResult<()> { use JoinValidation::*; - // Only check the `build` side. - // The other side use `validate_prone` to check - let valid = match self.swap(is_rhs) { - ManyToMany | ManyToOne => true, - OneToMany | OneToOne => build_size == expected_size, + // In default, build is in rhs. + let valid = match self.swap(swapped) { + // Only check the `build` side. + // The other side use `validate_prone` to check + ManyToMany | OneToMany => true, + ManyToOne | OneToOne => build_size == expected_size, }; polars_ensure!(valid, ComputeError: "the join keys did not fulfil {} validation", self); Ok(()) diff --git a/crates/polars-core/src/frame/hash_join/single_keys_dispatch.rs b/crates/polars-core/src/frame/hash_join/single_keys_dispatch.rs index f5c7b32485df..9ecf3c84a403 100644 --- a/crates/polars-core/src/frame/hash_join/single_keys_dispatch.rs +++ b/crates/polars-core/src/frame/hash_join/single_keys_dispatch.rs @@ -13,6 +13,7 @@ impl Series { validate: JoinValidation, ) -> PolarsResult { let (lhs, rhs) = (self.to_physical_repr(), other.to_physical_repr()); + validate.validate_probe(&lhs, &rhs, false)?; use DataType::*; match lhs.dtype() { @@ -83,6 +84,7 @@ impl Series { validate: JoinValidation, ) -> PolarsResult<(InnerJoinIds, bool)> { let (lhs, rhs) = (self.to_physical_repr(), other.to_physical_repr()); + validate.validate_probe(&lhs, &rhs, true)?; use DataType::*; match lhs.dtype() { @@ -119,6 +121,7 @@ impl Series { validate: JoinValidation, ) -> PolarsResult, Option)>> { let (lhs, rhs) = (self.to_physical_repr(), other.to_physical_repr()); + validate.validate_probe(&lhs, &rhs, true)?; use DataType::*; match lhs.dtype() { @@ -191,7 +194,7 @@ where Option: AsU64, { let n_threads = POOL.current_num_threads(); - let (a, b, swap) = det_hash_prone_order!(left, right); + let (a, b, swapped) = det_hash_prone_order!(left, right); let splitted_a = split_ca(a, n_threads).unwrap(); let splitted_b = split_ca(b, n_threads).unwrap(); match ( @@ -204,24 +207,24 @@ where let keys_a = splitted_to_slice(&splitted_a); let keys_b = splitted_to_slice(&splitted_b); Ok(( - hash_join_tuples_inner(keys_a, keys_b, swap, validate)?, - !swap, + hash_join_tuples_inner(keys_a, keys_b, swapped, validate)?, + !swapped, )) } (true, true, _, _) => { let keys_a = splitted_by_chunks(&splitted_a); let keys_b = splitted_by_chunks(&splitted_b); Ok(( - hash_join_tuples_inner(keys_a, keys_b, swap, validate)?, - !swap, + hash_join_tuples_inner(keys_a, keys_b, swapped, validate)?, + !swapped, )) } _ => { let keys_a = splitted_to_opt_vec(&splitted_a); let keys_b = splitted_to_opt_vec(&splitted_b); Ok(( - hash_join_tuples_inner(keys_a, keys_b, swap, validate)?, - !swap, + hash_join_tuples_inner(keys_a, keys_b, swapped, validate)?, + !swapped, )) } } @@ -408,12 +411,12 @@ impl BinaryChunked { other: &BinaryChunked, validate: JoinValidation, ) -> PolarsResult<(InnerJoinIds, bool)> { - let (splitted_a, splitted_b, swap, hb) = self.prepare(other, true); + let (splitted_a, splitted_b, swapped, hb) = self.prepare(other, true); let str_hashes_a = prepare_bytes(&splitted_a, &hb); let str_hashes_b = prepare_bytes(&splitted_b, &hb); Ok(( - hash_join_tuples_inner(str_hashes_a, str_hashes_b, swap, validate)?, - !swap, + hash_join_tuples_inner(str_hashes_a, str_hashes_b, swapped, validate)?, + !swapped, )) } @@ -454,7 +457,7 @@ impl BinaryChunked { other: &BinaryChunked, validate: JoinValidation, ) -> PolarsResult, Option)>> { - let (a, b, swap) = det_hash_prone_order!(self, other); + let (a, b, swapped) = det_hash_prone_order!(self, other); let n_partitions = _set_partition_size(); let splitted_a = split_ca(a, n_partitions).unwrap(); @@ -470,7 +473,7 @@ impl BinaryChunked { .iter() .map(|ca| ca.into_no_null_iter()) .collect::>(); - hash_join_tuples_outer(iters_a, iters_b, swap, validate) + hash_join_tuples_outer(iters_a, iters_b, swapped, validate) } _ => { let iters_a = splitted_a @@ -481,7 +484,7 @@ impl BinaryChunked { .iter() .map(|ca| ca.into_iter()) .collect::>(); - hash_join_tuples_outer(iters_a, iters_b, swap, validate) + hash_join_tuples_outer(iters_a, iters_b, swapped, validate) } } } diff --git a/crates/polars-core/src/frame/hash_join/single_keys_inner.rs b/crates/polars-core/src/frame/hash_join/single_keys_inner.rs index 9ad2702df320..a877fa516ab5 100644 --- a/crates/polars-core/src/frame/hash_join/single_keys_inner.rs +++ b/crates/polars-core/src/frame/hash_join/single_keys_inner.rs @@ -49,19 +49,10 @@ where // first we hash one relation let hash_tbls = if validate.needs_checks() { - validate.validate_probe( - || { - let nested: &[IntoSlice] = probe.as_ref(); - nested - .iter() - .flat_map(|into_slice| into_slice.as_ref().iter()) - }, - swapped, - )?; let expected_size = build.iter().map(|v| v.as_ref().len()).sum(); let hash_tbls = build_tables(build); let build_size = hash_tbls.iter().map(|m| m.len()).sum(); - validate.validate_build(build_size, expected_size, !swapped)?; + validate.validate_build(build_size, expected_size, swapped)?; hash_tbls } else { build_tables(build) diff --git a/crates/polars-core/src/frame/hash_join/single_keys_left.rs b/crates/polars-core/src/frame/hash_join/single_keys_left.rs index 8cf32f44e8e8..4783df039004 100644 --- a/crates/polars-core/src/frame/hash_join/single_keys_left.rs +++ b/crates/polars-core/src/frame/hash_join/single_keys_left.rs @@ -113,19 +113,10 @@ where { // first we hash one relation let hash_tbls = if validate.needs_checks() { - validate.validate_probe( - || { - let nested: &[IntoSlice] = probe.as_ref(); - nested - .iter() - .flat_map(|into_slice| into_slice.as_ref().iter()) - }, - false, - )?; let expected_size = build.iter().map(|v| v.as_ref().len()).sum(); let hash_tbls = build_tables(build); let build_size = hash_tbls.iter().map(|m| m.len()).sum(); - validate.validate_build(build_size, expected_size, true)?; + validate.validate_build(build_size, expected_size, false)?; hash_tbls } else { build_tables(build) diff --git a/crates/polars-core/src/frame/hash_join/single_keys_outer.rs b/crates/polars-core/src/frame/hash_join/single_keys_outer.rs index e064affcb06c..61148c393fab 100644 --- a/crates/polars-core/src/frame/hash_join/single_keys_outer.rs +++ b/crates/polars-core/src/frame/hash_join/single_keys_outer.rs @@ -88,7 +88,7 @@ where let expected_size = build.iter().map(|i| i.size_hint().0).sum(); let hash_tbls = prepare_hashed_relation_threaded(build); let build_size = hash_tbls.iter().map(|m| m.len()).sum(); - validate.validate_build(build_size, expected_size, !swapped)?; + validate.validate_build(build_size, expected_size, swapped)?; hash_tbls } else { prepare_hashed_relation_threaded(build) @@ -98,17 +98,6 @@ where // we pre hash the probing values let (probe_hashes, _) = create_hash_and_keys_threaded_vectorized(probe, Some(random_state)); - if validate.needs_checks() { - validate.validate_probe( - || { - probe_hashes - .iter() - .flat_map(|it| it.iter().map(|item| item.0)) - }, - swapped, - )?; - } - let n_tables = hash_tbls.len() as u64; // probe the hash table.