Skip to content

Commit

Permalink
wip: only used udfs in program
Browse files Browse the repository at this point in the history
  • Loading branch information
jbeisen committed Oct 27, 2023
1 parent cddd8c4 commit 14b6946
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 5 deletions.
3 changes: 3 additions & 0 deletions arroyo-compiler-service/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ impl CompilerGrpc for CompileService {
&self,
request: Request<CheckUdfsReq>,
) -> Result<Response<CheckUdfsCompilerResp>, Status> {
// only allow one request to be active at a given time
let _guard = self.lock.lock().await;

info!("Checking UDFs");
let start = Instant::now();

Expand Down
62 changes: 62 additions & 0 deletions arroyo-sql/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ use datafusion_expr::{
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use regex::Regex;
use std::collections::HashSet;
use std::hash::Hash;
use std::{fmt::Debug, sync::Arc, time::Duration};
use syn::{parse_quote, parse_str, Ident, Path};

Expand Down Expand Up @@ -315,6 +317,44 @@ impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for Expression {
}

impl Expression {
pub fn udfs(&self) -> HashSet<String> {
return match self {
Expression::RustUdf(r) => {
let mut udfs = HashSet::new();
udfs.insert(r.name.clone());
udfs.extend(
r.args
.iter()
.flat_map(|(_, e)| e.udfs())
.collect::<HashSet<String>>(),
);
udfs
}
Expression::UnaryBoolean(u) => u.input.udfs(),
Expression::Column(_) => HashSet::new(),
Expression::Literal(_) => HashSet::new(),
Expression::BinaryComparison(b) => {
b.left.udfs().into_iter().chain(b.right.udfs()).collect()
}
Expression::BinaryMath(b) => b.left.udfs().into_iter().chain(b.right.udfs()).collect(),
Expression::StructField(s) => s.struct_expression.udfs(),
Expression::Aggregation(a) => a.producing_expression.udfs(),
Expression::Cast(c) => c.input.udfs(),
Expression::Numeric(n) => n.input.udfs(),
Expression::Date(d) => d.udfs(),
_ => HashSet::new(),
// TODO: the rest
// Expression::String(_) => {}
// Expression::Hash(_) => {}
// Expression::DataStructure(_) => {}
// Expression::Json(_) => {}
// Expression::WrapType(_) => {}
// Expression::Case(_) => {}
// Expression::WindowUDF(_) => {}
// Expression::Unnest(_, _) => {}
};
}

pub(crate) fn has_max_value(&self, field: &StructField) -> Option<u64> {
match self {
Expression::BinaryComparison(BinaryComparisonExpression { left, op, right }) => {
Expand Down Expand Up @@ -3558,6 +3598,15 @@ pub struct RustUdfExpression {
ret_type: TypeDef,
}

impl RustUdafExpression {
fn udfs(&self) -> Vec<String> {
self.args
.iter()
.flat_map(|(_, expression)| expression.udfs())
.collect()
}
}

impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for RustUdfExpression {
fn generate(&self, input_context: &ValuePointerContext) -> syn::Expr {
let name = format_ident!("{}", &self.name);
Expand Down Expand Up @@ -3617,6 +3666,10 @@ pub struct RustUdafExpression {
ret_type: TypeDef,
}
impl RustUdafExpression {
pub fn expressions(&self) -> Vec<Expression> {
self.args.iter().map(|(_, e)| e.clone()).collect()
}

fn try_from_aggregate_udf(
ctx: &mut ExpressionContext<'_>,
aggregate_udf: &AggregateUDF,
Expand Down Expand Up @@ -3953,6 +4006,15 @@ fn extract_literal_string(expr: Expression) -> Result<String, anyhow::Error> {
}

impl DateTimeFunction {
fn udfs(&self) -> HashSet<String> {
match self {
DateTimeFunction::DatePart(_, expr) | DateTimeFunction::DateTrunc(_, expr) => {
expr.udfs()
}
DateTimeFunction::FromUnixTime(expr) => expr.udfs(),
}
}

fn date_part(date_part: Expression, expr: Expression) -> anyhow::Result<Expression> {
let date_part = extract_literal_string(date_part)?
.as_str()
Expand Down
75 changes: 70 additions & 5 deletions arroyo-sql/src/plan_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use petgraph::graph::{DiGraph, NodeIndex};
use quote::{quote, ToTokens};
use syn::{parse_quote, parse_str, Type};

use crate::expressions::AggregateComputation;
use crate::{
code_gen::{
BinAggregatingContext, CodeGenerator, CombiningContext, JoinListsContext, JoinPairContext,
Expand Down Expand Up @@ -1861,6 +1862,70 @@ pub fn get_program(
key_structs.extend(key_names);
});

let mut expressions: Vec<Expression> = vec![];
let mut record_transforms: Vec<RecordTransform> = vec![];
for node in plan_graph.graph.node_weights() {
match &node.operator {
PlanOperator::Source(_, _) => {}
PlanOperator::Watermark(_) => {}
PlanOperator::RecordTransform(r) => {
record_transforms.push(r.clone());
}
PlanOperator::FusedRecordTransform(f) => {
record_transforms.extend(f.expressions.iter().cloned())
}
PlanOperator::Unkey => {}
PlanOperator::WindowAggregate { window, projection } => {
// let p = projection.aggregates.iter();
projection.aggregates.iter().for_each(|a| match a {
AggregateComputation::Builtin {
column,
computation,
} => {
expressions.push(*computation.producing_expression.clone());
}
AggregateComputation::UDAF {
column,
computation,
} => {
expressions.extend(computation.expressions());
}
});
}
PlanOperator::NonWindowAggregate { .. } => {}
PlanOperator::TumblingWindowTwoPhaseAggregator { .. } => {}
PlanOperator::SlidingWindowTwoPhaseAggregator { .. } => {}
PlanOperator::InstantJoin => {}
PlanOperator::JoinWithExpiration { .. } => {}
PlanOperator::JoinListMerge(_, _) => {}
PlanOperator::JoinPairMerge(_, _) => {}
PlanOperator::Flatten => {}
PlanOperator::WindowFunction(_) => {}
PlanOperator::TumblingLocalAggregator { .. } => {}
PlanOperator::SlidingAggregatingTopN { .. } => {}
PlanOperator::TumblingTopN { .. } => {}
PlanOperator::StreamOperator(_, _) => {}
PlanOperator::ToDebezium => {}
PlanOperator::FromDebezium => {}
PlanOperator::FromUpdating => {}
PlanOperator::Sink(_, _) => {}
}
}

let mut used_udfs = HashSet::new();
record_transforms.iter().for_each(|r| match r {
RecordTransform::ValueProjection(p) => p
.fields
.iter()
.for_each(|(c, e)| used_udfs.extend(e.udfs())),
// TODO: the rest
RecordTransform::KeyProjection(_) => {}
RecordTransform::UnnestProjection(_) => {}
RecordTransform::TimestampAssignment(_) => {}
RecordTransform::Filter(e) => used_udfs.extend(e.udfs()),
});
expressions.iter().for_each(|e| used_udfs.extend(e.udfs()));

// find all types that are produced by a source or consumed by a sink
let connector_types: HashSet<_> = plan_graph
.graph
Expand Down Expand Up @@ -1911,14 +1976,14 @@ pub fn get_program(
.map(|(_, v)| v),
);

// add only the used udfs to the program
let mut udfs = vec![];
udfs.push(
schema_provider
.udf_defs
.values()
.map(|u| u.def.as_str())
used_udfs
.iter()
.map(|u| schema_provider.udf_defs.get(u).unwrap().def.as_str())
.collect::<Vec<_>>()
.join("\n\n"),
.join("\n\n"), // TODO: why
);

let graph: DiGraph<StreamNode, StreamEdge> = plan_graph.into();
Expand Down

0 comments on commit 14b6946

Please sign in to comment.