diff --git a/crates/arroyo-connectors/src/mqtt/mod.rs b/crates/arroyo-connectors/src/mqtt/mod.rs index e0e54d3fd..ff604d51c 100644 --- a/crates/arroyo-connectors/src/mqtt/mod.rs +++ b/crates/arroyo-connectors/src/mqtt/mod.rs @@ -158,7 +158,7 @@ impl Connector for MqttConnector { config: MqttConfig, table: MqttTable, schema: Option<&ConnectionSchema>, - _metadata_fields: Option>, + metadata_fields: Option>, ) -> anyhow::Result { let (typ, desc) = match table.type_ { TableType::Source { .. } => ( @@ -178,6 +178,13 @@ impl Connector for MqttConnector { .map(|t| t.to_owned()) .ok_or_else(|| anyhow!("'format' must be set for Mqtt connection"))?; + let metadata_fields = metadata_fields.map(|fields| { + fields + .into_iter() + .map(|(k, (v, _))| (k, v)) + .collect::>() + }); + let config = OperatorConfig { connection: serde_json::to_value(config).unwrap(), table: serde_json::to_value(table).unwrap(), @@ -185,7 +192,7 @@ impl Connector for MqttConnector { format: Some(format), bad_data: schema.bad_data.clone(), framing: schema.framing.clone(), - additional_fields: None, + additional_fields: metadata_fields, }; Ok(Connection { @@ -246,7 +253,7 @@ impl Connector for MqttConnector { options: &mut HashMap, schema: Option<&ConnectionSchema>, profile: Option<&ConnectionProfile>, - _metadata_fields: Option>, + metadata_fields: Option>, ) -> anyhow::Result { let connection = profile .map(|p| { @@ -258,7 +265,25 @@ impl Connector for MqttConnector { let table = Self::table_from_options(options)?; - Self::from_config(self, None, name, connection, table, schema, None) + if let Some(fields) = &metadata_fields { + for (k, (v, t)) in fields { + if v != "topic" { + return Err(anyhow!( + "Invalid metadata field name for mqtt connector: {}", + k + )); + } + if *t != DataType::Utf8 { + return Err(anyhow!( + "Invalid datatype: {} for metadata field: {} for mqtt connector", + k, + v + )); + } + } + } + + Self::from_config(self, None, name, connection, table, schema, metadata_fields) } fn make_operator( @@ -286,6 +311,7 @@ impl Connector for MqttConnector { ) .unwrap(), subscribed: Arc::new(AtomicBool::new(false)), + metadata_fields: config.additional_fields, })), TableType::Sink { retain } => OperatorNode::from_operator(Box::new(MqttSinkFunc { config: profile, diff --git a/crates/arroyo-connectors/src/mqtt/source/mod.rs b/crates/arroyo-connectors/src/mqtt/source/mod.rs index 485ee5560..3ef88bee4 100644 --- a/crates/arroyo-connectors/src/mqtt/source/mod.rs +++ b/crates/arroyo-connectors/src/mqtt/source/mod.rs @@ -1,3 +1,4 @@ +use arroyo_formats::de::FieldValueType; use async_trait::async_trait; use std::collections::HashMap; use std::num::NonZeroU32; @@ -33,6 +34,7 @@ pub struct MqttSourceFunc { pub bad_data: Option, pub messages_per_second: NonZeroU32, pub subscribed: Arc, + pub metadata_fields: Option>, } #[async_trait] @@ -65,6 +67,7 @@ impl SourceOperator for MqttSourceFunc { } } +#[allow(clippy::too_many_arguments)] impl MqttSourceFunc { pub fn new( config: MqttConfig, @@ -74,6 +77,7 @@ impl MqttSourceFunc { framing: Option, bad_data: Option, messages_per_second: u32, + metadata_fields: Option>, ) -> Self { Self { config, @@ -84,6 +88,7 @@ impl MqttSourceFunc { bad_data, messages_per_second: NonZeroU32::new(messages_per_second).unwrap(), subscribed: Arc::new(AtomicBool::new(false)), + metadata_fields, } } @@ -143,7 +148,20 @@ impl MqttSourceFunc { event = eventloop.poll() => { match event { Ok(MqttEvent::Incoming(Incoming::Publish(p))) => { - ctx.deserialize_slice(&p.payload, SystemTime::now(), None).await?; + let topic = String::from_utf8_lossy(&p.topic).to_string(); + let connector_metadata: Option>> = + self.metadata_fields.as_ref().map(|fields| { + fields.iter() + .filter_map(|(k, v)| { + if v == "topic" { + Some((k, FieldValueType::String(&topic))) + } else { + None + } + }) + .collect() + }); + ctx.deserialize_slice(&p.payload, SystemTime::now(), connector_metadata).await?; rate_limiter.until_ready().await; } Ok(MqttEvent::Outgoing(Outgoing::Subscribe(_))) => { diff --git a/crates/arroyo-connectors/src/mqtt/source/test.rs b/crates/arroyo-connectors/src/mqtt/source/test.rs index c6f26a432..5ab61c09d 100644 --- a/crates/arroyo-connectors/src/mqtt/source/test.rs +++ b/crates/arroyo-connectors/src/mqtt/source/test.rs @@ -125,6 +125,7 @@ impl MqttTopicTester { None, None, 10, + None, ); let (to_control_tx, control_rx) = channel(128);