diff --git a/mqtt-v5-broker/src/broker.rs b/mqtt-v5-broker/src/broker.rs index bbc5131..d561ede 100644 --- a/mqtt-v5-broker/src/broker.rs +++ b/mqtt-v5-broker/src/broker.rs @@ -1,4 +1,5 @@ -use crate::{client::ClientMessage, tree::SubscriptionTree}; +use crate::{client::ClientMessage, retained::RetainedMessageTree, tree::SubscriptionTree}; +use bytes::Bytes; use mqtt_v5::{ topic::TopicFilter, types::{ @@ -193,13 +194,20 @@ pub struct Broker { sender: Sender, receiver: Receiver, subscriptions: SubscriptionTree, + retained_messages: RetainedMessageTree, } impl Broker { pub fn new() -> Self { let (sender, receiver) = mpsc::channel(100); - Self { sessions: HashMap::new(), sender, receiver, subscriptions: SubscriptionTree::new() } + Self { + sessions: HashMap::new(), + sender, + receiver, + subscriptions: SubscriptionTree::new(), + retained_messages: RetainedMessageTree::new(), + } } pub fn sender(&self) -> Sender { @@ -346,6 +354,7 @@ impl Broker { async fn handle_subscribe(&mut self, client_id: String, packet: SubscribePacket) { let subscriptions = &mut self.subscriptions; + let retained_messages = &self.retained_messages; if let Some(session) = self.sessions.get_mut(&client_id) { // If a Server receives a SUBSCRIBE packet containing a Topic Filter that @@ -368,7 +377,7 @@ impl Broker { // and return the QoS that was granted. let granted_qos_values = packet .subscription_topics - .into_iter() + .iter() .map(|topic| { let session_subscription = SessionSubscription { client_id: client_id.clone(), @@ -394,6 +403,35 @@ impl Broker { }; session.send(ClientMessage::Packet(Packet::SubscribeAck(subscribe_ack))).await; + + // Send all retained messages which match the new subscriptions + let mut publish_packets = vec![]; + for topic in packet.subscription_topics { + for (topic, retained_message) in + retained_messages.retained_messages(&topic.topic_filter) + { + let publish = PublishPacket { + is_duplicate: false, + qos: QoS::AtMostOnce, // TODO(bschwind) + retain: true, + topic, + payload: retained_message.clone(), + + packet_id: None, // TODO(bschwind) + payload_format_indicator: None, + message_expiry_interval: None, + topic_alias: None, + response_topic: None, + correlation_data: None, + user_properties: vec![], + subscription_identifier: None, + content_type: None, + }; + publish_packets.push(Packet::Publish(publish)); + } + } + + session.send(ClientMessage::Packets(publish_packets)).await; } } @@ -531,6 +569,16 @@ impl Broker { async fn handle_publish(&mut self, client_id: String, packet: PublishPacket) { let mut is_dup = false; + if packet.retain { + if packet.payload.len() > 0 { + println!("Storing retained message for topic {:?}", packet.topic); + self.retained_messages.insert(&packet.topic, packet.payload.clone()); + } else { + println!("Deleting retained message for topic {:?}", packet.topic); + self.retained_messages.remove(&packet.topic); + } + } + // For QoS2, ensure this packet isn't delivered twice. So if we have an outgoing // publish receive with the same ID, just send the publish receive again but don't forward // the message. diff --git a/mqtt-v5-broker/src/main.rs b/mqtt-v5-broker/src/main.rs index bbce2d4..11a2823 100644 --- a/mqtt-v5-broker/src/main.rs +++ b/mqtt-v5-broker/src/main.rs @@ -22,6 +22,7 @@ use tokio_util::codec::Framed; mod broker; mod client; +mod retained; mod tree; async fn client_handler(stream: TcpStream, broker_tx: Sender) { diff --git a/mqtt-v5-broker/src/retained.rs b/mqtt-v5-broker/src/retained.rs new file mode 100644 index 0000000..056cce8 --- /dev/null +++ b/mqtt-v5-broker/src/retained.rs @@ -0,0 +1,304 @@ +use mqtt_v5::topic::{Topic, TopicFilter, TopicLevel}; +use std::collections::{hash_map::Entry, HashMap}; + +#[derive(Debug)] +pub struct RetainedMessageTreeNode { + retained_data: Option, + // TODO(bschwind) - use TopicLevel instead of String + concrete_topic_levels: HashMap>, +} + +#[derive(Debug)] +pub struct RetainedMessageTree { + root: RetainedMessageTreeNode, +} + +impl RetainedMessageTree { + pub fn new() -> Self { + Self { root: RetainedMessageTreeNode::new() } + } + + pub fn insert(&mut self, topic: &Topic, retained_data: T) { + self.root.insert(topic, retained_data); + } + + /// Get the retained messages which match a given topic filter. + pub fn retained_messages( + &self, + topic_filter: &TopicFilter, + ) -> impl Iterator { + self.root.retained_messages_recursive(topic_filter) + } + + pub fn remove(&mut self, topic: &Topic) -> Option { + self.root.remove(topic) + } + + #[allow(dead_code)] + fn is_empty(&self) -> bool { + self.root.is_empty() + } +} + +impl RetainedMessageTreeNode { + fn new() -> Self { + Self { retained_data: None, concrete_topic_levels: HashMap::new() } + } + + fn is_empty(&self) -> bool { + self.retained_data.is_none() && self.concrete_topic_levels.is_empty() + } + + fn insert(&mut self, topic: &Topic, retained_data: T) { + let mut current_tree = self; + + for level in topic.levels() { + match level { + TopicLevel::SingleLevelWildcard | TopicLevel::MultiLevelWildcard => { + unreachable!("Publish topics only contain concrete levels"); + }, + TopicLevel::Concrete(concrete_topic_level) => { + if !current_tree.concrete_topic_levels.contains_key(concrete_topic_level) { + current_tree.concrete_topic_levels.insert( + concrete_topic_level.to_string(), + RetainedMessageTreeNode::new(), + ); + } + + // TODO - Do this without another hash lookup + current_tree = + current_tree.concrete_topic_levels.get_mut(concrete_topic_level).unwrap(); + }, + } + } + + current_tree.retained_data = Some(retained_data); + } + + fn remove(&mut self, topic: &Topic) -> Option { + let mut current_tree = self; + let mut stack: Vec<(*mut RetainedMessageTreeNode, usize)> = vec![]; + + let levels: Vec = topic.levels().collect(); + let mut level_index = 0; + + for level in &levels { + match level { + TopicLevel::SingleLevelWildcard | TopicLevel::MultiLevelWildcard => { + unreachable!("Publish topics only contain concrete levels"); + }, + TopicLevel::Concrete(concrete_topic_level) => { + if current_tree.concrete_topic_levels.contains_key(*concrete_topic_level) { + stack.push((&mut *current_tree, level_index)); + level_index += 1; + + current_tree = current_tree + .concrete_topic_levels + .get_mut(*concrete_topic_level) + .unwrap(); + } else { + return None; + } + }, + } + } + + let return_val = current_tree.retained_data.take(); + + // Go up the stack, cleaning up empty nodes + while let Some((stack_val, level_index)) = stack.pop() { + let tree = unsafe { &mut *stack_val }; + + let level = &levels[level_index]; + + match level { + TopicLevel::SingleLevelWildcard | TopicLevel::MultiLevelWildcard => { + unreachable!("Publish topics only contain concrete levels"); + }, + TopicLevel::Concrete(concrete_topic_level) => { + if let Entry::Occupied(o) = + tree.concrete_topic_levels.entry((*concrete_topic_level).to_string()) + { + if o.get().is_empty() { + o.remove_entry(); + } + } + }, + } + } + + return_val + } + + pub fn retained_messages_recursive( + &self, + topic_filter: &TopicFilter, + ) -> impl Iterator { + let mut retained_messages = vec![]; + let mut path = vec![]; + let levels: Vec = topic_filter.levels().collect(); + + Self::retained_messages_inner(self, &mut path, &levels, 0, &mut retained_messages); + + retained_messages.into_iter() + } + + fn retained_messages_inner<'a>( + current_tree: &'a Self, + path: &mut Vec, + levels: &[TopicLevel], + current_level: usize, + retained_messages: &mut Vec<(Topic, &'a T)>, + ) { + let level = &levels[current_level]; + + match level { + TopicLevel::SingleLevelWildcard => { + for (level, sub_tree) in ¤t_tree.concrete_topic_levels { + path.push(level.to_string()); + + if current_level + 1 < levels.len() { + Self::retained_messages_inner( + sub_tree, + path, + levels, + current_level + 1, + retained_messages, + ); + } else if let Some(retained_data) = sub_tree.retained_data.as_ref() { + let topic = Topic::from_concrete_levels(path); + retained_messages.push((topic, retained_data)); + } + path.pop(); + } + }, + TopicLevel::MultiLevelWildcard => { + for (level, sub_tree) in ¤t_tree.concrete_topic_levels { + path.push(level.to_string()); + + if let Some(retained_data) = sub_tree.retained_data.as_ref() { + let topic = Topic::from_concrete_levels(path); + retained_messages.push((topic, retained_data)); + } + + Self::retained_messages_multilevel(sub_tree, path, retained_messages); + path.pop(); + } + }, + TopicLevel::Concrete(concrete_topic_level) => { + if current_tree.concrete_topic_levels.contains_key(*concrete_topic_level) { + let sub_tree = + current_tree.concrete_topic_levels.get(*concrete_topic_level).unwrap(); + + path.push(concrete_topic_level.to_string()); + + if current_level + 1 < levels.len() { + let sub_tree = + current_tree.concrete_topic_levels.get(*concrete_topic_level).unwrap(); + Self::retained_messages_inner( + sub_tree, + path, + levels, + current_level + 1, + retained_messages, + ); + } else if let Some(retained_data) = sub_tree.retained_data.as_ref() { + let topic = Topic::from_concrete_levels(path); + retained_messages.push((topic, retained_data)); + } + + path.pop(); + } + }, + } + } + + fn retained_messages_multilevel<'a>( + current_tree: &'a Self, + path: &mut Vec, + retained_messages: &mut Vec<(Topic, &'a T)>, + ) { + // Add all the retained messages and keep going. + for (level, sub_tree) in ¤t_tree.concrete_topic_levels { + path.push(level.to_string()); + if let Some(retained_data) = sub_tree.retained_data.as_ref() { + let topic = Topic::from_concrete_levels(path); + retained_messages.push((topic, retained_data)); + } + + Self::retained_messages_multilevel(sub_tree, path, retained_messages); + + path.pop(); + } + } +} + +#[cfg(test)] +mod tests { + use crate::retained::RetainedMessageTree; + + #[test] + fn test_insert() { + let mut sub_tree = RetainedMessageTree::new(); + sub_tree.insert(&"home/kitchen/temperature".parse().unwrap(), 1); + sub_tree.insert(&"home/bedroom/temperature".parse().unwrap(), 2); + sub_tree.insert(&"home/kitchen".parse().unwrap(), 7); + + sub_tree.insert(&"office/cafe".parse().unwrap(), 12); + sub_tree.insert(&"office/cafe/temperature".parse().unwrap(), 27); + + for msg in sub_tree.retained_messages(&"+/+/temperature".parse().unwrap()) { + dbg!(msg); + } + + assert_eq!(sub_tree.remove(&"home/kitchen/temperature".parse().unwrap()), Some(1)); + assert_eq!(sub_tree.remove(&"home/kitchen".parse().unwrap()), Some(7)); + assert_eq!(sub_tree.remove(&"home/kitchen".parse().unwrap()), None); + dbg!(sub_tree); + } + + #[test] + fn test_wildcards() { + let mut sub_tree = RetainedMessageTree::new(); + sub_tree.insert(&"home/bedroom/humidity/val".parse().unwrap(), 1); + sub_tree.insert(&"home/bedroom/temperature/val".parse().unwrap(), 2); + sub_tree.insert(&"home/kitchen/temperature/val".parse().unwrap(), 3); + sub_tree.insert(&"home/kitchen/humidity/val".parse().unwrap(), 4); + sub_tree.insert(&"home/kitchen/humidity/val/celsius".parse().unwrap(), 42); + + sub_tree.insert(&"office/cafe/humidity/val".parse().unwrap(), 5); + sub_tree.insert(&"office/cafe/temperature/val".parse().unwrap(), 6); + sub_tree.insert(&"office/meeting_room_1/temperature/val".parse().unwrap(), 7); + sub_tree.insert(&"office/meeting_room_1/humidity/val".parse().unwrap(), 8); + + let filter = "home/+/+/val"; + println!("{}", filter); + for msg in sub_tree.retained_messages(&filter.parse().unwrap()) { + dbg!(msg); + } + + let filter = "home/bedroom/#"; + println!("{}", filter); + for msg in sub_tree.retained_messages(&filter.parse().unwrap()) { + dbg!(msg); + } + + let filter = "#"; + println!("{}", filter); + for msg in sub_tree.retained_messages(&filter.parse().unwrap()) { + dbg!(msg); + } + + let filter = "+"; + println!("{}", filter); + for msg in sub_tree.retained_messages(&filter.parse().unwrap()) { + dbg!(msg); + } + + let filter = "+/+/#"; + println!("{}", filter); + for msg in sub_tree.retained_messages(&filter.parse().unwrap()) { + dbg!(msg); + } + } +} diff --git a/mqtt-v5-broker/src/tree.rs b/mqtt-v5-broker/src/tree.rs index 8698e95..5faa30c 100644 --- a/mqtt-v5-broker/src/tree.rs +++ b/mqtt-v5-broker/src/tree.rs @@ -224,7 +224,6 @@ impl SubscriptionTreeNode { let sub_tree = current_tree.concrete_topic_levels.get(*level).unwrap(); if current_level + 1 < levels.len() { - let sub_tree = current_tree.concrete_topic_levels.get(*level).unwrap(); tree_stack.push((sub_tree, current_level + 1)); } else { subscriptions diff --git a/mqtt-v5/src/topic.rs b/mqtt-v5/src/topic.rs index afa2804..be352c3 100644 --- a/mqtt-v5/src/topic.rs +++ b/mqtt-v5/src/topic.rs @@ -27,6 +27,12 @@ impl Topic { pub fn topic_name(&self) -> &str { &self.topic_name } + + pub fn from_concrete_levels(levels: &[String]) -> Self { + let topic_name = levels.join(&TOPIC_SEPARATOR.to_string()); + + Self { topic_name, level_count: levels.len() as u32 } + } } #[derive(Debug, PartialEq)]