Skip to content

Commit

Permalink
working version with left
Browse files Browse the repository at this point in the history
  • Loading branch information
barak1412 committed Nov 8, 2024
1 parent 66a8381 commit e02aa32
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 14 deletions.
2 changes: 1 addition & 1 deletion crates/polars-ops/src/frame/join/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ impl JoinValidation {
OneToMany | OneToOne => {
if !join_nulls && probe.null_count() > 0 {
probe.n_unique()? - 1 == probe.len() - probe.null_count()
}else {
} else {
probe.n_unique()? == probe.len()
}
},
Expand Down
69 changes: 58 additions & 11 deletions crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ pub trait SeriesJoin: SeriesSealed + Sized {
let (lhs, rhs, _, _) = prepare_binary::<BinaryType>(lhs, rhs, false);
let lhs = lhs.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
let rhs = rhs.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls)
let build_null_count = other.null_count();
hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls, build_null_count)
},
T::BinaryOffset => {
let lhs = lhs.binary_offset().unwrap();
Expand All @@ -44,7 +45,8 @@ pub trait SeriesJoin: SeriesSealed + Sized {
// Take slices so that vecs are not copied
let lhs = lhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
let rhs = rhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls)
let build_null_count = other.null_count();
hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls, build_null_count)
},
x if x.is_float() => {
with_match_physical_float_polars_type!(lhs.dtype(), |$T| {
Expand Down Expand Up @@ -184,8 +186,20 @@ pub trait SeriesJoin: SeriesSealed + Sized {
// Take slices so that vecs are not copied
let lhs = lhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
let rhs = rhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
let build_null_count = if swapped {
s_self.null_count()
} else {
other.null_count()
};
Ok((
hash_join_tuples_inner(lhs, rhs, swapped, validate, join_nulls)?,
hash_join_tuples_inner(
lhs,
rhs,
swapped,
validate,
join_nulls,
build_null_count,
)?,
!swapped,
))
},
Expand All @@ -196,8 +210,20 @@ pub trait SeriesJoin: SeriesSealed + Sized {
// Take slices so that vecs are not copied
let lhs = lhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
let rhs = rhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
let build_null_count = if swapped {
s_self.null_count()
} else {
other.null_count()
};
Ok((
hash_join_tuples_inner(lhs, rhs, swapped, validate, join_nulls)?,
hash_join_tuples_inner(
lhs,
rhs,
swapped,
validate,
join_nulls,
build_null_count,
)?,
!swapped,
))
},
Expand Down Expand Up @@ -352,20 +378,38 @@ where
.map(|arr| arr.as_slice().unwrap())
.collect::<Vec<_>>();
Ok((
hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate, join_nulls)?,
hash_join_tuples_inner(
splitted_a, splitted_b, swapped, validate, join_nulls, 0,
)?,
!swapped,
))
} else {
Ok((
hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate, join_nulls)?,
hash_join_tuples_inner(
splitted_a, splitted_b, swapped, validate, join_nulls, 0,
)?,
!swapped,
))
}
},
_ => Ok((
hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate, join_nulls)?,
!swapped,
)),
_ => {
let build_null_count = if swapped {
left.null_count()
} else {
right.null_count()
};
Ok((
hash_join_tuples_inner(
splitted_a,
splitted_b,
swapped,
validate,
join_nulls,
build_null_count,
)?,
!swapped,
))
},
}
}

Expand Down Expand Up @@ -430,7 +474,7 @@ where
(0, 0, 1, 1) => {
let keys_a = chunks_as_slices(&splitted_a);
let keys_b = chunks_as_slices(&splitted_b);
hash_join_tuples_left(keys_a, keys_b, None, None, validate, join_nulls)
hash_join_tuples_left(keys_a, keys_b, None, None, validate, join_nulls, 0)
},
(0, 0, _, _) => {
let keys_a = chunks_as_slices(&splitted_a);
Expand All @@ -445,20 +489,23 @@ where
mapping_right.as_deref(),
validate,
join_nulls,
0,
)
},
_ => {
let keys_a = get_arrays(&splitted_a);
let keys_b = get_arrays(&splitted_b);
let (mapping_left, mapping_right) =
create_mappings(left.chunks(), right.chunks(), left.len(), right.len());
let build_null_count = right.null_count();
hash_join_tuples_left(
keys_a,
keys_b,
mapping_left.as_deref(),
mapping_right.as_deref(),
validate,
join_nulls,
build_null_count,
)
},
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pub(super) fn hash_join_tuples_inner<T, I>(
swapped: bool,
validate: JoinValidation,
join_nulls: bool,
// We should know the number of nulls to avoid extra calculation
build_null_count: usize,
) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)>
where
I: IntoIterator<Item = T> + Send + Sync + Clone,
Expand All @@ -53,10 +55,13 @@ where
// NOTE: see the left join for more elaborate comments
// first we hash one relation
let hash_tbls = if validate.needs_checks() {
let expected_size = build
let mut expected_size = build
.iter()
.map(|v| v.clone().into_iter().size_hint().1.unwrap())
.sum();
if !join_nulls {
expected_size = expected_size - build_null_count;
}
let hash_tbls = build_tables(build, join_nulls);
let build_size = hash_tbls.iter().map(|m| m.len()).sum();
validate.validate_build(build_size, expected_size, swapped)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ pub(super) fn hash_join_tuples_left<T, I>(
chunk_mapping_right: Option<&[ChunkId]>,
validate: JoinValidation,
join_nulls: bool,
// We should know the number of nulls to avoid extra calculation
build_null_count: usize,
) -> PolarsResult<LeftJoinIds>
where
I: IntoIterator<Item = T>,
Expand All @@ -123,7 +125,10 @@ where
let build = build.into_iter().map(|i| i.into_iter()).collect::<Vec<_>>();
// first we hash one relation
let hash_tbls = if validate.needs_checks() {
let expected_size = build.iter().map(|v| v.size_hint().1.unwrap()).sum();
let mut expected_size = build.iter().map(|v| v.size_hint().1.unwrap()).sum();
if !join_nulls {
expected_size = expected_size - build_null_count;
}
let hash_tbls = build_tables(build, join_nulls);
let build_size = hash_tbls.iter().map(|m| m.len()).sum();
validate.validate_build(build_size, expected_size, false)?;
Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/sql/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,3 +663,26 @@ def test_nested_join(join_clause: str) -> None:
"Species": "Human",
},
]


def test_join_nulls_19624() -> None:
    # Regression test for GH issue 19624: with join_nulls=False, null keys in
    # the build side must not be counted when validating join cardinality.
    left_frame = pl.DataFrame({"a": [1, 2, None, None]})
    right_frame = pl.DataFrame({"a": [1, 1, 2, 2, None], "b": [0, 1, 2, 3, 4]})

    # left join
    expected = pl.DataFrame(
        {"a": [1, 1, 2, 2, None, None], "b": [0, 1, 2, 3, None, None]}
    )
    out = left_frame.join(
        right_frame, how="left", on="a", join_nulls=False, validate="1:m"
    )
    assert_frame_equal(out, expected)

    expected = pl.DataFrame({"a": [1, 1, 2, 2, None], "b": [0, 1, 2, 3, 4]})
    out = right_frame.join(
        left_frame, how="left", on="a", join_nulls=False, validate="m:1"
    )
    assert_frame_equal(out, expected)

    # inner join
    expected = pl.DataFrame({"a": [1, 1, 2, 2], "b": [0, 1, 2, 3]})
    out = left_frame.join(
        right_frame, how="inner", on="a", join_nulls=False, validate="1:m"
    )
    assert_frame_equal(out, expected)

    out = right_frame.join(
        left_frame, how="inner", on="a", join_nulls=False, validate="m:1"
    )
    assert_frame_equal(out, expected)

0 comments on commit e02aa32

Please sign in to comment.