diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 4ba9f21b641c..dcfb30198725 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -346,6 +346,62 @@ async fn topk_invariants() -> Result<()> { Ok(()) } +#[tokio::test] +async fn topk_invariants_after_invalid_mutation() -> Result<()> { + // CONTROL + // Build a valid topK plan. + let config = SessionConfig::new().with_target_partitions(48); + let runtime = Arc::new(RuntimeEnv::default()); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .with_query_planner(Arc::new(TopKQueryPlanner {})) + // 1. adds a valid TopKPlanNode + .with_optimizer_rule(Arc::new(TopKOptimizerRule { + invariant_mock: Some(InvariantMock { + should_fail_invariant: false, + kind: InvariantLevel::Always, + }), + })) + .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) + .build(); + let ctx = setup_table(SessionContext::new_with_state(state)).await?; + run_and_compare_query(ctx, "Topk context").await?; + + // Test + // Build a valid topK plan. + // Then have an invalid mutation in an optimizer run. + let config = SessionConfig::new().with_target_partitions(48); + let runtime = Arc::new(RuntimeEnv::default()); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .with_query_planner(Arc::new(TopKQueryPlanner {})) + // 1. adds a valid TopKPlanNode + .with_optimizer_rule(Arc::new(TopKOptimizerRule { + invariant_mock: Some(InvariantMock { + should_fail_invariant: false, + kind: InvariantLevel::Always, + }), + })) + // 2. break the TopKPlanNode + .with_optimizer_rule(Arc::new(OptimizerMakeExtensionNodeInvalid {})) + .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) + .build(); + let ctx = setup_table(SessionContext::new_with_state(state)).await?; + matches!( + &*run_and_compare_query(ctx, "Topk context") + .await + .unwrap_err() + .message(), + "node fails check, such as improper inputs" + ); + + Ok(()) +} + fn make_topk_context() -> SessionContext { make_topk_context_with_invariants(None) } @@ -366,6 +422,49 @@ fn make_topk_context_with_invariants( SessionContext::new_with_state(state) } +#[derive(Debug)] +struct OptimizerMakeExtensionNodeInvalid; + +impl OptimizerRule for OptimizerMakeExtensionNodeInvalid { + fn name(&self) -> &str { + "OptimizerMakeExtensionNodeInvalid" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn supports_rewrite(&self) -> bool { + true + } + + // Example rewrite pass which impacts validity of the extension node. + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result, DataFusionError> { + if let LogicalPlan::Extension(Extension { node }) = &plan { + if let Some(prev) = node.as_any().downcast_ref::() { + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: prev.k, + input: prev.input.clone(), + expr: prev.expr.clone(), + // In a real use case, this rewriter could have change the number of inputs, etc + invariant_mock: Some(InvariantMock { + should_fail_invariant: true, + kind: InvariantLevel::Always, + }), + }), + }))); + } + }; + + Ok(Transformed::no(plan)) + } +} + // ------ The implementation of the TopK code follows ----- #[derive(Debug)] diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 4959dafdef99..fb50fbe42e81 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -61,8 +61,8 @@ pub enum InvariantLevel { Executable, } -/// Apply the [`InvariantLevel::Always`] check at the root plan node only. -pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> { +/// Apply the [`InvariantLevel::Always`] check at the current plan node only. +pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> { // Refer to assert_unique_field_names(plan)?; @@ -73,7 +73,7 @@ pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> { /// as well as the less stringent [`InvariantLevel::Always`] checks. pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> { // Always invariants - assert_always_invariants(plan)?; + assert_always_invariants_at_current_node(plan)?; assert_valid_extension_nodes(plan, InvariantLevel::Always)?; // Executable invariants diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 24fb0609b0fe..ebad1dcf9de4 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -25,7 +25,8 @@ use std::sync::{Arc, LazyLock}; use super::dml::CopyTo; use super::invariants::{ - assert_always_invariants, assert_executable_invariants, InvariantLevel, + assert_always_invariants_at_current_node, assert_executable_invariants, + InvariantLevel, }; use super::DdlStatement; use crate::builder::{change_redundant_column, unnest_with_options}; @@ -1137,7 +1138,7 @@ impl LogicalPlan { /// checks that the plan conforms to the listed invariant level, returning an Error if not pub fn check_invariants(&self, check: InvariantLevel) -> Result<()> { match check { - InvariantLevel::Always => assert_always_invariants(self), + InvariantLevel::Always => assert_always_invariants_at_current_node(self), InvariantLevel::Executable => assert_executable_invariants(self), } }