diff --git a/Cargo.lock b/Cargo.lock index 55685015e..1f149ba57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -487,9 +487,11 @@ dependencies = [ "arroyo-storage", "arroyo-types", "axum", + "base64 0.13.1", "chrono", "eventsource-client", "futures", + "rand", "rdkafka", "regress", "reqwest", @@ -8170,9 +8172,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" +checksum = "88ad59a7560b41a70d191093a945f0b87bc1deeda46fb237479708a1d6b6cdfc" dependencies = [ "getrandom", ] diff --git a/arroyo-connectors/Cargo.toml b/arroyo-connectors/Cargo.toml index c3a20bca6..e80ac126c 100644 --- a/arroyo-connectors/Cargo.toml +++ b/arroyo-connectors/Cargo.toml @@ -34,3 +34,5 @@ futures = "0.3.28" tokio-tungstenite = { version = "0.20.1", features = ["native-tls"] } axum = {version = "0.6.12"} reqwest = "0.11.20" +rand = "0.8.5" +base64 = "0.13.1" diff --git a/arroyo-connectors/src/websocket.rs b/arroyo-connectors/src/websocket.rs index b9e262fd9..343c4e705 100644 --- a/arroyo-connectors/src/websocket.rs +++ b/arroyo-connectors/src/websocket.rs @@ -1,17 +1,21 @@ use std::convert::Infallible; +use std::str::FromStr; use std::time::Duration; use anyhow::anyhow; +use arroyo_rpc::api_types::connections::{ConnectionSchema, ConnectionType, TestSourceMessage}; use arroyo_rpc::OperatorConfig; +use arroyo_types::string_to_map; use axum::response::sse::Event; use futures::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; use tokio::sync::mpsc::Sender; +use tokio_tungstenite::tungstenite::handshake::client::generate_key; +use tokio_tungstenite::tungstenite::http::Uri; use tokio_tungstenite::{connect_async, tungstenite}; +use tungstenite::http::Request; use typify::import_types; -use arroyo_rpc::api_types::connections::{ConnectionSchema, ConnectionType, TestSourceMessage}; -use serde::{Deserialize, Serialize}; - use crate::{pull_opt, Connection, EmptyConfig}; use super::Connector; @@ -72,7 +76,55 @@ impl Connector for WebsocketConnector { } }; - let ws_stream = match connect_async(&table.endpoint).await { + let headers = + match string_to_map(table.headers.as_ref().map(|t| t.0.as_str()).unwrap_or("")) + .ok_or_else(|| anyhow!("Headers are invalid; should be comma-separated pairs")) + { + Ok(headers) => headers, + Err(e) => { + send(true, true, format!("Failed to parse headers: {:?}", e)).await; + return; + } + }; + + let uri = match Uri::from_str(&table.endpoint.to_string()) { + Ok(uri) => uri, + Err(e) => { + send(true, true, format!("Failed to parse endpoint: {:?}", e)).await; + return; + } + }; + + let host = match uri.host() { + Some(host) => host, + None => { + send(true, true, "Endpoint must have a host".to_string()).await; + return; + } + }; + + let mut request_builder = Request::builder().uri(&table.endpoint); + + for (k, v) in headers { + request_builder = request_builder.header(k, v); + } + + let request = match request_builder + .header("Host", host) + .header("Sec-WebSocket-Key", generate_key()) + .header("Sec-WebSocket-Version", "13") + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .body(()) + { + Ok(request) => request, + Err(e) => { + send(true, true, format!("Failed to build request: {:?}", e)).await; + return; + } + }; + + let ws_stream = match connect_async(request).await { Ok((ws_stream, _)) => ws_stream, Err(e) => { send( @@ -94,7 +146,7 @@ impl Connector for WebsocketConnector { let (mut tx, mut rx) = ws_stream.split(); - if let Some(msg) = table.subscription_message { + for msg in table.subscription_messages { match tx .send(tungstenite::Message::Text(msg.clone().into())) .await @@ -159,6 +211,15 @@ impl Connector for WebsocketConnector { ) -> anyhow::Result { let description = format!("WebsocketSource<{}>", table.endpoint); + if let Some(headers) = &table.headers { + string_to_map(headers).ok_or_else(|| { + anyhow!( + "Invalid format for headers; should be a \ + comma-separated list of colon-separated key value pairs" + ) + })?; + } + let schema = schema .map(|s| s.to_owned()) .ok_or_else(|| anyhow!("no schema defined for WebSocket connection"))?; @@ -195,7 +256,26 @@ impl Connector for WebsocketConnector { schema: Option<&ConnectionSchema>, ) -> anyhow::Result { let endpoint = pull_opt("endpoint", opts)?; - let subscription_message = opts.remove("subscription_message"); + let headers = opts.remove("headers"); + let mut subscription_messages = vec![]; + + // add the single subscription message if it exists + if let Some(message) = opts.remove("subscription_message") { + subscription_messages.push(SubscriptionMessage(message)); + + if opts.contains_key("subscription_messages.0") { + return Err(anyhow!( + "Cannot specify both 'subscription_message' and 'subscription_messages.0'" + )); + } + } + + // add the indexed subscription messages if they exist + let mut message_index = 0; + while let Some(message) = opts.remove(&format!("subscription_messages.{}", message_index)) { + subscription_messages.push(SubscriptionMessage(message)); + message_index += 1; + } self.from_config( None, @@ -203,7 +283,9 @@ impl Connector for WebsocketConnector { EmptyConfig {}, WebsocketTable { endpoint, - subscription_message: subscription_message.map(SubscriptionMessage), + headers: headers.map(Headers), + subscription_message: None, + subscription_messages, }, schema, ) diff --git a/arroyo-console/src/routes/connections/JsonForm.tsx b/arroyo-console/src/routes/connections/JsonForm.tsx index 804805d9e..e005e1b09 100644 --- a/arroyo-console/src/routes/connections/JsonForm.tsx +++ b/arroyo-console/src/routes/connections/JsonForm.tsx @@ -3,10 +3,12 @@ import { AlertIcon, Box, Button, + Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, + IconButton, Input, Select, Stack, @@ -17,7 +19,8 @@ import { useFormik } from 'formik'; import Ajv from 'ajv'; import addFormats from 'ajv-formats'; -import { useEffect, useMemo } from 'react'; +import React, { useEffect, useMemo } from 'react'; +import { AddIcon, DeleteIcon } from '@chakra-ui/icons'; function StringWidget({ path, @@ -158,6 +161,102 @@ function SelectWidget({ ); } +export function ArrayWidget({ + schema, + onChange, + path, + values, + errors, +}: { + schema: JSONSchema7; + onChange: (e: React.ChangeEvent) => void; + path: string; + values: any; + errors: any; +}) { + const add = () => { + // @ts-ignore + onChange({ target: { name: path, value: [...(values || []), undefined] } }); + }; + + const deleteItem = (index: number) => { + values.splice(index, 1); + // @ts-ignore + onChange({ target: { name: path, value: values } }); + }; + + const itemsSchema = schema.items as JSONSchema7; + + const example = + itemsSchema.examples && Array.isArray(itemsSchema.examples) + ? (itemsSchema.examples[0] as string) + : undefined; + + const arrayItem = (v: string, i: number) => { + switch (itemsSchema.type) { + case 'string': + return ( + + ); + default: + console.warn('Unsupported array item type', itemsSchema.type); + return <>; + } + }; + + return ( + +
+ + {schema.title} + + + + {errors[path] ? ( + {errors[path]} + ) : ( + schema.description && ( + + {schema.description} + + ) + )} + {values?.map((value: any, index: number) => ( + + {arrayItem(value, index)} + deleteItem(index)} + icon={} + /> + + ))} + } /> + + +
+
+ ); +} + export function FormInner({ schema, onChange, @@ -180,137 +279,156 @@ export function FormInner({ return ( - {Object.keys(schema.properties || {}).map(key => { - const property = schema.properties![key]; - if (typeof property == 'object') { - switch (property.type) { - case 'string': - if (property.enum) { + {Object.keys(schema.properties || {}) + .filter(key => { + const property = schema.properties![key]; + // @ts-ignore + return !property.deprecated ?? true; + }) + .map(key => { + const property = schema.properties![key]; + if (typeof property == 'object') { + switch (property.type) { + case 'string': + if (property.enum) { + return ( + ({ + value: value!.toString(), + label: value!.toString(), + }))} + value={values[key]} + onChange={onChange} + /> + ); + } else { + return ( + + ); + } + case 'number': + case 'integer': { return ( - ({ - value: value!.toString(), - label: value!.toString(), - }))} + required={schema.required?.includes(key)} + type={property.type} + placeholder={ + // @ts-ignore + property.examples ? (property.examples[0] as number) : undefined + } + min={property.minimum} + max={property.maximum} value={values[key]} + errors={errors} onChange={onChange} /> ); - } else { + 342; + } + case 'array': { return ( - ); } - case 'number': - case 'integer': { - return ( - - ); - 342; - } - case 'object': { - if (property.oneOf) { - const typeKey = '__meta.' + key + '.type'; - const value = ((values.__meta || {})[key] || {}).type; - return ( -
- - {property.title || key} - - - ({ - // @ts-ignore - value: oneOf.title!, - // @ts-ignore - label: oneOf.title!, - }))} - value={value} - onChange={onChange} - /> - - {value != undefined && ( - - x.title == value) || property.oneOf[0]; + return ( +
+ + {property.title || key} + + + ({ // @ts-ignore - schema={property.oneOf.find(x => x.title == value) || property.oneOf[0]} - errors={errors} - onChange={onChange} - values={values[key] || {}} - /> - - )} - -
- ); - } else if ((values[key].properties?.length || 0) > 0) { - return ( -
- - {property.title || key} - - - - -
- ); + value: oneOf.title!, + // @ts-ignore + label: oneOf.title!, + }))} + value={value} + onChange={onChange} + /> + {value != undefined && ( + + + + )} +
+
+ ); + } else if ((values[key].properties?.length || 0) > 0) { + return ( +
+ + {property.title || key} + + + + +
+ ); + } + } + default: { + console.warn('Unsupported field type', property.type); } - } - default: { - console.warn('Unsupported field type', property.type); } } - } - })} + })}
); } diff --git a/arroyo-worker/src/connectors/websocket.rs b/arroyo-worker/src/connectors/websocket.rs index 40acda754..02e30a72c 100644 --- a/arroyo-worker/src/connectors/websocket.rs +++ b/arroyo-worker/src/connectors/websocket.rs @@ -1,3 +1,4 @@ +use std::str::FromStr; use std::{ marker::PhantomData, time::{Duration, Instant, SystemTime}, @@ -9,14 +10,17 @@ use arroyo_rpc::{ ControlMessage, OperatorConfig, }; use arroyo_state::tables::global_keyed_map::GlobalKeyedState; -use arroyo_types::{Data, Message, Record, UserError, Watermark}; +use arroyo_types::{string_to_map, Data, Message, Record, UserError, Watermark}; use bincode::{Decode, Encode}; use futures::{SinkExt, StreamExt}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use tokio::select; +use tokio_tungstenite::tungstenite::handshake::client::generate_key; +use tokio_tungstenite::tungstenite::http::Uri; use tokio_tungstenite::{connect_async, tungstenite}; use tracing::{debug, info}; +use tungstenite::http::Request; use typify::import_types; use crate::formats::DataDeserializer; @@ -37,7 +41,8 @@ where T: SchemaData, { url: String, - subscription_message: Option, + headers: Vec<(String, String)>, + subscription_messages: Vec, deserializer: DataDeserializer, state: WebsocketSourceState, _t: PhantomData, @@ -55,9 +60,25 @@ where let table: WebsocketTable = serde_json::from_value(config.table).expect("Invalid table config for WebsocketSource"); + // Include subscription_message for backwards compatibility + let mut subscription_messages = vec![]; + if let Some(message) = table.subscription_message { + subscription_messages.push(message.to_string()); + }; + subscription_messages.extend( + table + .subscription_messages + .into_iter() + .map(|m| m.to_string()), + ); + Self { url: table.endpoint, - subscription_message: table.subscription_message.map(|s| s.into()), + headers: string_to_map(table.headers.as_ref().map(|t| t.0.as_str()).unwrap_or("")) + .expect("Invalid header map") + .into_iter() + .collect(), + subscription_messages, deserializer: DataDeserializer::new( config.format.expect("WebsocketSource requires a format"), config.framing, @@ -143,7 +164,47 @@ where } async fn run(&mut self, ctx: &mut Context<(), T>) -> SourceFinishType { - let ws_stream = match connect_async(&self.url).await { + let uri = match Uri::from_str(&self.url.to_string()) { + Ok(uri) => uri, + Err(e) => { + ctx.report_error("Failed to parse endpoint".to_string(), format!("{:?}", e)) + .await; + panic!("Failed to parse endpoint: {:?}", e); + } + }; + + let host = match uri.host() { + Some(host) => host, + None => { + ctx.report_error("Endpoint must have a host".to_string(), "".to_string()) + .await; + panic!("Endpoint must have a host"); + } + }; + + let mut request_builder = Request::builder().uri(&self.url); + + for (k, v) in &self.headers { + request_builder = request_builder.header(k, v); + } + + let request = match request_builder + .header("Host", host) + .header("Sec-WebSocket-Key", generate_key()) + .header("Sec-WebSocket-Version", "13") + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .body(()) + { + Ok(request) => request, + Err(e) => { + ctx.report_error("Failed to build request".to_string(), format!("{:?}", e)) + .await; + panic!("Failed to build request: {:?}", e); + } + }; + + let ws_stream = match connect_async(request).await { Ok((ws_stream, _)) => ws_stream, Err(e) => { ctx.report_error( @@ -160,7 +221,7 @@ where let (mut tx, mut rx) = ws_stream.split(); - if let Some(msg) = &self.subscription_message { + for msg in &self.subscription_messages { if let Err(e) = tx.send(tungstenite::Message::Text(msg.clone())).await { ctx.report_error( "Failed to send subscription message to websocket server".to_string(), diff --git a/connector-schemas/websocket/table.json b/connector-schemas/websocket/table.json index b3d7fdb1b..b8781ad8d 100644 --- a/connector-schemas/websocket/table.json +++ b/connector-schemas/websocket/table.json @@ -11,14 +11,35 @@ ], "format": "uri" }, + "headers": { + "title": "Headers", + "type": "string", + "description": "Comma separated list of headers to send with the request", + "pattern": "([a-zA-Z0-9-]+: ?.+,)*([a-zA-Z0-9-]+: ?.+)", + "examples": ["Authentication: digest 1234,Content-Type: application/json"] + }, "subscription_message": { "title": "Subscription Message", "type": "string", + "description": "[Deprecated] An optional message to send after the socket is opened.", "maxLength": 2048, - "description": "An optional message to send after the socket is opened", "examples": [ "{\"type\":\"subscribe\",\"channels\":[\"updates\"]}" - ] + ], + "deprecated": true + }, + "subscription_messages": { + "title": "Subscription Messages", + "type": "array", + "description": "An optional array of messages to send after the socket is opened. The messages will be sent in order.", + "items": { + "title": "Subscription Message", + "type": "string", + "maxLength": 2048, + "examples": [ + "{\"type\":\"subscribe\",\"channels\":[\"updates\"]}" + ] + } } }, "required": [