From ff1c13959096d3fc44ba15add01aa12c8fbf304e Mon Sep 17 00:00:00 2001 From: Albert Skalt Date: Mon, 27 Jan 2025 21:17:55 +0300 Subject: [PATCH] equivalence classes: fix projection This patch fixes the logic that projects equivalence classes: when iterating over the projection mapping to find new equivalent expressions, we need to normalize each source expression. --- .../physical-expr/src/equivalence/class.rs | 64 +++++++++++++++- .../sqllogictest/test_files/join.slt.part | 76 +++++++++++++++++++ 2 files changed, 138 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 5c749a1a5a6e..7ee90bf3fc86 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -584,12 +584,18 @@ impl EquivalenceGroup { .collect::>(); (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) }); + // the key is the source expression and the value is the EquivalenceClass that contains the target expression of the source expression. let mut new_classes: IndexMap, EquivalenceClass> = IndexMap::new(); mapping.iter().for_each(|(source, target)| { + // We need to find equivalent projected expressions. + // e.g. table with columns [a,b,c] and a == b, projection: [a+c, b+c]. + // To conclude that a + c == b + c we first normalize all source expressions + // in the mapping, then merge all equivalent expressions into the classes. 
+ let normalized_expr = self.normalize_expr(Arc::clone(source)); new_classes - .entry(Arc::clone(source)) + .entry(normalized_expr) .or_insert_with(EquivalenceClass::new_empty) .push(Arc::clone(target)); }); @@ -752,8 +758,9 @@ mod tests { use super::*; use crate::equivalence::tests::create_test_params; - use crate::expressions::{lit, BinaryExpr, Literal}; + use crate::expressions::{binary, col, lit, BinaryExpr, Literal}; + use arrow_schema::{DataType, Field, Schema}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; @@ -1038,4 +1045,57 @@ mod tests { Ok(()) } + + #[test] + fn test_project_classes() -> Result<()> { + // - columns: [a, b, c]. + // - "a" and "b" in the same equivalence class. + // - then after a+c, b+c projection col(0) and col(1) must be + // in the same class too. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let mut group = EquivalenceGroup::empty(); + group.add_equal_conditions(&col("a", &schema)?, &col("b", &schema)?); + + let projected_schema = Arc::new(Schema::new(vec![ + Field::new("a+c", DataType::Int32, false), + Field::new("b+c", DataType::Int32, false), + ])); + + let mapping = ProjectionMapping { + map: vec![ + ( + binary( + col("a", &schema)?, + Operator::Plus, + col("c", &schema)?, + &schema, + )?, + col("a+c", &projected_schema)?, + ), + ( + binary( + col("b", &schema)?, + Operator::Plus, + col("c", &schema)?, + &schema, + )?, + col("b+c", &projected_schema)?, + ), + ], + }; + + let projected = group.project(&mapping); + + assert!(!projected.is_empty()); + let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?); + let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?); + + assert!(first_normalized.eq(&second_normalized)); + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/join.slt.part 
b/datafusion/sqllogictest/test_files/join.slt.part index 1feacc5ebe53..9d396a63d0b2 100644 --- a/datafusion/sqllogictest/test_files/join.slt.part +++ b/datafusion/sqllogictest/test_files/join.slt.part @@ -1312,3 +1312,79 @@ SELECT a+b*2, statement ok drop table t1; + +# Test that equivalence classes are projected correctly. + +statement ok +create table pairs(x int, y int) as values (1,1), (2,2), (3,3); + +statement ok +create table f(a int) as values (1), (2), (3); + +statement ok +create table s(b int) as values (1), (2), (3); + +statement ok +set datafusion.optimizer.repartition_joins = true; + +statement ok +set datafusion.execution.target_partitions = 16; + +# After applying the filter (x = y) we can join by both x and y, +# partitioning only once. + +query TT +explain +SELECT * FROM +(SELECT x+1 AS col0, y+1 AS col1 FROM PAIRS WHERE x == y) +JOIN f +ON col0 = f.a +JOIN s +ON col1 = s.b +---- +logical_plan +01)Inner Join: col1 = CAST(s.b AS Int64) +02)--Inner Join: col0 = CAST(f.a AS Int64) +03)----Projection: CAST(pairs.x AS Int64) + Int64(1) AS col0, CAST(pairs.y AS Int64) + Int64(1) AS col1 +04)------Filter: pairs.y = pairs.x +05)--------TableScan: pairs projection=[x, y] +06)----TableScan: f projection=[a] +07)--TableScan: s projection=[b] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col1@1, CAST(s.b AS Int64)@1)], projection=[col0@0, col1@1, a@2, b@3] +03)----ProjectionExec: expr=[col0@1 as col0, col1@2 as col1, a@0 as a] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(CAST(f.a AS Int64)@1, col0@0)], projection=[a@0, col0@2, col1@3] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------RepartitionExec: partitioning=Hash([CAST(f.a AS Int64)@1], 16), input_partitions=1 +08)--------------ProjectionExec: expr=[a@0 as a, CAST(a@0 AS Int64) as CAST(f.a AS Int64)] +09)----------------MemoryExec: 
partitions=1, partition_sizes=[1] +10)----------CoalesceBatchesExec: target_batch_size=8192 +11)------------RepartitionExec: partitioning=Hash([col0@0], 16), input_partitions=16 +12)--------------ProjectionExec: expr=[CAST(x@0 AS Int64) + 1 as col0, CAST(y@1 AS Int64) + 1 as col1] +13)----------------RepartitionExec: partitioning=RoundRobinBatch(16), input_partitions=1 +14)------------------CoalesceBatchesExec: target_batch_size=8192 +15)--------------------FilterExec: y@1 = x@0 +16)----------------------MemoryExec: partitions=1, partition_sizes=[1] +17)----CoalesceBatchesExec: target_batch_size=8192 +18)------RepartitionExec: partitioning=Hash([CAST(s.b AS Int64)@1], 16), input_partitions=1 +19)--------ProjectionExec: expr=[b@0 as b, CAST(b@0 AS Int64) as CAST(s.b AS Int64)] +20)----------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +drop table pairs; + +statement ok +drop table f; + +statement ok +drop table s; + +# Reset the configs to old values. +statement ok +set datafusion.execution.target_partitions = 4; + +statement ok +set datafusion.optimizer.repartition_joins = false; +