Skip to content

Commit

Permalink
Add headers and multiple subscription messages to Websocket source
Browse files Browse the repository at this point in the history
Add headers and subscription_messages fields to the Websocket schema and
include them in the request. The console uses a new ArrayWidget
component to render this type of schema.
  • Loading branch information
jbeisen committed Oct 20, 2023
1 parent 08e43b8 commit a2ef9cc
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 22 deletions.
6 changes: 4 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions arroyo-connectors/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ futures = "0.3.28"
tokio-tungstenite = { version = "0.20.1", features = ["native-tls"] }
axum = {version = "0.6.12"}
reqwest = "0.11.20"
rand = "0.8.5"
base64 = "0.13.1"
77 changes: 70 additions & 7 deletions arroyo-connectors/src/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
use std::convert::Infallible;
use std::env;
use std::str::FromStr;
use std::time::Duration;

use anyhow::anyhow;
use arroyo_rpc::api_types::connections::{ConnectionSchema, ConnectionType, TestSourceMessage};
use arroyo_rpc::OperatorConfig;
use arroyo_types::string_to_map;
use axum::response::sse::Event;
use futures::{SinkExt, StreamExt};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Sender;
use tokio_tungstenite::tungstenite::handshake::client::generate_key;
use tokio_tungstenite::tungstenite::http::Uri;
use tokio_tungstenite::{connect_async, tungstenite};
use tungstenite::http::Request;
use typify::import_types;

use arroyo_rpc::api_types::connections::{ConnectionSchema, ConnectionType, TestSourceMessage};
use serde::{Deserialize, Serialize};

use crate::{pull_opt, Connection, EmptyConfig};

use super::Connector;
Expand Down Expand Up @@ -72,7 +78,34 @@ impl Connector for WebsocketConnector {
}
};

let ws_stream = match connect_async(&table.endpoint).await {
let headers =
match string_to_map(table.headers.as_ref().map(|t| t.0.as_str()).unwrap_or(""))
.ok_or_else(|| anyhow!("Headers are invalid; should be comma-separated pairs"))
{
Ok(headers) => headers,
Err(e) => {
send(true, true, format!("Failed to parse headers: {:?}", e)).await;
return;
}
};

let uri = Uri::from_str(&table.endpoint.to_string()).unwrap();
let mut request_builder = Request::builder().uri(&table.endpoint);

for (k, v) in headers {
request_builder = request_builder.header(k, v);
}

let request = request_builder
.header("Host", uri.host().unwrap())
.header("Sec-WebSocket-Key", generate_key())
.header("Sec-WebSocket-Version", "13")
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.body(())
.unwrap();

let ws_stream = match connect_async(request).await {
Ok((ws_stream, _)) => ws_stream,
Err(e) => {
send(
Expand All @@ -94,7 +127,7 @@ impl Connector for WebsocketConnector {

let (mut tx, mut rx) = ws_stream.split();

if let Some(msg) = table.subscription_message {
for msg in table.subscription_messages {
match tx
.send(tungstenite::Message::Text(msg.clone().into()))
.await
Expand Down Expand Up @@ -159,6 +192,15 @@ impl Connector for WebsocketConnector {
) -> anyhow::Result<crate::Connection> {
let description = format!("WebsocketSource<{}>", table.endpoint);

if let Some(headers) = &table.headers {
string_to_map(headers).ok_or_else(|| {
anyhow!(
"Invalid format for headers; should be a \
comma-separated list of colon-separated key value pairs"
)
})?;
}

let schema = schema
.map(|s| s.to_owned())
.ok_or_else(|| anyhow!("no schema defined for WebSocket connection"))?;
Expand Down Expand Up @@ -195,15 +237,36 @@ impl Connector for WebsocketConnector {
schema: Option<&ConnectionSchema>,
) -> anyhow::Result<crate::Connection> {
let endpoint = pull_opt("endpoint", opts)?;
let subscription_message = opts.remove("subscription_message");
let headers = opts.remove("headers");
let mut subscription_messages = vec![];

// add the single subscription message if it exists
if let Some(message) = opts.remove("subscription_message") {
subscription_messages.push(SubscriptionMessage(message));

if opts.contains_key("subscription_messages.0") {
return Err(anyhow!(
"Cannot specify both 'subscription_message' and 'subscription_messages.0'"
));
}
}

// add the indexed subscription messages if they exist
let mut message_index = 0;
while let Some(message) = opts.remove(&format!("subscription_messages.{}", message_index)) {
subscription_messages.push(SubscriptionMessage(message));
message_index += 1;
}

self.from_config(
None,
name,
EmptyConfig {},
WebsocketTable {
endpoint,
subscription_message: subscription_message.map(SubscriptionMessage),
headers: headers.map(Headers),
// subscription_message: subscription_message.map(SubscriptionMessage),
subscription_messages,
},
schema,
)
Expand Down
113 changes: 112 additions & 1 deletion arroyo-console/src/routes/connections/JsonForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ import {
AlertIcon,
Box,
Button,
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
IconButton,
Input,
Select,
Stack,
Expand All @@ -17,7 +19,8 @@ import { useFormik } from 'formik';

import Ajv from 'ajv';
import addFormats from 'ajv-formats';
import { useEffect, useMemo } from 'react';
import React, { useEffect, useMemo } from 'react';
import { AddIcon, DeleteIcon } from '@chakra-ui/icons';

function StringWidget({
path,
Expand Down Expand Up @@ -158,6 +161,102 @@ function SelectWidget({
);
}

export function ArrayWidget({
schema,
onChange,
path,
values,
errors,
}: {
schema: JSONSchema7;
onChange: (e: React.ChangeEvent<any>) => void;
path: string;
values: any;
errors: any;
}) {
const add = () => {
// @ts-ignore
onChange({ target: { name: path, value: [...(values || []), undefined] } });
};

const deleteItem = (index: number) => {
values.splice(index, 1);
// @ts-ignore
onChange({ target: { name: path, value: values } });
};

const itemsSchema = schema.items as JSONSchema7;

const example =
itemsSchema.examples && Array.isArray(itemsSchema.examples)
? (itemsSchema.examples[0] as string)
: undefined;

const arrayItem = (v: string, i: number) => {
switch (itemsSchema.type) {
case 'string':
return (
<StringWidget
path={`${path}.${i}`}
title={itemsSchema.title + ` ${i + 1}`}
value={v}
errors={errors}
onChange={onChange}
maxLength={itemsSchema.maxLength}
description={itemsSchema.description}
placeholder={example}
/>
);
default:
console.warn('Unsupported array item type', itemsSchema.type);
return <></>;
}
};

return (
<Box>
<fieldset key={schema.title} style={{ border: '1px solid #888', borderRadius: '8px' }}>
<legend
style={{
marginLeft: '8px',
paddingLeft: '16px',
paddingRight: '16px',
}}
>
{schema.title}
</legend>
<FormControl isInvalid={errors[path]}>
<Stack p={4} gap={2}>
{errors[path] ? (
<FormErrorMessage>{errors[path]}</FormErrorMessage>
) : (
schema.description && (
<FormHelperText mt={0} pb={2}>
{schema.description}
</FormHelperText>
)
)}
{values?.map((value: any, index: number) => (
<Flex alignItems={'flex-end'} gap={2}>
{arrayItem(value, index)}
<IconButton
width={8}
height={8}
minWidth={0}
aria-label="Delete item"
onClick={() => deleteItem(index)}
icon={<DeleteIcon width={3} />}
/>
</Flex>
))}
<IconButton aria-label="Add item" onClick={add} icon={<AddIcon />} />
</Stack>
</FormControl>
</fieldset>
</Box>
);
}

export function FormInner({
schema,
onChange,
Expand Down Expand Up @@ -241,6 +340,18 @@ export function FormInner({
);
342;
}
case 'array': {
return (
<ArrayWidget
path={(path ? `${path}.` : '') + key}
key={key}
schema={property}
values={values[key]}
errors={errors}
onChange={onChange}
/>
);
}
case 'object': {
if (property.oneOf) {
const typeKey = '__meta.' + key + '.type';
Expand Down
39 changes: 34 additions & 5 deletions arroyo-worker/src/connectors/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::str::FromStr;
use std::{
marker::PhantomData,
time::{Duration, Instant, SystemTime},
Expand All @@ -9,14 +10,17 @@ use arroyo_rpc::{
ControlMessage, OperatorConfig,
};
use arroyo_state::tables::global_keyed_map::GlobalKeyedState;
use arroyo_types::{Data, Message, Record, UserError, Watermark};
use arroyo_types::{string_to_map, Data, Message, Record, UserError, Watermark};
use bincode::{Decode, Encode};
use futures::{SinkExt, StreamExt};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::select;
use tokio_tungstenite::tungstenite::handshake::client::generate_key;
use tokio_tungstenite::tungstenite::http::Uri;
use tokio_tungstenite::{connect_async, tungstenite};
use tracing::{debug, info};
use tungstenite::http::Request;
use typify::import_types;

use crate::formats::DataDeserializer;
Expand All @@ -37,7 +41,8 @@ where
T: SchemaData,
{
url: String,
subscription_message: Option<String>,
headers: Vec<(String, String)>,
subscription_messages: Vec<String>,
deserializer: DataDeserializer<T>,
state: WebsocketSourceState,
_t: PhantomData<K>,
Expand All @@ -57,7 +62,15 @@ where

Self {
url: table.endpoint,
subscription_message: table.subscription_message.map(|s| s.into()),
headers: string_to_map(table.headers.as_ref().map(|t| t.0.as_str()).unwrap_or(""))
.expect("Invalid header map")
.into_iter()
.collect(),
subscription_messages: table
.subscription_messages
.into_iter()
.map(|m| m.to_string())
.collect(),
deserializer: DataDeserializer::new(
config.format.expect("WebsocketSource requires a format"),
config.framing,
Expand Down Expand Up @@ -143,7 +156,23 @@ where
}

async fn run(&mut self, ctx: &mut Context<(), T>) -> SourceFinishType {
let ws_stream = match connect_async(&self.url).await {
let uri = Uri::from_str(&self.url.to_string()).unwrap();
let mut request_builder = Request::builder().uri(&self.url);

for (k, v) in &self.headers {
request_builder = request_builder.header(k, v);
}

let request = request_builder
.header("Host", uri.host().unwrap())
.header("Sec-WebSocket-Key", generate_key())
.header("Sec-WebSocket-Version", "13")
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.body(())
.unwrap();

let ws_stream = match connect_async(request).await {
Ok((ws_stream, _)) => ws_stream,
Err(e) => {
ctx.report_error(
Expand All @@ -160,7 +189,7 @@ where

let (mut tx, mut rx) = ws_stream.split();

if let Some(msg) = &self.subscription_message {
for msg in &self.subscription_messages {
if let Err(e) = tx.send(tungstenite::Message::Text(msg.clone())).await {
ctx.report_error(
"Failed to send subscription message to websocket server".to_string(),
Expand Down
Loading

0 comments on commit a2ef9cc

Please sign in to comment.