Skip to content

Commit

Permalink
test(13525): demonstrate extension node invariants catching improper …
Browse files Browse the repository at this point in the history
…mutation during an optimizer pass
  • Loading branch information
wiedld committed Jan 28, 2025
1 parent 876959d commit 3836444
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 5 deletions.
99 changes: 99 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,62 @@ async fn topk_invariants() -> Result<()> {
Ok(())
}

/// Demonstrates that extension-node invariants catch an optimizer pass which
/// mutates a previously valid `TopKPlanNode` into an invalid one.
///
/// The test runs in two phases:
/// * CONTROL: only the rule that inserts a valid `TopKPlanNode` is registered,
///   and the query must succeed.
/// * TEST: `OptimizerMakeExtensionNodeInvalid` is registered in addition, and
///   the query must fail with the invariant-violation message.
#[tokio::test]
async fn topk_invariants_after_invalid_mutation() -> Result<()> {
    // CONTROL
    // Build a valid topK plan.
    let config = SessionConfig::new().with_target_partitions(48);
    let runtime = Arc::new(RuntimeEnv::default());
    let state = SessionStateBuilder::new()
        .with_config(config)
        .with_runtime_env(runtime)
        .with_default_features()
        .with_query_planner(Arc::new(TopKQueryPlanner {}))
        // 1. adds a valid TopKPlanNode
        .with_optimizer_rule(Arc::new(TopKOptimizerRule {
            invariant_mock: Some(InvariantMock {
                should_fail_invariant: false,
                kind: InvariantLevel::Always,
            }),
        }))
        .with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
        .build();
    let ctx = setup_table(SessionContext::new_with_state(state)).await?;
    run_and_compare_query(ctx, "Topk context").await?;

    // TEST
    // Build a valid topK plan.
    // Then have an invalid mutation in an optimizer run.
    let config = SessionConfig::new().with_target_partitions(48);
    let runtime = Arc::new(RuntimeEnv::default());
    let state = SessionStateBuilder::new()
        .with_config(config)
        .with_runtime_env(runtime)
        .with_default_features()
        .with_query_planner(Arc::new(TopKQueryPlanner {}))
        // 1. adds a valid TopKPlanNode
        .with_optimizer_rule(Arc::new(TopKOptimizerRule {
            invariant_mock: Some(InvariantMock {
                should_fail_invariant: false,
                kind: InvariantLevel::Always,
            }),
        }))
        // 2. break the TopKPlanNode
        .with_optimizer_rule(Arc::new(OptimizerMakeExtensionNodeInvalid {}))
        .with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
        .build();
    let ctx = setup_table(SessionContext::new_with_state(state)).await?;

    // The query must fail, and it must fail for the expected reason.
    let err = run_and_compare_query(ctx, "Topk context")
        .await
        .expect_err("query must fail once the extension node has been made invalid");
    let message = err.message();
    // A bare `matches!` only yields a bool; it has to be asserted to have any
    // effect. A substring check is used because the error text may carry
    // additional context around the invariant message.
    assert!(
        message.contains("node fails check, such as improper inputs"),
        "unexpected error message: {message}"
    );

    Ok(())
}

/// Builds a TopK session context without any invariant mock, by delegating
/// to `make_topk_context_with_invariants` with `None`.
fn make_topk_context() -> SessionContext {
    let invariant_mock = None;
    make_topk_context_with_invariants(invariant_mock)
}
Expand All @@ -366,6 +422,49 @@ fn make_topk_context_with_invariants(
SessionContext::new_with_state(state)
}

/// Optimizer rule used by the tests to turn a `TopKPlanNode` into one whose
/// invariant check is configured to fail.
#[derive(Debug)]
struct OptimizerMakeExtensionNodeInvalid;

impl OptimizerRule for OptimizerMakeExtensionNodeInvalid {
    /// Name reported for this rule.
    fn name(&self) -> &str {
        "OptimizerMakeExtensionNodeInvalid"
    }

    /// Requests top-down application of the rule.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// This rule implements `rewrite`.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Example rewrite pass which impacts the validity of the extension node.
    ///
    /// A `LogicalPlan::Extension` wrapping a `TopKPlanNode` is replaced by a
    /// copy with the same `k`, `input` and `expr`, but with an invariant mock
    /// that fails at `InvariantLevel::Always`. Any other plan is returned
    /// untransformed.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
        // Build the invalid replacement (if this is a TopK extension node)
        // as an owned value, so the borrow of `plan` ends before `plan` is
        // handed back in the pass-through case.
        let invalid_topk = match &plan {
            LogicalPlan::Extension(Extension { node }) => node
                .as_any()
                .downcast_ref::<TopKPlanNode>()
                .map(|prev| TopKPlanNode {
                    k: prev.k,
                    input: prev.input.clone(),
                    expr: prev.expr.clone(),
                    // In a real use case, this rewriter could have changed
                    // the number of inputs, etc.
                    invariant_mock: Some(InvariantMock {
                        should_fail_invariant: true,
                        kind: InvariantLevel::Always,
                    }),
                }),
            _ => None,
        };

        Ok(match invalid_topk {
            Some(topk) => Transformed::yes(LogicalPlan::Extension(Extension {
                node: Arc::new(topk),
            })),
            None => Transformed::no(plan),
        })
    }
}

// ------ The implementation of the TopK code follows -----

#[derive(Debug)]
Expand Down
6 changes: 3 additions & 3 deletions datafusion/expr/src/logical_plan/invariants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ pub enum InvariantLevel {
Executable,
}

/// Apply the [`InvariantLevel::Always`] check at the root plan node only.
pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> {
/// Apply the [`InvariantLevel::Always`] check at the current plan node only.
pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> {
// Refer to <https://datafusion.apache.org/contributor-guide/specification/invariants.html#relation-name-tuples-in-logical-fields-and-logical-columns-are-unique>
assert_unique_field_names(plan)?;

Expand All @@ -73,7 +73,7 @@ pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> {
/// as well as the less stringent [`InvariantLevel::Always`] checks.
pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
// Always invariants
assert_always_invariants(plan)?;
assert_always_invariants_at_current_node(plan)?;
assert_valid_extension_nodes(plan, InvariantLevel::Always)?;

// Executable invariants
Expand Down
5 changes: 3 additions & 2 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ use std::sync::{Arc, LazyLock};

use super::dml::CopyTo;
use super::invariants::{
assert_always_invariants, assert_executable_invariants, InvariantLevel,
assert_always_invariants_at_current_node, assert_executable_invariants,
InvariantLevel,
};
use super::DdlStatement;
use crate::builder::{change_redundant_column, unnest_with_options};
Expand Down Expand Up @@ -1137,7 +1138,7 @@ impl LogicalPlan {
/// checks that the plan conforms to the listed invariant level, returning an Error if not
pub fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
match check {
InvariantLevel::Always => assert_always_invariants(self),
InvariantLevel::Always => assert_always_invariants_at_current_node(self),
InvariantLevel::Executable => assert_executable_invariants(self),
}
}
Expand Down

0 comments on commit 3836444

Please sign in to comment.