
wip: only used udfs in program
jbeisen committed Oct 27, 2023
1 parent cddd8c4 commit e21569f
Showing 4 changed files with 131 additions and 6 deletions.
arroyo-compiler-service/src/main.rs (3 additions, 0 deletions)
@@ -270,6 +270,9 @@ impl CompilerGrpc for CompileService {
        &self,
        request: Request<CheckUdfsReq>,
    ) -> Result<Response<CheckUdfsCompilerResp>, Status> {
        // only allow one request to be active at a given time
        let _guard = self.lock.lock().await;

        info!("Checking UDFs");
        let start = Instant::now();

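The new guard at the top of the UDF-check handler does what the comment above it says: a shared async mutex is held for the whole request so only one compilation job runs in the service at a time. A minimal sketch of that pattern (the struct layout and names here are illustrative, not the actual service definition), assuming a tokio::sync::Mutex stored on the service:

use std::sync::Arc;
use tokio::sync::Mutex;

struct Service {
    // Shared lock; cloning the Arc lets every handler see the same mutex.
    lock: Arc<Mutex<()>>,
}

impl Service {
    async fn handle(&self) {
        // Concurrent callers park here; the guard is released when the
        // handler returns, letting the next request proceed.
        let _guard = self.lock.lock().await;
        // ... run the expensive compile/check work ...
    }
}

Binding the guard to `_guard` rather than discarding it with a bare `let _ =` is what keeps the lock held until the end of the function; an unbound guard would be dropped, and the lock released, immediately.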
arroyo-sql/src/expressions.rs (59 additions, 1 deletion)
@@ -23,6 +23,8 @@ use datafusion_expr::{
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use regex::Regex;
use std::collections::HashSet;
use std::hash::Hash;
use std::{fmt::Debug, sync::Arc, time::Duration};
use syn::{parse_quote, parse_str, Ident, Path};

@@ -315,6 +317,44 @@ impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for Expression {
}

impl Expression {
    pub fn udfs(&self) -> HashSet<String> {
        return match self {
            Expression::RustUdf(r) => {
                let mut udfs = HashSet::new();
                udfs.insert(r.name.clone());
                udfs.extend(
                    r.args
                        .iter()
                        .flat_map(|(_, e)| e.udfs())
                        .collect::<HashSet<String>>(),
                );
                udfs
            }
            Expression::UnaryBoolean(u) => u.input.udfs(),
            Expression::Column(_) => HashSet::new(),
            Expression::Literal(_) => HashSet::new(),
            Expression::BinaryComparison(b) => {
                b.left.udfs().into_iter().chain(b.right.udfs()).collect()
            }
            Expression::BinaryMath(b) => b.left.udfs().into_iter().chain(b.right.udfs()).collect(),
            Expression::StructField(s) => s.struct_expression.udfs(),
            Expression::Aggregation(a) => a.producing_expression.udfs(),
            Expression::Cast(c) => c.input.udfs(),
            Expression::Numeric(n) => n.input.udfs(),
            Expression::Date(d) => d.udfs(),
            _ => HashSet::new(),
            // TODO: the rest
            // Expression::String(_) => {}
            // Expression::Hash(_) => {}
            // Expression::DataStructure(_) => {}
            // Expression::Json(_) => {}
            // Expression::WrapType(_) => {}
            // Expression::Case(_) => {}
            // Expression::WindowUDF(_) => {}
            // Expression::Unnest(_, _) => {}
        };
    }

    pub(crate) fn has_max_value(&self, field: &StructField) -> Option<u64> {
        match self {
            Expression::BinaryComparison(BinaryComparisonExpression { left, op, right }) => {
@@ -3553,11 +3593,20 @@ impl JsonExpression {

#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd)]
pub struct RustUdfExpression {
-    name: String,
+    pub name: String,
    args: Vec<(TypeDef, Expression)>,
    ret_type: TypeDef,
}

impl RustUdafExpression {
    fn udfs(&self) -> Vec<String> {
        self.args
            .iter()
            .flat_map(|(_, expression)| expression.udfs())
            .collect()
    }
}

impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for RustUdfExpression {
    fn generate(&self, input_context: &ValuePointerContext) -> syn::Expr {
        let name = format_ident!("{}", &self.name);
@@ -3953,6 +4002,15 @@ fn extract_literal_string(expr: Expression) -> Result<String, anyhow::Error> {
}

impl DateTimeFunction {
    fn udfs(&self) -> HashSet<String> {
        match self {
            DateTimeFunction::DatePart(_, expr) | DateTimeFunction::DateTrunc(_, expr) => {
                expr.udfs()
            }
            DateTimeFunction::FromUnixTime(expr) => expr.udfs(),
        }
    }

    fn date_part(date_part: Expression, expr: Expression) -> anyhow::Result<Expression> {
        let date_part = extract_literal_string(date_part)?
            .as_str()
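The udfs helpers added in this file all follow the same shape: walk an expression tree and return the set of Rust UDF names it references, so the planner can later keep only those definitions. A simplified, self-contained sketch of that recursion (using a made-up Expr enum, not the real Expression type) shows how duplicate names collapse into a single HashSet entry:

use std::collections::HashSet;

// Hypothetical, stripped-down expression tree for illustration only.
enum Expr {
    Udf { name: String, args: Vec<Expr> },
    Binary(Box<Expr>, Box<Expr>),
    Column(String),
}

fn udfs(e: &Expr) -> HashSet<String> {
    match e {
        // A UDF call contributes its own name plus anything nested in its arguments.
        Expr::Udf { name, args } => {
            let mut set: HashSet<String> = args.iter().flat_map(udfs).collect();
            set.insert(name.clone());
            set
        }
        // Binary nodes union the sets from both sides.
        Expr::Binary(l, r) => udfs(l).into_iter().chain(udfs(r)).collect(),
        // Leaf nodes cannot reference a UDF.
        Expr::Column(_) => HashSet::new(),
    }
}

Variants that cannot contain a nested expression return an empty set, matching the Column and Literal arms above; the arms still commented out under the TODO would need the same treatment before the walker is complete.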
arroyo-sql/src/lib.rs (2 additions, 0 deletions)
@@ -40,11 +40,13 @@ use schemas::window_arrow_struct;
use tables::{schema_defs, ConnectorTable, Insert, Table};

use crate::code_gen::{CodeGenerator, ValuePointerContext};
use crate::plan_graph::PlanOperator;
use crate::types::{StructDef, StructField, TypeDef};
use arroyo_rpc::api_types::connections::{ConnectionSchema, ConnectionType};
use arroyo_rpc::formats::{Format, JsonFormat};
use datafusion_common::DataFusionError;
use quote::ToTokens;
use std::collections::HashSet;
use std::time::{Duration, SystemTime};
use std::{collections::HashMap, sync::Arc};
use syn::{parse_quote, parse_str, FnArg, Item, ReturnType, Visibility};
arroyo-sql/src/plan_graph.rs (67 additions, 5 deletions)
@@ -14,6 +14,7 @@ use petgraph::graph::{DiGraph, NodeIndex};
use quote::{quote, ToTokens};
use syn::{parse_quote, parse_str, Type};

use crate::expressions::AggregateComputation;
use crate::{
    code_gen::{
        BinAggregatingContext, CodeGenerator, CombiningContext, JoinListsContext, JoinPairContext,
@@ -1861,6 +1862,63 @@ pub fn get_program(
            key_structs.extend(key_names);
        });

    let mut record_transforms: Vec<RecordTransform> = vec![];
    let mut expressions: Vec<Expression> = vec![];
    for node in plan_graph.graph.node_weights() {
        match &node.operator {
            PlanOperator::Source(_, _) => {}
            PlanOperator::Watermark(_) => {}
            PlanOperator::RecordTransform(r) => {
                record_transforms.push(r.clone());
            }
            PlanOperator::FusedRecordTransform(f) => {
                record_transforms.extend(f.expressions.iter().cloned())
            }
            PlanOperator::Unkey => {}
            PlanOperator::WindowAggregate { window, projection } => {
                // let p = projection.aggregates.iter();
                projection.aggregates.iter().for_each(|a| match a {
                    AggregateComputation::Builtin {
                        column,
                        computation,
                    } => {
                        expressions.push(*computation.producing_expression.clone());
                    }
                    AggregateComputation::UDAF { .. } => {}
                });
            }
            PlanOperator::NonWindowAggregate { .. } => {}
            PlanOperator::TumblingWindowTwoPhaseAggregator { .. } => {}
            PlanOperator::SlidingWindowTwoPhaseAggregator { .. } => {}
            PlanOperator::InstantJoin => {}
            PlanOperator::JoinWithExpiration { .. } => {}
            PlanOperator::JoinListMerge(_, _) => {}
            PlanOperator::JoinPairMerge(_, _) => {}
            PlanOperator::Flatten => {}
            PlanOperator::WindowFunction(_) => {}
            PlanOperator::TumblingLocalAggregator { .. } => {}
            PlanOperator::SlidingAggregatingTopN { .. } => {}
            PlanOperator::TumblingTopN { .. } => {}
            PlanOperator::StreamOperator(_, _) => {}
            PlanOperator::ToDebezium => {}
            PlanOperator::FromDebezium => {}
            PlanOperator::FromUpdating => {}
            PlanOperator::Sink(_, _) => {}
        }
    }

    let mut used_udfs = HashSet::new();
    record_transforms.iter().for_each(|r| match r {
        RecordTransform::ValueProjection(p) => p
            .fields
            .iter()
            .for_each(|(c, e)| used_udfs.extend(e.udfs())),
        RecordTransform::KeyProjection(_) => {}
        RecordTransform::UnnestProjection(_) => {}
        RecordTransform::TimestampAssignment(_) => {}
        RecordTransform::Filter(e) => used_udfs.extend(e.udfs()),
    });

    // find all types that are produced by a source or consumed by a sink
    let connector_types: HashSet<_> = plan_graph
        .graph
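The scan over the plan above relies on petgraph's node_weights(), which iterates the payload of every node without exposing indices; that is all the UDF collection needs, since edges are irrelevant here. A tiny standalone example of that accessor (unrelated to the Arroyo types):

use petgraph::graph::DiGraph;

fn main() {
    let mut g: DiGraph<&str, ()> = DiGraph::new();
    let src = g.add_node("source");
    let sink = g.add_node("sink");
    g.add_edge(src, sink, ());

    // Yields each node's weight ("source", then "sink"), ignoring edges.
    for op in g.node_weights() {
        println!("{op}");
    }
}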
@@ -1911,15 +1969,19 @@
        .map(|(_, v)| v),
    );

    println!("schema_provider udfs: {:?}", schema_provider.udf_defs);
    println!("used udfs: {:?}", used_udfs);

    // add only the used udfs to the program
    let mut udfs = vec![];
    udfs.push(
-        schema_provider
-            .udf_defs
-            .values()
-            .map(|u| u.def.as_str())
+        used_udfs
+            .iter()
+            .map(|u| schema_provider.udf_defs.get(u).unwrap().def.as_str())
            .collect::<Vec<_>>()
-            .join("\n\n"),
+            .join("\n\n"), // TODO: why
    );
    println!("program udfs: {:?}", udfs);

    let graph: DiGraph<StreamNode, StreamEdge> = plan_graph.into();

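After this change the program is assembled only from the definitions named in used_udfs, looked up in schema_provider.udf_defs with unwrap; a name collected by the walker but missing from the provider would panic. A hedged sketch of the same filtering step with assumed types (UdfDef here is a stand-in, not the real struct), skipping unknown names instead of unwrapping:

use std::collections::{HashMap, HashSet};

struct UdfDef {
    def: String,
}

// udf_defs maps a UDF name to its source definition; `used` holds the names
// collected from the plan. Only referenced definitions are emitted, joined
// into one source string separated by blank lines.
fn used_udf_source(udf_defs: &HashMap<String, UdfDef>, used: &HashSet<String>) -> String {
    used.iter()
        .filter_map(|name| udf_defs.get(name)) // silently drop names with no definition
        .map(|u| u.def.as_str())
        .collect::<Vec<_>>()
        .join("\n\n")
}

Whether a missing definition should be skipped or surfaced as an error is a design choice this work-in-progress commit has not settled; the println! calls and the "TODO: why" comment suggest the lookup is still being debugged.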
