Skip to content

Commit

Permalink
Merge branch 'master' into object_store/upgrade_to_0_7_1
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackson Newhouse authored Oct 23, 2023
2 parents 47ed88a + d0e6e63 commit 1ac36f2
Show file tree
Hide file tree
Showing 6 changed files with 413 additions and 127 deletions.
6 changes: 4 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions arroyo-connectors/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ futures = "0.3.28"
tokio-tungstenite = { version = "0.20.1", features = ["native-tls"] }
axum = {version = "0.6.12"}
reqwest = "0.11.20"
rand = "0.8.5"
base64 = "0.13.1"
96 changes: 89 additions & 7 deletions arroyo-connectors/src/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use std::convert::Infallible;
use std::str::FromStr;
use std::time::Duration;

use anyhow::anyhow;
use arroyo_rpc::api_types::connections::{ConnectionSchema, ConnectionType, TestSourceMessage};
use arroyo_rpc::OperatorConfig;
use arroyo_types::string_to_map;
use axum::response::sse::Event;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Sender;
use tokio_tungstenite::tungstenite::handshake::client::generate_key;
use tokio_tungstenite::tungstenite::http::Uri;
use tokio_tungstenite::{connect_async, tungstenite};
use tungstenite::http::Request;
use typify::import_types;

use arroyo_rpc::api_types::connections::{ConnectionSchema, ConnectionType, TestSourceMessage};
use serde::{Deserialize, Serialize};

use crate::{pull_opt, Connection, EmptyConfig};

use super::Connector;
Expand Down Expand Up @@ -72,7 +76,55 @@ impl Connector for WebsocketConnector {
}
};

let ws_stream = match connect_async(&table.endpoint).await {
let headers =
match string_to_map(table.headers.as_ref().map(|t| t.0.as_str()).unwrap_or(""))
.ok_or_else(|| anyhow!("Headers are invalid; should be comma-separated pairs"))
{
Ok(headers) => headers,
Err(e) => {
send(true, true, format!("Failed to parse headers: {:?}", e)).await;
return;
}
};

let uri = match Uri::from_str(&table.endpoint.to_string()) {
Ok(uri) => uri,
Err(e) => {
send(true, true, format!("Failed to parse endpoint: {:?}", e)).await;
return;
}
};

let host = match uri.host() {
Some(host) => host,
None => {
send(true, true, "Endpoint must have a host".to_string()).await;
return;
}
};

let mut request_builder = Request::builder().uri(&table.endpoint);

for (k, v) in headers {
request_builder = request_builder.header(k, v);
}

let request = match request_builder
.header("Host", host)
.header("Sec-WebSocket-Key", generate_key())
.header("Sec-WebSocket-Version", "13")
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.body(())
{
Ok(request) => request,
Err(e) => {
send(true, true, format!("Failed to build request: {:?}", e)).await;
return;
}
};

let ws_stream = match connect_async(request).await {
Ok((ws_stream, _)) => ws_stream,
Err(e) => {
send(
Expand All @@ -94,7 +146,7 @@ impl Connector for WebsocketConnector {

let (mut tx, mut rx) = ws_stream.split();

if let Some(msg) = table.subscription_message {
for msg in table.subscription_messages {
match tx
.send(tungstenite::Message::Text(msg.clone().into()))
.await
Expand Down Expand Up @@ -159,6 +211,15 @@ impl Connector for WebsocketConnector {
) -> anyhow::Result<crate::Connection> {
let description = format!("WebsocketSource<{}>", table.endpoint);

if let Some(headers) = &table.headers {
string_to_map(headers).ok_or_else(|| {
anyhow!(
"Invalid format for headers; should be a \
comma-separated list of colon-separated key value pairs"
)
})?;
}

let schema = schema
.map(|s| s.to_owned())
.ok_or_else(|| anyhow!("no schema defined for WebSocket connection"))?;
Expand Down Expand Up @@ -195,15 +256,36 @@ impl Connector for WebsocketConnector {
schema: Option<&ConnectionSchema>,
) -> anyhow::Result<crate::Connection> {
let endpoint = pull_opt("endpoint", opts)?;
let subscription_message = opts.remove("subscription_message");
let headers = opts.remove("headers");
let mut subscription_messages = vec![];

// add the single subscription message if it exists
if let Some(message) = opts.remove("subscription_message") {
subscription_messages.push(SubscriptionMessage(message));

if opts.contains_key("subscription_messages.0") {
return Err(anyhow!(
"Cannot specify both 'subscription_message' and 'subscription_messages.0'"
));
}
}

// add the indexed subscription messages if they exist
let mut message_index = 0;
while let Some(message) = opts.remove(&format!("subscription_messages.{}", message_index)) {
subscription_messages.push(SubscriptionMessage(message));
message_index += 1;
}

self.from_config(
None,
name,
EmptyConfig {},
WebsocketTable {
endpoint,
subscription_message: subscription_message.map(SubscriptionMessage),
headers: headers.map(Headers),
subscription_message: None,
subscription_messages,
},
schema,
)
Expand Down
Loading

0 comments on commit 1ac36f2

Please sign in to comment.