Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide user-defined invariants for logical node extensions. #14329

Merged
merged 5 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 187 additions & 3 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
//!

use std::fmt::Debug;
use std::hash::Hash;
use std::task::{Context, Poll};
use std::{any::Any, collections::BTreeMap, fmt, sync::Arc};

Expand Down Expand Up @@ -93,7 +94,7 @@ use datafusion::{
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::{FetchType, Projection, SortExpr};
use datafusion_expr::{FetchType, InvariantLevel, Projection, SortExpr};
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::AnalyzerRule;
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
Expand Down Expand Up @@ -295,20 +296,175 @@ async fn topk_plan() -> Result<()> {
Ok(())
}

#[tokio::test]
/// Run invariant checks on the logical plan extension [`TopKPlanNode`].
async fn topk_invariants() -> Result<()> {
Comment on lines +300 to +301
Copy link
Contributor Author

@wiedld wiedld Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test: demonstrate the basic use case. That user-defined invariants will fail for an invalid extension node.

// Test: pass an InvariantLevel::Always
let pass = InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Always,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?;
run_and_compare_query(ctx, "Topk context").await?;

// Test: fail an InvariantLevel::Always
let fail = InvariantMock {
should_fail_invariant: true,
kind: InvariantLevel::Always,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?;
matches!(
&*run_and_compare_query(ctx, "Topk context")
.await
.unwrap_err()
.message(),
"node fails check, such as improper inputs"
);

// Test: pass an InvariantLevel::Executable
let pass = InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Executable,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?;
run_and_compare_query(ctx, "Topk context").await?;

// Test: fail an InvariantLevel::Executable
let fail = InvariantMock {
should_fail_invariant: true,
kind: InvariantLevel::Executable,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?;
matches!(
&*run_and_compare_query(ctx, "Topk context")
.await
.unwrap_err()
.message(),
"node fails check, such as improper inputs"
);

Ok(())
}

#[tokio::test]
async fn topk_invariants_after_invalid_mutation() -> Result<()> {
Comment on lines +349 to +350
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test: demonstrate a failed invariant check after logical plan mutation (during optimizer run).

// CONTROL
// Build a valid topK plan.
let config = SessionConfig::new().with_target_partitions(48);
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.with_query_planner(Arc::new(TopKQueryPlanner {}))
// 1. adds a valid TopKPlanNode
.with_optimizer_rule(Arc::new(TopKOptimizerRule {
invariant_mock: Some(InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Always,
}),
}))
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
.build();
let ctx = setup_table(SessionContext::new_with_state(state)).await?;
run_and_compare_query(ctx, "Topk context").await?;

// Test
// Build a valid topK plan.
// Then have an invalid mutation in an optimizer run.
let config = SessionConfig::new().with_target_partitions(48);
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.with_query_planner(Arc::new(TopKQueryPlanner {}))
// 1. adds a valid TopKPlanNode
.with_optimizer_rule(Arc::new(TopKOptimizerRule {
invariant_mock: Some(InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Always,
}),
}))
// 2. break the TopKPlanNode
.with_optimizer_rule(Arc::new(OptimizerMakeExtensionNodeInvalid {}))
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
.build();
let ctx = setup_table(SessionContext::new_with_state(state)).await?;
matches!(
&*run_and_compare_query(ctx, "Topk context")
.await
.unwrap_err()
.message(),
"node fails check, such as improper inputs"
);

Ok(())
}

fn make_topk_context() -> SessionContext {
make_topk_context_with_invariants(None)
}

fn make_topk_context_with_invariants(
invariant_mock: Option<InvariantMock>,
) -> SessionContext {
let config = SessionConfig::new().with_target_partitions(48);
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.with_query_planner(Arc::new(TopKQueryPlanner {}))
.with_optimizer_rule(Arc::new(TopKOptimizerRule {}))
.with_optimizer_rule(Arc::new(TopKOptimizerRule { invariant_mock }))
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
.build();
SessionContext::new_with_state(state)
}

#[derive(Debug)]
struct OptimizerMakeExtensionNodeInvalid;

impl OptimizerRule for OptimizerMakeExtensionNodeInvalid {
fn name(&self) -> &str {
"OptimizerMakeExtensionNodeInvalid"
}

fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}

fn supports_rewrite(&self) -> bool {
true
}

// Example rewrite pass which impacts validity of the extension node.
fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>, DataFusionError> {
if let LogicalPlan::Extension(Extension { node }) = &plan {
if let Some(prev) = node.as_any().downcast_ref::<TopKPlanNode>() {
return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
node: Arc::new(TopKPlanNode {
k: prev.k,
input: prev.input.clone(),
expr: prev.expr.clone(),
// In a real use case, this rewriter could have change the number of inputs, etc
invariant_mock: Some(InvariantMock {
should_fail_invariant: true,
kind: InvariantLevel::Always,
}),
}),
})));
}
};

Ok(Transformed::no(plan))
}
}

// ------ The implementation of the TopK code follows -----

#[derive(Debug)]
Expand Down Expand Up @@ -336,7 +492,10 @@ impl QueryPlanner for TopKQueryPlanner {
}

#[derive(Default, Debug)]
struct TopKOptimizerRule {}
struct TopKOptimizerRule {
/// A testing-only hashable fixture.
invariant_mock: Option<InvariantMock>,
}

impl OptimizerRule for TopKOptimizerRule {
fn name(&self) -> &str {
Expand Down Expand Up @@ -380,6 +539,7 @@ impl OptimizerRule for TopKOptimizerRule {
k: fetch,
input: input.as_ref().clone(),
expr: expr[0].clone(),
invariant_mock: self.invariant_mock.clone(),
}),
})));
}
Expand All @@ -396,6 +556,10 @@ struct TopKPlanNode {
/// The sort expression (this example only supports a single sort
/// expr)
expr: SortExpr,

/// A testing-only hashable fixture.
/// For actual use, define the [`Invariant`] in the [`UserDefinedLogicalNodeCore::invariants`].
invariant_mock: Option<InvariantMock>,
}

impl Debug for TopKPlanNode {
Expand All @@ -406,6 +570,12 @@ impl Debug for TopKPlanNode {
}
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
struct InvariantMock {
should_fail_invariant: bool,
kind: InvariantLevel,
}

impl UserDefinedLogicalNodeCore for TopKPlanNode {
fn name(&self) -> &str {
"TopK"
Expand All @@ -420,6 +590,19 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
self.input.schema()
}

fn check_invariants(&self, check: InvariantLevel, _plan: &LogicalPlan) -> Result<()> {
if let Some(InvariantMock {
should_fail_invariant,
kind,
}) = self.invariant_mock.clone()
{
if should_fail_invariant && check == kind {
return internal_err!("node fails check, such as improper inputs");
}
}
Ok(())
}

fn expressions(&self) -> Vec<Expr> {
vec![self.expr.expr.clone()]
}
Expand All @@ -440,6 +623,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
k: self.k,
input: inputs.swap_remove(0),
expr: self.expr.with_expr(exprs.swap_remove(0)),
invariant_mock: self.invariant_mock.clone(),
})
}

Expand Down
20 changes: 20 additions & 0 deletions datafusion/expr/src/logical_plan/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::{any::Any, collections::HashSet, fmt, sync::Arc};

use super::InvariantLevel;

/// This defines the interface for [`LogicalPlan`] nodes that can be
/// used to extend DataFusion with custom relational operators.
///
Expand Down Expand Up @@ -54,6 +56,9 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
/// Return the output schema of this logical plan node.
fn schema(&self) -> &DFSchemaRef;

/// Perform check of invariants for the extension node.
fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()>;

/// Returns all expressions in the current logical plan node. This should
/// not include expressions of any inputs (aka non-recursively).
///
Expand Down Expand Up @@ -244,6 +249,17 @@ pub trait UserDefinedLogicalNodeCore:
/// Return the output schema of this logical plan node.
fn schema(&self) -> &DFSchemaRef;

/// Perform check of invariants for the extension node.
///
/// This is the default implementation for extension nodes.
fn check_invariants(
&self,
_check: InvariantLevel,
_plan: &LogicalPlan,
) -> Result<()> {
Ok(())
}

/// Returns all expressions in the current logical plan node. This
/// should not include expressions of any inputs (aka
/// non-recursively). These expressions are used for optimizer
Expand Down Expand Up @@ -336,6 +352,10 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
self.schema()
}

fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> {
self.check_invariants(check, plan)
}

fn expressions(&self) -> Vec<Expr> {
self.expressions()
}
Expand Down
46 changes: 43 additions & 3 deletions datafusion/expr/src/logical_plan/invariants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ use crate::{
Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
};

use super::Extension;

#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum InvariantLevel {
/// Invariants that are always true in DataFusion `LogicalPlan`s
/// such as the number of expected children and no duplicated output fields
Expand All @@ -41,19 +44,56 @@ pub enum InvariantLevel {
Executable,
}

pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> {
/// Apply the [`InvariantLevel::Always`] check at the current plan node only.
wiedld marked this conversation as resolved.
Show resolved Hide resolved
///
/// This does not recurs to any child nodes.
pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> {
// Refer to <https://datafusion.apache.org/contributor-guide/specification/invariants.html#relation-name-tuples-in-logical-fields-and-logical-columns-are-unique>
assert_unique_field_names(plan)?;

Ok(())
}

/// Visit the plan nodes, and confirm the [`InvariantLevel::Executable`]
/// as well as the less stringent [`InvariantLevel::Always`] checks.
pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
Comment on lines +57 to 59
Copy link
Contributor Author

@wiedld wiedld Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Documents the existing behavior.

  • The assert_always_invariants() (renamed to assert_always_invariants_at_current_node) assess only the current node, and does not assess the remaining DAG.
  • whereas the assert_executable_invariants() (a) visits the subplan, and (b) validates the always and executable.

Copy link
Contributor Author

@wiedld wiedld Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The does feel like a blurring of definitions. The previous decision was made based upon minimizing the performance impact; at the time, we wanted the "always" invariants to be be cheap and able to be checked more frequently.

However, as of now the frequency is:

  • LP always invariants checked before analyzer
  • LP executable (including always) invariants checked:
    • after analyzer
    • once before all optimizer runs
    • once after all optimizer runs

Should I undo this blurring, and have the assert_always_invariants also be recursive?

assert_always_invariants(plan)?;
// Always invariants
assert_always_invariants_at_current_node(plan)?;
assert_valid_extension_nodes(plan, InvariantLevel::Always)?;

// Executable invariants
assert_valid_extension_nodes(plan, InvariantLevel::Executable)?;
assert_valid_semantic_plan(plan)?;
Ok(())
}

/// Asserts that the query plan, and subplan, extension nodes have valid invariants.
///
/// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode)
/// for more details of user-provided extension node invariants.
fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As written I think this this does a separate walk of the tree than assert_subqueries_are_valid

Maybe as a follow on PR we could unify the walks (so the tree gets walked once and all checks applied) rather than two separate plans

plan.apply_with_subqueries(|plan: &LogicalPlan| {
if let LogicalPlan::Extension(Extension { node }) = plan {
node.check_invariants(check, plan)?;
}
plan.apply_expressions(|expr| {
// recursively look for subqueries
expr.apply(|expr| {
match expr {
Expr::Exists(Exists { subquery, .. })
| Expr::InSubquery(InSubquery { subquery, .. })
| Expr::ScalarSubquery(subquery) => {
assert_valid_extension_nodes(&subquery.subquery, check)?;
}
_ => {}
};
Ok(TreeNodeRecursion::Continue)
})
})
})
.map(|_| ())
}

/// Returns an error if plan, and subplans, do not have unique fields.
///
/// This invariant is subject to change.
Expand Down Expand Up @@ -87,7 +127,7 @@ pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Resul

/// Asserts that the subqueries are structured properly with valid node placement.
///
/// Refer to [`check_subquery_expr`] for more details.
/// Refer to [`check_subquery_expr`] for more details of the internal invariants.
fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
plan.apply_with_subqueries(|plan: &LogicalPlan| {
plan.apply_expressions(|expr| {
Expand Down
Loading