Skip to content

Commit

Permalink
Revert validate probe and fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
CanglongCl committed Aug 4, 2023
1 parent 771b0c4 commit c2be3cb
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 90 deletions.
71 changes: 26 additions & 45 deletions crates/polars-core/src/frame/hash_join/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,64 +131,45 @@ impl JoinValidation {
Ok(())
}

pub(super) fn validate_probe<'a, F, I, T>(
pub(super) fn validate_probe(
&self,
probe: F,
// In a left join, probe is always in lhs.
// In a inner or outer join, it is the longest relationship of both sides.
is_rhs: bool,
) -> PolarsResult<()>
where
F: Fn() -> I + Send + Sync,
I: 'a + Iterator<Item = T> + Send + Sync,
T: Send + Hash + Eq + Sync + Copy + AsU64 + 'a,
{
s_left: &Series,
s_right: &Series,
build_shortest_table: bool,
) -> PolarsResult<()> {
// Only check the `build` side.
// The other side use `validate_build` to check

// In default, probe is the left series.
// the shortest relation is built, and will be put in the right
// If left is shorter, swap.
// If left == right, apply swap which is the same logic as `det_hash_prone_order`
let should_swap = build_shortest_table && s_left.len() >= s_right.len();
let probe = if should_swap { s_right } else { s_left };

use JoinValidation::*;
let fail = match self.swap(is_rhs) {
// Only check the `prone` side.
// The other side use `validate_build` to check
ManyToMany | ManyToOne => false,
OneToMany | OneToOne => {
// check any key in prone is duplicated
let n_partitions = _set_partition_size();
POOL.install(|| {
(0..n_partitions)
.into_par_iter()
.find_any(|partition_no| {
let partition_no = *partition_no as u64;
let n_partitions = n_partitions as u64;

let mut hash_set: PlHashSet<T> =
PlHashSet::with_capacity(HASHMAP_INIT_SIZE);
probe().any(|key| {
if this_partition(key.as_u64(), partition_no, n_partitions) {
!hash_set.insert(key)
} else {
false
}
})
})
.is_some()
})
}
let valid = match self.swap(should_swap) {
ManyToMany | ManyToOne => true,
OneToMany | OneToOne => probe.n_unique()? == probe.len(),
};
polars_ensure!(!fail, ComputeError: "the join keys did not fulfil {} validation", self);
polars_ensure!(valid, ComputeError: "the join keys did not fulfil {} validation", self);
Ok(())
}

pub(super) fn validate_build(
&self,
build_size: usize,
expected_size: usize,
is_rhs: bool,
swapped: bool,
) -> PolarsResult<()> {
use JoinValidation::*;

// Only check the `build` side.
// The other side use `validate_prone` to check
let valid = match self.swap(is_rhs) {
ManyToMany | ManyToOne => true,
OneToMany | OneToOne => build_size == expected_size,
// In default, build is in rhs.
let valid = match self.swap(swapped) {
// Only check the `build` side.
// The other side use `validate_prone` to check
ManyToMany | OneToMany => true,
ManyToOne | OneToOne => build_size == expected_size,
};
polars_ensure!(valid, ComputeError: "the join keys did not fulfil {} validation", self);
Ok(())
Expand Down
29 changes: 16 additions & 13 deletions crates/polars-core/src/frame/hash_join/single_keys_dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ impl Series {
validate: JoinValidation,
) -> PolarsResult<LeftJoinIds> {
let (lhs, rhs) = (self.to_physical_repr(), other.to_physical_repr());
validate.validate_probe(&lhs, &rhs, false)?;

use DataType::*;
match lhs.dtype() {
Expand Down Expand Up @@ -83,6 +84,7 @@ impl Series {
validate: JoinValidation,
) -> PolarsResult<(InnerJoinIds, bool)> {
let (lhs, rhs) = (self.to_physical_repr(), other.to_physical_repr());
validate.validate_probe(&lhs, &rhs, true)?;

use DataType::*;
match lhs.dtype() {
Expand Down Expand Up @@ -119,6 +121,7 @@ impl Series {
validate: JoinValidation,
) -> PolarsResult<Vec<(Option<IdxSize>, Option<IdxSize>)>> {
let (lhs, rhs) = (self.to_physical_repr(), other.to_physical_repr());
validate.validate_probe(&lhs, &rhs, true)?;

use DataType::*;
match lhs.dtype() {
Expand Down Expand Up @@ -191,7 +194,7 @@ where
Option<T::Native>: AsU64,
{
let n_threads = POOL.current_num_threads();
let (a, b, swap) = det_hash_prone_order!(left, right);
let (a, b, swapped) = det_hash_prone_order!(left, right);
let splitted_a = split_ca(a, n_threads).unwrap();
let splitted_b = split_ca(b, n_threads).unwrap();
match (
Expand All @@ -204,24 +207,24 @@ where
let keys_a = splitted_to_slice(&splitted_a);
let keys_b = splitted_to_slice(&splitted_b);
Ok((
hash_join_tuples_inner(keys_a, keys_b, swap, validate)?,
!swap,
hash_join_tuples_inner(keys_a, keys_b, swapped, validate)?,
!swapped,
))
}
(true, true, _, _) => {
let keys_a = splitted_by_chunks(&splitted_a);
let keys_b = splitted_by_chunks(&splitted_b);
Ok((
hash_join_tuples_inner(keys_a, keys_b, swap, validate)?,
!swap,
hash_join_tuples_inner(keys_a, keys_b, swapped, validate)?,
!swapped,
))
}
_ => {
let keys_a = splitted_to_opt_vec(&splitted_a);
let keys_b = splitted_to_opt_vec(&splitted_b);
Ok((
hash_join_tuples_inner(keys_a, keys_b, swap, validate)?,
!swap,
hash_join_tuples_inner(keys_a, keys_b, swapped, validate)?,
!swapped,
))
}
}
Expand Down Expand Up @@ -408,12 +411,12 @@ impl BinaryChunked {
other: &BinaryChunked,
validate: JoinValidation,
) -> PolarsResult<(InnerJoinIds, bool)> {
let (splitted_a, splitted_b, swap, hb) = self.prepare(other, true);
let (splitted_a, splitted_b, swapped, hb) = self.prepare(other, true);
let str_hashes_a = prepare_bytes(&splitted_a, &hb);
let str_hashes_b = prepare_bytes(&splitted_b, &hb);
Ok((
hash_join_tuples_inner(str_hashes_a, str_hashes_b, swap, validate)?,
!swap,
hash_join_tuples_inner(str_hashes_a, str_hashes_b, swapped, validate)?,
!swapped,
))
}

Expand Down Expand Up @@ -454,7 +457,7 @@ impl BinaryChunked {
other: &BinaryChunked,
validate: JoinValidation,
) -> PolarsResult<Vec<(Option<IdxSize>, Option<IdxSize>)>> {
let (a, b, swap) = det_hash_prone_order!(self, other);
let (a, b, swapped) = det_hash_prone_order!(self, other);

let n_partitions = _set_partition_size();
let splitted_a = split_ca(a, n_partitions).unwrap();
Expand All @@ -470,7 +473,7 @@ impl BinaryChunked {
.iter()
.map(|ca| ca.into_no_null_iter())
.collect::<Vec<_>>();
hash_join_tuples_outer(iters_a, iters_b, swap, validate)
hash_join_tuples_outer(iters_a, iters_b, swapped, validate)
}
_ => {
let iters_a = splitted_a
Expand All @@ -481,7 +484,7 @@ impl BinaryChunked {
.iter()
.map(|ca| ca.into_iter())
.collect::<Vec<_>>();
hash_join_tuples_outer(iters_a, iters_b, swap, validate)
hash_join_tuples_outer(iters_a, iters_b, swapped, validate)
}
}
}
Expand Down
11 changes: 1 addition & 10 deletions crates/polars-core/src/frame/hash_join/single_keys_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,10 @@ where

// first we hash one relation
let hash_tbls = if validate.needs_checks() {
validate.validate_probe(
|| {
let nested: &[IntoSlice] = probe.as_ref();
nested
.iter()
.flat_map(|into_slice| into_slice.as_ref().iter())
},
swapped,
)?;
let expected_size = build.iter().map(|v| v.as_ref().len()).sum();
let hash_tbls = build_tables(build);
let build_size = hash_tbls.iter().map(|m| m.len()).sum();
validate.validate_build(build_size, expected_size, !swapped)?;
validate.validate_build(build_size, expected_size, swapped)?;
hash_tbls
} else {
build_tables(build)
Expand Down
11 changes: 1 addition & 10 deletions crates/polars-core/src/frame/hash_join/single_keys_left.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,10 @@ where
{
// first we hash one relation
let hash_tbls = if validate.needs_checks() {
validate.validate_probe(
|| {
let nested: &[IntoSlice] = probe.as_ref();
nested
.iter()
.flat_map(|into_slice| into_slice.as_ref().iter())
},
false,
)?;
let expected_size = build.iter().map(|v| v.as_ref().len()).sum();
let hash_tbls = build_tables(build);
let build_size = hash_tbls.iter().map(|m| m.len()).sum();
validate.validate_build(build_size, expected_size, true)?;
validate.validate_build(build_size, expected_size, false)?;
hash_tbls
} else {
build_tables(build)
Expand Down
13 changes: 1 addition & 12 deletions crates/polars-core/src/frame/hash_join/single_keys_outer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ where
let expected_size = build.iter().map(|i| i.size_hint().0).sum();
let hash_tbls = prepare_hashed_relation_threaded(build);
let build_size = hash_tbls.iter().map(|m| m.len()).sum();
validate.validate_build(build_size, expected_size, !swapped)?;
validate.validate_build(build_size, expected_size, swapped)?;
hash_tbls
} else {
prepare_hashed_relation_threaded(build)
Expand All @@ -98,17 +98,6 @@ where
// we pre hash the probing values
let (probe_hashes, _) = create_hash_and_keys_threaded_vectorized(probe, Some(random_state));

if validate.needs_checks() {
validate.validate_probe(
|| {
probe_hashes
.iter()
.flat_map(|it| it.iter().map(|item| item.0))
},
swapped,
)?;
}

let n_tables = hash_tbls.len() as u64;

// probe the hash table.
Expand Down

0 comments on commit c2be3cb

Please sign in to comment.