From e21569fc314121dfe7a0fb6105c2a81e5d72210c Mon Sep 17 00:00:00 2001
From: Jonah Eisen
Date: Fri, 27 Oct 2023 15:47:44 -0700
Subject: [PATCH] wip: only include used UDFs in program

---
 arroyo-compiler-service/src/main.rs |  3 ++
 arroyo-sql/src/expressions.rs       | 60 +++++++++++++++++++++++-
 arroyo-sql/src/lib.rs               |  2 +
 arroyo-sql/src/plan_graph.rs        | 72 +++++++++++++++++++++++++++--
 4 files changed, 131 insertions(+), 6 deletions(-)

diff --git a/arroyo-compiler-service/src/main.rs b/arroyo-compiler-service/src/main.rs
index f5b0f6147..e6527807e 100644
--- a/arroyo-compiler-service/src/main.rs
+++ b/arroyo-compiler-service/src/main.rs
@@ -270,6 +270,9 @@ impl CompilerGrpc for CompileService {
         &self,
         request: Request,
     ) -> Result, Status> {
+        // only allow one request to be active at a given time
+        let _guard = self.lock.lock().await;
+
         info!("Checking UDFs");
 
         let start = Instant::now();
diff --git a/arroyo-sql/src/expressions.rs b/arroyo-sql/src/expressions.rs
index 2d3b6fca1..f0117008a 100644
--- a/arroyo-sql/src/expressions.rs
+++ b/arroyo-sql/src/expressions.rs
@@ -23,6 +23,8 @@ use datafusion_expr::{
 use proc_macro2::TokenStream;
 use quote::{format_ident, quote};
 use regex::Regex;
+use std::collections::HashSet;
+use std::hash::Hash;
 use std::{fmt::Debug, sync::Arc, time::Duration};
 use syn::{parse_quote, parse_str, Ident, Path};
 
@@ -315,6 +317,44 @@ impl CodeGenerator for Expression {
 }
 
 impl Expression {
+    pub fn udfs(&self) -> HashSet<String> {
+        return match self {
+            Expression::RustUdf(r) => {
+                let mut udfs = HashSet::new();
+                udfs.insert(r.name.clone());
+                udfs.extend(
+                    r.args
+                        .iter()
+                        .flat_map(|(_, e)| e.udfs())
+                        .collect::<Vec<_>>(),
+                );
+                udfs
+            }
+            Expression::UnaryBoolean(u) => u.input.udfs(),
+            Expression::Column(_) => HashSet::new(),
+            Expression::Literal(_) => HashSet::new(),
+            Expression::BinaryComparison(b) => {
+                b.left.udfs().into_iter().chain(b.right.udfs()).collect()
+            }
+            Expression::BinaryMath(b) => b.left.udfs().into_iter().chain(b.right.udfs()).collect(),
+            Expression::StructField(s) => s.struct_expression.udfs(),
+            Expression::Aggregation(a) => a.producing_expression.udfs(),
+            Expression::Cast(c) => c.input.udfs(),
+            Expression::Numeric(n) => n.input.udfs(),
+            Expression::Date(d) => d.udfs(),
+            _ => HashSet::new(),
+            // TODO: the rest
+            // Expression::String(_) => {}
+            // Expression::Hash(_) => {}
+            // Expression::DataStructure(_) => {}
+            // Expression::Json(_) => {}
+            // Expression::WrapType(_) => {}
+            // Expression::Case(_) => {}
+            // Expression::WindowUDF(_) => {}
+            // Expression::Unnest(_, _) => {}
+        };
+    }
+
     pub(crate) fn has_max_value(&self, field: &StructField) -> Option {
         match self {
             Expression::BinaryComparison(BinaryComparisonExpression { left, op, right }) => {
@@ -3553,11 +3593,20 @@ impl JsonExpression {
 
 #[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd)]
 pub struct RustUdfExpression {
-    name: String,
+    pub name: String,
     args: Vec<(TypeDef, Expression)>,
     ret_type: TypeDef,
 }
 
+impl RustUdafExpression {
+    fn udfs(&self) -> Vec<String> {
+        self.args
+            .iter()
+            .flat_map(|(_, expression)| expression.udfs())
+            .collect()
+    }
+}
+
 impl CodeGenerator for RustUdfExpression {
     fn generate(&self, input_context: &ValuePointerContext) -> syn::Expr {
         let name = format_ident!("{}", &self.name);
@@ -3953,6 +4002,15 @@ fn extract_literal_string(expr: Expression) -> Result {
 }
 
 impl DateTimeFunction {
+    fn udfs(&self) -> HashSet<String> {
+        match self {
+            DateTimeFunction::DatePart(_, expr) | DateTimeFunction::DateTrunc(_, expr) => {
+                expr.udfs()
+            }
+            DateTimeFunction::FromUnixTime(expr) => expr.udfs(),
+        }
+    }
+
     fn date_part(date_part: Expression, expr: Expression) -> anyhow::Result {
         let date_part = extract_literal_string(date_part)?
             .as_str()
diff --git a/arroyo-sql/src/lib.rs b/arroyo-sql/src/lib.rs
index 721fae45d..db7232732 100644
--- a/arroyo-sql/src/lib.rs
+++ b/arroyo-sql/src/lib.rs
@@ -40,11 +40,13 @@ use schemas::window_arrow_struct;
 use tables::{schema_defs, ConnectorTable, Insert, Table};
 
 use crate::code_gen::{CodeGenerator, ValuePointerContext};
+use crate::plan_graph::PlanOperator;
 use crate::types::{StructDef, StructField, TypeDef};
 use arroyo_rpc::api_types::connections::{ConnectionSchema, ConnectionType};
 use arroyo_rpc::formats::{Format, JsonFormat};
 use datafusion_common::DataFusionError;
 use quote::ToTokens;
+use std::collections::HashSet;
 use std::time::{Duration, SystemTime};
 use std::{collections::HashMap, sync::Arc};
 use syn::{parse_quote, parse_str, FnArg, Item, ReturnType, Visibility};
diff --git a/arroyo-sql/src/plan_graph.rs b/arroyo-sql/src/plan_graph.rs
index 1a92e70b6..84fa20b41 100644
--- a/arroyo-sql/src/plan_graph.rs
+++ b/arroyo-sql/src/plan_graph.rs
@@ -14,6 +14,7 @@ use petgraph::graph::{DiGraph, NodeIndex};
 use quote::{quote, ToTokens};
 use syn::{parse_quote, parse_str, Type};
 
+use crate::expressions::AggregateComputation;
 use crate::{
     code_gen::{
         BinAggregatingContext, CodeGenerator, CombiningContext, JoinListsContext, JoinPairContext,
@@ -1861,6 +1862,63 @@ pub fn get_program(
         key_structs.extend(key_names);
     });
 
+    let mut record_transforms: Vec<RecordTransform> = vec![];
+    let mut expressions: Vec<Expression> = vec![];
+    for node in plan_graph.graph.node_weights() {
+        match &node.operator {
+            PlanOperator::Source(_, _) => {}
+            PlanOperator::Watermark(_) => {}
+            PlanOperator::RecordTransform(r) => {
+                record_transforms.push(r.clone());
+            }
+            PlanOperator::FusedRecordTransform(f) => {
+                record_transforms.extend(f.expressions.iter().cloned())
+            }
+            PlanOperator::Unkey => {}
+            PlanOperator::WindowAggregate { window, projection } => {
+                // let p = projection.aggregates.iter();
+                projection.aggregates.iter().for_each(|a| match a {
+                    AggregateComputation::Builtin {
+                        column,
+                        computation,
+                    } => {
+                        expressions.push(*computation.producing_expression.clone());
+                    }
+                    AggregateComputation::UDAF { .. } => {}
+                });
+            }
+            PlanOperator::NonWindowAggregate { .. } => {}
+            PlanOperator::TumblingWindowTwoPhaseAggregator { .. } => {}
+            PlanOperator::SlidingWindowTwoPhaseAggregator { .. } => {}
+            PlanOperator::InstantJoin => {}
+            PlanOperator::JoinWithExpiration { .. } => {}
+            PlanOperator::JoinListMerge(_, _) => {}
+            PlanOperator::JoinPairMerge(_, _) => {}
+            PlanOperator::Flatten => {}
+            PlanOperator::WindowFunction(_) => {}
+            PlanOperator::TumblingLocalAggregator { .. } => {}
+            PlanOperator::SlidingAggregatingTopN { .. } => {}
+            PlanOperator::TumblingTopN { .. } => {}
+            PlanOperator::StreamOperator(_, _) => {}
+            PlanOperator::ToDebezium => {}
+            PlanOperator::FromDebezium => {}
+            PlanOperator::FromUpdating => {}
+            PlanOperator::Sink(_, _) => {}
+        }
+    }
+
+    let mut used_udfs = HashSet::new();
+    record_transforms.iter().for_each(|r| match r {
+        RecordTransform::ValueProjection(p) => p
+            .fields
+            .iter()
+            .for_each(|(c, e)| used_udfs.extend(e.udfs())),
+        RecordTransform::KeyProjection(_) => {}
+        RecordTransform::UnnestProjection(_) => {}
+        RecordTransform::TimestampAssignment(_) => {}
+        RecordTransform::Filter(e) => used_udfs.extend(e.udfs()),
+    });
+
     // find all types that are produced by a source or consumed by a sink
     let connector_types: HashSet<_> = plan_graph
         .graph
@@ -1911,15 +1969,19 @@
             .map(|(_, v)| v),
     );
 
+    println!("schema_provider udfs: {:?}", schema_provider.udf_defs);
+    println!("used udfs: {:?}", used_udfs);
+
+    // add only the used udfs to the program
     let mut udfs = vec![];
     udfs.push(
-        schema_provider
-            .udf_defs
-            .values()
-            .map(|u| u.def.as_str())
+        used_udfs
+            .iter()
+            .map(|u| schema_provider.udf_defs.get(u).unwrap().def.as_str())
             .collect::<Vec<_>>()
-            .join("\n\n"),
+            .join("\n\n"), // TODO: why
     );
+    println!("program udfs: {:?}", udfs);
 
     let graph: DiGraph = plan_graph.into();
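
The change above boils down to two steps: walk every expression in the plan to collect the UDF names it actually references, then emit only the matching definitions instead of every entry in schema_provider.udf_defs. A minimal standalone sketch of that pattern follows; Expr, used_defs, and the string-keyed definition map are simplified, hypothetical stand-ins, not the arroyo-sql types.

use std::collections::{HashMap, HashSet};

// Hypothetical, simplified expression tree: a UDF call or a literal leaf.
enum Expr {
    Udf { name: String, args: Vec<Expr> },
    Literal(i64),
}

impl Expr {
    // Recursively collect every UDF name referenced by this expression.
    fn udfs(&self) -> HashSet<String> {
        match self {
            Expr::Udf { name, args } => {
                let mut out = HashSet::new();
                out.insert(name.clone());
                out.extend(args.iter().flat_map(|a| a.udfs()));
                out
            }
            Expr::Literal(_) => HashSet::new(),
        }
    }
}

// Keep only the definitions whose names are actually referenced,
// mirroring the used_udfs filtering added to get_program().
fn used_defs<'a>(exprs: &[Expr], defs: &'a HashMap<String, String>) -> Vec<&'a str> {
    let used: HashSet<String> = exprs.iter().flat_map(|e| e.udfs()).collect();
    used.iter()
        .filter_map(|name| defs.get(name).map(|d| d.as_str()))
        .collect()
}

fn main() {
    let mut defs = HashMap::new();
    defs.insert(
        "square".to_string(),
        "fn square(x: i64) -> i64 { x * x }".to_string(),
    );
    defs.insert("unused".to_string(), "fn unused() {}".to_string());

    let plan = vec![Expr::Udf {
        name: "square".to_string(),
        args: vec![Expr::Literal(2)],
    }];

    // Only the `square` definition is emitted; `unused` is dropped.
    println!("{}", used_defs(&plan, &defs).join("\n\n"));
}

Running the sketch prints only the square definition, which is the same effect the used_udfs filter has on the UDF block compiled into the program.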