Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Support uncorrelated EXISTS #14474

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1602,7 +1602,7 @@ pub fn table_scan_with_filter_and_fetch(
)
}

fn table_source(table_schema: &Schema) -> Arc<dyn TableSource> {
pub fn table_source(table_schema: &Schema) -> Arc<dyn TableSource> {
let table_schema = Arc::new(table_schema.clone());
Arc::new(LogicalTableSource { table_schema })
}
Expand Down
45 changes: 24 additions & 21 deletions datafusion/expr/src/logical_plan/invariants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,31 +209,26 @@ pub fn check_subquery_expr(

// Recursively check the unsupported outer references in the sub query plan.
fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> {
check_inner_plan(inner_plan, true)
check_inner_plan(inner_plan)
}

// Recursively check the unsupported outer references in the sub query plan.
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Result<()> {
if !can_contain_outer_ref && inner_plan.contains_outer_reference() {
return plan_err!("Accessing outer reference columns is not allowed in the plan");
}
fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> {
// We want to support as many operators as possible inside the correlated subquery
match inner_plan {
LogicalPlan::Aggregate(_) => {
inner_plan.apply_children(|plan| {
check_inner_plan(plan, can_contain_outer_ref)?;
check_inner_plan(plan)?;
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
}
LogicalPlan::Filter(Filter { input, .. }) => {
check_inner_plan(input, can_contain_outer_ref)
}
LogicalPlan::Filter(Filter { input, .. }) => check_inner_plan(input),
LogicalPlan::Window(window) => {
check_mixed_out_refer_in_window(window)?;
inner_plan.apply_children(|plan| {
check_inner_plan(plan, can_contain_outer_ref)?;
check_inner_plan(plan)?;
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
Expand All @@ -250,7 +245,7 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re
| LogicalPlan::SubqueryAlias(_)
| LogicalPlan::Unnest(_) => {
inner_plan.apply_children(|plan| {
check_inner_plan(plan, can_contain_outer_ref)?;
check_inner_plan(plan)?;
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
Expand All @@ -263,7 +258,7 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re
}) => match join_type {
JoinType::Inner => {
inner_plan.apply_children(|plan| {
check_inner_plan(plan, can_contain_outer_ref)?;
check_inner_plan(plan)?;
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
Expand All @@ -272,26 +267,34 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re
| JoinType::LeftSemi
| JoinType::LeftAnti
| JoinType::LeftMark => {
check_inner_plan(left, can_contain_outer_ref)?;
check_inner_plan(right, false)
check_inner_plan(left)?;
check_no_outer_references(right)
}
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
check_inner_plan(left, false)?;
check_inner_plan(right, can_contain_outer_ref)
check_no_outer_references(left)?;
check_inner_plan(right)
}
JoinType::Full => {
inner_plan.apply_children(|plan| {
check_inner_plan(plan, false)?;
check_no_outer_references(plan)?;
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
}
},
LogicalPlan::Extension(_) => Ok(()),
_ => plan_err!(
"Unsupported operator in the subquery plan: {}",
plan => check_no_outer_references(plan),
}
}

fn check_no_outer_references(inner_plan: &LogicalPlan) -> Result<()> {
if inner_plan.contains_outer_reference() {
plan_err!(
"Accessing outer reference columns is not allowed in the plan: {}",
findepi marked this conversation as resolved.
Show resolved Hide resolved
inner_plan.display()
),
)
} else {
Ok(())
}
}

Expand Down Expand Up @@ -433,6 +436,6 @@ mod test {
}),
});

check_inner_plan(&plan, true).unwrap();
check_inner_plan(&plan).unwrap();
}
}
149 changes: 116 additions & 33 deletions datafusion/optimizer/src/decorrelate_predicate_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
use datafusion_expr::logical_plan::{JoinType, Subquery};
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
use datafusion_expr::{
exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
LogicalPlan, LogicalPlanBuilder, Operator,
};

Expand Down Expand Up @@ -342,7 +342,7 @@ fn build_join(
replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some)
})?;

if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) {
let join_filter = match (join_filter_opt, in_predicate_opt) {
(
Some(join_filter),
Some(Expr::BinaryExpr(BinaryExpr {
Expand All @@ -353,9 +353,9 @@ fn build_join(
) => {
let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
Some(in_predicate.and(join_filter))
in_predicate.and(join_filter)
}
(Some(join_filter), _) => Some(join_filter),
(Some(join_filter), _) => join_filter,
(
_,
Some(Expr::BinaryExpr(BinaryExpr {
Expand All @@ -366,24 +366,23 @@ fn build_join(
) => {
let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
Some(in_predicate)
in_predicate
}
_ => None,
} {
// join our sub query into the main plan
let new_plan = LogicalPlanBuilder::from(left.clone())
.join_on(sub_query_alias, join_type, Some(join_filter))?
.build()?;
debug!(
"predicate subquery optimized:\n{}",
new_plan.display_indent()
);
Ok(Some(new_plan))
} else {
Ok(None)
}
(None, None) => lit(true),
_ => return Ok(None),
};
// join our sub query into the main plan
let new_plan = LogicalPlanBuilder::from(left.clone())
.join_on(sub_query_alias, join_type, Some(join_filter))?
.build()?;
debug!(
"predicate subquery optimized:\n{}",
new_plan.display_indent()
);
Ok(Some(new_plan))
}

#[derive(Debug)]
struct SubqueryInfo {
query: Subquery,
where_in_expr: Option<Expr>,
Expand Down Expand Up @@ -429,6 +428,7 @@ mod tests {
use crate::test::*;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::builder::table_source;
use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan};

fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
Expand Down Expand Up @@ -1423,7 +1423,14 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;

assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan)
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n LeftSemi Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
\n Projection: orders.o_custkey [o_custkey:Int64]\
\n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_equal(plan, expected)
}

/// Test for correlated exists subquery not equal
Expand Down Expand Up @@ -1609,13 +1616,14 @@ mod tests {
.build()?;

// optimized: the EXISTS inside the OR is rewritten into a LeftMark join and the filter reads its `mark` column
findepi marked this conversation as resolved.
Show resolved Hide resolved
let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
Filter: EXISTS (<subquery>) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
Subquery: [o_custkey:Int64]
Projection: orders.o_custkey [o_custkey:Int64]
Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]\
\n LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
\n Projection: orders.o_custkey [o_custkey:Int64]\
\n Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

assert_optimized_plan_equal(plan, expected)
}
Expand Down Expand Up @@ -1654,7 +1662,13 @@ mod tests {
.project(vec![col("test.b")])?
.build()?;

assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan)
let expected = "Projection: test.b [b:UInt32]\
\n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
\n Projection: sq.c [c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_equal(plan, expected)
}

/// Test for single NOT exists subquery filter
Expand All @@ -1666,7 +1680,13 @@ mod tests {
.project(vec![col("test.b")])?
.build()?;

assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan)
let expected = "Projection: test.b [b:UInt32]\
\n LeftAnti Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
\n Projection: sq.c [c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_equal(plan, expected)
}

#[test]
Expand Down Expand Up @@ -1750,12 +1770,12 @@ mod tests {

// Subquery and outer query refer to the same table.
let expected = "Projection: test.b [b:UInt32]\
\n Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
\n Subquery: [c:UInt32]\
\n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
\n Projection: test.c [c:UInt32]\
\n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
}
Expand Down Expand Up @@ -1844,6 +1864,69 @@ mod tests {
assert_optimized_plan_equal(plan, expected)
}

/// `EXISTS` over a subquery that only unnests a list column (no outer-reference
/// filter). The expected plan below shows it rewritten to a `LeftSemi` join whose
/// filter is the literal `Boolean(true)`.
#[test]
fn exists_uncorrelated_unnest() -> Result<()> {
    // Subquery input: table `sq` with a single nullable list-of-Int32 column `arr`.
    let list_item = Field::new_list_field(DataType::Int32, true);
    let arr_field = Field::new("arr", DataType::List(Arc::new(list_item)), true);
    let sq_schema = Schema::new(vec![arr_field]);
    let sq_source = table_source(&sq_schema);

    // Scan `sq` without projection or filters, then unnest `arr`.
    let unnested = LogicalPlanBuilder::scan_with_filters("sq", sq_source, None, vec![])?
        .unnest_column("arr")?
        .build()?;

    // Outer query: keep rows of `test` for which the subquery is non-empty, project `test.b`.
    let outer = LogicalPlanBuilder::from(test_table_scan()?);
    let plan = outer
        .filter(exists(Arc::new(unnested)))?
        .project(vec![col("test.b")])?
        .build()?;

    // NOTE(review): every nested plan line below carries a single leading space;
    // confirm this matches the plan rendering the assertion compares against.
    let expected = "Projection: test.b [b:UInt32]\
        \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
        \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
        \n SubqueryAlias: __correlated_sq_1 [arr:Int32;N]\
        \n Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N]\
        \n TableScan: sq [arr:List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]";
    assert_optimized_plan_equal(plan, expected)
}

/// `EXISTS` over an unnested list column that is correlated with the outer query
/// through `a = outer_ref(test.b)`. The expected plan below shows the predicate
/// lifted into the filter of a `LeftSemi` join (`__correlated_sq_1.a = test.b`).
#[test]
fn exists_correlated_unnest() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Subquery input: table `sq` with a single nullable list-of-UInt32 column `a`.
    let subquery_table_source = table_source(&Schema::new(vec![Field::new(
        "a",
        DataType::List(Arc::new(Field::new_list_field(DataType::UInt32, true))),
        true,
    )]));

    // Unnest `a`, then compare the unnested value with the outer column `test.b`.
    let subquery = LogicalPlanBuilder::scan_with_filters(
        "sq",
        subquery_table_source,
        None,
        vec![],
    )?
    .unnest_column("a")?
    .filter(col("a").eq(out_ref_col(DataType::UInt32, "test.b")))?
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .filter(exists(Arc::new(subquery)))?
        .project(vec![col("test.b")])?
        .build()?;

    // NOTE(review): every nested plan line below carries a single leading space;
    // confirm this matches the plan rendering the assertion compares against.
    let expected = "Projection: test.b [b:UInt32]\
        \n LeftSemi Join: Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32]\
        \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
        \n SubqueryAlias: __correlated_sq_1 [a:UInt32;N]\
        \n Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N]\
        \n TableScan: sq [a:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]";

    assert_optimized_plan_equal(plan, expected)
}

#[test]
fn upper_case_ident() -> Result<()> {
let fields = vec![
Expand Down
18 changes: 11 additions & 7 deletions datafusion/sqllogictest/test_files/explain.slt
Original file line number Diff line number Diff line change
Expand Up @@ -423,13 +423,17 @@ query TT
explain select a from t1 where exists (select count(*) from t2);
----
logical_plan
01)Filter: EXISTS (<subquery>)
02)--Subquery:
03)----Projection: count(*)
04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]
05)--------TableScan: t2
06)--TableScan: t1 projection=[a]
physical_plan_error This feature is not implemented: Physical plan does not support logical expression Exists(Exists { subquery: <subquery>, negated: false })
01)LeftSemi Join:
02)--TableScan: t1 projection=[a]
03)--SubqueryAlias: __correlated_sq_1
04)----Projection:
05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]
06)--------TableScan: t2 projection=[]
physical_plan
01)NestedLoopJoinExec: join_type=LeftSemi
02)--MemoryExec: partitions=1, partition_sizes=[0]
03)--ProjectionExec: expr=[]
04)----PlaceholderRowExec

statement ok
drop table t1;
Expand Down
Loading