diff --git a/Cargo.lock b/Cargo.lock
index 30760e7fc..391e93b42 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -660,6 +660,7 @@ version = "0.7.0"
 dependencies = [
  "anyhow",
  "arroyo-types",
+ "async-trait",
  "bincode 2.0.0-rc.3",
  "log",
  "nanoid",
@@ -1085,9 +1086,9 @@ checksum = "ecc7ab41815b3c653ccd2978ec3255c81349336702dfdf62ee6f7069b12a3aae"
 
 [[package]]
 name = "async-trait"
-version = "0.1.73"
+version = "0.1.74"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0"
+checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9"
 dependencies = [
  "proc-macro2",
  "quote",
diff --git a/arroyo-api/src/connection_tables.rs b/arroyo-api/src/connection_tables.rs
index 545aba61c..b0b6c47e6 100644
--- a/arroyo-api/src/connection_tables.rs
+++ b/arroyo-api/src/connection_tables.rs
@@ -16,16 +16,17 @@ use tokio::sync::mpsc::channel;
 use tokio_stream::wrappers::ReceiverStream;
 use tracing::warn;
 
-use arroyo_connectors::{Connector, connector_for_type, ErasedConnector};
 use arroyo_connectors::kafka::{KafkaConfig, KafkaConnector, KafkaTable};
+use arroyo_connectors::{connector_for_type, Connector, ErasedConnector};
 use arroyo_rpc::api_types::connections::{
-    ConnectionProfile, ConnectionSchema,
-    ConnectionTable, ConnectionTablePost, SchemaDefinition,
+    ConnectionProfile, ConnectionSchema, ConnectionTable, ConnectionTablePost, SchemaDefinition,
 };
 use arroyo_rpc::api_types::{ConnectionTableCollection, PaginationQueryParams};
 use arroyo_rpc::formats::{AvroFormat, Format, JsonFormat};
 use arroyo_rpc::public_ids::{generate_id, IdTypes};
-use arroyo_rpc::schema_resolver::{ConfluentSchemaResolver, ConfluentSchemaResponse, ConfluentSchemaType};
+use arroyo_rpc::schema_resolver::{
+    ConfluentSchemaResolver, ConfluentSchemaResponse, ConfluentSchemaType,
+};
 use arroyo_sql::avro;
 use arroyo_sql::json_schema::convert_json_schema;
 use arroyo_sql::types::{StructField, TypeDef};
@@ -431,48 +432,75 @@ pub(crate) async fn expand_schema(
     };
 
     match format {
-        Format::Json(_) => expand_json_schema(name, connector, schema, table_config, profile_config).await,
-        Format::Avro(_) => expand_avro_schema(name, connector, schema, table_config, profile_config).await,
+        Format::Json(_) => {
+            expand_json_schema(name, connector, schema, table_config, profile_config).await
+        }
+        Format::Avro(_) => {
+            expand_avro_schema(name, connector, schema, table_config, profile_config).await
+        }
         Format::Parquet(_) => Ok(schema),
         Format::RawString(_) => Ok(schema),
     }
 }
 
-async fn expand_avro_schema(name: &str, connector: &str, mut schema: ConnectionSchema, table_config: &Value, profile_config: &Value) -> Result<ConnectionSchema, ErrorResp> {
-    if let Some(Format::Avro(AvroFormat { confluent_schema_registry: true, .. })) = &schema.format {
+async fn expand_avro_schema(
+    name: &str,
+    connector: &str,
+    mut schema: ConnectionSchema,
+    table_config: &Value,
+    profile_config: &Value,
+) -> Result<ConnectionSchema, ErrorResp> {
+    if let Some(Format::Avro(AvroFormat {
+        confluent_schema_registry: true,
+        ..
+    })) = &schema.format
+    {
         let schema_response = get_schema(connector, table_config, profile_config).await?;
 
         if schema_response.schema_type != ConfluentSchemaType::Avro {
-            return Err(bad_request(format!("Format configured is avro, but confluent schema repository returned a {:?} schema",
-                schema_response.schema_type)));
+            return Err(bad_request(format!(
+                "Format configured is avro, but confluent schema repository returned a {:?} schema",
+                schema_response.schema_type
+            )));
         }
 
         schema.definition = Some(SchemaDefinition::AvroSchema(schema_response.schema));
     }
-
     let Some(SchemaDefinition::AvroSchema(definition)) = schema.definition.as_ref() else {
         return Err(bad_request("avro format requires an avro schema be set"));
     };
 
     let fields: Result<_, String> = avro::convert_avro_schema(&name, &definition)
         .map_err(|e| bad_request(format!("Invalid avro schema: {}", e)))?
-        .into_iter().map(|f| f.try_into()).collect();
-
+        .into_iter()
+        .map(|f| f.try_into())
+        .collect();
 
-    schema.fields = fields
-        .map_err(|e| bad_request(format!("Failed to convert schema: {}", e)))?;
+    schema.fields = fields.map_err(|e| bad_request(format!("Failed to convert schema: {}", e)))?;
 
     Ok(schema)
 }
 
-async fn expand_json_schema(name: &str, connector: &str, mut schema: ConnectionSchema, table_config: &Value, profile_config: &Value) -> Result<ConnectionSchema, ErrorResp> {
-    if let Some(Format::Json(JsonFormat { confluent_schema_registry: true, .. })) = &schema.format {
+async fn expand_json_schema(
+    name: &str,
+    connector: &str,
+    mut schema: ConnectionSchema,
+    table_config: &Value,
+    profile_config: &Value,
+) -> Result<ConnectionSchema, ErrorResp> {
+    if let Some(Format::Json(JsonFormat {
+        confluent_schema_registry: true,
+        ..
+    })) = &schema.format
+    {
         let schema_response = get_schema(connector, table_config, profile_config).await?;
 
         if schema_response.schema_type != ConfluentSchemaType::Json {
-            return Err(bad_request(format!("Format configured is json, but confluent schema repository returned a {:?} schema",
-                schema_response.schema_type)));
+            return Err(bad_request(format!(
+                "Format configured is json, but confluent schema repository returned a {:?} schema",
+                schema_response.schema_type
+            )));
         }
 
         schema.definition = Some(SchemaDefinition::JsonSchema(schema_response.schema));
@@ -487,11 +515,7 @@ async fn expand_json_schema(name: &str, connector: &str, mut schema: ConnectionS
             None,
             TypeDef::DataType(DataType::Utf8, false),
         )],
-        _ => {
-            return Err(bad_request(
-                "Invalid schema type for json format",
-            ))
-        }
+        _ => return Err(bad_request("Invalid schema type for json format")),
     };
 
     let fields: Result<_, String> = fields.into_iter().map(|f| f.try_into()).collect();
@@ -501,29 +525,44 @@ async fn expand_json_schema(name: &str, connector: &str, mut schema: ConnectionS
     }
 
     Ok(schema)
-
 }
 
-async fn get_schema(connector: &str, table_config: &Value, profile_config: &Value) -> Result<ConfluentSchemaResponse, ErrorResp> {
+async fn get_schema(
+    connector: &str,
+    table_config: &Value,
+    profile_config: &Value,
+) -> Result<ConfluentSchemaResponse, ErrorResp> {
     if connector != "kafka" {
-        return Err(bad_request("confluent schema registry can only be used for Kafka connections"));
+        return Err(bad_request(
+            "confluent schema registry can only be used for Kafka connections",
+        ));
     }
 
     // we unwrap here because this should already have been validated
-    let profile: KafkaConfig = serde_json::from_value(profile_config.clone())
-        .expect("invalid kafka config");
-
-    let table: KafkaTable = serde_json::from_value(table_config.clone())
-        .expect("invalid kafka table");
-
-    let schema_registry = profile.schema_registry.as_ref().ok_or_else(||
-        bad_request("schema registry must be configured on the Kafka connection profile"))?;
-
-    let resolver = ConfluentSchemaResolver::new(&schema_registry.endpoint, &table.topic)
-        .map_err(|e| bad_request(format!("failed to fetch schemas from schema repository: {}", e)))?;
-
-    resolver.get_schema(None).await
-        .map_err(|e| bad_request(format!("failed to fetch schemas from schema repository: {}", e)))
+    let profile: KafkaConfig =
+        serde_json::from_value(profile_config.clone()).expect("invalid kafka config");
+
+    let table: KafkaTable =
+        serde_json::from_value(table_config.clone()).expect("invalid kafka table");
+
+    let schema_registry = profile.schema_registry.as_ref().ok_or_else(|| {
+        bad_request("schema registry must be configured on the Kafka connection profile")
+    })?;
+
+    let resolver =
+        ConfluentSchemaResolver::new(&schema_registry.endpoint, &table.topic).map_err(|e| {
+            bad_request(format!(
+                "failed to fetch schemas from schema repository: {}",
+                e
+            ))
+        })?;
+
+    resolver.get_schema(None).await.map_err(|e| {
+        bad_request(format!(
+            "failed to fetch schemas from schema repository: {}",
+            e
+        ))
+    })
 }
 
 /// Test a Connection Schema
@@ -556,4 +595,4 @@ pub(crate) async fn test_schema(
         Ok(())
     }
 }
-}
\ No newline at end of file
+}
diff --git a/arroyo-api/src/lib.rs b/arroyo-api/src/lib.rs
index c54ceb395..43d9f4312 100644
--- a/arroyo-api/src/lib.rs
+++ b/arroyo-api/src/lib.rs
@@ -8,8 +8,8 @@ use crate::connection_profiles::{
     __path_create_connection_profile, __path_get_connection_profiles,
 };
 use crate::connection_tables::{
-    __path_create_connection_table, __path_delete_connection_table,
-    __path_get_connection_tables, __path_test_connection_table, __path_test_schema,
+    __path_create_connection_table, __path_delete_connection_table, __path_get_connection_tables,
+    __path_test_connection_table, __path_test_schema,
 };
 use crate::connectors::__path_get_connectors;
 use crate::jobs::{
diff --git a/arroyo-api/src/rest.rs b/arroyo-api/src/rest.rs
index bb713012a..8da8d0741 100644
--- a/arroyo-api/src/rest.rs
+++ b/arroyo-api/src/rest.rs
@@ -19,8 +19,8 @@ use utoipa_swagger_ui::SwaggerUi;
 
 use crate::connection_profiles::{create_connection_profile, get_connection_profiles};
 use crate::connection_tables::{
-    create_connection_table, delete_connection_table, get_connection_tables,
-    test_connection_table, test_schema,
+    create_connection_table, delete_connection_table, get_connection_tables, test_connection_table,
+    test_schema,
 };
 use crate::connectors::get_connectors;
 use crate::jobs::{
diff --git a/arroyo-rpc/Cargo.toml b/arroyo-rpc/Cargo.toml
index 1e3eedd15..475cc714d 100644
--- a/arroyo-rpc/Cargo.toml
+++ b/arroyo-rpc/Cargo.toml
@@ -17,9 +17,10 @@ serde_json = "1.0"
 nanoid = "0.4"
 utoipa = "3"
 anyhow = "1.0.75"
-reqwest = "0.11.22"
+reqwest = { version = "0.11.22", features = ["default", "serde_json", "json"] }
 log = "0.4.20"
 tracing = "0.1.40"
+async-trait = "0.1.74"
 
 [build-dependencies]
 tonic-build = { workspace = true }
diff --git a/arroyo-rpc/src/formats.rs b/arroyo-rpc/src/formats.rs
index 94b0ac451..a74b18f19 100644
--- a/arroyo-rpc/src/formats.rs
+++ b/arroyo-rpc/src/formats.rs
@@ -1,10 +1,10 @@
+use arroyo_types::UserError;
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
 use std::collections::HashMap;
 use std::hash::{Hash, Hasher};
 use std::str::FromStr;
-use serde_json::Value;
 use utoipa::ToSchema;
-use arroyo_types::UserError;
 
 #[derive(
     Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Default, Hash, PartialOrd, ToSchema,
 )]
@@ -102,13 +102,15 @@ pub struct ConfluentSchemaRegistryConfig {
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, ToSchema)]
 #[serde(rename_all = "camelCase")]
 pub struct AvroFormat {
-    pub confluent_schema_registry: bool
+    pub confluent_schema_registry: bool,
+    pub embedded_schema: bool,
 }
 
 impl AvroFormat {
     pub fn from_opts(opts: &mut HashMap<String, String>) -> Result<Self, String> {
         Ok(Self {
-            confluent_schema_registry: false
+            confluent_schema_registry: false,
+            embedded_schema: false,
         })
     }
 }
diff --git a/arroyo-rpc/src/lib.rs b/arroyo-rpc/src/lib.rs
index 821e61928..5b744abc5 100644
--- a/arroyo-rpc/src/lib.rs
+++ b/arroyo-rpc/src/lib.rs
@@ -1,8 +1,7 @@
 pub mod api_types;
 pub mod formats;
-pub mod schema_resolver;
 pub mod public_ids;
-
+pub mod schema_resolver;
 
 use std::{fs, time::SystemTime};
diff --git a/arroyo-rpc/src/schema_resolver.rs b/arroyo-rpc/src/schema_resolver.rs
index 1c4e0414e..e352d5b0a 100644
--- a/arroyo-rpc/src/schema_resolver.rs
+++ b/arroyo-rpc/src/schema_resolver.rs
@@ -1,24 +1,30 @@
-use std::time::Duration;
 use anyhow::{anyhow, bail};
-use arroyo_types::UserError;
+use async_trait::async_trait;
 use reqwest::{Client, StatusCode, Url};
-use serde_json::Value;
-use tracing::warn;
 use serde::{Deserialize, Serialize};
+use std::time::Duration;
+use tracing::warn;
 
-pub trait SchemaResolver {
-    fn resolve_schema(&self, id: [u8; 4]) -> Result<Option<String>, UserError>;
+#[async_trait]
+pub trait SchemaResolver: Send {
+    async fn resolve_schema(&self, id: [u8; 4]) -> Result<Option<String>, String>;
 }
 
-pub struct FailingSchemaResolver {
+pub struct FailingSchemaResolver {}
+
+impl FailingSchemaResolver {
+    pub fn new() -> Self {
+        FailingSchemaResolver {}
+    }
 }
 
+#[async_trait]
 impl SchemaResolver for FailingSchemaResolver {
-    fn resolve_schema(&self, id: [u8; 4]) -> Result<Option<String>, UserError> {
-        Err(UserError {
-            name: "Could not deserialize".to_string(),
-            details: format!("Schema with id {:?} not available, and no schema registry configured", id),
-        })
+    async fn resolve_schema(&self, id: [u8; 4]) -> Result<Option<String>, String> {
+        Err(format!(
+            "Schema with id {:?} not available, and no schema registry configured",
+            id
+        ))
     }
 }
 
@@ -28,7 +34,7 @@ pub enum ConfluentSchemaType {
     #[default]
     Avro,
     Json,
-    Protobuf
+    Protobuf,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -52,11 +58,12 @@ impl ConfluentSchemaResolver {
     pub fn new(endpoint: &str, topic: &str) -> anyhow::Result<Self> {
         let client = Client::builder()
             .timeout(Duration::from_secs(5))
-            .build().unwrap();
+            .build()
+            .unwrap();
 
-
-        let endpoint: Url =
-            format!("{}/subjects/{}-value/versions/", endpoint, topic).as_str().try_into()
+        let endpoint: Url = format!("{}/subjects/{}-value/versions/", endpoint, topic)
+            .as_str()
+            .try_into()
             .map_err(|e| anyhow!("{} is not a valid url", endpoint))?;
 
         Ok(Self {
@@ -66,21 +73,30 @@ impl ConfluentSchemaResolver {
         })
     }
 
+    pub async fn get_schema(
+        &self,
+        version: Option<u32>,
+    ) -> anyhow::Result<ConfluentSchemaResponse> {
+        let url = self
+            .endpoint
+            .join(
+                &version
+                    .map(|v| format!("{}", v))
+                    .unwrap_or_else(|| "latest".to_string()),
+            )
+            .unwrap();
 
-    pub async fn get_schema(&self, version: Option<u32>) -> anyhow::Result<ConfluentSchemaResponse> {
-        let url = self.endpoint.join(
-            &version.map(|v| format!("{}", v)).unwrap_or_else(|| "latest".to_string())).unwrap();
-
-        let resp = reqwest::get(url.clone()).await.map_err(|e| {
+        let resp = self.client.get(url.clone()).send().await.map_err(|e| {
             warn!("Got error response from schema registry: {:?}", e);
             match e.status() {
-                Some(StatusCode::NOT_FOUND) => anyhow!(
-                    "Could not find value schema for topic '{}'",
-                    self.topic),
+                Some(StatusCode::NOT_FOUND) => {
+                    anyhow!("Could not find value schema for topic '{}'", self.topic)
+                }
                 Some(code) => anyhow!("Schema registry returned error: {}", code),
                 None => {
-                    warn!("Unknown error connecting to schema registry {}: {:?}",
+                    warn!(
+                        "Unknown error connecting to schema registry {}: {:?}",
                         self.endpoint, e
                     );
                     anyhow!(
@@ -103,10 +119,22 @@ impl ConfluentSchemaResolver {
         }
 
         resp.json().await.map_err(|e| {
-            warn!("Invalid json from schema registry: {:?} for request {:?}", e, url);
-            anyhow!(
-                "Schema registry response could not be deserialized: {}", e
-            )
+            warn!(
+                "Invalid json from schema registry: {:?} for request {:?}",
+                e, url
+            );
+            anyhow!("Schema registry response could not be deserialized: {}", e)
         })
     }
-}
\ No newline at end of file
+}
+
+#[async_trait]
+impl SchemaResolver for ConfluentSchemaResolver {
+    async fn resolve_schema(&self, id: [u8; 4]) -> Result<Option<String>, String> {
+        let version = u32::from_be_bytes(id);
+        self.get_schema(Some(version))
+            .await
+            .map(|s| Some(s.schema))
+            .map_err(|e| e.to_string())
+    }
+}
diff --git a/arroyo-sql/src/avro.rs b/arroyo-sql/src/avro.rs
index 856109db3..1bb6f178b 100644
--- a/arroyo-sql/src/avro.rs
+++ b/arroyo-sql/src/avro.rs
@@ -1,22 +1,20 @@
-use std::sync::Arc;
+use crate::types::{StructDef, StructField, TypeDef};
 use anyhow::{anyhow, bail};
 use apache_avro::Schema;
-use arrow_schema::{DataType};
+use arrow_schema::DataType;
 use proc_macro2::Ident;
 use quote::quote;
-use crate::types::{StructDef, StructField, TypeDef};
+use std::sync::Arc;
 
 pub const ROOT_NAME: &str = "ArroyoAvroRoot";
 
 pub fn convert_avro_schema(name: &str, schema: &str) -> anyhow::Result<Vec<StructField>> {
-    let schema = Schema::parse_str(schema)
-        .map_err(|e| anyhow!("avro schema is not valid: {:?}", e))?;
+    let schema =
+        Schema::parse_str(schema).map_err(|e| anyhow!("avro schema is not valid: {:?}", e))?;
 
     let (typedef, _) = to_typedef(name, &schema);
     match typedef {
-        TypeDef::StructDef(sd, _) => {
-            Ok(sd.fields)
-        }
+        TypeDef::StructDef(sd, _) => Ok(sd.fields),
         TypeDef::DataType(_, _) => {
             bail!("top-level schema must be a record")
         }
@@ -27,8 +25,15 @@ pub fn get_defs(name: &str, schema: &str) -> anyhow::Result<String> {
     let fields = convert_avro_schema(name, schema)?;
 
     let sd = StructDef::new(Some(ROOT_NAME.to_string()), true, fields, None);
-    let defs: Vec<_> = sd.all_structs_including_named().iter()
-        .map(|p| vec![syn::parse_str(&p.def(false)).unwrap(), p.generate_serializer_items()])
+    let defs: Vec<_> = sd
+        .all_structs_including_named()
+        .iter()
+        .map(|p| {
+            vec![
+                syn::parse_str(&p.def(false)).unwrap(),
+                p.generate_serializer_items(),
+            ]
+        })
         .flatten()
         .collect();
@@ -39,35 +44,20 @@ pub fn get_defs(name: &str, schema: &str) -> anyhow::Result<String> {
             #(#defs) *
         }
-    }.to_string())
+    }
+    .to_string())
 }
 
 fn to_typedef(source_name: &str, schema: &Schema) -> (TypeDef, Option<String>) {
     match schema {
-        Schema::Null => {
-            (TypeDef::DataType(DataType::Null, false), None)
-        }
-        Schema::Boolean => {
-            (TypeDef::DataType(DataType::Boolean, false), None)
-        }
-        Schema::Int => {
-            (TypeDef::DataType(DataType::Int32, false), None)
-        }
-        Schema::Long => {
-            (TypeDef::DataType(DataType::Int64, false), None)
-        }
-        Schema::Float => {
-            (TypeDef::DataType(DataType::Float32, false), None)
-        }
-        Schema::Double => {
-            (TypeDef::DataType(DataType::Float64, false), None)
-        }
-        Schema::Bytes => {
-            (TypeDef::DataType(DataType::Binary, false), None)
-        }
-        Schema::String => {
-            (TypeDef::DataType(DataType::Utf8, false), None)
-        }
+        Schema::Null => (TypeDef::DataType(DataType::Null, false), None),
+        Schema::Boolean => (TypeDef::DataType(DataType::Boolean, false), None),
+        Schema::Int => (TypeDef::DataType(DataType::Int32, false), None),
+        Schema::Long => (TypeDef::DataType(DataType::Int64, false), None),
+        Schema::Float => (TypeDef::DataType(DataType::Float32, false), None),
+        Schema::Double => (TypeDef::DataType(DataType::Float64, false), None),
+        Schema::Bytes => (TypeDef::DataType(DataType::Binary, false), None),
+        Schema::String => (TypeDef::DataType(DataType::Utf8, false), None),
         // Schema::Array(t) => {
         //     let dt = match to_typedef(source_name, t) {
        //         (TypeDef::StructDef(sd, _), _) => {
@@ -88,7 +78,10 @@ fn to_typedef(source_name: &str, schema: &Schema) -> (TypeDef, Option<String>) {
 
             // currently just support unions that have [t, null] as variants, which is the
             // avro way to represent optional fields
-            let (nulls, not_nulls): (Vec<_>, Vec<_>) = union.variants().iter().partition(|v| matches!(v, Schema::Null));
+            let (nulls, not_nulls): (Vec<_>, Vec<_>) = union
+                .variants()
+                .iter()
+                .partition(|v| matches!(v, Schema::Null));
 
             if nulls.len() == 1 && not_nulls.len() == 1 {
                 let (dt, original) = to_typedef(source_name, not_nulls[0]);
@@ -101,36 +94,41 @@ fn to_typedef(source_name: &str, schema: &Schema) -> (TypeDef, Option<String>) {
             }
         }
         Schema::Record(record) => {
-            let fields = record.fields.iter().map(|f| {
-                let (ft, original) = to_typedef(source_name, &f.schema);
-                StructField::with_rename(f.name.clone(), None, ft, None, original)
-            }).collect();
+            let fields = record
+                .fields
+                .iter()
+                .map(|f| {
+                    let (ft, original) = to_typedef(source_name, &f.schema);
+                    StructField::with_rename(f.name.clone(), None, ft, None, original)
+                })
+                .collect();
 
             (
-                TypeDef::StructDef(StructDef::for_name(Some(format!("{}::{}", source_name, record.name.name)), fields), false),
-                None
+                TypeDef::StructDef(
+                    StructDef::for_name(
+                        Some(format!("{}::{}", source_name, record.name.name)),
+                        fields,
+                    ),
+                    false,
+                ),
+                None,
             )
         }
-        _ => {
-            (
-                TypeDef::DataType(DataType::Utf8, false),
-                Some("json".to_string()),
-            )
-        }
-        // Schema::Enum(_) => {}
-        // Schema::Fixed(_) => {}
-        // Schema::Decimal(_) => {}
-        // Schema::Uuid => {}
-        // Schema::Date => {}
-        // Schema::TimeMillis => {}
-        // Schema::TimeMicros => {}
-        // Schema::TimestampMillis => {}
-        // Schema::TimestampMicros => {}
-        // Schema::LocalTimestampMillis => {}
-        // Schema::LocalTimestampMicros => {}
-        // Schema::Duration => {}
-        // Schema::Ref { .. } => {}
+        _ => (
+            TypeDef::DataType(DataType::Utf8, false),
+            Some("json".to_string()),
+        ), // Schema::Enum(_) => {}
+           // Schema::Fixed(_) => {}
+           // Schema::Decimal(_) => {}
+           // Schema::Uuid => {}
+           // Schema::Date => {}
+           // Schema::TimeMillis => {}
+           // Schema::TimeMicros => {}
+           // Schema::TimestampMillis => {}
+           // Schema::TimestampMicros => {}
+           // Schema::LocalTimestampMillis => {}
+           // Schema::LocalTimestampMicros => {}
+           // Schema::Duration => {}
+           // Schema::Ref { .. } => {}
     }
 }
-
-
diff --git a/arroyo-sql/src/lib.rs b/arroyo-sql/src/lib.rs
index 4492b33bb..488cebf15 100644
--- a/arroyo-sql/src/lib.rs
+++ b/arroyo-sql/src/lib.rs
@@ -8,11 +8,11 @@ use arroyo_connectors::{Connection, Connector};
 use arroyo_datastream::Program;
 use datafusion::physical_plan::functions::make_scalar_function;
 
+pub mod avro;
 pub(crate) mod code_gen;
 pub mod expressions;
 pub mod external;
 pub mod json_schema;
-pub mod avro;
 mod operators;
 mod optimizations;
 mod pipeline;
diff --git a/arroyo-worker/src/connectors/polling_http.rs b/arroyo-worker/src/connectors/polling_http.rs
index 591f883b3..0d3fcb1d2 100644
--- a/arroyo-worker/src/connectors/polling_http.rs
+++ b/arroyo-worker/src/connectors/polling_http.rs
@@ -259,24 +259,29 @@ where
                 }
             }
 
-            let iter = self.deserializer.deserialize_slice(&buf);
-
-            for record in iter {
-                match record {
-                    Ok(value) => {
-                        ctx.collect(Record {
-                            timestamp: SystemTime::now(),
-                            key: None,
-                            value,
-                        }).await;
-                    }
-                    Err(e) => {
-                        ctx.report_user_error(e).await;
+            match self.deserializer.deserialize_slice(&buf).await {
+                Ok(iter) => {
+                    for record in iter {
+                        match record {
+                            Ok(value) => {
+                                ctx.collect(Record {
+                                    timestamp: SystemTime::now(),
+                                    key: None,
+                                    value,
+                                }).await;
+                            }
+                            Err(e) => {
+                                ctx.report_user_error(e).await;
+                            }
+                        }
                     }
+
+                    self.state.last_message = Some(buf);
+                }
+                Err(e) => {
+                    ctx.report_user_error(e).await;
                 }
             }
-
-            self.state.last_message = Some(buf);
         }
         Err(e) => {
            ctx.report_user_error(e).await;
diff --git a/arroyo-worker/src/formats.rs b/arroyo-worker/src/formats.rs
index 443d241e8..0e82c8d35 100644
--- a/arroyo-worker/src/formats.rs
+++ b/arroyo-worker/src/formats.rs
@@ -1,13 +1,15 @@
+use apache_avro::{Reader, Schema};
 use std::sync::Arc;
 use std::{collections::HashMap, marker::PhantomData};
 
-use apache_avro::Schema;
 use arrow::datatypes::{Field, Fields};
 use arroyo_rpc::formats::{AvroFormat, Format, Framing, FramingMethod, JsonFormat};
+use arroyo_rpc::schema_resolver::{FailingSchemaResolver, SchemaResolver};
 use arroyo_types::UserError;
 use serde::de::DeserializeOwned;
 use serde_json::{json, Value};
 use tokio::sync::Mutex;
+use tracing::warn;
 
 use crate::SchemaData;
@@ -56,12 +59,74 @@ fn deserialize_raw_string<T: DeserializeOwned>(msg: &[u8]) -> Result<T, String> {
     Ok(serde_json::from_value(json).unwrap())
 }
 
-fn deserialize_slice_avro<T: DeserializeOwned>(
+async fn deserialize_slice_avro<'a, T: DeserializeOwned>(
     format: &AvroFormat,
     schema_registry: Arc<Mutex<HashMap<[u8; 4], Schema>>>,
-    msg: &[u8],
-) {
+    resolver: Arc<dyn SchemaResolver + Sync>,
+    mut msg: &'a [u8],
+) -> Result<impl Iterator<Item = Result<T, UserError>> + 'a, String> {
+    let id = if format.confluent_schema_registry {
+        let magic_byte = msg[0];
+        if magic_byte != 0 {
+            return Err(format!("data was not encoded with schema registry wire format; magic byte has unexpected value: {}", magic_byte));
+        }
+
+        let id = [msg[1], msg[2], msg[3], msg[4]];
+        msg = &msg[5..];
+        id
+    } else {
+        [0, 0, 0, 0]
+    };
+
+    let mut registry = schema_registry.lock().await;
+
+    let mut reader = if format.embedded_schema {
+        Reader::new(&msg[..]).map_err(|e| format!("invalid Avro schema in message: {:?}", e))?
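+        // (No embedded schema: the branch below looks up the writer's schema by its
+        // Confluent registry id, first in the local cache, then via the SchemaResolver.)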
+    } else {
+        let schema = if registry.contains_key(&id) {
+            registry.get(&id).unwrap()
+        } else {
+            let new_schema = resolver.resolve_schema(id).await?.ok_or_else(|| {
+                format!(
+                    "could not resolve schema for message with id {}",
+                    u32::from_be_bytes(id)
+                )
+            })?;
+
+            let new_schema = Schema::parse_str(&new_schema)
+                .map_err(|e| format!("invalid avro schema: {:?}", e))?;
+
+            registry.insert(id, new_schema);
+
+            registry.get(&id).unwrap()
+        };
+
+        Reader::with_schema(schema, &msg[..])
+            .map_err(|e| format!("invalid avro schema: {:?}", e))?
+    };
+
+    let messages: Vec<_> = reader.collect();
+    Ok(messages.into_iter().map(|record| {
+        apache_avro::from_value::<T>(&record.map_err(|e| {
+            UserError::new(
+                "Deserialization failed",
+                format!(
+                    "Failed to deserialize from avro: {:?}",
+                    e
+                ),
+            )
+        })?)
+        .map_err(|e| {
+            UserError::new(
+                "Deserialization failed",
+                format!("Failed to convert avro message into struct type: {:?}", e),
+            )
+        })
+    }))
+    // let record = reader.next()
+    //     .ok_or_else(|| "avro record did not contain any messages")?
+    //     .map_err(|e| e.to_string())?;
 }
 
 pub struct FramingIterator<'a> {
@@ -120,38 +185,56 @@ pub struct DataDeserializer<T: SchemaData> {
     format: Arc<Format>,
     framing: Option<Arc<Framing>>,
     schema_registry: Arc<Mutex<HashMap<[u8; 4], Schema>>>,
-    schema_resolver: Arc<Box<dyn SchemaResolver>>,
+    schema_resolver: Arc<dyn SchemaResolver + Sync>,
     _t: PhantomData<T>,
 }
-
 impl<T: SchemaData + DeserializeOwned> DataDeserializer<T> {
-    pub fn new(format: Format, framing: Option<Framing>,) -> Self {
-        if let Format::Avro(avro) = &format {
-
-        };
+    pub fn new(format: Format, framing: Option<Framing>) -> Self {
+        Self::with_schema_resolver(format, framing, Arc::new(FailingSchemaResolver::new()))
+    }
 
+    pub fn with_schema_resolver(
+        format: Format,
+        framing: Option<Framing>,
+        schema_resolver: Arc<dyn SchemaResolver + Sync>,
+    ) -> Self {
         Self {
             format: Arc::new(format),
             framing: framing.map(|f| Arc::new(f)),
+            schema_registry: Arc::new(Mutex::new(HashMap::new())),
+            schema_resolver,
             _t: PhantomData,
         }
     }
 
-    pub fn deserialize_slice<'a>(
+    pub async fn deserialize_slice<'a>(
         &self,
         msg: &'a [u8],
-    ) -> impl Iterator<Item = Result<T, UserError>> + 'a {
-        let new_self = self.clone();
-        FramingIterator::new(self.framing.clone(), msg)
-            .map(move |t| new_self.deserialize_single(t))
+    ) -> Result<Box<dyn Iterator<Item = Result<T, UserError>> + 'a>, UserError> {
+        match &*self.format {
+            Format::Avro(avro) => deserialize_slice_avro(
+                avro,
+                self.schema_registry.clone(),
+                self.schema_resolver.clone(),
+                msg,
+            )
+            .await
+            .map(|iter| Box::new(iter) as Box<dyn Iterator<Item = Result<T, UserError>> + 'a>)
+            .map_err(|e| UserError::new("Avro deserialization failed", e)),
+            _ => {
+                let new_self = self.clone();
+                Ok(Box::new(FramingIterator::new(self.framing.clone(), msg)
+                    .map(move |t| new_self.deserialize_single(t))))
+            }
+        }
     }
 
     fn deserialize_single(&self, msg: &[u8]) -> Result<T, UserError> {
         match &*self.format {
             Format::Json(json) => deserialize_slice_json(json, msg),
-            Format::Avro(avro) => deserialie_slice_avro(),
-            Format::Parquet(_) => todo!(),
+            Format::Avro(_) => unreachable!("avro is handled separately in deserialize_slice"),
+            Format::Parquet(_) => todo!("parquet is not supported as an input format"),
             Format::RawString(_) => deserialize_raw_string(msg),
         }
         .map_err(|e| {