Skip to content

Commit

Permalink
feat(13525): permit user-defined invariants on logical plan extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
wiedld committed Jan 22, 2025
1 parent 2aff98e commit 876959d
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 5 deletions.
106 changes: 103 additions & 3 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
//!
use std::fmt::Debug;
use std::hash::Hash;
use std::task::{Context, Poll};
use std::{any::Any, collections::BTreeMap, fmt, sync::Arc};

Expand Down Expand Up @@ -93,7 +94,7 @@ use datafusion::{
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::{FetchType, Projection, SortExpr};
use datafusion_expr::{FetchType, Invariant, InvariantLevel, Projection, SortExpr};
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::AnalyzerRule;
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
Expand Down Expand Up @@ -295,15 +296,71 @@ async fn topk_plan() -> Result<()> {
Ok(())
}

#[tokio::test]
/// Run invariant checks on the logical plan extension [`TopKPlanNode`].
///
/// For each [`InvariantLevel`], a mock invariant that passes must let the
/// query run to completion, and a mock invariant that fails must surface the
/// mock's error message.
async fn topk_invariants() -> Result<()> {
    // Message produced by `invariant_helper_mock_fails`.
    const EXPECTED_FAILURE: &str = "node fails check, such as improper inputs";

    for kind in [InvariantLevel::Always, InvariantLevel::Executable] {
        // Test: pass an invariant of this level
        let pass = InvariantMock {
            should_fail_invariant: false,
            kind,
        };
        let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?;
        run_and_compare_query(ctx, "Topk context").await?;

        // Test: fail an invariant of this level
        let fail = InvariantMock {
            should_fail_invariant: true,
            kind,
        };
        let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?;
        let err = run_and_compare_query(ctx, "Topk context")
            .await
            .unwrap_err();

        // A bare `matches!(..)` only yields a bool; it must be asserted to have
        // any effect. A substring check is used so that any prefix, suffix or
        // context the error type adds around the message does not matter.
        let rendered = err.to_string();
        assert!(
            rendered.contains(EXPECTED_FAILURE),
            "unexpected error for {kind:?} invariant: {rendered}"
        );
    }

    Ok(())
}

/// Build the TopK test [`SessionContext`] with no mock invariant attached.
fn make_topk_context() -> SessionContext {
    let no_invariant_mock = None;
    make_topk_context_with_invariants(no_invariant_mock)
}

fn make_topk_context_with_invariants(
invariant_mock: Option<InvariantMock>,
) -> SessionContext {
let config = SessionConfig::new().with_target_partitions(48);
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.with_query_planner(Arc::new(TopKQueryPlanner {}))
.with_optimizer_rule(Arc::new(TopKOptimizerRule {}))
.with_optimizer_rule(Arc::new(TopKOptimizerRule { invariant_mock }))
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
.build();
SessionContext::new_with_state(state)
Expand Down Expand Up @@ -336,7 +393,10 @@ impl QueryPlanner for TopKQueryPlanner {
}

#[derive(Default, Debug)]
struct TopKOptimizerRule {}
struct TopKOptimizerRule {
    /// A testing-only hashable fixture.
    ///
    /// When set, the rule clones it onto the plan node it builds, so tests can
    /// control whether that node's invariant passes or fails. `None` means no
    /// mock invariant is attached.
    invariant_mock: Option<InvariantMock>,
}

impl OptimizerRule for TopKOptimizerRule {
fn name(&self) -> &str {
Expand Down Expand Up @@ -380,6 +440,7 @@ impl OptimizerRule for TopKOptimizerRule {
k: fetch,
input: input.as_ref().clone(),
expr: expr[0].clone(),
invariant_mock: self.invariant_mock.clone(),
}),
})));
}
Expand All @@ -396,6 +457,10 @@ struct TopKPlanNode {
/// The sort expression (this example only supports a single sort
/// expr)
expr: SortExpr,

/// A testing-only hashable fixture.
/// For actual use, define the [`Invariant`] in the [`UserDefinedLogicalNodeCore::invariants`].
invariant_mock: Option<InvariantMock>,
}

impl Debug for TopKPlanNode {
Expand All @@ -406,6 +471,20 @@ impl Debug for TopKPlanNode {
}
}

/// Test fixture describing a single mock invariant for [`TopKPlanNode`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
struct InvariantMock {
    /// When `true`, the generated invariant always returns an error;
    /// when `false`, it always succeeds.
    should_fail_invariant: bool,
    /// The level the generated invariant is registered under.
    kind: InvariantLevel,
}

/// Mock invariant function that accepts every plan.
fn invariant_helper_mock_ok(_plan: &LogicalPlan) -> Result<()> {
    let outcome = Ok(());
    outcome
}

/// Mock invariant function that rejects every plan with an internal error.
fn invariant_helper_mock_fails(_plan: &LogicalPlan) -> Result<()> {
    let rejection = internal_err!("node fails check, such as improper inputs");
    rejection
}

impl UserDefinedLogicalNodeCore for TopKPlanNode {
fn name(&self) -> &str {
"TopK"
Expand All @@ -420,6 +499,26 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
self.input.schema()
}

/// Expose the configured mock, if any, as this node's single [`Invariant`].
///
/// The invariant is registered under the mock's level and either always
/// fails or always succeeds, depending on `should_fail_invariant`.
fn invariants(&self) -> Vec<Invariant> {
    match &self.invariant_mock {
        Some(mock) if mock.should_fail_invariant => vec![Invariant {
            kind: mock.kind,
            fun: Arc::new(invariant_helper_mock_fails),
        }],
        Some(mock) => vec![Invariant {
            kind: mock.kind,
            fun: Arc::new(invariant_helper_mock_ok),
        }],
        // same as default impl
        None => vec![],
    }
}

/// Return the node's only expression: the expression wrapped by its sort expr.
fn expressions(&self) -> Vec<Expr> {
    let sort_expression = self.expr.expr.clone();
    vec![sort_expression]
}
Expand All @@ -440,6 +539,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
k: self.k,
input: inputs.swap_remove(0),
expr: self.expr.with_expr(exprs.swap_remove(0)),
invariant_mock: self.invariant_mock.clone(),
})
}

Expand Down
31 changes: 31 additions & 0 deletions datafusion/expr/src/logical_plan/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::{any::Any, collections::HashSet, fmt, sync::Arc};

use super::invariants::Invariant;
use super::InvariantLevel;

/// This defines the interface for [`LogicalPlan`] nodes that can be
/// used to extend DataFusion with custom relational operators.
///
Expand Down Expand Up @@ -54,6 +57,22 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
/// Return the output schema of this logical plan node.
fn schema(&self) -> &DFSchemaRef;

/// Return the invariants declared by this extension node.
///
/// Override this method to attach user-defined invariants to a logical
/// plan extension. The default declares none.
fn invariants(&self) -> Vec<Invariant> {
    Vec::new()
}

/// Check this extension node's invariants of the requested level against `plan`.
///
/// Only invariants whose `kind` equals `check` are run. They are evaluated in
/// the order returned by `invariants()`, and the first error is returned.
fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> {
    for invariant in self.invariants() {
        if check == invariant.kind {
            invariant.check(plan)?;
        }
    }
    Ok(())
}

/// Returns all expressions in the current logical plan node. This should
/// not include expressions of any inputs (aka non-recursively).
///
Expand Down Expand Up @@ -244,6 +263,14 @@ pub trait UserDefinedLogicalNodeCore:
/// Return the output schema of this logical plan node.
fn schema(&self) -> &DFSchemaRef;

/// Return the invariants declared by this extension node.
///
/// Override this method to attach user-defined invariants to a logical
/// plan extension. The default declares none.
fn invariants(&self) -> Vec<Invariant> {
    Vec::new()
}

/// Returns all expressions in the current logical plan node. This
/// should not include expressions of any inputs (aka
/// non-recursively). These expressions are used for optimizer
Expand Down Expand Up @@ -336,6 +363,10 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
self.schema()
}

fn invariants(&self) -> Vec<Invariant> {
self.invariants()
}

fn expressions(&self) -> Vec<Expr> {
self.expressions()
}
Expand Down
57 changes: 56 additions & 1 deletion datafusion/expr/src/logical_plan/invariants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use datafusion_common::{
internal_err, plan_err,
tree_node::{TreeNode, TreeNodeRecursion},
Expand All @@ -28,6 +30,24 @@ use crate::{
Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
};

use super::Extension;

/// Shared, thread-safe callback that inspects a [`LogicalPlan`] and returns
/// an error when the invariant it encodes does not hold.
pub type InvariantFn = Arc<dyn Fn(&LogicalPlan) -> Result<()> + Send + Sync>;

/// A single invariant: a check function paired with the level it belongs to.
#[derive(Clone)]
pub struct Invariant {
    /// The [`InvariantLevel`] this invariant is associated with.
    pub kind: InvariantLevel,
    /// The function that performs the check.
    pub fun: InvariantFn,
}

impl Invariant {
    /// Run the invariant's function on `plan`, returning its error if the
    /// invariant does not hold.
    pub fn check(&self, plan: &LogicalPlan) -> Result<()> {
        let verify = &*self.fun;
        verify(plan)
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum InvariantLevel {
/// Invariants that are always true in DataFusion `LogicalPlan`s
/// such as the number of expected children and no duplicated output fields
Expand All @@ -41,19 +61,54 @@ pub enum InvariantLevel {
Executable,
}

/// Apply the [`InvariantLevel::Always`] check at the root plan node only.
///
/// The only check performed here is that the plan's field names are unique.
/// This function does not run extension-node invariants.
///
/// # Errors
///
/// Returns the error produced by the unique-field-name check, if any.
pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> {
    // Refer to <https://datafusion.apache.org/contributor-guide/specification/invariants.html#relation-name-tuples-in-logical-fields-and-logical-columns-are-unique>
    assert_unique_field_names(plan)?;

    Ok(())
}

/// Confirm the [`InvariantLevel::Executable`] checks for the plan, together
/// with the less stringent [`InvariantLevel::Always`] checks.
///
/// Order of evaluation: built-in always-invariants, extension-node invariants
/// at the `Always` level, extension-node invariants at the `Executable`
/// level, then the built-in semantic plan check. The first error is returned.
pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
    // Built-in always-invariants come first.
    assert_always_invariants(plan)?;

    // Extension nodes: the weaker level is checked before the stricter one.
    for level in [InvariantLevel::Always, InvariantLevel::Executable] {
        assert_valid_extension_nodes(plan, level)?;
    }

    // Built-in executable invariants.
    assert_valid_semantic_plan(plan)?;
    Ok(())
}

/// Asserts that every extension node in the query plan, including those in
/// subquery plans, satisfies its user-defined invariants of level `check`.
///
/// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode)
/// for more details of user-provided extension node invariants.
///
/// # Errors
///
/// Returns the first error reported by an extension node's invariant check.
fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
    // `apply_with_subqueries` already descends into the subquery plans that
    // appear in expressions (`EXISTS`, `IN (SELECT ..)`, scalar subqueries).
    // Recursing into those expressions by hand as well would re-check every
    // subquery, with the repeat count growing with the nesting depth.
    plan.apply_with_subqueries(|plan: &LogicalPlan| {
        if let LogicalPlan::Extension(Extension { node }) = plan {
            node.check_invariants(check, plan)?;
        }
        Ok(TreeNodeRecursion::Continue)
    })
    .map(|_| ())
}

/// Returns an error if plan, and subplans, do not have unique fields.
///
/// This invariant is subject to change.
Expand Down Expand Up @@ -87,7 +142,7 @@ pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Resul

/// Asserts that the subqueries are structured properly with valid node placement.
///
/// Refer to [`check_subquery_expr`] for more details.
/// Refer to [`check_subquery_expr`] for more details of the internal invariants.
fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
plan.apply_with_subqueries(|plan: &LogicalPlan| {
plan.apply_expressions(|expr| {
Expand Down
4 changes: 3 additions & 1 deletion datafusion/expr/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ pub mod display;
pub mod dml;
mod extension;
pub(crate) mod invariants;
pub use invariants::{assert_expected_schema, check_subquery_expr, InvariantLevel};
pub use invariants::{
assert_expected_schema, check_subquery_expr, Invariant, InvariantFn, InvariantLevel,
};
mod plan;
mod statement;
pub mod tree_node;
Expand Down

0 comments on commit 876959d

Please sign in to comment.