diff --git a/src/daft-local-execution/src/sinks/hash_join.rs b/src/daft-local-execution/src/sinks/hash_join.rs index 77b22a5791..c0cb2c36fc 100644 --- a/src/daft-local-execution/src/sinks/hash_join.rs +++ b/src/daft-local-execution/src/sinks/hash_join.rs @@ -87,6 +87,7 @@ impl HashJoinState { pub(crate) struct HashJoinOperator { right_on: Vec, + pruned_right_side_columns: Vec, _join_type: JoinType, join_state: HashJoinState, } @@ -130,9 +131,27 @@ impl HashJoinOperator { .zip(key_schema.fields.values()) .map(|(e, f)| e.cast(&f.dtype)) .collect::>(); + let common_join_keys = left_on + .iter() + .zip(right_on.iter()) + .filter_map(|(l, r)| { + if l.name() == r.name() { + Some(l.name()) + } else { + None + } + }) + .collect::>(); + let pruned_right_side_columns = right_schema + .fields + .keys() + .filter(|k| !common_join_keys.contains(k.as_str())) + .cloned() + .collect::>(); assert_eq!(join_type, JoinType::Inner); Ok(Self { right_on, + pruned_right_side_columns, _join_type: join_type, join_state: HashJoinState::new(&key_schema, left_on)?, }) @@ -152,6 +171,7 @@ impl HashJoinOperator { probe_table: probe_table.clone(), tables: tables.clone(), right_on: self.right_on.clone(), + pruned_right_side_columns: self.pruned_right_side_columns.clone(), }) } else { panic!("can't call as_intermediate_op when not in probing state") @@ -163,6 +183,7 @@ struct HashJoinProber { probe_table: Arc, tables: Arc>, right_on: Vec, + pruned_right_side_columns: Vec, } impl IntermediateOperator for HashJoinProber { @@ -201,7 +222,9 @@ impl IntermediateOperator for HashJoinProber { let left_table = left_growable.build()?; let right_table = right_growable.build()?; - let final_table = left_table.union(&right_table)?; + let pruned_right_table = right_table.get_columns(&self.pruned_right_side_columns)?; + + let final_table = left_table.union(&pruned_right_table)?; Ok(Arc::new(MicroPartition::new_loaded( final_table.schema.clone(), Arc::new(vec![final_table]),