Skip to content

Commit

Permalink
refactor: removed mpc-net managed channel tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
fabian1409 committed Nov 7, 2024
1 parent b9584f3 commit e400d11
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 570 deletions.
37 changes: 13 additions & 24 deletions mpc-core/src/protocols/rep3/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use ark_ff::PrimeField;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use bytes::{Bytes, BytesMut};
use eyre::{bail, eyre, Report};
use mpc_net::{channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler};
use mpc_net::{channel::BytesChannel, config::NetworkConfig, MpcNetworkHandler, TlsStream};
use tokio::runtime::Runtime;

use super::{
Expand Down Expand Up @@ -236,8 +236,8 @@ pub trait Rep3Network: Send {
#[derive(Debug)]
pub struct Rep3MpcNet {
pub(crate) id: PartyID,
pub(crate) chan_next: ChannelHandle<Bytes, BytesMut>,
pub(crate) chan_prev: ChannelHandle<Bytes, BytesMut>,
pub(crate) chan_next: BytesChannel<TlsStream, TlsStream>,
pub(crate) chan_prev: BytesChannel<TlsStream, TlsStream>,
// TODO we should be able to get rid of this mutex once we dont remove streams from the pool anymore
pub(crate) net_handler: Arc<Mutex<MpcNetworkHandler>>,
// order is important, runtime MUST be dropped after network handler
Expand All @@ -263,9 +263,6 @@ impl Rep3MpcNet {
.get_byte_channel(&id.prev_id().into())
.ok_or(eyre!("no prev channel found"))?;

let chan_next = net_handler.spawn(chan_next);
let chan_prev = net_handler.spawn(chan_prev);

Ok(Self {
id,
net_handler: Arc::new(Mutex::new(net_handler)),
Expand All @@ -278,35 +275,29 @@ impl Rep3MpcNet {
/// Sends bytes over the network to the target party.
pub fn send_bytes(&mut self, target: PartyID, data: Bytes) -> std::io::Result<()> {
if target == self.id.next_id() {
std::mem::drop(self.chan_next.blocking_send(data));
Ok(())
self.runtime.block_on(self.chan_next.send(data))
} else if target == self.id.prev_id() {
std::mem::drop(self.chan_prev.blocking_send(data));
Ok(())
self.runtime.block_on(self.chan_prev.send(data))
} else {
return Err(std::io::Error::new(
Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Cannot send to self",
));
))
}
}

/// Receives bytes over the network from the party with the given id.
pub fn recv_bytes(&mut self, from: PartyID) -> std::io::Result<BytesMut> {
let data = if from == self.id.prev_id() {
self.chan_prev.blocking_recv().blocking_recv()
if from == self.id.prev_id() {
self.runtime.block_on(self.chan_prev.recv())
} else if from == self.id.next_id() {
self.chan_next.blocking_recv().blocking_recv()
self.runtime.block_on(self.chan_next.recv())
} else {
return Err(std::io::Error::new(
Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Cannot recv from self",
));
};
let data = data.map_err(|_| {
std::io::Error::new(std::io::ErrorKind::BrokenPipe, "receive channel end died")
})??;
Ok(data)
))
}
}
}

Expand Down Expand Up @@ -364,8 +355,6 @@ impl Rep3Network for Rep3MpcNet {
let chan_prev = net_handler
.get_byte_channel(&id.prev_id().into())
.expect("no prev channel found");
let chan_next = net_handler.spawn(chan_next);
let chan_prev = net_handler.spawn(chan_prev);

Ok(Self {
id,
Expand Down
52 changes: 17 additions & 35 deletions mpc-core/src/protocols/shamir/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use bytes::{Bytes, BytesMut};
use eyre::{bail, eyre, Report};
use mpc_net::{channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler};
use eyre::{bail, ContextCompat, Report};
use mpc_net::{channel::BytesChannel, config::NetworkConfig, MpcNetworkHandler, TlsStream};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
Expand Down Expand Up @@ -79,7 +79,7 @@ pub trait ShamirNetwork: Send {
pub struct ShamirMpcNet {
pub(crate) id: usize, // 0 <= id < num_parties
pub(crate) num_parties: usize,
pub(crate) channels: HashMap<usize, ChannelHandle<Bytes, BytesMut>>,
pub(crate) channels: HashMap<usize, BytesChannel<TlsStream, TlsStream>>,
// TODO we should be able to get rid of this mutex once we dont remove streams from the pool anymore
pub(crate) net_handler: Arc<Mutex<MpcNetworkHandler>>,
// order is important, runtime MUST be dropped after network handler
Expand All @@ -103,15 +103,11 @@ impl ShamirMpcNet {
.enable_all()
.build()?;
let mut net_handler = runtime.block_on(MpcNetworkHandler::establish(config))?;
let mut channels = HashMap::with_capacity(num_parties - 1);

for other_id in 0..num_parties {
if other_id != id {
let chan = net_handler
.get_byte_channel(&other_id)
.ok_or_else(|| eyre!("no channel found for party id={}", other_id))?;
channels.insert(other_id, net_handler.spawn(chan));
}
let channels = net_handler
.get_byte_channels()
.context("all channels found")?;
if channels.len() != num_parties - 1 {
bail!("Unexpected channels found")
}

Ok(Self {
Expand All @@ -126,8 +122,7 @@ impl ShamirMpcNet {
/// Sends bytes over the network to the target party.
pub fn send_bytes(&mut self, target: usize, data: Bytes) -> std::io::Result<()> {
if let Some(chan) = self.channels.get_mut(&target) {
std::mem::drop(chan.blocking_send(data));
Ok(())
self.runtime.block_on(chan.send(data))
} else {
Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
Expand All @@ -138,19 +133,14 @@ impl ShamirMpcNet {

/// Receives bytes over the network from the party with the given id.
pub fn recv_bytes(&mut self, from: usize) -> std::io::Result<BytesMut> {
let data = if let Some(chan) = self.channels.get_mut(&from) {
chan.blocking_recv().blocking_recv()
if let Some(chan) = self.channels.get_mut(&from) {
self.runtime.block_on(chan.recv())
} else {
return Err(std::io::Error::new(
Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("No channel found for party id={}", from),
));
};

let data = data.map_err(|_| {
std::io::Error::new(std::io::ErrorKind::BrokenPipe, "receive channel end died")
})??;
Ok(data)
))
}
}

pub(crate) fn _id(&self) -> usize {
Expand Down Expand Up @@ -262,17 +252,9 @@ impl ShamirNetwork for ShamirMpcNet {
let id = self.id;
let num_parties = self.num_parties;
let mut net_handler = self.net_handler.lock().unwrap();

let mut channels = HashMap::with_capacity(num_parties - 1);
for other_id in 0..num_parties {
if other_id != id {
let chan = net_handler
.get_byte_channel(&other_id)
.expect("to find channel");
channels.insert(other_id, net_handler.spawn(chan));
}
}

// TODO return error
let channels = net_handler.get_byte_channels().expect("all channels found");
assert_eq!(channels.len(), self.num_parties - 1);
Ok(Self {
id,
num_parties,
Expand Down
16 changes: 3 additions & 13 deletions mpc-net/examples/three_party.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ use color_eyre::{
eyre::{eyre, Context, ContextCompat},
Result,
};
use futures::{SinkExt, StreamExt};
use mpc_net::{
config::{NetworkConfig, NetworkConfigFile},
MpcNetworkHandler,
};
use tokio::io::AsyncWriteExt;

#[derive(Parser)]
struct Args {
Expand Down Expand Up @@ -43,17 +41,9 @@ async fn main() -> Result<()> {
}
// recv from all channels
for (&_, channel) in channels.iter_mut() {
let buf = channel.next().await;
if let Some(Ok(b)) = buf {
println!("received {}, should be {}", b[0], my_id as u8);
assert!(b.iter().all(|&x| x == my_id as u8))
}
}

// make sure all write are done by shutting down all streams
for (_, channel) in channels.into_iter() {
let (write, _) = channel.split();
write.into_inner().shutdown().await?;
let buf = channel.recv().await?;
println!("received {}, should be {}", buf[0], my_id as u8);
assert!(buf.iter().all(|&x| x == my_id as u8))
}

network.print_connection_stats(&mut std::io::stdout())?;
Expand Down
19 changes: 4 additions & 15 deletions mpc-net/examples/three_party_bincode_channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@ use color_eyre::{
eyre::{eyre, Context, ContextCompat},
Result,
};
use futures::{SinkExt, StreamExt};
use mpc_net::{
config::{NetworkConfig, NetworkConfigFile},
MpcNetworkHandler,
};
use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;

#[derive(Parser)]
struct Args {
Expand Down Expand Up @@ -46,11 +44,9 @@ async fn main() -> Result<()> {
}
// recv from all channels
for (&_, channel) in channels.iter_mut() {
let buf = channel.next().await;
if let Some(Ok(Message::Ping(b))) = buf {
let buf = channel.recv().await?;
if let Message::Ping(b) = buf {
assert!(b.iter().all(|&x| x == my_id as u8))
} else {
panic!("could not receive message");
}
}
// send to all channels
Expand All @@ -60,18 +56,11 @@ async fn main() -> Result<()> {
}
// recv from all channels
for (&_, channel) in channels.iter_mut() {
let buf = channel.next().await;
if let Some(Ok(Message::Pong(b))) = buf {
let buf = channel.recv().await?;
if let Message::Pong(b) = buf {
assert!(b.iter().all(|&x| x == my_id as u8))
} else {
panic!("could not receive message");
}
}
// make sure all write are done by shutting down all streams
for (_, channel) in channels.into_iter() {
let (write, _) = channel.split();
write.into_inner().shutdown().await?;
}
network.print_connection_stats(&mut std::io::stdout())?;

Ok(())
Expand Down
19 changes: 4 additions & 15 deletions mpc-net/examples/three_party_custom_channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@ use color_eyre::{
eyre::{eyre, Context, ContextCompat},
Result,
};
use futures::{SinkExt, StreamExt};
use mpc_net::{
config::{NetworkConfig, NetworkConfigFile},
MpcNetworkHandler,
};
use tokio::io::AsyncWriteExt;
use tokio_util::codec::{Decoder, Encoder};

#[derive(Parser)]
Expand Down Expand Up @@ -46,11 +44,9 @@ async fn main() -> Result<()> {
}
// recv from all channels
for (&_, channel) in channels.iter_mut() {
let buf = channel.next().await;
if let Some(Ok(Message::Ping(b))) = buf {
let buf = channel.recv().await?;
if let Message::Ping(b) = buf {
assert!(b.iter().all(|&x| x == my_id as u8))
} else {
panic!("could not receive message");
}
}
// send to all channels
Expand All @@ -60,18 +56,11 @@ async fn main() -> Result<()> {
}
// recv from all channels
for (&_, channel) in channels.iter_mut() {
let buf = channel.next().await;
if let Some(Ok(Message::Pong(b))) = buf {
let buf = channel.recv().await?;
if let Message::Pong(b) = buf {
assert!(b.iter().all(|&x| x == my_id as u8))
} else {
panic!("could not receive message");
}
}
// make sure all write are done by shutting down all streams
for (_, channel) in channels.into_iter() {
let (write, _) = channel.split();
write.into_inner().shutdown().await?;
}
network.print_connection_stats(&mut std::io::stdout())?;

Ok(())
Expand Down
61 changes: 0 additions & 61 deletions mpc-net/examples/three_party_managed.rs

This file was deleted.

Loading

0 comments on commit e400d11

Please sign in to comment.