diff --git a/CHANGELOG.md b/CHANGELOG.md index 818fd47b..42eb05f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - ReleaseDate +### Added +- message: support deserialization of AnyMessage ([#107]). + +[#107]: https://github.com/elfo-rs/elfo/pull/107 ## [0.2.0-alpha.7] - 2023-08-11 ### Fixed diff --git a/elfo-core/src/lib.rs b/elfo-core/src/lib.rs index a551bd1a..cc6aad9f 100644 --- a/elfo-core/src/lib.rs +++ b/elfo-core/src/lib.rs @@ -92,6 +92,7 @@ pub mod _priv { permissions::{AtomicPermissions, Permissions}, request_table::RequestId, }; + pub use erased_serde; pub use linkme; pub use metrics; #[cfg(feature = "network")] diff --git a/elfo-core/src/message.rs b/elfo-core/src/message.rs index cf48a381..e6ee81ee 100644 --- a/elfo-core/src/message.rs +++ b/elfo-core/src/message.rs @@ -4,7 +4,11 @@ use fxhash::{FxHashMap, FxHashSet}; use linkme::distributed_slice; use metrics::Label; use once_cell::sync::Lazy; -use serde::{Deserialize, Serialize}; +use serde::{ + de::{DeserializeSeed, SeqAccess, Visitor}, + ser::SerializeTuple, + Deserialize, Deserializer, Serialize, +}; use smallbox::{smallbox, SmallBox}; use elfo_utils::unlikely; @@ -146,19 +150,75 @@ impl Serialize for AnyMessage { where S: serde::ser::Serializer, { + let mut tuple = serializer.serialize_tuple(3)?; + tuple.serialize_element(self.protocol())?; + tuple.serialize_element(self.name())?; + // TODO: avoid allocation here - self._erase().serialize(serializer) + let erased_msg = self._erase(); + tuple.serialize_element(&*erased_msg)?; + + tuple.end() } } impl<'de> Deserialize<'de> for AnyMessage { - fn deserialize(_deserializer: D) -> Result + fn deserialize(deserializer: D) -> Result where D: serde::de::Deserializer<'de>, { - Err(serde::de::Error::custom( - "AnyMessage cannot be deserialized", - )) + deserializer.deserialize_tuple(3, AnyMessageDeserializeVisitor) + } +} + +struct AnyMessageDeserializeVisitor; + +impl<'de> Visitor<'de> for AnyMessageDeserializeVisitor { + type Value = AnyMessage; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "tuple of 3 elements") + } + + #[inline] + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let protocol = serde::de::SeqAccess::next_element::<&str>(&mut seq)?.ok_or( + serde::de::Error::invalid_length(0usize, &"tuple of 3 elements"), + )?; + + let name = serde::de::SeqAccess::next_element::<&str>(&mut seq)?.ok_or( + serde::de::Error::invalid_length(1usize, &"tuple of 3 elements"), + )?; + + serde::de::SeqAccess::next_element_seed(&mut seq, MessageTag { protocol, name })?.ok_or( + serde::de::Error::invalid_length(2usize, &"tuple of 3 elements"), + ) + } +} + +struct MessageTag<'a> { + protocol: &'a str, + name: &'a str, +} + +impl<'de, 'tag> DeserializeSeed<'de> for MessageTag<'tag> { + type Value = AnyMessage; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let deserialize_any = lookup_vtable(self.protocol, self.name) + .ok_or(serde::de::Error::custom( + "unknown protocol/name combination", + ))? + .deserialize_any; + + let mut deserializer = >::erase(deserializer); + deserialize_any(&mut deserializer).map_err(serde::de::Error::custom) } } @@ -275,6 +335,8 @@ pub struct MessageVTable { pub clone: fn(&AnyMessage) -> AnyMessage, pub debug: fn(&AnyMessage, &mut fmt::Formatter<'_>) -> fmt::Result, pub erase: fn(&AnyMessage) -> dumping::ErasedMessage, + pub deserialize_any: + fn(&mut dyn erased_serde::Deserializer<'_>) -> Result, #[cfg(feature = "network")] pub write_msgpack: fn(&AnyMessage, &mut Vec, usize) -> Result<(), rmps::encode::Error>, #[cfg(feature = "network")] @@ -327,3 +389,32 @@ pub(crate) fn check_uniqueness() -> Result<(), Vec<(String, String)>> { .into_iter() .collect::>()) } + +#[cfg(test)] +mod tests { + use crate::{message, message::AnyMessage, Message}; + + #[test] + fn any_message_deserialize() { + #[message] + #[derive(PartialEq)] + struct MyCoolMessage { + field_a: u32, + field_b: String, + field_c: f64, + } + + let msg = MyCoolMessage { + field_a: 123, + field_b: String::from("Hello world"), + field_c: 0.5, + }; + let any_msg = msg.clone().upcast(); + let serialized = serde_json::to_string(&any_msg).unwrap(); + + let deserialized_any_msg: AnyMessage = serde_json::from_str(&serialized).unwrap(); + let deserialized_msg: MyCoolMessage = deserialized_any_msg.downcast().unwrap(); + + assert_eq!(msg, deserialized_msg); + } +} diff --git a/elfo-macros-impl/src/message.rs b/elfo-macros-impl/src/message.rs index d8b42ab6..bf2b92c4 100644 --- a/elfo-macros-impl/src/message.rs +++ b/elfo-macros-impl/src/message.rs @@ -252,6 +252,10 @@ pub fn message_impl( smallbox!(Clone::clone(cast_ref(message))) } + fn deserialize_any(deserializer: &mut dyn #internal::erased_serde::Deserializer<'_>) -> Result<#internal::AnyMessage, #internal::erased_serde::Error> { + #internal::erased_serde::deserialize::<#name>(deserializer).map(#crate_::Message::upcast) + } + #network_fns #[linkme::distributed_slice(MESSAGE_LIST)] @@ -267,6 +271,7 @@ pub fn message_impl( clone, debug, erase, + deserialize_any, #network_fns_ref };