From cc35c8011f8c4982f27ca1bdf5d456c9d1eed44e Mon Sep 17 00:00:00 2001 From: Jonah Eisen Date: Tue, 14 Nov 2023 09:36:21 -0800 Subject: [PATCH] Make UDFs public when added to schema provider Restore the previous logic of modifying the UDF syntax tree at the time we add it to the schema provider. Also use the schema provider in the check udf flow. --- Cargo.lock | 83 ++++----- arroyo-console/src/udf_state.ts | 2 +- arroyo-controller/src/compiler.rs | 5 +- arroyo-controller/src/lib.rs | 126 +------------ arroyo-datastream/src/lib.rs | 3 + arroyo-rpc/proto/api.proto | 1 + arroyo-sql/Cargo.toml | 3 +- arroyo-sql/src/lib.rs | 283 +++++++++++++++++++----------- arroyo-sql/src/plan_graph.rs | 4 +- 9 files changed, 239 insertions(+), 271 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7f7f30513..775c4ddba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -484,7 +484,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "syn 2.0.33", + "syn 2.0.39", "thiserror", "time", "tokio", @@ -585,7 +585,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "syn 2.0.33", + "syn 2.0.39", "thiserror", "time", "tokio", @@ -614,7 +614,7 @@ dependencies = [ "rand", "regex", "serde", - "syn 2.0.33", + "syn 2.0.39", "tokio", "toml 0.7.8", "tonic", @@ -626,7 +626,7 @@ version = "0.7.0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -737,6 +737,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "petgraph", + "prettyplease 0.2.15", "proc-macro2", "quote", "regex", @@ -744,7 +745,7 @@ dependencies = [ "serde", "serde_json", "serde_json_path", - "syn 2.0.33", + "syn 2.0.39", "tokio", "tracing", "typify", @@ -761,7 +762,7 @@ dependencies = [ "quote", "runtime-macros-derive", "serde_json", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -785,7 +786,7 @@ dependencies = [ "quote", "serde", "serde_json", - "syn 2.0.33", + "syn 2.0.39", "test-log", "tokio", "tracing", @@ -1096,7 +1097,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -1113,7 +1114,7 @@ checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -1515,7 +1516,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -1933,7 +1934,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -2425,7 +2426,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f34ba9a9bcb8645379e9de8cb3ecfcf4d1c85ba66d90deb3259206fa5aa193b" dependencies = [ "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -2473,7 +2474,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -2495,7 +2496,7 @@ checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" dependencies = [ "darling_core 0.20.3", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -3072,7 +3073,7 @@ dependencies = [ "num-traits", "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -3422,7 +3423,7 @@ checksum = "1f1446badfb800d7c940c35627e4dc3c7fabaab093f009507f3d288430172c30" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -3630,7 +3631,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -4960,7 +4961,7 @@ checksum = "49e7bc1560b95a3c4a25d03de42fe76ca718ab92d1a22a55b9b4cf67b3ae635c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -5357,7 +5358,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -5693,7 +5694,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -5785,7 +5786,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -5865,7 +5866,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" dependencies = [ "proc-macro2", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -6316,7 +6317,7 @@ dependencies = [ "quote", "refinery-core", "regex", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -6721,7 +6722,7 @@ dependencies = [ "quote", "rust-embed-utils", "shellexpand", - "syn 2.0.33", + "syn 2.0.39", "walkdir", ] @@ -7052,7 +7053,7 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -7128,7 +7129,7 @@ checksum = "bb9387330da43020c17237e22c76bd19c93305c75d99ec962c58f385c7e1f5ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -7149,7 +7150,7 @@ checksum = "8725e1dfadb3a50f7e5ce0b1a540466f6ed3fe7a0fca2ac2b8b831d31316bd00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -7170,7 +7171,7 @@ dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -7211,7 +7212,7 @@ dependencies = [ "darling 0.20.3", "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -7567,7 +7568,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -7640,9 +7641,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.33" +version = "2.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9caece70c63bfba29ec2fed841a09851b14a235c60010fa4de58089b6c025668" +checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" dependencies = [ "proc-macro2", "quote", @@ -7794,7 +7795,7 @@ checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -7910,7 +7911,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -8265,7 +8266,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -8381,7 +8382,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -8411,7 +8412,7 @@ dependencies = [ "regress", "schemars", "serde_json", - "syn 2.0.33", + "syn 2.0.39", "thiserror", "unicode-ident", ] @@ -8427,7 +8428,7 @@ dependencies = [ "serde", "serde_json", "serde_tokenstream", - "syn 2.0.33", + "syn 2.0.39", "typify-impl", ] @@ -8560,7 +8561,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", ] [[package]] @@ -8682,7 +8683,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", "wasm-bindgen-shared", ] @@ -8716,7 +8717,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.39", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/arroyo-console/src/udf_state.ts b/arroyo-console/src/udf_state.ts index 6b93a8eff..79eb8c03f 100644 --- a/arroyo-console/src/udf_state.ts +++ b/arroyo-console/src/udf_state.ts @@ -137,7 +137,7 @@ export const getLocalUdfsContextValue = () => { `/*\n` + `[dependencies]\n\n` + `*/\n\n` + - `pub fn ${functionName}(x: i64) -> i64 {\n` + + `fn ${functionName}(x: i64) -> i64 {\n` + ' // Write your function here\n' + ' // Tip: rename the function to something descriptive\n\n' + '}'; diff --git a/arroyo-controller/src/compiler.rs b/arroyo-controller/src/compiler.rs index 0efb44736..792a2146a 100644 --- a/arroyo-controller/src/compiler.rs +++ b/arroyo-controller/src/compiler.rs @@ -1,5 +1,5 @@ +use crate::cargo_toml; use crate::states::fatal; -use crate::{cargo_toml, parse_dependencies}; use anyhow::{anyhow, Result}; use arroyo_datastream::{parse_type, Operator, Program, WasmBehavior}; use arroyo_rpc::grpc::compiler_grpc_client::CompilerGrpcClient; @@ -145,11 +145,10 @@ impl ProgramCompiler { .udfs .iter() .map(|udf| { - let dependencies = parse_dependencies(&udf.definition)?; Ok(UdfCrate { name: udf.name.to_string(), definition: udf.definition.to_string(), - cargo_toml: cargo_toml(&udf.name, &dependencies), + cargo_toml: cargo_toml(&udf.name, &udf.dependencies), }) }) .collect() diff --git a/arroyo-controller/src/lib.rs b/arroyo-controller/src/lib.rs index b111d8074..12603d154 100644 --- a/arroyo-controller/src/lib.rs +++ b/arroyo-controller/src/lib.rs @@ -2,7 +2,6 @@ // TODO: factor out complex types #![allow(clippy::type_complexity)] -use anyhow::bail; use arroyo_rpc::grpc::compiler_grpc_client::CompilerGrpcClient; use arroyo_rpc::grpc::controller_grpc_server::{ControllerGrpc, ControllerGrpcServer}; use arroyo_rpc::grpc::{ @@ -18,13 +17,13 @@ use arroyo_rpc::grpc::{ }; use arroyo_rpc::public_ids::{generate_id, IdTypes}; use arroyo_server_common::log_event; +use arroyo_sql::{parse_dependencies, ArroyoSchemaProvider}; use arroyo_types::{ from_micros, ports, DatabaseConfig, NodeId, WorkerId, REMOTE_COMPILER_ENDPOINT_ENV, }; use deadpool_postgres::{ManagerConfig, Pool, RecyclingMethod}; use lazy_static::lazy_static; use prometheus::{register_gauge, Gauge}; -use regex::Regex; use serde_json::json; use states::{Created, State, StateMachine}; use std::collections::{HashMap, HashSet}; @@ -32,7 +31,6 @@ use std::env; use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, Instant, SystemTime}; -use syn::{parse_file, Item}; use time::OffsetDateTime; use tokio::sync::broadcast; use tokio::sync::mpsc::error::TrySendError; @@ -459,45 +457,10 @@ impl ControllerGrpc for ControllerServer { } }; - let mut function_name = None; - { - let result = match parse_file(definition.as_str()) { - Ok(result) => result, - Err(e) => { - return Ok(udf_error_resp(e)); - } - }; - - for item in result.items { - match item { - Item::Fn(f) => { - if function_name.is_some() { - return Ok(Response::new(CheckUdfsResp { - errors: vec!["Only one function is allowed in UDFs".to_string()], - udf_name: None, - })); - } - function_name = Some(f.sig.ident.to_string()); - } - Item::Use(_) => {} - _ => { - return Ok(Response::new(CheckUdfsResp { - errors: vec![ - "Only functions and use statements are allowed in UDFs".to_string() - ], - udf_name: None, - })) - } - } - } - } - - // unwrap function or return error - let Some(function_name) = function_name else { - return Ok(Response::new(CheckUdfsResp { - errors: vec!["No function found in UDF".to_string()], - udf_name: None, - })); + // use the ArroyoSchemaProvider to do some validation and to get the function name + let function_name = match ArroyoSchemaProvider::new().add_rust_udf(&definition) { + Ok(function_name) => function_name, + Err(e) => return Ok(udf_error_resp(e)), }; // build cargo.toml @@ -745,82 +708,3 @@ where udf_name: None, }) } - -fn parse_dependencies(definition: &str) -> anyhow::Result { - // get content of dependencies comment using regex - let re = Regex::new(r"\/\*\n(\[dependencies\]\n[\s\S]*?)\*\/").unwrap(); - if re.find_iter(&definition).count() > 1 { - bail!("Only one dependencies definition is allowed in a UDF"); - } - - return if let Some(captures) = re.captures(&definition) { - if captures.len() != 2 { - bail!("Error parsing dependencies"); - } - Ok(captures.get(1).unwrap().as_str().to_string()) - } else { - Ok("[dependencies]\n# none defined\n".to_string()) - }; -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_dependencies_valid() { - let definition = r#" -/* -[dependencies] -serde = "1.0" -*/ - -pub fn my_udf() -> i64 { - 1 -} - "#; - - assert_eq!( - parse_dependencies(definition).unwrap(), - r#"[dependencies] -serde = "1.0" -"# - ); - } - - #[test] - fn test_parse_dependencies_none() { - let definition = r#" -pub fn my_udf() -> i64 { - 1 -} - "#; - - assert_eq!( - parse_dependencies(definition).unwrap(), - r#"[dependencies] -# none defined -"# - ); - } - - #[test] - fn test_parse_dependencies_multiple() { - let definition = r#" -/* -[dependencies] -serde = "1.0" -*/ - -/* -[dependencies] -serde = "1.0" -*/ - -pub fn my_udf() -> i64 { - 1 - - "#; - assert!(parse_dependencies(definition).is_err()); - } -} diff --git a/arroyo-datastream/src/lib.rs b/arroyo-datastream/src/lib.rs index 23fa18e36..0173606de 100644 --- a/arroyo-datastream/src/lib.rs +++ b/arroyo-datastream/src/lib.rs @@ -1081,6 +1081,7 @@ impl, &T1) -> T2> BiFunc pub struct ProgramUdf { pub name: String, pub definition: String, + pub dependencies: String, } #[derive(Encode, Decode, Clone, Debug)] @@ -1932,6 +1933,7 @@ impl Into for ProgramUdf { PipelineProgramUdf { name: self.name, definition: self.definition, + dependencies: self.dependencies, } } } @@ -1941,6 +1943,7 @@ impl Into for PipelineProgramUdf { ProgramUdf { name: self.name, definition: self.definition, + dependencies: self.dependencies, } } } diff --git a/arroyo-rpc/proto/api.proto b/arroyo-rpc/proto/api.proto index 132789fa3..52b7ce071 100644 --- a/arroyo-rpc/proto/api.proto +++ b/arroyo-rpc/proto/api.proto @@ -37,6 +37,7 @@ message CreateJobReq { message PipelineProgramUdf { string name = 1; string definition = 2; + string dependencies = 3; } message PipelineProgram { diff --git a/arroyo-sql/Cargo.toml b/arroyo-sql/Cargo.toml index df54a7705..db2c88a17 100644 --- a/arroyo-sql/Cargo.toml +++ b/arroyo-sql/Cargo.toml @@ -32,4 +32,5 @@ tracing = "0.1.37" typify = "0.0.13" schemars = "0.8" serde_json_path = "0.6.3" -apache-avro = "0.16.0" \ No newline at end of file +apache-avro = "0.16.0" +prettyplease = "0.2.4" diff --git a/arroyo-sql/src/lib.rs b/arroyo-sql/src/lib.rs index cf98a1032..cc050c334 100644 --- a/arroyo-sql/src/lib.rs +++ b/arroyo-sql/src/lib.rs @@ -45,10 +45,12 @@ use crate::types::{StructDef, StructField, TypeDef}; use arroyo_rpc::api_types::connections::{ConnectionSchema, ConnectionType}; use arroyo_rpc::formats::{Format, JsonFormat}; use datafusion_common::DataFusionError; +use prettyplease::unparse; +use regex::Regex; use std::collections::HashSet; use std::time::{Duration, SystemTime}; use std::{collections::HashMap, sync::Arc}; -use syn::{parse_quote, parse_str, FnArg, Item, ReturnType, Visibility}; +use syn::{parse_file, parse_quote, parse_str, FnArg, Item, ReturnType, Visibility}; use tracing::warn; const DEFAULT_IDLE_TIME: Option = Some(Duration::from_secs(5 * 60)); @@ -61,6 +63,7 @@ pub struct UdfDef { args: Vec, ret: TypeDef, def: String, + dependencies: String, } #[derive(Debug, Clone, Default)] @@ -234,126 +237,138 @@ impl ArroyoSchemaProvider { None } - pub fn add_rust_udf(&mut self, body: &str) -> Result<()> { - let file = syn::parse_file(body)?; + pub fn add_rust_udf(&mut self, body: &str) -> Result { + let mut file = parse_file(body)?; - if file - .items - .iter() - .filter(|item| matches!(item, Item::Fn(..))) - .count() - != 1 - { - bail!("UDF definition must contain exactly 1 function."); - }; + let mut functions = file.items.iter_mut().filter_map(|item| match item { + Item::Fn(function) => Some(function), + _ => None, + }); - for item in file.items { - let Item::Fn(mut function) = item else { - continue; - }; + let function = match (functions.next(), functions.next()) { + (Some(function), None) => function, + _ => bail!("UDF definition must contain exactly 1 function."), + }; - let name = function.sig.ident.to_string(); - let mut args: Vec = vec![]; - let mut vec_arguments = 0; - for (i, arg) in function.sig.inputs.iter().enumerate() { - match arg { - FnArg::Receiver(_) => { - bail!( - "Function {} has a 'self' argument, which is not allowed", - name - ) - } - FnArg::Typed(t) => { - if let Some(vec_type) = Self::vec_inner_type(&*t.ty) { - vec_arguments += 1; - args.push((&vec_type).try_into().map_err(|_| { + let name = function.sig.ident.to_string(); + let mut args: Vec = vec![]; + let mut vec_arguments = 0; + for (i, arg) in function.sig.inputs.iter().enumerate() { + match arg { + FnArg::Receiver(_) => { + bail!( + "Function {} has a 'self' argument, which is not allowed", + name + ) + } + FnArg::Typed(t) => { + if let Some(vec_type) = Self::vec_inner_type(&*t.ty) { + vec_arguments += 1; + args.push((&vec_type).try_into().map_err(|_| { anyhow!( "Could not convert function {} inner vector arg {} into a SQL data type", name, i ) })?); - } else { - args.push((&*t.ty).try_into().map_err(|_| { - anyhow!( - "Could not convert function {} arg {} into a SQL data type", - name, - i - ) - })?); - } + } else { + args.push((&*t.ty).try_into().map_err(|_| { + anyhow!( + "Could not convert function {} arg {} into a SQL data type", + name, + i + ) + })?); } } } + } - let ret: TypeDef = match &function.sig.output { - ReturnType::Default => bail!("Function {} return type must be specified", name), - ReturnType::Type(_, t) => (&**t).try_into().map_err(|_| { - anyhow!( - "Could not convert function {} return type into a SQL data type", - name - ) - })?, - }; - if vec_arguments > 0 && vec_arguments != args.len() { - bail!("Function {} arguments must be vectors or none", name); - } - if vec_arguments > 0 { - let return_type = Arc::new(ret.as_datatype().unwrap().clone()); - let name = function.sig.ident.to_string(); - let signature = Signature::exact( - args.iter() - .map(|t| t.as_datatype().unwrap().clone()) - .collect(), - Volatility::Volatile, + let ret: TypeDef = match &function.sig.output { + ReturnType::Default => bail!("Function {} return type must be specified", name), + ReturnType::Type(_, t) => (&**t).try_into().map_err(|_| { + anyhow!( + "Could not convert function {} return type into a SQL data type", + name + ) + })?, + }; + if vec_arguments > 0 && vec_arguments != args.len() { + bail!("Function {} arguments must be vectors or none", name); + } + if vec_arguments > 0 { + let return_type = Arc::new(ret.as_datatype().unwrap().clone()); + let name = function.sig.ident.to_string(); + let signature = Signature::exact( + args.iter() + .map(|t| t.as_datatype().unwrap().clone()) + .collect(), + Volatility::Volatile, + ); + let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); + let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unreachable!()); + let state_type: StateTypeFunction = Arc::new(|_| unreachable!()); + let udaf = + AggregateUDF::new(&name, &signature, &return_type, &accumulator, &state_type); + self.aggregate_functions + .insert(function.sig.ident.to_string(), Arc::new(udaf)); + } else { + let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); + + if self + .functions + .insert( + function.sig.ident.to_string(), + Arc::new(create_udf( + &function.sig.ident.to_string(), + args.iter() + .map(|t| t.as_datatype().unwrap().clone()) + .collect(), + Arc::new(ret.as_datatype().unwrap().clone()), + Volatility::Volatile, + make_scalar_function(fn_impl), + )), + ) + .is_some() + { + warn!( + "Global UDF '{}' is being overwritten", + function.sig.ident.to_string() ); - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unreachable!()); - let state_type: StateTypeFunction = Arc::new(|_| unreachable!()); - let udaf = - AggregateUDF::new(&name, &signature, &return_type, &accumulator, &state_type); - self.aggregate_functions - .insert(function.sig.ident.to_string(), Arc::new(udaf)); - } else { - let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); - - if self - .functions - .insert( - function.sig.ident.to_string(), - Arc::new(create_udf( - &function.sig.ident.to_string(), - args.iter() - .map(|t| t.as_datatype().unwrap().clone()) - .collect(), - Arc::new(ret.as_datatype().unwrap().clone()), - Volatility::Volatile, - make_scalar_function(fn_impl), - )), - ) - .is_some() - { - warn!( - "Global UDF '{}' is being overwritten", - function.sig.ident.to_string() - ); - }; - } + }; + } - function.vis = Visibility::Public(Default::default()); + function.vis = Visibility::Public(Default::default()); - self.udf_defs.insert( - function.sig.ident.to_string(), - UdfDef { - args, - ret, - def: body.to_string(), - }, - ); - } + self.udf_defs.insert( + function.sig.ident.to_string(), + UdfDef { + args, + ret, + def: unparse(&file.clone()), + dependencies: parse_dependencies(&body)?, + }, + ); + + Ok(name) + } +} - Ok(()) +pub fn parse_dependencies(definition: &str) -> Result { + // get content of dependencies comment using regex + let re = Regex::new(r"\/\*\n(\[dependencies\]\n[\s\S]*?)\*\/").unwrap(); + if re.find_iter(&definition).count() > 1 { + bail!("Only one dependencies definition is allowed in a UDF"); } + + return if let Some(captures) = re.captures(&definition) { + if captures.len() != 2 { + bail!("Error parsing dependencies"); + } + Ok(captures.get(1).unwrap().as_str().to_string()) + } else { + Ok("[dependencies]\n# none defined\n".to_string()) + }; } fn create_table_source(fields: Vec) -> Arc { @@ -738,3 +753,65 @@ pub fn has_duplicate_udf_names<'a>(definitions: impl Iterator } false } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_dependencies_valid() { + let definition = r#" +/* +[dependencies] +serde = "1.0" +*/ + +pub fn my_udf() -> i64 { + 1 +} + "#; + + assert_eq!( + parse_dependencies(definition).unwrap(), + r#"[dependencies] +serde = "1.0" +"# + ); + } + + #[test] + fn test_parse_dependencies_none() { + let definition = r#" +pub fn my_udf() -> i64 { + 1 +} + "#; + + assert_eq!( + parse_dependencies(definition).unwrap(), + r#"[dependencies] +# none defined +"# + ); + } + + #[test] + fn test_parse_dependencies_multiple() { + let definition = r#" +/* +[dependencies] +serde = "1.0" +*/ + +/* +[dependencies] +serde = "1.0" +*/ + +pub fn my_udf() -> i64 { + 1 + + "#; + assert!(parse_dependencies(definition).is_err()); + } +} diff --git a/arroyo-sql/src/plan_graph.rs b/arroyo-sql/src/plan_graph.rs index 3c85b0d42..6490a49ae 100644 --- a/arroyo-sql/src/plan_graph.rs +++ b/arroyo-sql/src/plan_graph.rs @@ -2025,11 +2025,13 @@ pub fn get_program( // add only the used udfs to the program let mut udfs: HashMap = HashMap::new(); used_udfs.iter().for_each(|u| { + let udf = schema_provider.udf_defs.get(u).unwrap(); udfs.insert( u.clone(), ProgramUdf { name: u.clone(), - definition: schema_provider.udf_defs.get(u).unwrap().def.clone(), + definition: udf.def.clone(), + dependencies: udf.dependencies.clone(), }, ); });