Commit
wip: save udfs
jbeisen committed Oct 26, 2023
1 parent e2007b4 commit e737624
Showing 39 changed files with 2,132 additions and 456 deletions.
11 changes: 11 additions & 0 deletions arroyo-api/migrations/V17__udfs_table.sql
@@ -0,0 +1,11 @@
CREATE TABLE udfs (
    pub_id VARCHAR PRIMARY KEY,
    organization_id VARCHAR NOT NULL,
    created_by VARCHAR NOT NULL,
    created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP NOT NULL,
    updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP NOT NULL,
    name TEXT NOT NULL,
    definition TEXT NOT NULL,
    language TEXT NOT NULL,
    description TEXT
);
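The queries file below declares a cornucopia row type for this table (`--: DbUdf (description?)`), which generates a Rust struct from the SELECT columns. As a rough sketch, assuming cornucopia's usual mapping of TIMESTAMPTZ to time::OffsetDateTime and of the `(description?)` marker to an Option, the generated row type plausibly looks like this (the actual generated api-sql.rs is not part of this diff):

use time::OffsetDateTime;

// Hypothetical sketch of the cornucopia-generated DbUdf row type; the field
// names come from the SELECT lists below, but the exact types are assumptions.
pub struct DbUdf {
    pub pub_id: String,
    pub name: String,
    pub definition: String,
    pub created_at: OffsetDateTime,
    pub updated_at: OffsetDateTime,
    pub language: String,
    pub description: Option<String>,
}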
27 changes: 27 additions & 0 deletions arroyo-api/queries/api_queries.sql
@@ -283,3 +283,30 @@ WHERE job_configs.organization_id = :organization_id AND job_configs.id = :job_i
) OR :starting_after = '')
ORDER BY jlm.created_at DESC
LIMIT :limit::integer;

----------- udfs -----------------------

--: DbUdf (description?)

--! create_udf
INSERT INTO udfs (pub_id, organization_id, created_by, name, language, definition, description)
VALUES (:pub_id, :organization_id, :created_by, :name, :language, :definition, :description);

--! get_udf: DbUdf
SELECT pub_id, name, definition, created_at, updated_at, language, description
FROM udfs
WHERE organization_id = :organization_id AND pub_id = :pub_id;

--! get_udf_by_name: DbUdf
SELECT pub_id, name, definition, created_at, updated_at, language, description
FROM udfs
WHERE organization_id = :organization_id AND name = :name;

--! get_udfs: DbUdf
SELECT pub_id, name, definition, created_at, updated_at, language, description
FROM udfs
WHERE organization_id = :organization_id;

--! delete_udf
DELETE FROM udfs
WHERE organization_id = :organization_id AND pub_id = :pub_id;
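Each `--!` declaration above becomes a typed function in the generated `api_queries` module. The `get_udfs` call added to compile_sql later in this diff shows the pattern; a minimal usage sketch, assuming a client or transaction `tx` and an `auth_data` with an `organization_id` in scope:

// Usage sketch mirroring the call site added in arroyo-api/src/pipelines.rs
// below; returns every UDF row for the organization as a Vec of DbUdf.
let shared_udfs = api_queries::get_udfs()
    .bind(tx, &auth_data.organization_id)
    .all()
    .await?;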
15 changes: 12 additions & 3 deletions arroyo-api/src/lib.rs
@@ -21,12 +21,14 @@ use crate::pipelines::__path_get_pipelines;
use crate::pipelines::__path_post_pipeline;
use crate::pipelines::{
__path_delete_pipeline, __path_get_pipeline, __path_get_pipeline_jobs, __path_patch_pipeline,
-    __path_restart_pipeline, __path_validate_query, __path_validate_udfs,
+    __path_restart_pipeline, __path_validate_query,
};
use crate::rest::__path_ping;
use crate::rest_utils::{bad_request, log_and_map, ErrorResp};
use crate::udfs::{__path_create_udf, __path_delete_udf, __path_get_udfs, __path_validate_udf};
use arroyo_rpc::api_types::{checkpoints::*, connections::*, metrics::*, pipelines::*, udfs::*, *};
use arroyo_rpc::formats::*;

mod cloud;
mod connection_profiles;
mod connection_tables;
@@ -37,6 +39,7 @@ mod optimizations;
mod pipelines;
pub mod rest;
mod rest_utils;
mod udfs;

include!(concat!(env!("OUT_DIR"), "/api-sql.rs"));

@@ -127,7 +130,7 @@ pub(crate) fn to_micros(dt: OffsetDateTime) -> u64 {
paths(
ping,
validate_query,
-    validate_udfs,
+    validate_udf,
post_pipeline,
patch_pipeline,
restart_pipeline,
@@ -150,6 +153,9 @@ pub(crate) fn to_micros(dt: OffsetDateTime) -> u64 {
test_schema,
get_confluent_schema,
get_checkpoint_details,
create_udf,
get_udfs,
delete_udf
),
components(schemas(
PipelinePost,
@@ -210,9 +216,12 @@ pub(crate) fn to_micros(dt: OffsetDateTime) -> u64 {
OperatorCheckpointGroup,
ValidateQueryPost,
QueryValidationResult,
-    ValidateUdfsPost,
+    ValidateUdfPost,
UdfValidationResult,
Udf,
UdfPost,
SharedUdf,
SharedUdfCollection,
)),
tags(
(name = "ping", description = "Ping endpoint"),
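Every handler named in `paths(...)` carries a `#[utoipa::path]` annotation; for the new UDF endpoints those live in the `udfs` module added by this commit, which this page does not show. As a hypothetical sketch following the annotation style visible in pipelines.rs below, `create_udf` might be declared like this (the route, tag, and response body are assumptions, not taken from the commit):

// Hypothetical OpenAPI annotation for create_udf; only the macro shape is
// grounded in this repo's existing handlers, the specifics are assumed.
#[utoipa::path(
    post,
    path = "/v1/udfs",
    tag = "udfs",
    request_body = UdfPost,
    responses(
        (status = 200, description = "Created UDF", body = Udf),
    ),
)]
pub async fn create_udf(/* State, BearerAuth, Json<UdfPost> */) {
    // ...
}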
238 changes: 105 additions & 133 deletions arroyo-api/src/pipelines.rs
@@ -14,11 +14,11 @@ use arroyo_rpc::api_types::pipelines::{
Job, Pipeline, PipelineEdge, PipelineGraph, PipelineNode, PipelinePatch, PipelinePost,
PipelineRestart, QueryValidationResult, StopType, ValidateQueryPost,
};
-use arroyo_rpc::api_types::udfs::{UdfValidationResult, ValidateUdfsPost};
+use arroyo_rpc::api_types::udfs::{SharedUdf, UdfLanguage, UdfValidationResult, ValidateUdfPost};
use arroyo_rpc::api_types::{JobCollection, PaginationQueryParams, PipelineCollection};
use arroyo_rpc::grpc::api::{
create_pipeline_req, CreateJobReq, CreatePipelineReq, CreateSqlJob, CreateUdf, PipelineProgram,
-    Udf, UdfLanguage,
+    Udf,
};
use arroyo_rpc::grpc::controller_grpc_client::ControllerGrpcClient;
use arroyo_rpc::grpc::{CheckUdfsReq, ValidationResult};
@@ -47,7 +49,7 @@ use create_pipeline_req::Config::Sql;
const DEFAULT_CHECKPOINT_INTERVAL: Duration = Duration::from_secs(10);

async fn compile_sql<'e, E>(
-    sql: &CreateSqlJob,
+    query: String,
+    local_udf_defs: Vec<String>,
+    parallelism: usize,
auth_data: &AuthData,
tx: &E,
) -> anyhow::Result<(Program, Vec<i64>)>
@@ -56,19 +58,36 @@ where
{
let mut schema_provider = ArroyoSchemaProvider::new();

-    for udf in sql.udfs.iter() {
-        match UdfLanguage::from_i32(udf.language) {
-            Some(UdfLanguage::Rust) => {
-                schema_provider
-                    .add_rust_udf(&udf.definition)
-                    .map_err(|e| anyhow!(format!("Could not process UDF: {:?}", e)))?;
-            }
-            None => {
-                return Err(anyhow!("Unsupported UDF language."));
+    let shared_udfs = api_queries::get_udfs()
+        .bind(tx, &auth_data.organization_id)
+        .all()
+        .await
+        .map_err(|e| anyhow!(format!("Error getting shared UDFs: {}", e)))?
+        .into_iter()
+        .map(|u| u.into())
+        .collect::<Vec<SharedUdf>>();
+
+    // TODO: add only the UDFs that are actually used in the query
+    for udf in shared_udfs {
+        match udf.language {
+            UdfLanguage::Rust => {
+                let _ = schema_provider.add_rust_udf(&udf.definition).map_err(|e| {
+                    warn!(
+                        "Could not process shared UDF {}: {:?}",
+                        udf.name,
+                        e.root_cause()
+                    );
+                });
             }
         }
     }
+
+    for udf_def in local_udf_defs.iter() {
+        schema_provider
+            .add_rust_udf(&udf_def)
+            .map_err(|e| anyhow!(format!("Could not process local UDF: {:?}", e)))?;
+    }

let tables = connection_tables::get_all_connection_tables(auth_data, tx)
.await
.map_err(|e| anyhow!(e.message))?;
@@ -97,10 +116,10 @@ where
}

let (program, connections) = arroyo_sql::parse_and_get_program(
-        &sql.query,
+        &query,
schema_provider,
SqlConfig {
-            default_parallelism: sql.parallelism as usize,
+            default_parallelism: parallelism,
},
)
.await
@@ -158,10 +177,18 @@ pub(crate) async fn create_pipeline<'a>(
)));
}

+            let udf_defs = sql.udfs.iter().map(|t| t.definition.to_string()).collect();
+
            pipeline_type = PipelineType::sql;
-            (program, connections) = compile_sql(&sql, &auth, tx)
-                .await
-                .map_err(|e| bad_request(e.to_string()))?;
+            (program, connections) = compile_sql(
+                sql.query.clone(),
+                udf_defs,
+                sql.parallelism as usize,
+                &auth,
+                tx,
+            )
+            .await
+            .map_err(|e| bad_request(e.to_string()))?;
text = Some(sql.query);
udfs = Some(
sql.udfs
@@ -329,129 +356,74 @@ pub async fn validate_query(
let client = client(&state.pool).await?;
let auth_data = authenticate(&state.pool, bearer_auth).await?;

-    let sql = CreateSqlJob {
-        query: validate_query_post.query,
-        parallelism: 1,
-        udfs: validate_query_post
-            .udfs
-            .clone()
-            .unwrap_or(vec![])
-            .into_iter()
-            .map(|u| CreateUdf {
-                language: 0,
-                definition: u.definition.to_string(),
-            })
-            .collect(),
-        preview: false,
-    };
+    // let sql = CreateSqlJob {
+    //     query: validate_query_post.query,
+    //     parallelism: 1,
+    //     udfs: validate_query_post
+    //         .udfs
+    //         .clone()
+    //         .unwrap_or(vec![])
+    //         .into_iter()
+    //         .map(|u| CreateUdf {
+    //             language: 0,
+    //             definition: u.definition.to_string(),
+    //         })
+    //         .collect(),
+    //     preview: false,
+    // };

+    let udf_defs = validate_query_post
+        .udfs
+        .clone()
+        .unwrap_or(vec![])
+        .into_iter()
+        .map(|u| u.definition.to_string())
+        .collect();

-    let pipeline_graph_validation_result = match compile_sql(&sql, &auth_data, &client).await {
-        Ok((mut program, _)) => {
-            optimizations::optimize(&mut program.graph);
-            let nodes = program
-                .graph
-                .node_weights()
-                .map(|node| PipelineNode {
-                    node_id: node.operator_id.to_string(),
-                    operator: format!("{:?}", node),
-                    parallelism: node.clone().parallelism as u32,
-                })
-                .collect();
-
-            let edges = program
-                .graph
-                .edge_references()
-                .map(|edge| {
-                    let src = program.graph.node_weight(edge.source()).unwrap();
-                    let target = program.graph.node_weight(edge.target()).unwrap();
-                    PipelineEdge {
-                        src_id: src.operator_id.to_string(),
-                        dest_id: target.operator_id.to_string(),
-                        key_type: edge.weight().key.to_string(),
-                        value_type: edge.weight().value.to_string(),
-                        edge_type: format!("{:?}", edge.weight().typ),
-                    }
-                })
-                .collect();
+    let pipeline_graph_validation_result =
+        match compile_sql(validate_query_post.query, udf_defs, 1, &auth_data, &client).await {
+            Ok((mut program, _)) => {
+                optimizations::optimize(&mut program.graph);
+                let nodes = program
+                    .graph
+                    .node_weights()
+                    .map(|node| PipelineNode {
+                        node_id: node.operator_id.to_string(),
+                        operator: format!("{:?}", node),
+                        parallelism: node.clone().parallelism as u32,
+                    })
+                    .collect();
+
+                let edges = program
+                    .graph
+                    .edge_references()
+                    .map(|edge| {
+                        let src = program.graph.node_weight(edge.source()).unwrap();
+                        let target = program.graph.node_weight(edge.target()).unwrap();
+                        PipelineEdge {
+                            src_id: src.operator_id.to_string(),
+                            dest_id: target.operator_id.to_string(),
+                            key_type: edge.weight().key.to_string(),
+                            value_type: edge.weight().value.to_string(),
+                            edge_type: format!("{:?}", edge.weight().typ),
+                        }
+                    })
+                    .collect();

-            QueryValidationResult {
-                graph: Some(PipelineGraph { nodes, edges }),
-                errors: None,
-            }
-        }
+                QueryValidationResult {
+                    graph: Some(PipelineGraph { nodes, edges }),
+                    errors: None,
+                }
+            }
-        Err(e) => QueryValidationResult {
-            graph: None,
-            errors: Some(vec![e.to_string()]),
-        },
-    };
+            Err(e) => QueryValidationResult {
+                graph: None,
+                errors: Some(vec![e.to_string()]),
+            },
+        };

Ok(Json(pipeline_graph_validation_result))
}

/// Validate UDFs
#[utoipa::path(
post,
path = "/v1/pipelines/validate_udfs",
tag = "pipelines",
request_body = ValidateUdfsPost,
responses(
(status = 200, description = "Validated query", body = UdfValidationResult),
),
)]
pub async fn validate_udfs(
State(state): State<AppState>,
bearer_auth: BearerAuth,
WithRejection(Json(validate_udfs_post), _): WithRejection<Json<ValidateUdfsPost>, ApiError>,
) -> Result<Json<UdfValidationResult>, ErrorResp> {
let _auth_data = authenticate(&state.pool, bearer_auth).await?;

// Return an ok (valid) if the controller is not available or if it fails to validate the UDFs

let mut controller = match ControllerGrpcClient::connect(state.controller_addr.clone()).await {
Ok(controller) => controller,
Err(e) => {
warn!(
"Failed to connect to controller, skipping UDF validation: {}",
e
);
return Ok(Json(UdfValidationResult {
udfs_rs: Some(validate_udfs_post.udfs_rs),
errors: None,
}));
}
};

let check_udfs_resp = match controller
.check_udfs(CheckUdfsReq {
udfs_rs: validate_udfs_post.udfs_rs.clone(),
})
.await
{
Ok(resp) => resp.into_inner(),
Err(e) => {
warn!("Controller failed to validate UDF: {}", e);
return Ok(Json(UdfValidationResult {
udfs_rs: Some(validate_udfs_post.udfs_rs),
errors: None,
}));
}
};

let udf_validation_result = if check_udfs_resp.result == ValidationResult::Error as i32 {
UdfValidationResult {
udfs_rs: None,
errors: Some(check_udfs_resp.errors),
}
} else {
UdfValidationResult {
udfs_rs: Some(validate_udfs_post.udfs_rs),
errors: None,
}
};

Ok(Json(udf_validation_result))
}

/// Create a new pipeline
///
/// The API will create a single job for the pipeline.
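One piece implied but not shown by this diff: compile_sql maps each DbUdf row into a SharedUdf with `u.into()`, so the commit must define that conversion somewhere, presumably the new udfs module. A hypothetical sketch; only `name`, `definition`, and `language` are confirmed by their use in compile_sql, the rest is assumed:

// Hypothetical From impl backing the `u.into()` call in compile_sql.
// SharedUdf's full field list is not visible in this diff; the language
// mapping assumes Rust, the only variant compile_sql matches.
impl From<DbUdf> for SharedUdf {
    fn from(u: DbUdf) -> Self {
        SharedUdf {
            name: u.name,
            definition: u.definition,
            language: UdfLanguage::Rust,
            description: u.description,
        }
    }
}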