Skip to content

Commit

Permalink
feat: support deserialization of AnyMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
laplab committed Aug 24, 2023
1 parent 64af576 commit 016868a
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 6 deletions.
1 change: 1 addition & 0 deletions elfo-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub mod _priv {
permissions::{AtomicPermissions, Permissions},
request_table::RequestId,
};
pub use erased_serde;
pub use linkme;
pub use metrics;
#[cfg(feature = "network")]
Expand Down
103 changes: 97 additions & 6 deletions elfo-core/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ use fxhash::{FxHashMap, FxHashSet};
use linkme::distributed_slice;
use metrics::Label;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde::{
de::{DeserializeSeed, SeqAccess, Visitor},
ser::SerializeTuple,
Deserialize, Deserializer, Serialize,
};
use smallbox::{smallbox, SmallBox};

use elfo_utils::unlikely;
Expand Down Expand Up @@ -146,19 +150,75 @@ impl Serialize for AnyMessage {
where
S: serde::ser::Serializer,
{
let mut tuple = serializer.serialize_tuple(3)?;
tuple.serialize_element(self.protocol())?;
tuple.serialize_element(self.name())?;

// TODO: avoid allocation here
self._erase().serialize(serializer)
let erased_msg = self._erase();
tuple.serialize_element(&*erased_msg)?;

tuple.end()
}
}

impl<'de> Deserialize<'de> for AnyMessage {
fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
Err(serde::de::Error::custom(
"AnyMessage cannot be deserialized",
))
deserializer.deserialize_tuple(3, AnyMessageDeserializeVisitor)
}
}

struct AnyMessageDeserializeVisitor;

impl<'de> Visitor<'de> for AnyMessageDeserializeVisitor {
type Value = AnyMessage;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "tuple of 3 elements")
}

#[inline]
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let protocol = serde::de::SeqAccess::next_element::<&str>(&mut seq)?.ok_or(
serde::de::Error::invalid_length(0usize, &"tuple of 3 elements"),
)?;

let name = serde::de::SeqAccess::next_element::<&str>(&mut seq)?.ok_or(
serde::de::Error::invalid_length(1usize, &"tuple of 3 elements"),
)?;

serde::de::SeqAccess::next_element_seed(&mut seq, MessageTag { protocol, name })?.ok_or(
serde::de::Error::invalid_length(2usize, &"tuple of 3 elements"),
)
}
}

struct MessageTag<'a> {
protocol: &'a str,
name: &'a str,
}

impl<'de, 'tag> DeserializeSeed<'de> for MessageTag<'tag> {
type Value = AnyMessage;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
let deserialize = lookup_vtable(self.protocol, self.name)
.ok_or(serde::de::Error::custom(
"unknown protocol/name combination",
))?
.deserialize;

let mut deserializer = <dyn erased_serde::Deserializer<'_>>::erase(deserializer);
deserialize(&mut deserializer).map_err(serde::de::Error::custom)
}
}

Expand Down Expand Up @@ -275,6 +335,8 @@ pub struct MessageVTable {
pub clone: fn(&AnyMessage) -> AnyMessage,
pub debug: fn(&AnyMessage, &mut fmt::Formatter<'_>) -> fmt::Result,
pub erase: fn(&AnyMessage) -> dumping::ErasedMessage,
pub deserialize:
fn(&mut dyn erased_serde::Deserializer<'_>) -> Result<AnyMessage, erased_serde::Error>,
#[cfg(feature = "network")]
pub write_msgpack: fn(&AnyMessage, &mut Vec<u8>, usize) -> Result<(), rmps::encode::Error>,
#[cfg(feature = "network")]
Expand Down Expand Up @@ -327,3 +389,32 @@ pub(crate) fn check_uniqueness() -> Result<(), Vec<(String, String)>> {
.into_iter()
.collect::<Vec<_>>())
}

#[cfg(test)]
mod tests {
use crate::{message, message::AnyMessage, Message};

#[test]
fn any_message_deserialize() {
#[message]
#[derive(PartialEq)]
struct MyCoolMessage {
field_a: u32,
field_b: String,
field_c: f64,
}

let msg = MyCoolMessage {
field_a: 123,
field_b: String::from("Hello world"),
field_c: 0.5,
};
let any_msg = msg.clone().upcast();
let serialized = serde_json::to_string(&any_msg).unwrap();

let deserialized_any_msg: AnyMessage = serde_json::from_str(&serialized).unwrap();
let deserialized_msg: MyCoolMessage = deserialized_any_msg.downcast().unwrap();

assert_eq!(msg, deserialized_msg);
}
}
5 changes: 5 additions & 0 deletions elfo-macros-impl/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ pub fn message_impl(
smallbox!(Clone::clone(cast_ref(message)))
}

fn deserialize(deserializer: &mut dyn #internal::erased_serde::Deserializer<'_>) -> Result<#internal::AnyMessage, #internal::erased_serde::Error> {
#internal::erased_serde::deserialize::<#name>(deserializer).map(#crate_::Message::upcast)
}

#network_fns

#[linkme::distributed_slice(MESSAGE_LIST)]
Expand All @@ -267,6 +271,7 @@ pub fn message_impl(
clone,
debug,
erase,
deserialize,
#network_fns_ref
};

Expand Down

0 comments on commit 016868a

Please sign in to comment.