diff --git a/crates/arroyo-connectors/src/kafka/mod.rs b/crates/arroyo-connectors/src/kafka/mod.rs index 14bb5455e..6c63f10fc 100644 --- a/crates/arroyo-connectors/src/kafka/mod.rs +++ b/crates/arroyo-connectors/src/kafka/mod.rs @@ -336,6 +336,10 @@ impl Connector for KafkaConnector { name: "timestamp", data_type: DataType::Int64, }, + MetadataDef { + name: "key", + data_type: DataType::Binary, + }, ] } diff --git a/crates/arroyo-connectors/src/kafka/source/mod.rs b/crates/arroyo-connectors/src/kafka/source/mod.rs index 020c0da7a..083b68b39 100644 --- a/crates/arroyo-connectors/src/kafka/source/mod.rs +++ b/crates/arroyo-connectors/src/kafka/source/mod.rs @@ -204,10 +204,11 @@ impl KafkaSourceFunc { let mut connector_metadata = HashMap::new(); for f in &self.metadata_fields { connector_metadata.insert(f.field_name.as_str(), match f.key.as_str() { - "offset_id" => FieldValueType::Int64(msg.offset()), - "partition" => FieldValueType::Int32(msg.partition()), - "topic" => FieldValueType::String(topic), - "timestamp" => FieldValueType::Int64(timestamp), + "key" => FieldValueType::Bytes(msg.key()), + "offset_id" => FieldValueType::Int64(Some(msg.offset())), + "partition" => FieldValueType::Int32(Some(msg.partition())), + "topic" => FieldValueType::String(Some(topic)), + "timestamp" => FieldValueType::Int64(Some(timestamp)), k => unreachable!("Invalid metadata key '{}'", k), }); } diff --git a/crates/arroyo-connectors/src/mqtt/source/mod.rs b/crates/arroyo-connectors/src/mqtt/source/mod.rs index 4f1508fd2..69b96a74d 100644 --- a/crates/arroyo-connectors/src/mqtt/source/mod.rs +++ b/crates/arroyo-connectors/src/mqtt/source/mod.rs @@ -154,7 +154,7 @@ impl MqttSourceFunc { let mut connector_metadata = HashMap::new(); for mf in &self.metadata_fields { connector_metadata.insert(mf.field_name.as_str(), match mf.key.as_str() { - "topic" => FieldValueType::String(&topic), + "topic" => FieldValueType::String(Some(&topic)), k => unreachable!("invalid metadata key '{}' for mqtt", k) }); } diff --git a/crates/arroyo-connectors/src/redis/lookup.rs b/crates/arroyo-connectors/src/redis/lookup.rs index 79400bc4d..db63382d2 100644 --- a/crates/arroyo-connectors/src/redis/lookup.rs +++ b/crates/arroyo-connectors/src/redis/lookup.rs @@ -59,12 +59,15 @@ impl LookupConnector for RedisLookup { let mut additional = HashMap::new(); for (idx, (v, k)) in vs.iter().zip(keys).enumerate() { - additional.insert(LOOKUP_KEY_INDEX_FIELD, FieldValueType::UInt64(idx as u64)); + additional.insert( + LOOKUP_KEY_INDEX_FIELD, + FieldValueType::UInt64(Some(idx as u64)), + ); for m in &self.metadata_fields { additional.insert( m.field_name.as_str(), match m.key.as_str() { - "key" => FieldValueType::String(k.unwrap()), + "key" => FieldValueType::String(Some(k.unwrap())), k => unreachable!("Invalid metadata key '{}'", k), }, ); diff --git a/crates/arroyo-formats/src/de.rs b/crates/arroyo-formats/src/de.rs index 7d089b9c6..258b71a35 100644 --- a/crates/arroyo-formats/src/de.rs +++ b/crates/arroyo-formats/src/de.rs @@ -24,13 +24,13 @@ use std::sync::Arc; use std::time::{Instant, SystemTime}; use tokio::sync::Mutex; -#[derive(Debug, Clone)] +#[derive(Debug, Copy, Clone)] pub enum FieldValueType<'a> { - Int64(i64), - UInt64(u64), - Int32(i32), - String(&'a str), - // Extend with more types as needed + Int64(Option), + UInt64(Option), + Int32(Option), + String(Option<&'a str>), + Bytes(Option<&'a [u8]>), } struct ContextBuffer { @@ -480,6 +480,7 @@ impl ArrowDeserializer { FieldValueType::Int64(_) => Box::new(Int64Builder::new()), FieldValueType::UInt64(_) => Box::new(UInt64Builder::new()), FieldValueType::String(_) => Box::new(StringBuilder::new()), + FieldValueType::Bytes(_) => Box::new(BinaryBuilder::new()), }; builders.insert(key.to_string(), builder); } @@ -489,7 +490,7 @@ impl ArrowDeserializer { let builders = self.additional_fields_builder.as_mut().unwrap(); for (k, v) in additional_fields { - add_additional_fields(builders, k, v, count); + add_additional_fields(builders, k, *v, count); } } } @@ -674,52 +675,50 @@ impl ArrowDeserializer { } } +macro_rules! append_repeated_value { + ($builder:expr, $builder_ty:ty, $value:expr, $count:expr) => {{ + let b = $builder + .downcast_mut::<$builder_ty>() + .expect("additional field has incorrect type"); + + if let Some(v) = $value { + for _ in 0..$count { + b.append_value(v); + } + } else { + for _ in 0..$count { + b.append_null(); + } + } + }}; +} + fn add_additional_fields( builders: &mut HashMap>, key: &str, - value: &FieldValueType<'_>, + value: FieldValueType<'_>, count: usize, ) { let builder = builders .get_mut(key) .unwrap_or_else(|| panic!("unexpected additional field '{}'", key)) .as_any_mut(); - match value { - FieldValueType::Int32(i) => { - let b = builder - .downcast_mut::() - .expect("additional field has incorrect type"); - for _ in 0..count { - b.append_value(*i); - } + match value { + FieldValueType::Int32(v) => { + append_repeated_value!(builder, Int32Builder, v, count); } - FieldValueType::Int64(i) => { - let b = builder - .downcast_mut::() - .expect("additional field has incorrect type"); - - for _ in 0..count { - b.append_value(*i); - } + FieldValueType::Int64(v) => { + append_repeated_value!(builder, Int64Builder, v, count); } - FieldValueType::UInt64(i) => { - let b = builder - .downcast_mut::() - .expect("additional field has incorrect type"); - - for _ in 0..count { - b.append_value(*i); - } + FieldValueType::UInt64(v) => { + append_repeated_value!(builder, UInt64Builder, v, count); } - FieldValueType::String(s) => { - let b = builder - .downcast_mut::() - .expect("additional field has incorrect type"); - - for _ in 0..count { - b.append_value(*s); - } + FieldValueType::String(v) => { + append_repeated_value!(builder, StringBuilder, v, count); + } + FieldValueType::Bytes(v) => { + append_repeated_value!(builder, BinaryBuilder, v, count); } } } @@ -987,8 +986,8 @@ mod tests { let time = SystemTime::now(); let mut additional_fields = std::collections::HashMap::new(); - additional_fields.insert("y", FieldValueType::Int32(5)); - additional_fields.insert("z", FieldValueType::String("hello")); + additional_fields.insert("y", FieldValueType::Int32(Some(5))); + additional_fields.insert("z", FieldValueType::String(Some("hello"))); let result = deserializer .deserialize_slice(