Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support deserialization of AnyMessage #107

Merged
merged 2 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
<!-- next-header -->

## [Unreleased] - ReleaseDate
### Added
- message: support deserialization of AnyMessage ([#107]).

[#107]: https://github.com/elfo-rs/elfo/pull/107

## [0.2.0-alpha.7] - 2023-08-11
### Fixed
Expand Down
1 change: 1 addition & 0 deletions elfo-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub mod _priv {
permissions::{AtomicPermissions, Permissions},
request_table::RequestId,
};
pub use erased_serde;
pub use linkme;
pub use metrics;
#[cfg(feature = "network")]
Expand Down
103 changes: 97 additions & 6 deletions elfo-core/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ use fxhash::{FxHashMap, FxHashSet};
use linkme::distributed_slice;
use metrics::Label;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde::{
de::{DeserializeSeed, SeqAccess, Visitor},
ser::SerializeTuple,
Deserialize, Deserializer, Serialize,
};
use smallbox::{smallbox, SmallBox};

use elfo_utils::unlikely;
Expand Down Expand Up @@ -146,19 +150,75 @@ impl Serialize for AnyMessage {
where
S: serde::ser::Serializer,
{
let mut tuple = serializer.serialize_tuple(3)?;
tuple.serialize_element(self.protocol())?;
tuple.serialize_element(self.name())?;

// TODO: avoid allocation here
self._erase().serialize(serializer)
let erased_msg = self._erase();
tuple.serialize_element(&*erased_msg)?;

tuple.end()
}
}

impl<'de> Deserialize<'de> for AnyMessage {
fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
Err(serde::de::Error::custom(
"AnyMessage cannot be deserialized",
))
deserializer.deserialize_tuple(3, AnyMessageDeserializeVisitor)
}
}

struct AnyMessageDeserializeVisitor;

impl<'de> Visitor<'de> for AnyMessageDeserializeVisitor {
type Value = AnyMessage;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "tuple of 3 elements")
}

#[inline]
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let protocol = serde::de::SeqAccess::next_element::<&str>(&mut seq)?.ok_or(
serde::de::Error::invalid_length(0usize, &"tuple of 3 elements"),
)?;

let name = serde::de::SeqAccess::next_element::<&str>(&mut seq)?.ok_or(
serde::de::Error::invalid_length(1usize, &"tuple of 3 elements"),
)?;

serde::de::SeqAccess::next_element_seed(&mut seq, MessageTag { protocol, name })?.ok_or(
serde::de::Error::invalid_length(2usize, &"tuple of 3 elements"),
)
}
}

struct MessageTag<'a> {
protocol: &'a str,
name: &'a str,
}

impl<'de, 'tag> DeserializeSeed<'de> for MessageTag<'tag> {
type Value = AnyMessage;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
let deserialize_any = lookup_vtable(self.protocol, self.name)
.ok_or(serde::de::Error::custom(
"unknown protocol/name combination",
))?
.deserialize_any;

let mut deserializer = <dyn erased_serde::Deserializer<'_>>::erase(deserializer);
deserialize_any(&mut deserializer).map_err(serde::de::Error::custom)
}
}

Expand Down Expand Up @@ -275,6 +335,8 @@ pub struct MessageVTable {
pub clone: fn(&AnyMessage) -> AnyMessage,
pub debug: fn(&AnyMessage, &mut fmt::Formatter<'_>) -> fmt::Result,
pub erase: fn(&AnyMessage) -> dumping::ErasedMessage,
pub deserialize_any:
fn(&mut dyn erased_serde::Deserializer<'_>) -> Result<AnyMessage, erased_serde::Error>,
#[cfg(feature = "network")]
pub write_msgpack: fn(&AnyMessage, &mut Vec<u8>, usize) -> Result<(), rmps::encode::Error>,
#[cfg(feature = "network")]
Expand Down Expand Up @@ -327,3 +389,32 @@ pub(crate) fn check_uniqueness() -> Result<(), Vec<(String, String)>> {
.into_iter()
.collect::<Vec<_>>())
}

#[cfg(test)]
mod tests {
use crate::{message, message::AnyMessage, Message};

#[test]
fn any_message_deserialize() {
#[message]
#[derive(PartialEq)]
struct MyCoolMessage {
field_a: u32,
field_b: String,
field_c: f64,
}

let msg = MyCoolMessage {
field_a: 123,
field_b: String::from("Hello world"),
field_c: 0.5,
};
let any_msg = msg.clone().upcast();
let serialized = serde_json::to_string(&any_msg).unwrap();

let deserialized_any_msg: AnyMessage = serde_json::from_str(&serialized).unwrap();
let deserialized_msg: MyCoolMessage = deserialized_any_msg.downcast().unwrap();

assert_eq!(msg, deserialized_msg);
}
}
5 changes: 5 additions & 0 deletions elfo-macros-impl/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ pub fn message_impl(
smallbox!(Clone::clone(cast_ref(message)))
}

fn deserialize_any(deserializer: &mut dyn #internal::erased_serde::Deserializer<'_>) -> Result<#internal::AnyMessage, #internal::erased_serde::Error> {
#internal::erased_serde::deserialize::<#name>(deserializer).map(#crate_::Message::upcast)
}

#network_fns

#[linkme::distributed_slice(MESSAGE_LIST)]
Expand All @@ -267,6 +271,7 @@ pub fn message_impl(
clone,
debug,
erase,
deserialize_any,
#network_fns_ref
};

Expand Down