From 963e5baf3ed58b016cd3613dfecd0a731719bbde Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Sun, 3 Dec 2023 16:26:56 -0800 Subject: [PATCH] lots of progress --- Cargo.lock | 155 +++++++---------- arroyo-api/src/pipelines.rs | 18 +- arroyo-rpc/src/api_types/pipelines.rs | 10 +- arroyo-sql-macro/src/lib.rs | 1 + arroyo-sql/src/lib.rs | 10 +- arroyo/Cargo.toml | 5 +- arroyo/src/query/mod.rs | 240 ++++++++++++++++++++++---- arroyo/src/query/model.rs | 204 +++++----------------- arroyo/src/query/runner.rs | 149 ++++++++++++++++ 9 files changed, 491 insertions(+), 301 deletions(-) create mode 100644 arroyo/src/query/runner.rs diff --git a/Cargo.lock b/Cargo.lock index f7461c707..9b4e2bdb4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -442,10 +442,11 @@ dependencies = [ "eventsource-client 0.12.0", "futures", "open", + "reedline", "reqwest", - "rustyline", "serde_json", "sqlparser 0.4.0", + "thiserror", "tokio", "tokio-stream", "tracing", @@ -1739,6 +1740,9 @@ name = "bitflags" version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +dependencies = [ + "serde", +] [[package]] name = "blake2" @@ -2062,17 +2066,6 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" -[[package]] -name = "clipboard-win" -version = "4.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7191c27c2357d9b7ef96baac1773290d4ca63b24205b82a3fd8a0637afcf0362" -dependencies = [ - "error-code", - "str-buf", - "winapi", -] - [[package]] name = "cmake" version = "0.1.50" @@ -2473,9 +2466,11 @@ checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" dependencies = [ "bitflags 2.4.1", "crossterm_winapi", + "futures-core", "libc", "mio", "parking_lot 0.12.1", + "serde", "signal-hook", "signal-hook-mio", "winapi", @@ -3174,12 +3169,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "endian-type" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" - [[package]] name = "enum-ordinalize" version = "3.1.15" @@ -3222,16 +3211,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "error-code" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64f18991e7bf11e7ffee451b5318b5c1a73c52d0d0ada6e5a3017c8c1ced6a21" -dependencies = [ - "libc", - "str-buf", -] - [[package]] name = "event-listener" version = "2.5.3" @@ -5261,15 +5240,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "nibble_vec" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" -dependencies = [ - "smallvec", -] - [[package]] name = "nix" version = "0.26.4" @@ -5324,6 +5294,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "nu-ansi-term" +version = "0.49.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c073d3c1930d0751774acf49e66653acecb416c3a54c6ec095a9b11caddb5a68" +dependencies = [ + "windows-sys 0.48.0", +] + [[package]] name = "num" version = "0.4.1" @@ -6428,16 +6407,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "radix_trie" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" -dependencies = [ - "endian-type", - "nibble_vec", -] - [[package]] name = "rand" version = "0.8.5" @@ -6589,6 +6558,26 @@ dependencies = [ "thiserror", ] +[[package]] +name = "reedline" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a093a20a6c473247c2e9971aaf4cedf9041bcd3f444dc7fad667d3b6b7a5fd" +dependencies = [ + "chrono", + "crossterm", + "fd-lock", + "itertools 0.10.5", + "nu-ansi-term 0.49.0", + "serde", + "strip-ansi-escapes", + "strum", + "strum_macros", + "thiserror", + "unicode-segmentation", + "unicode-width", +] + [[package]] name = "refinery" version = "0.8.11" @@ -7211,41 +7200,6 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" -[[package]] -name = "rustyline" -version = "12.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "994eca4bca05c87e86e15d90fc7a91d1be64b4482b38cb2d27474568fe7c9db9" -dependencies = [ - "bitflags 2.4.1", - "cfg-if", - "clipboard-win", - "fd-lock", - "home", - "libc", - "log", - "memchr", - "nix 0.26.4", - "radix_trie", - "rustyline-derive", - "scopeguard", - "unicode-segmentation", - "unicode-width", - "utf8parse", - "winapi", -] - -[[package]] -name = "rustyline-derive" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a32af5427251d2e4be14fc151eabe18abb4a7aad5efee7044da9f096c906a43" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.39", -] - [[package]] name = "ryu" version = "1.0.15" @@ -7897,12 +7851,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "str-buf" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e08d8363704e6c71fc928674353e6b7c23dcea9d82d7012c8faf2a3a025f8d0" - [[package]] name = "stringprep" version = "0.1.4" @@ -7914,6 +7862,15 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strip-ansi-escapes" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ff8ef943b384c414f54aefa961dd2bd853add74ec75e7ac74cf91dba62bcfa" +dependencies = [ + "vte", +] + [[package]] name = "strsim" version = "0.10.0" @@ -8682,7 +8639,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ "matchers", - "nu-ansi-term", + "nu-ansi-term 0.46.0", "once_cell", "regex", "sharded-slab", @@ -9045,6 +9002,26 @@ version = "0.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9dcc60c0624df774c82a0ef104151231d37da4962957d691c011c852b2473314" +[[package]] +name = "vte" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5022b5fbf9407086c180e9557be968742d839e68346af7792b8592489732197" +dependencies = [ + "utf8parse", + "vte_generate_state_changes", +] + +[[package]] +name = "vte_generate_state_changes" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d257817081c7dffcdbab24b9e62d2def62e2ff7d00b1c20062551e6cccc145ff" +dependencies = [ + "proc-macro2", + "quote", +] + [[package]] name = "waker-fn" version = "1.1.1" diff --git a/arroyo-api/src/pipelines.rs b/arroyo-api/src/pipelines.rs index c089fea24..395dc8e21 100644 --- a/arroyo-api/src/pipelines.rs +++ b/arroyo-api/src/pipelines.rs @@ -61,7 +61,7 @@ async fn compile_sql<'e, E>( parallelism: usize, auth_data: &AuthData, tx: &E, -) -> anyhow::Result +) -> anyhow::Result> where E: GenericClient, { @@ -293,7 +293,8 @@ pub(crate) async fn create_pipeline<'a>( tx, ) .await - .map_err(|e| bad_request(e.to_string()))?; + .map_err(|e| bad_request(e.to_string()))? + .ok_or_else(|| bad_request("The provided SQL does not contain a query"))?; text = Some(sql.query); udfs = Some(api_udfs); is_preview = sql.preview; @@ -469,7 +470,7 @@ pub async fn validate_query( let pipeline_graph_validation_result = match compile_sql(validate_query_post.query, &udfs, 1, &auth_data, &client).await { - Ok(CompiledSql { mut program, .. }) => { + Ok(Some(CompiledSql { mut program, .. })) => { optimizations::optimize(&mut program.graph); let nodes = program .graph @@ -499,12 +500,19 @@ pub async fn validate_query( QueryValidationResult { graph: Some(PipelineGraph { nodes, edges }), - errors: None, + errors: vec![], + missing_query: false, } } + Ok(None) => QueryValidationResult { + graph: None, + errors: vec![], + missing_query: true, + }, Err(e) => QueryValidationResult { graph: None, - errors: Some(vec![e.to_string()]), + errors: vec![e.to_string()], + missing_query: false, }, }; diff --git a/arroyo-rpc/src/api_types/pipelines.rs b/arroyo-rpc/src/api_types/pipelines.rs index a9496b7b7..b65faf021 100644 --- a/arroyo-rpc/src/api_types/pipelines.rs +++ b/arroyo-rpc/src/api_types/pipelines.rs @@ -2,6 +2,7 @@ use crate::api_types::udfs::Udf; use crate::grpc as grpc_proto; use crate::grpc::api as api_proto; use serde::{Deserialize, Serialize}; +use serde_json::Value; use utoipa::ToSchema; #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] @@ -15,7 +16,9 @@ pub struct ValidateQueryPost { #[serde(rename_all = "camelCase")] pub struct QueryValidationResult { pub graph: Option, - pub errors: Option>, + #[serde(default)] + pub errors: Vec, + pub missing_query: bool, } #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] @@ -165,7 +168,7 @@ pub struct OutputData { pub operator_id: String, pub timestamp: u64, pub key: String, - pub value: String, + pub value: Value, } impl From for OutputData { @@ -174,7 +177,8 @@ impl From for OutputData { operator_id: value.operator_id, timestamp: value.timestamp, key: value.key, - value: value.value, + value: serde_json::from_str(&value.value) + .expect("Received non-JSON data from web sink"), } } } diff --git a/arroyo-sql-macro/src/lib.rs b/arroyo-sql-macro/src/lib.rs index eeadf347f..b6cde2281 100644 --- a/arroyo-sql-macro/src/lib.rs +++ b/arroyo-sql-macro/src/lib.rs @@ -181,6 +181,7 @@ fn get_pipeline_module( }, ) .unwrap() + .unwrap() .program; let function = program.make_graph_function(); diff --git a/arroyo-sql/src/lib.rs b/arroyo-sql/src/lib.rs index e0ad6722b..1353079f0 100644 --- a/arroyo-sql/src/lib.rs +++ b/arroyo-sql/src/lib.rs @@ -450,11 +450,11 @@ pub async fn parse_and_get_program( query: &str, schema_provider: ArroyoSchemaProvider, config: SqlConfig, -) -> Result { +) -> Result> { let query = query.to_string(); if query.trim().is_empty() { - bail!("Query is empty"); + return Ok(None); } tokio::spawn(async move { parse_and_get_program_sync(query, schema_provider, config) }) @@ -466,7 +466,7 @@ pub fn parse_and_get_program_sync( query: String, mut schema_provider: ArroyoSchemaProvider, config: SqlConfig, -) -> Result { +) -> Result> { let dialect = PostgreSqlDialect {}; let mut inserts = vec![]; for statement in Parser::parse_sql(&dialect, &query)? { @@ -489,7 +489,7 @@ pub fn parse_and_get_program_sync( // if there are no insert nodes, return an error if sql_pipeline_builder.insert_nodes.is_empty() { - bail!("The provided SQL does not contain a query"); + return Ok(None); } // If there isn't a sink, add a web sink to the last insert @@ -526,7 +526,7 @@ pub fn parse_and_get_program_sync( plan_graph.add_sql_operator(output); } - get_program(plan_graph, sql_pipeline_builder.schema_provider.clone()) + get_program(plan_graph, sql_pipeline_builder.schema_provider.clone()).map(Some) } #[derive(Clone)] diff --git a/arroyo/Cargo.toml b/arroyo/Cargo.toml index 48252f749..69ebab80c 100644 --- a/arroyo/Cargo.toml +++ b/arroyo/Cargo.toml @@ -30,12 +30,13 @@ open = "5.0.0" reqwest = "0.11.20" tokio = { version = "1.32.0", features = ["full"] } tokio-stream = "0.1.14" -crossterm = "0.27.0" +crossterm = { version = "0.27.0", features = ["event-stream"] } comfy-table = "7.1.0" url = { version = "2.4.1", features = [] } -rustyline = { version = "12.0.0", features = ["derive"] } dirs = "5.0.1" eventsource-client = "0.12.0" futures = "0.3.29" serde_json = "1.0.108" tracing = "0.1.40" +reedline = "0.26.0" +thiserror = "1.0.50" diff --git a/arroyo/src/query/mod.rs b/arroyo/src/query/mod.rs index 0c074a711..fabd8e450 100644 --- a/arroyo/src/query/mod.rs +++ b/arroyo/src/query/mod.rs @@ -1,14 +1,25 @@ mod model; +mod runner; -use crate::query::model::QueryModel; +use crate::query::model::{QueryError, QueryModel}; +use crate::query::runner::{PipelineEvent, QueryRunner}; use crate::VERSION; use anyhow::{anyhow, bail, Context}; use arroyo_openapi::{Client, Error}; +use crossterm::event::Event::Key; +use crossterm::event::{Event, EventStream, KeyCode, KeyEvent, KeyModifiers}; +use crossterm::style::{style, Stylize}; +use crossterm::{cursor, execute, queue, style}; +use crossterm::{terminal, terminal::ClearType}; use eventsource_client::{Client as ESClient, SSE}; +use futures::future::FutureExt; +use reedline::{DefaultPrompt, FileBackedHistory, Reedline, Signal, ValidationResult, Validator}; use reqwest::ClientBuilder; -use rustyline::error::ReadlineError; -use rustyline::DefaultEditor; +use serde_json::Value; +use std::collections::HashMap; +use std::io::{stderr, stdout}; use std::time::Duration; +use tokio::select; use tokio_stream::StreamExt; use url::Url; @@ -55,63 +66,226 @@ pub async fn start_query(endpoint: Option) -> anyhow::Result<()> { } } - let mut session = QuerySession::new(client)?; + let mut session = QuerySession::new(&endpoint, client)?; session.run().await?; Ok(()) } +pub struct SQLValidator; + +impl Validator for SQLValidator { + fn validate(&self, line: &str) -> ValidationResult { + if !line.trim_end().ends_with(';') { + ValidationResult::Incomplete + } else { + ValidationResult::Complete + } + } +} +struct OutputTable { + columns: Vec<(String, usize)>, + width: usize, +} + +impl OutputTable { + pub fn new() -> Self { + OutputTable { + columns: vec![], + width: terminal::size().unwrap().0 as usize, + } + } + + pub fn add_data(&mut self, data: &HashMap) { + if self.columns.is_empty() { + self.columns = data + .iter() + .map(|(k, v)| (k.clone(), v.to_string().len())) + .collect(); + } + } +} + struct QuerySession { + endpoint: Url, client: Client, model: QueryModel, - editor: DefaultEditor, + editor: Reedline, } impl QuerySession { - fn new(client: Client) -> anyhow::Result { - Ok(Self { - model: QueryModel::new(client.clone()), - client, - editor: DefaultEditor::new().context("Failed to construct line editor")?, - }) - } - - async fn load_history(&mut self) -> anyhow::Result<()> { + fn new(endpoint: &str, client: Client) -> anyhow::Result { let config_dir = dirs::config_dir().unwrap_or_default().join("arroyo"); - tokio::fs::create_dir_all(&config_dir) - .await + std::fs::create_dir_all(&config_dir) .map_err(|_| anyhow!("failed to create config directory"))?; - let _ = self.editor.load_history(&config_dir.join("history.txt")); + let history = Box::new( + FileBackedHistory::with_file(5, config_dir.join("history.txt").into()) + .expect("Error configuring history with file"), + ); - Ok(()) + Ok(Self { + endpoint: Url::parse(&endpoint).unwrap(), + model: QueryModel::new(client.clone()), + editor: Reedline::create() + .with_validator(Box::new(SQLValidator)) + .with_ansi_colors(true) + .with_history(history), + client, + }) } - pub async fn run(&mut self) -> anyhow::Result<()> { - if let Err(e) = self.load_history().await { - eprintln!("Failed to load query history: {}", e); - } + pub async fn edit(&mut self) -> anyhow::Result> { + let prompt = DefaultPrompt::default(); loop { - match self.editor.readline("> ") { - Ok(line) => { - if self.model.push(&line) { - self.model.process_buffer().await?; + match self.editor.read_line(&prompt) { + Ok(Signal::Success(buffer)) => { + match self.model.process_buffer(buffer).await { + Ok(Some(query)) => { + return Ok(Some(query)); + } + Ok(_) => { + // updated our model with new DDL, but no query to execute + } + Err(err) => { + let msg = match err { + QueryError::ClientError(err) => { + format!("Error communicating with server: {}", err) + .red() + .to_string() + } + QueryError::InvalidQuery(errors) => { + format!( + "{}:\n{}", + "ERROR".bold().red(), + errors + .into_iter() + .map(|e| format!("• {}", e)) + .collect::>() + .join("\n") + .red() + ) + } + }; + eprintln!("{}", msg); + } } } - Err(ReadlineError::Interrupted) => { - println!("CTRL-C"); + Ok(Signal::CtrlC) | Ok(Signal::CtrlD) => { break; } - Err(ReadlineError::Eof) => { - println!("CTRL-D"); + Err(e) => { + eprintln!("Error: {:?}", e); break; } - Err(err) => { - println!("Error: {:?}", err); - break; + } + } + Ok(None) + } + + async fn handle_event(event: PipelineEvent) -> bool { + match event { + PipelineEvent::Finish => { + return false; + } + PipelineEvent::Error(e) => { + queue!( + stderr(), + style::PrintStyledContent("Error: ".red().bold()), + style::PrintStyledContent(e.red()) + ) + .unwrap(); + return false; + } + PipelineEvent::Warning(e) => { + queue!( + stderr(), + style::PrintStyledContent("Warning: ".yellow().bold()), + style::PrintStyledContent(e.yellow()) + ) + .unwrap(); + } + PipelineEvent::StateChange(state) => { + queue!( + stderr(), + cursor::MoveToColumn(0), + terminal::Clear(ClearType::CurrentLine), + style::PrintStyledContent("Job state: ".bold()), + style::Print(&state) + ) + .unwrap(); + + if state == "Running" { + println!(); + } + } + PipelineEvent::Output(data) => { + print!("| "); + println!("{:?}", data.value); + for (_, v) in data.value.as_object().unwrap() { + print!(" {} |", v.to_string()); } + println!(); + } + } + + true + } + + pub async fn run_query(&mut self, query: String) -> anyhow::Result<()> { + let mut runner = QueryRunner::run_query(self.client.clone(), &query).await?; + let mut reader = EventStream::new(); + + eprintln!("Running query as pipeline {}", runner.pipeline.id); + + let web_url = self + .endpoint + .join(&format!("pipelines/{}", runner.pipeline.id)) + .unwrap(); + eprintln!("Web UI: {}", web_url); + + loop { + let mut reader_events = reader.next().fuse(); + select! { + event = runner.rx.recv() => { + if let Some(event) = event { + if !Self::handle_event(event).await { + break; + } + } else { + break; + } + } + event = reader_events => { + match event { + Some(Ok(Key(key))) => { + if key.code == KeyCode::Char('c') && key.modifiers == KeyModifiers::CONTROL { + // TODO: stop job + break; + } + } + _ => { + // ignore + } + } + } + } + } + + Ok(()) + } + + pub async fn run(&mut self) -> anyhow::Result<()> { + loop { + let Some(query) = self.edit().await? else { + eprintln!("Exiting..."); + break; + }; + + if let Err(e) = self.run_query(query).await { + eprintln!("{}", format!("Error while running query:\n{}", e).red()); } } diff --git a/arroyo/src/query/model.rs b/arroyo/src/query/model.rs index 5df3d99a4..dfd76c783 100644 --- a/arroyo/src/query/model.rs +++ b/arroyo/src/query/model.rs @@ -1,191 +1,62 @@ -use anyhow::{anyhow, bail, Context}; -use arroyo_openapi::types::{Job, OutputData, Pipeline, PipelinePost, ValidateQueryPost}; +use anyhow::bail; +use arroyo_openapi::types::{QueryValidationResult, ValidateQueryPost}; use arroyo_openapi::Client; -use eventsource_client::{Client as ESClient, SSE}; -use serde_json::Value; use sqlparser::ast::{ObjectName, Statement}; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; use std::collections::HashMap; -use std::sync::mpsc::Receiver; -use std::time::Duration; -use tokio::sync::mpsc::Sender; -use tokio_stream::StreamExt; +use thiserror::Error; use tracing::warn; -struct RunningPipeline { - pipeline: Pipeline, - job: Job, - stream: Receiver, -} - pub struct QueryModel { ddl: HashMap, - current_line: String, client: Client, - running: Option, +} + +#[derive(Error, Debug)] +pub enum QueryError { + #[error("error communicating with server")] + ClientError(#[from] arroyo_openapi::Error), + + #[error("Query error")] + InvalidQuery(Vec), } impl QueryModel { pub fn new(client: Client) -> Self { Self { ddl: HashMap::new(), - current_line: String::new(), client, - running: None, } } - pub fn push(&mut self, s: &str) -> bool { - self.current_line.push_str(s); - self.current_line.push('\n'); - - return self.current_line.trim_end().ends_with(';'); - } - fn build_query(queries: &Vec, ddl: &HashMap) -> String { let mut statements: Vec<_> = ddl.iter().map(|(_, t)| t.to_string()).collect(); statements.extend(queries.iter().cloned()); - statements.join("\n") + statements.join(";\n") } - async fn validate_query(&self, query: &str) -> anyhow::Result> { - let result = self + async fn validate_query( + &self, + query: &str, + ) -> Result { + Ok(self .client .validate_query() .body(ValidateQueryPost::builder().query(query.to_string())) .send() .await? - .into_inner(); - - Ok(result.errors.unwrap_or_default()) - } - - async fn start_pipeline(&self, query: &str) -> anyhow::Result { - let validation = self - .client - .validate_query() - .body(ValidateQueryPost::builder().query(query)) - .send() - .await? - .into_inner(); - - if let Some(errors) = validation.errors { - eprintln!("Invalid query:"); - for error in errors { - eprintln!(" * {}", error); - } - bail!("failed to run query"); - } - - let pipeline = self - .client - .post_pipeline() - .body( - PipelinePost::builder() - .name("cli-pipeline") - .query(query) - .parallelism(1) - .preview(true), - ) - .send() - .await? - .into_inner(); - - Ok(pipeline.id) - } - - async fn get_job(&self, pipeline_id: &str) -> anyhow::Result { - self.client - .get_pipeline_jobs() - .id(pipeline_id) - .send() - .await? - .into_inner() - .data - .into_iter() - .next() - .ok_or_else(|| anyhow!("No job found for pipeline")) - } - - async fn run_query(&self, query: &str) -> anyhow::Result> { - println!("Starting pipeline..."); - let pipeline_id = self.start_pipeline(query).await?; - - let job = self.get_job(&pipeline_id).await?; - - let mut outputs = eventsource_client::ClientBuilder::for_url(&format!( - "{}/v1/pipelines/{}/jobs/{}/output", - self.client.baseurl(), - pipeline_id, - job.id - )) - .unwrap() - .build() - .stream(); - - let mut last_state = String::new(); - while last_state != "Running" { - let job = self - .get_job(&pipeline_id) - .await - .context("waiting for job startup")?; - - if job.state != last_state { - println!("Job entered {}", job.state); - last_state = job.state; - } - - tokio::time::sleep(Duration::from_millis(300)).await; - } - - loop { - match outputs.next().await { - Some(Ok(msg)) => { - match msg { - SSE::Event(e) => { - let Ok(data) = serde_json::from_str::(&e.data) else { - eprintln!("received invalid outputs from server"); - continue; - }; - - let Ok(fields) = serde_json::from_str::(&data.value) else { - eprintln!( - "received invalid data from output operator {}", - data.operator_id - ); - continue; - }; - - print!("| "); - for (k, value) in fields.as_object().unwrap() { - print!("{} | ", value); - } - print!("\n"); - } - SSE::Comment(_) => { - // ignore - } - } - } - Some(Err(msg)) => { - bail!("error while reading output: {}", msg); - } - None => { - println!("output completed"); - break; - } - } - } - - Ok(()) + .into_inner()) } - pub async fn process_buffer(&mut self) -> anyhow::Result>> { + fn parse_query( + &self, + query: &str, + ) -> anyhow::Result<(Vec, HashMap)> { let dialect = PostgreSqlDialect {}; - let ast = Parser::parse_sql(&dialect, self.current_line.clone())?; + let ast = Parser::parse_sql(&dialect, query.to_string())?; let mut queries = vec![]; @@ -233,22 +104,27 @@ impl QueryModel { } } + Ok((queries, new_ddl)) + } + + pub async fn process_buffer(&mut self, query: String) -> Result, QueryError> { + let (queries, new_ddl) = self + .parse_query(&query) + .map_err(|e| QueryError::InvalidQuery(vec![e.to_string()]))?; + let query = Self::build_query(&queries, &new_ddl); - let errors = self.validate_query(&query).await?; + let result = self.validate_query(&query).await?; - if errors.len() == 1 && &errors[0] == "The provided SQL does not contain a query" { + if result.missing_query + || result.errors.len() == 1 + && (&result.errors[0] == "The provided SQL does not contain a query" + || &result.errors[0] == "Query is empty") + { self.ddl = new_ddl; Ok(None) - } else if !errors.is_empty() { - bail!( - "{}", - errors - .iter() - .map(|e| format!("* {}", e)) - .collect::>() - .join("\n") - ) + } else if !result.errors.is_empty() { + Err(QueryError::InvalidQuery(result.errors)) } else { Ok(Some(query)) } diff --git a/arroyo/src/query/runner.rs b/arroyo/src/query/runner.rs new file mode 100644 index 000000000..08492c152 --- /dev/null +++ b/arroyo/src/query/runner.rs @@ -0,0 +1,149 @@ +use anyhow::{anyhow, bail, Context}; +use arroyo_openapi::types::{Job, OutputData, Pipeline, PipelinePost}; +use arroyo_openapi::Client; +use eventsource_client::{Client as ESClient, SSE}; +use std::future::Future; +use std::time::Duration; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio_stream::StreamExt; + +pub struct QueryRunner { + client: Client, + pub pipeline: Pipeline, + pub job: Job, + pub rx: Receiver, +} + +pub enum PipelineEvent { + StateChange(String), + Warning(String), + Error(String), + Output(OutputData), + Finish, +} + +impl QueryRunner { + async fn start_pipeline(client: &Client, query: &str) -> anyhow::Result { + let pipeline = client + .post_pipeline() + .body( + PipelinePost::builder() + .name("cli-pipeline") + .query(query) + .parallelism(1) + .preview(true), + ) + .send() + .await? + .into_inner(); + + Ok(pipeline) + } + + async fn get_job(client: &Client, pipeline_id: &str) -> anyhow::Result { + client + .get_pipeline_jobs() + .id(pipeline_id) + .send() + .await? + .into_inner() + .data + .into_iter() + .next() + .ok_or_else(|| anyhow!("No job found for pipeline")) + } + + pub async fn run_query(client: Client, query: &str) -> anyhow::Result { + let (tx, rx) = channel(128); + + let pipeline = Self::start_pipeline(&client, query).await?; + let job = Self::get_job(&client, &pipeline.id).await?; + + let pipeline_id = pipeline.id.clone(); + let job_id = job.id.clone(); + let inner_client = client.clone(); + + tokio::spawn(async move { + match Self::run_query_int(inner_client, pipeline_id, job_id, tx.clone()).await { + Ok(_) => { + tx.send(PipelineEvent::Finish).await.unwrap(); + } + Err(e) => { + tx.send(PipelineEvent::Error(e.to_string())).await.unwrap(); + } + } + }); + + Ok(Self { + client, + pipeline, + job, + rx, + }) + } + + async fn run_query_int( + client: Client, + pipeline_id: String, + job_id: String, + tx: Sender, + ) -> anyhow::Result<()> { + let mut outputs = eventsource_client::ClientBuilder::for_url(&format!( + "{}/v1/pipelines/{}/jobs/{}/output", + client.baseurl(), + pipeline_id, + job_id + )) + .unwrap() + .build() + .stream(); + + let mut last_state = String::new(); + while last_state != "Running" { + let job = Self::get_job(&client, &pipeline_id) + .await + .context("waiting for job startup")?; + + if job.state != last_state { + tx.send(PipelineEvent::StateChange(job.state.clone())) + .await + .unwrap(); + last_state = job.state; + } + + tokio::time::sleep(Duration::from_millis(300)).await; + } + + loop { + match outputs.next().await { + Some(Ok(msg)) => { + match msg { + SSE::Event(e) => { + let Ok(data) = serde_json::from_str::(&e.data) else { + tx.send(PipelineEvent::Warning( + "received invalid outputs from server".to_string(), + )) + .await + .unwrap(); + continue; + }; + + tx.send(PipelineEvent::Output(data)).await.unwrap(); + } + SSE::Comment(_) => { + // ignore + } + } + } + Some(Err(msg)) => { + bail!("error while reading output: {}", msg); + } + None => { + break; + } + } + } + + Ok(()) + } +}