Skip to content

Commit

Permalink
Provide user-defined invariants for logical node extensions. (#14329)
Browse files Browse the repository at this point in the history
* feat(13525): permit user-defined invariants on logical plan extensions

* test(13525): demonstrate extension node invariants catching improper mutation during an optimizer pass

* chore: update docs

* refactor: remove the extra Invariant interface around an FnMut, since it doesn't make sense for the extension node's checks

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
wiedld and alamb authored Feb 4, 2025
1 parent cfc7c60 commit d8bc49f
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 10 deletions.
190 changes: 187 additions & 3 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
//!
use std::fmt::Debug;
use std::hash::Hash;
use std::task::{Context, Poll};
use std::{any::Any, collections::BTreeMap, fmt, sync::Arc};

Expand Down Expand Up @@ -93,7 +94,7 @@ use datafusion::{
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::{FetchType, Projection, SortExpr};
use datafusion_expr::{FetchType, InvariantLevel, Projection, SortExpr};
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::AnalyzerRule;
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
Expand Down Expand Up @@ -295,20 +296,175 @@ async fn topk_plan() -> Result<()> {
Ok(())
}

#[tokio::test]
/// Run invariant checks on the logical plan extension [`TopKPlanNode`].
async fn topk_invariants() -> Result<()> {
// Test: pass an InvariantLevel::Always
let pass = InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Always,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?;
run_and_compare_query(ctx, "Topk context").await?;

// Test: fail an InvariantLevel::Always
let fail = InvariantMock {
should_fail_invariant: true,
kind: InvariantLevel::Always,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?;
matches!(
&*run_and_compare_query(ctx, "Topk context")
.await
.unwrap_err()
.message(),
"node fails check, such as improper inputs"
);

// Test: pass an InvariantLevel::Executable
let pass = InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Executable,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?;
run_and_compare_query(ctx, "Topk context").await?;

// Test: fail an InvariantLevel::Executable
let fail = InvariantMock {
should_fail_invariant: true,
kind: InvariantLevel::Executable,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?;
matches!(
&*run_and_compare_query(ctx, "Topk context")
.await
.unwrap_err()
.message(),
"node fails check, such as improper inputs"
);

Ok(())
}

#[tokio::test]
async fn topk_invariants_after_invalid_mutation() -> Result<()> {
// CONTROL
// Build a valid topK plan.
let config = SessionConfig::new().with_target_partitions(48);
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.with_query_planner(Arc::new(TopKQueryPlanner {}))
// 1. adds a valid TopKPlanNode
.with_optimizer_rule(Arc::new(TopKOptimizerRule {
invariant_mock: Some(InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Always,
}),
}))
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
.build();
let ctx = setup_table(SessionContext::new_with_state(state)).await?;
run_and_compare_query(ctx, "Topk context").await?;

// Test
// Build a valid topK plan.
// Then have an invalid mutation in an optimizer run.
let config = SessionConfig::new().with_target_partitions(48);
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.with_query_planner(Arc::new(TopKQueryPlanner {}))
// 1. adds a valid TopKPlanNode
.with_optimizer_rule(Arc::new(TopKOptimizerRule {
invariant_mock: Some(InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Always,
}),
}))
// 2. break the TopKPlanNode
.with_optimizer_rule(Arc::new(OptimizerMakeExtensionNodeInvalid {}))
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
.build();
let ctx = setup_table(SessionContext::new_with_state(state)).await?;
matches!(
&*run_and_compare_query(ctx, "Topk context")
.await
.unwrap_err()
.message(),
"node fails check, such as improper inputs"
);

Ok(())
}

fn make_topk_context() -> SessionContext {
make_topk_context_with_invariants(None)
}

fn make_topk_context_with_invariants(
invariant_mock: Option<InvariantMock>,
) -> SessionContext {
let config = SessionConfig::new().with_target_partitions(48);
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.with_query_planner(Arc::new(TopKQueryPlanner {}))
.with_optimizer_rule(Arc::new(TopKOptimizerRule {}))
.with_optimizer_rule(Arc::new(TopKOptimizerRule { invariant_mock }))
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
.build();
SessionContext::new_with_state(state)
}

#[derive(Debug)]
struct OptimizerMakeExtensionNodeInvalid;

impl OptimizerRule for OptimizerMakeExtensionNodeInvalid {
fn name(&self) -> &str {
"OptimizerMakeExtensionNodeInvalid"
}

fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}

fn supports_rewrite(&self) -> bool {
true
}

// Example rewrite pass which impacts validity of the extension node.
fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>, DataFusionError> {
if let LogicalPlan::Extension(Extension { node }) = &plan {
if let Some(prev) = node.as_any().downcast_ref::<TopKPlanNode>() {
return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
node: Arc::new(TopKPlanNode {
k: prev.k,
input: prev.input.clone(),
expr: prev.expr.clone(),
// In a real use case, this rewriter could have change the number of inputs, etc
invariant_mock: Some(InvariantMock {
should_fail_invariant: true,
kind: InvariantLevel::Always,
}),
}),
})));
}
};

Ok(Transformed::no(plan))
}
}

// ------ The implementation of the TopK code follows -----

#[derive(Debug)]
Expand Down Expand Up @@ -336,7 +492,10 @@ impl QueryPlanner for TopKQueryPlanner {
}

#[derive(Default, Debug)]
struct TopKOptimizerRule {}
struct TopKOptimizerRule {
/// A testing-only hashable fixture.
invariant_mock: Option<InvariantMock>,
}

impl OptimizerRule for TopKOptimizerRule {
fn name(&self) -> &str {
Expand Down Expand Up @@ -380,6 +539,7 @@ impl OptimizerRule for TopKOptimizerRule {
k: fetch,
input: input.as_ref().clone(),
expr: expr[0].clone(),
invariant_mock: self.invariant_mock.clone(),
}),
})));
}
Expand All @@ -396,6 +556,10 @@ struct TopKPlanNode {
/// The sort expression (this example only supports a single sort
/// expr)
expr: SortExpr,

/// A testing-only hashable fixture.
/// For actual use, define the [`Invariant`] in the [`UserDefinedLogicalNodeCore::invariants`].
invariant_mock: Option<InvariantMock>,
}

impl Debug for TopKPlanNode {
Expand All @@ -406,6 +570,12 @@ impl Debug for TopKPlanNode {
}
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
struct InvariantMock {
should_fail_invariant: bool,
kind: InvariantLevel,
}

impl UserDefinedLogicalNodeCore for TopKPlanNode {
fn name(&self) -> &str {
"TopK"
Expand All @@ -420,6 +590,19 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
self.input.schema()
}

fn check_invariants(&self, check: InvariantLevel, _plan: &LogicalPlan) -> Result<()> {
if let Some(InvariantMock {
should_fail_invariant,
kind,
}) = self.invariant_mock.clone()
{
if should_fail_invariant && check == kind {
return internal_err!("node fails check, such as improper inputs");
}
}
Ok(())
}

fn expressions(&self) -> Vec<Expr> {
vec![self.expr.expr.clone()]
}
Expand All @@ -440,6 +623,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
k: self.k,
input: inputs.swap_remove(0),
expr: self.expr.with_expr(exprs.swap_remove(0)),
invariant_mock: self.invariant_mock.clone(),
})
}

Expand Down
20 changes: 20 additions & 0 deletions datafusion/expr/src/logical_plan/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::{any::Any, collections::HashSet, fmt, sync::Arc};

use super::InvariantLevel;

/// This defines the interface for [`LogicalPlan`] nodes that can be
/// used to extend DataFusion with custom relational operators.
///
Expand Down Expand Up @@ -54,6 +56,9 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
/// Return the output schema of this logical plan node.
fn schema(&self) -> &DFSchemaRef;

/// Perform check of invariants for the extension node.
fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()>;

/// Returns all expressions in the current logical plan node. This should
/// not include expressions of any inputs (aka non-recursively).
///
Expand Down Expand Up @@ -244,6 +249,17 @@ pub trait UserDefinedLogicalNodeCore:
/// Return the output schema of this logical plan node.
fn schema(&self) -> &DFSchemaRef;

/// Perform check of invariants for the extension node.
///
/// This is the default implementation for extension nodes.
fn check_invariants(
&self,
_check: InvariantLevel,
_plan: &LogicalPlan,
) -> Result<()> {
Ok(())
}

/// Returns all expressions in the current logical plan node. This
/// should not include expressions of any inputs (aka
/// non-recursively). These expressions are used for optimizer
Expand Down Expand Up @@ -336,6 +352,10 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
self.schema()
}

fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> {
self.check_invariants(check, plan)
}

fn expressions(&self) -> Vec<Expr> {
self.expressions()
}
Expand Down
46 changes: 43 additions & 3 deletions datafusion/expr/src/logical_plan/invariants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ use crate::{
Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
};

use super::Extension;

#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum InvariantLevel {
/// Invariants that are always true in DataFusion `LogicalPlan`s
/// such as the number of expected children and no duplicated output fields
Expand All @@ -41,19 +44,56 @@ pub enum InvariantLevel {
Executable,
}

pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> {
/// Apply the [`InvariantLevel::Always`] check at the current plan node only.
///
/// This does not recurs to any child nodes.
pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> {
// Refer to <https://datafusion.apache.org/contributor-guide/specification/invariants.html#relation-name-tuples-in-logical-fields-and-logical-columns-are-unique>
assert_unique_field_names(plan)?;

Ok(())
}

/// Visit the plan nodes, and confirm the [`InvariantLevel::Executable`]
/// as well as the less stringent [`InvariantLevel::Always`] checks.
pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
assert_always_invariants(plan)?;
// Always invariants
assert_always_invariants_at_current_node(plan)?;
assert_valid_extension_nodes(plan, InvariantLevel::Always)?;

// Executable invariants
assert_valid_extension_nodes(plan, InvariantLevel::Executable)?;
assert_valid_semantic_plan(plan)?;
Ok(())
}

/// Asserts that the query plan, and subplan, extension nodes have valid invariants.
///
/// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode)
/// for more details of user-provided extension node invariants.
fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
plan.apply_with_subqueries(|plan: &LogicalPlan| {
if let LogicalPlan::Extension(Extension { node }) = plan {
node.check_invariants(check, plan)?;
}
plan.apply_expressions(|expr| {
// recursively look for subqueries
expr.apply(|expr| {
match expr {
Expr::Exists(Exists { subquery, .. })
| Expr::InSubquery(InSubquery { subquery, .. })
| Expr::ScalarSubquery(subquery) => {
assert_valid_extension_nodes(&subquery.subquery, check)?;
}
_ => {}
};
Ok(TreeNodeRecursion::Continue)
})
})
})
.map(|_| ())
}

/// Returns an error if plan, and subplans, do not have unique fields.
///
/// This invariant is subject to change.
Expand Down Expand Up @@ -87,7 +127,7 @@ pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Resul

/// Asserts that the subqueries are structured properly with valid node placement.
///
/// Refer to [`check_subquery_expr`] for more details.
/// Refer to [`check_subquery_expr`] for more details of the internal invariants.
fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
plan.apply_with_subqueries(|plan: &LogicalPlan| {
plan.apply_expressions(|expr| {
Expand Down
Loading

0 comments on commit d8bc49f

Please sign in to comment.