diff --git a/mpc-core/src/protocols/rep3/network.rs b/mpc-core/src/protocols/rep3/network.rs index a21c53746..59b92edf2 100644 --- a/mpc-core/src/protocols/rep3/network.rs +++ b/mpc-core/src/protocols/rep3/network.rs @@ -9,7 +9,7 @@ use ark_ff::PrimeField; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use bytes::{Bytes, BytesMut}; use eyre::{bail, eyre, Report}; -use mpc_net::{channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler}; +use mpc_net::{channel::BytesChannel, config::NetworkConfig, MpcNetworkHandler, TlsStream}; use tokio::runtime::Runtime; use super::{ @@ -236,8 +236,8 @@ pub trait Rep3Network: Send { #[derive(Debug)] pub struct Rep3MpcNet { pub(crate) id: PartyID, - pub(crate) chan_next: ChannelHandle, - pub(crate) chan_prev: ChannelHandle, + pub(crate) chan_next: BytesChannel, + pub(crate) chan_prev: BytesChannel, // TODO we should be able to get rid of this mutex once we dont remove streams from the pool anymore pub(crate) net_handler: Arc>, // order is important, runtime MUST be dropped after network handler @@ -263,9 +263,6 @@ impl Rep3MpcNet { .get_byte_channel(&id.prev_id().into()) .ok_or(eyre!("no prev channel found"))?; - let chan_next = net_handler.spawn(chan_next); - let chan_prev = net_handler.spawn(chan_prev); - Ok(Self { id, net_handler: Arc::new(Mutex::new(net_handler)), @@ -278,35 +275,29 @@ impl Rep3MpcNet { /// Sends bytes over the network to the target party. pub fn send_bytes(&mut self, target: PartyID, data: Bytes) -> std::io::Result<()> { if target == self.id.next_id() { - std::mem::drop(self.chan_next.blocking_send(data)); - Ok(()) + self.runtime.block_on(self.chan_next.send(data)) } else if target == self.id.prev_id() { - std::mem::drop(self.chan_prev.blocking_send(data)); - Ok(()) + self.runtime.block_on(self.chan_prev.send(data)) } else { - return Err(std::io::Error::new( + Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, "Cannot send to self", - )); + )) } } /// Receives bytes over the network from the party with the given id. pub fn recv_bytes(&mut self, from: PartyID) -> std::io::Result { - let data = if from == self.id.prev_id() { - self.chan_prev.blocking_recv().blocking_recv() + if from == self.id.prev_id() { + self.runtime.block_on(self.chan_prev.recv()) } else if from == self.id.next_id() { - self.chan_next.blocking_recv().blocking_recv() + self.runtime.block_on(self.chan_next.recv()) } else { - return Err(std::io::Error::new( + Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, "Cannot recv from self", - )); - }; - let data = data.map_err(|_| { - std::io::Error::new(std::io::ErrorKind::BrokenPipe, "receive channel end died") - })??; - Ok(data) + )) + } } } @@ -364,8 +355,6 @@ impl Rep3Network for Rep3MpcNet { let chan_prev = net_handler .get_byte_channel(&id.prev_id().into()) .expect("no prev channel found"); - let chan_next = net_handler.spawn(chan_next); - let chan_prev = net_handler.spawn(chan_prev); Ok(Self { id, diff --git a/mpc-core/src/protocols/shamir/network.rs b/mpc-core/src/protocols/shamir/network.rs index 3b77ace96..e0f1e1b19 100644 --- a/mpc-core/src/protocols/shamir/network.rs +++ b/mpc-core/src/protocols/shamir/network.rs @@ -4,8 +4,8 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use bytes::{Bytes, BytesMut}; -use eyre::{bail, eyre, Report}; -use mpc_net::{channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler}; +use eyre::{bail, ContextCompat, Report}; +use mpc_net::{channel::BytesChannel, config::NetworkConfig, MpcNetworkHandler, TlsStream}; use std::{ collections::HashMap, sync::{Arc, Mutex}, @@ -79,7 +79,7 @@ pub trait ShamirNetwork: Send { pub struct ShamirMpcNet { pub(crate) id: usize, // 0 <= id < num_parties pub(crate) num_parties: usize, - pub(crate) channels: HashMap>, + pub(crate) channels: HashMap>, // TODO we should be able to get rid of this mutex once we dont remove streams from the pool anymore pub(crate) net_handler: Arc>, // order is important, runtime MUST be dropped after network handler @@ -103,15 +103,11 @@ impl ShamirMpcNet { .enable_all() .build()?; let mut net_handler = runtime.block_on(MpcNetworkHandler::establish(config))?; - let mut channels = HashMap::with_capacity(num_parties - 1); - - for other_id in 0..num_parties { - if other_id != id { - let chan = net_handler - .get_byte_channel(&other_id) - .ok_or_else(|| eyre!("no channel found for party id={}", other_id))?; - channels.insert(other_id, net_handler.spawn(chan)); - } + let channels = net_handler + .get_byte_channels() + .context("all channels found")?; + if channels.len() != num_parties - 1 { + bail!("Unexpected channels found") } Ok(Self { @@ -126,8 +122,7 @@ impl ShamirMpcNet { /// Sends bytes over the network to the target party. pub fn send_bytes(&mut self, target: usize, data: Bytes) -> std::io::Result<()> { if let Some(chan) = self.channels.get_mut(&target) { - std::mem::drop(chan.blocking_send(data)); - Ok(()) + self.runtime.block_on(chan.send(data)) } else { Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, @@ -138,19 +133,14 @@ impl ShamirMpcNet { /// Receives bytes over the network from the party with the given id. pub fn recv_bytes(&mut self, from: usize) -> std::io::Result { - let data = if let Some(chan) = self.channels.get_mut(&from) { - chan.blocking_recv().blocking_recv() + if let Some(chan) = self.channels.get_mut(&from) { + self.runtime.block_on(chan.recv()) } else { - return Err(std::io::Error::new( + Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("No channel found for party id={}", from), - )); - }; - - let data = data.map_err(|_| { - std::io::Error::new(std::io::ErrorKind::BrokenPipe, "receive channel end died") - })??; - Ok(data) + )) + } } pub(crate) fn _id(&self) -> usize { @@ -262,17 +252,9 @@ impl ShamirNetwork for ShamirMpcNet { let id = self.id; let num_parties = self.num_parties; let mut net_handler = self.net_handler.lock().unwrap(); - - let mut channels = HashMap::with_capacity(num_parties - 1); - for other_id in 0..num_parties { - if other_id != id { - let chan = net_handler - .get_byte_channel(&other_id) - .expect("to find channel"); - channels.insert(other_id, net_handler.spawn(chan)); - } - } - + // TODO return error + let channels = net_handler.get_byte_channels().expect("all channels found"); + assert_eq!(channels.len(), self.num_parties - 1); Ok(Self { id, num_parties, diff --git a/mpc-net/examples/three_party.rs b/mpc-net/examples/three_party.rs index 9c57fcac7..c4d0490e6 100644 --- a/mpc-net/examples/three_party.rs +++ b/mpc-net/examples/three_party.rs @@ -5,12 +5,10 @@ use color_eyre::{ eyre::{eyre, Context, ContextCompat}, Result, }; -use futures::{SinkExt, StreamExt}; use mpc_net::{ config::{NetworkConfig, NetworkConfigFile}, MpcNetworkHandler, }; -use tokio::io::AsyncWriteExt; #[derive(Parser)] struct Args { @@ -43,17 +41,9 @@ async fn main() -> Result<()> { } // recv from all channels for (&_, channel) in channels.iter_mut() { - let buf = channel.next().await; - if let Some(Ok(b)) = buf { - println!("received {}, should be {}", b[0], my_id as u8); - assert!(b.iter().all(|&x| x == my_id as u8)) - } - } - - // make sure all write are done by shutting down all streams - for (_, channel) in channels.into_iter() { - let (write, _) = channel.split(); - write.into_inner().shutdown().await?; + let buf = channel.recv().await?; + println!("received {}, should be {}", buf[0], my_id as u8); + assert!(buf.iter().all(|&x| x == my_id as u8)) } network.print_connection_stats(&mut std::io::stdout())?; diff --git a/mpc-net/examples/three_party_bincode_channels.rs b/mpc-net/examples/three_party_bincode_channels.rs index b9bf8e23c..d95aa5915 100644 --- a/mpc-net/examples/three_party_bincode_channels.rs +++ b/mpc-net/examples/three_party_bincode_channels.rs @@ -5,13 +5,11 @@ use color_eyre::{ eyre::{eyre, Context, ContextCompat}, Result, }; -use futures::{SinkExt, StreamExt}; use mpc_net::{ config::{NetworkConfig, NetworkConfigFile}, MpcNetworkHandler, }; use serde::{Deserialize, Serialize}; -use tokio::io::AsyncWriteExt; #[derive(Parser)] struct Args { @@ -46,11 +44,9 @@ async fn main() -> Result<()> { } // recv from all channels for (&_, channel) in channels.iter_mut() { - let buf = channel.next().await; - if let Some(Ok(Message::Ping(b))) = buf { + let buf = channel.recv().await?; + if let Message::Ping(b) = buf { assert!(b.iter().all(|&x| x == my_id as u8)) - } else { - panic!("could not receive message"); } } // send to all channels @@ -60,18 +56,11 @@ async fn main() -> Result<()> { } // recv from all channels for (&_, channel) in channels.iter_mut() { - let buf = channel.next().await; - if let Some(Ok(Message::Pong(b))) = buf { + let buf = channel.recv().await?; + if let Message::Pong(b) = buf { assert!(b.iter().all(|&x| x == my_id as u8)) - } else { - panic!("could not receive message"); } } - // make sure all write are done by shutting down all streams - for (_, channel) in channels.into_iter() { - let (write, _) = channel.split(); - write.into_inner().shutdown().await?; - } network.print_connection_stats(&mut std::io::stdout())?; Ok(()) diff --git a/mpc-net/examples/three_party_custom_channels.rs b/mpc-net/examples/three_party_custom_channels.rs index 8a8b4e963..969f44657 100644 --- a/mpc-net/examples/three_party_custom_channels.rs +++ b/mpc-net/examples/three_party_custom_channels.rs @@ -6,12 +6,10 @@ use color_eyre::{ eyre::{eyre, Context, ContextCompat}, Result, }; -use futures::{SinkExt, StreamExt}; use mpc_net::{ config::{NetworkConfig, NetworkConfigFile}, MpcNetworkHandler, }; -use tokio::io::AsyncWriteExt; use tokio_util::codec::{Decoder, Encoder}; #[derive(Parser)] @@ -46,11 +44,9 @@ async fn main() -> Result<()> { } // recv from all channels for (&_, channel) in channels.iter_mut() { - let buf = channel.next().await; - if let Some(Ok(Message::Ping(b))) = buf { + let buf = channel.recv().await?; + if let Message::Ping(b) = buf { assert!(b.iter().all(|&x| x == my_id as u8)) - } else { - panic!("could not receive message"); } } // send to all channels @@ -60,18 +56,11 @@ async fn main() -> Result<()> { } // recv from all channels for (&_, channel) in channels.iter_mut() { - let buf = channel.next().await; - if let Some(Ok(Message::Pong(b))) = buf { + let buf = channel.recv().await?; + if let Message::Pong(b) = buf { assert!(b.iter().all(|&x| x == my_id as u8)) - } else { - panic!("could not receive message"); } } - // make sure all write are done by shutting down all streams - for (_, channel) in channels.into_iter() { - let (write, _) = channel.split(); - write.into_inner().shutdown().await?; - } network.print_connection_stats(&mut std::io::stdout())?; Ok(()) diff --git a/mpc-net/examples/three_party_managed.rs b/mpc-net/examples/three_party_managed.rs deleted file mode 100644 index ab4792e5a..000000000 --- a/mpc-net/examples/three_party_managed.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::{collections::HashMap, path::PathBuf}; - -use clap::Parser; -use color_eyre::{ - eyre::{eyre, Context, ContextCompat}, - Result, -}; -use mpc_net::{ - config::{NetworkConfig, NetworkConfigFile}, - MpcNetworkHandler, -}; - -#[derive(Parser)] -struct Args { - /// The config file path - #[clap(short, long, value_name = "FILE")] - config_file: PathBuf, -} - -#[tokio::main] -async fn main() -> Result<()> { - let args = Args::parse(); - rustls::crypto::aws_lc_rs::default_provider() - .install_default() - .map_err(|_| eyre!("Could not install default rustls crypto provider"))?; - - let config: NetworkConfigFile = - toml::from_str(&std::fs::read_to_string(args.config_file).context("opening config file")?) - .context("parsing config file")?; - let config = NetworkConfig::try_from(config).context("converting network config")?; - let my_id = config.my_id; - - let mut network = MpcNetworkHandler::establish(config).await?; - - let channels = network.get_byte_channels().context("get channels")?; - let mut managed_channels = channels - .into_iter() - .map(|(i, c)| (i, network.spawn(c))) - .collect::>(); - - // send to all channels - for (&i, channel) in managed_channels.iter_mut() { - let buf = vec![i as u8; 1024]; - let _ = channel.send(buf.into()).await.await?; - } - // recv from all channels - for (&_, channel) in managed_channels.iter_mut() { - let buf = channel.recv().await.await; - if let Ok(Ok(b)) = buf { - println!("received {}, should be {}", b[0], my_id as u8); - assert!(b.iter().all(|&x| x == my_id as u8)) - } - } - // drop handles so we can shutdown - drop(managed_channels); - // wait until all send and recv taks are done - network.shutdown().await?; - network.print_connection_stats(&mut std::io::stdout())?; - - Ok(()) -} diff --git a/mpc-net/src/channel.rs b/mpc-net/src/channel.rs index 2f5166eaf..d1a1e74cd 100644 --- a/mpc-net/src/channel.rs +++ b/mpc-net/src/channel.rs @@ -1,22 +1,27 @@ //! A channel abstraction for sending and receiving messages. -use futures::{Sink, SinkExt, Stream, StreamExt}; -use std::{io, marker::Unpin, pin::Pin}; -use tokio::{ - io::{AsyncRead, AsyncWrite, AsyncWriteExt}, - runtime::Handle, - sync::{mpsc, oneshot}, - task::{JoinError, JoinHandle}, +use futures::{SinkExt, StreamExt}; +use std::{ + io::{self}, + marker::Unpin, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{Context, Poll}, }; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec}; use crate::codecs::BincodeCodec; /// A read end of the channel, just a type alias for [`FramedRead`]. -pub type ReadChannel = FramedRead; +pub type ReadChannel = FramedRead, D>; /// A write end of the channel, just a type alias for [`FramedWrite`]. -pub type WriteChannel = FramedWrite; +pub type WriteChannel = FramedWrite, E>; /// A channel that uses a [`Encoder`] and [`Decoder`] to send and receive messages. +// TODO we can remove generics over R and W? #[derive(Debug)] pub struct Channel { read_conn: ReadChannel, @@ -29,304 +34,153 @@ pub type BytesChannel = Channel; /// A channel that uses a [`BincodeCodec`] to send and receive messages. pub type BincodeChannel = Channel>; -impl Channel { +impl Channel +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ /// Create a new [`Channel`], backed by a read and write half. Read and write buffers /// are automatically handled by [`LengthDelimitedCodec`]. - pub fn new(read_half: R, write_half: W, codec: C) -> Self + /// The number of sent and received bytes is tracked. + pub fn new( + read_half: R, + write_half: W, + codec: C, + bytes_read: Arc, + bytes_written: Arc, + ) -> Self where C: Clone + Decoder + Encoder, - R: AsyncRead, - W: AsyncWrite, { Channel { - write_conn: FramedWrite::new(write_half, codec.clone()), - read_conn: FramedRead::new(read_half, codec), + write_conn: FramedWrite::new( + TrackingAsyncWriter::new(write_half, bytes_written), + codec.clone(), + ), + read_conn: FramedRead::new(TrackingAsyncReader::new(read_half, bytes_read), codec), } } - /// Split Connection into a ([`WriteChannel`],[`ReadChannel`]) pair. - pub fn split(self) -> (WriteChannel, ReadChannel) { - (self.write_conn, self.read_conn) - } - - /// Join ([`WriteChannel`],[`ReadChannel`]) pair back into a [`Channel`]. - pub fn join(write_conn: WriteChannel, read_conn: ReadChannel) -> Self { - Self { - write_conn, - read_conn, - } - } - - /// Returns mutable reference to the ([`WriteChannel`],[`ReadChannel`]) pair. - pub fn inner_ref(&mut self) -> (&mut WriteChannel, &mut ReadChannel) { - (&mut self.write_conn, &mut self.read_conn) - } - - /// Closes the channel, flushing the write buffer and checking that there is no unread data. - pub async fn close(self) -> Result<(), io::Error> + /// Send data via the channel + pub async fn send(&mut self, data: MSend) -> std::io::Result<()> where - C: Encoder + Decoder, - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, + C: Clone + Encoder, { - let Channel { - mut read_conn, - mut write_conn, - .. - } = self; - write_conn.flush().await?; - write_conn.close().await?; - if let Some(x) = read_conn.next().await { - match x { - Ok(_) => { - return Err(io::Error::new( - io::ErrorKind::Other, - "Unexpected data on read channel when closing connections", - )); - } - Err(e) => { - return Err(e); - } - } - } - - Ok(()) + self.write_conn.send(data).await } -} -impl> Sink - for Channel -where - Self: Unpin, -{ - type Error = >::Error; - fn poll_ready( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.write_conn.poll_ready_unpin(cx) - } - - fn start_send(mut self: std::pin::Pin<&mut Self>, item: MSend) -> Result<(), Self::Error> { - self.write_conn.start_send_unpin(item) + /// Receive data via the channel + pub async fn recv(&mut self) -> std::io::Result + where + C: Decoder, + { + self.read_conn.next().await.ok_or(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "connection closed", + ))? } - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.write_conn.poll_flush_unpin(cx) + /// Split Connection into a ([`WriteChannel`],[`ReadChannel`]) pair. + pub fn split(self) -> (WriteChannel, ReadChannel) { + (self.write_conn, self.read_conn) } +} - fn poll_close( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.write_conn.poll_close_unpin(cx) - } +/// A wrapper around [`AsyncRead`] types that keeps track of the number of read bytes +#[derive(Debug)] +pub struct TrackingAsyncReader { + inner: R, + bytes_read: Arc, } -impl> Stream - for Channel -where - Self: Unpin, -{ - type Item = Result::Error>; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.read_conn.poll_next_unpin(cx) +impl TrackingAsyncReader { + /// Create a new [`TrackingAsyncReader`]. + pub fn new(inner: R, bytes_read: Arc) -> Self { + Self { inner, bytes_read } } } -struct WriteJob { - data: MSend, - ret: oneshot::Sender>, -} +impl AsyncRead for TrackingAsyncReader { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let inner = Pin::new(&mut self.inner); + let initial_len = buf.filled().len(); + let res = inner.poll_read(cx, buf); + + // if the read was ok, update bytes_read + if let Poll::Ready(Ok(())) = &res { + self.bytes_read + .fetch_add(buf.filled().len() - initial_len, Ordering::SeqCst); + } -struct ReadJob { - ret: oneshot::Sender>, + res + } } -/// A handle to a channel that allows sending and receiving messages. +/// A wrapper around [`AsyncWrite`] types that keeps track of the number of written bytes #[derive(Debug)] -pub struct ChannelHandle { - write_job_queue: mpsc::Sender>, - read_job_queue: mpsc::Sender>, +pub struct TrackingAsyncWriter { + inner: W, + bytes_written: Arc, } -impl ChannelHandle -where - MRecv: Send + std::fmt::Debug + 'static, - MSend: Send + std::fmt::Debug + 'static, -{ - /// Instructs the channel to send a message. Returns a [oneshot::Receiver] that will return the result of the send operation. - pub async fn send(&mut self, data: MSend) -> oneshot::Receiver> { - let (ret, recv) = oneshot::channel(); - let job = WriteJob { data, ret }; - match self.write_job_queue.send(job).await { - Ok(_) => {} - Err(job) => job - .0 - .ret - .send(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "ChannelHandle: send Channel is gone", - ))) - .unwrap(), +impl TrackingAsyncWriter { + /// Create a new [`TrackingAsyncWriter`]. + pub fn new(inner: R, bytes_written: Arc) -> Self { + Self { + inner, + bytes_written, } - recv } +} - /// Instructs the channel to receive a message. Returns a [oneshot::Receiver] that will return the result of the receive operation. - pub async fn recv(&mut self) -> oneshot::Receiver> { - let (ret, recv) = oneshot::channel(); - let job = ReadJob { ret }; - match self.read_job_queue.send(job).await { - Ok(_) => {} - Err(job) => job - .0 - .ret - .send(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "ChannelHandle: recv Channel is gone", - ))) - .unwrap(), +impl AsyncWrite for TrackingAsyncWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let inner = Pin::new(&mut self.inner); + let res = inner.poll_write(cx, buf); + + // if the write was ok, update bytes_written + if let Poll::Ready(Ok(bytes_written)) = &res { + self.bytes_written + .fetch_add(*bytes_written, Ordering::SeqCst); } - recv - } - /// A blocking version of [ChannelHandle::send]. This will block until the send operation is complete. - pub fn blocking_send(&mut self, data: MSend) -> oneshot::Receiver> { - let (ret, recv) = oneshot::channel(); - let job = WriteJob { data, ret }; - match self.write_job_queue.blocking_send(job) { - Ok(_) => {} - Err(job) => job - .0 - .ret - .send(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "ChannelHandle: send Channel is gone", - ))) - .unwrap(), - } - recv + res } - /// A blocking version of [ChannelHandle::recv]. This will block until the receive operation is complete. - pub fn blocking_recv(&mut self) -> oneshot::Receiver> { - let (ret, recv) = oneshot::channel(); - let job = ReadJob { ret }; - match self.read_job_queue.blocking_send(job) { - Ok(_) => {} - Err(job) => job - .0 - .ret - .send(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "ChannelHandle: recv Channel is gone", - ))) - .unwrap(), - } - recv + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) } -} - -/// Handles spawing and shutdown of channels. On drop, joins all [`JoinHandle`]s. The [`Handle`] musst be valid for the entire lifetime of this type. -#[derive(Debug)] -pub(crate) struct ChannelTasks { - tasks: Vec>, - handle: Handle, -} -impl ChannelTasks { - /// Create a new [`ChannelTasks`] instance. - pub fn new(handle: Handle) -> Self { - Self { - tasks: Vec::new(), - handle, - } + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) } - /// Create a new [`ChannelHandle`] from a [`Channel`]. This spawns a new tokio task that handles the read and write jobs so they can happen concurrently. - pub(crate) fn spawn( - &mut self, - chan: Channel, - ) -> ChannelHandle - where - C: 'static, - R: AsyncRead + Unpin + 'static, - W: AsyncWrite + Unpin + std::marker::Send + 'static, - FramedRead: Stream> + Send, - FramedWrite: Sink + Send, - MRecv: Send + std::fmt::Debug + 'static, - MSend: Send + std::fmt::Debug + 'static, - { - let (write_send, mut write_recv) = mpsc::channel::>(1024); - let (read_send, mut read_recv) = mpsc::channel::>(1024); - - let (mut write, mut read) = chan.split(); - - self.tasks.push(self.handle.spawn(async move { - while let Some(frame) = read.next().await { - let job = read_recv.recv().await; - match job { - Some(job) => { - if job.ret.send(frame).is_err() { - tracing::warn!("Warning: Read Job finished but receiver is gone!"); - } - } - None => { - if frame.is_ok() { - tracing::warn!("Warning: received Ok frame but receiver is gone!"); - } - break; - } - } - } - })); - self.tasks.push(self.handle.spawn(async move { - while let Some(write_job) = write_recv.recv().await { - let write_result = write.send(write_job.data).await; - // we don't really care if the receiver for a write job is gone, as this is a common case - // therefore we only emit a trace message - match write_job.ret.send(write_result) { - Ok(_) => {} - Err(_) => { - tracing::trace!("Debug: Write Job finished but receiver is gone!"); - } - } - } - // make sure all data is sent - if write.into_inner().shutdown().await.is_err() { - tracing::warn!("Warning: shutdown of stream failed!"); - } - })); - - ChannelHandle { - write_job_queue: write_send, - read_job_queue: read_send, + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + let inner = Pin::new(&mut self.inner); + let res = inner.poll_write_vectored(cx, bufs); + + // if the write was ok, update bytes_written + if let Poll::Ready(Ok(bytes_written)) = &res { + self.bytes_written + .fetch_add(*bytes_written, Ordering::SeqCst); } - } - /// Join all [`JoinHandle`]s and remove them. - pub(crate) async fn shutdown(&mut self) -> Result<(), JoinError> { - futures::future::try_join_all(std::mem::take(&mut self.tasks)) - .await - .map(|_| ()) + res } -} -impl Drop for ChannelTasks { - fn drop(&mut self) { - tokio::task::block_in_place(move || { - self.handle - .block_on(futures::future::try_join_all(std::mem::take( - &mut self.tasks, - ))) - .expect("can join all tasks"); - }); + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() } } diff --git a/mpc-net/src/lib.rs b/mpc-net/src/lib.rs index 0a0519e24..e7ca8a940 100644 --- a/mpc-net/src/lib.rs +++ b/mpc-net/src/lib.rs @@ -4,26 +4,21 @@ use std::{ collections::{BTreeMap, HashMap}, io, net::ToSocketAddrs, - pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, Arc, }, - task::{Context, Poll}, time::Duration, }; -use channel::{BincodeChannel, BytesChannel, Channel, ChannelHandle, ChannelTasks}; +use channel::{BincodeChannel, BytesChannel, Channel}; use codecs::BincodeCodec; -use color_eyre::eyre::{self, bail, Context as Ctx, ContextCompat, Report}; +use color_eyre::eyre::{self, bail, Context, ContextCompat, Report}; use config::NetworkConfig; -use futures::{Sink, Stream}; use serde::{de::DeserializeOwned, Serialize}; use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}, + io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream}, - runtime::Handle, - task::JoinError, }; use tokio_rustls::{ rustls::{ @@ -32,7 +27,7 @@ use tokio_rustls::{ }, TlsAcceptor, TlsConnector, }; -use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec}; +use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec}; pub mod channel; pub mod codecs; @@ -41,8 +36,8 @@ pub mod config; // TODO get this from network config const STREAMS_PER_CONN: usize = 8; -/// Type alias for a [rustls::TcpStream] over a [TcpStream]. -type TlsStream = tokio_rustls::TlsStream; +/// Type alias for a [tokio_rustls::TlsStream] over a [TcpStream]. +pub type TlsStream = tokio_rustls::TlsStream; /// A duplex TLS stream that uses one stream for sending and one for receiving. /// Splitting a single stream would add unwanted syncronization primitives. @@ -75,7 +70,6 @@ struct Connection { pub struct MpcNetworkHandler { // this is a btreemap because we rely on iteration order connections: BTreeMap, - tasks: ChannelTasks, my_id: usize, } @@ -247,33 +241,10 @@ impl MpcNetworkHandler { Ok(MpcNetworkHandler { connections, - tasks: ChannelTasks::new(Handle::current()), my_id: config.my_id, }) } - /// Create a new [`ChannelHandle`] from a [`Channel`]. This spawns a new tokio task that handles the read and write jobs so they can happen concurrently. - pub fn spawn( - &mut self, - chan: Channel, - ) -> ChannelHandle - where - C: 'static, - R: AsyncRead + Unpin + 'static, - W: AsyncWrite + Unpin + std::marker::Send + 'static, - FramedRead: Stream> + Send, - FramedWrite: Sink + Send, - MRecv: Send + std::fmt::Debug + 'static, - MSend: Send + std::fmt::Debug + 'static, - { - self.tasks.spawn(chan) - } - - /// Shutdown the network, waiting until all read and write tasks are completed. This happens automatically, when the network handler is dropped. - pub async fn shutdown(&mut self) -> Result<(), JoinError> { - self.tasks.shutdown().await - } - /// Returns the number of sent and received bytes. pub fn get_send_receive(&self, i: usize) -> std::io::Result<(usize, usize)> { let conn = self @@ -301,10 +272,7 @@ impl MpcNetworkHandler { } /// Get a [Channel] to party with `id`. This pops a stream from the pool. - pub fn get_byte_channel( - &mut self, - id: &usize, - ) -> Option> { + pub fn get_byte_channel(&mut self, id: &usize) -> Option> { let mut codec = LengthDelimitedCodec::new(); codec.set_max_frame_length(1_000_000_000); self.get_custom_channel(id, codec) @@ -314,7 +282,7 @@ impl MpcNetworkHandler { pub fn get_serde_bincode_channel( &mut self, id: &usize, - ) -> Option>> { + ) -> Option>> { let bincodec = BincodeCodec::::new(); self.get_custom_channel(id, bincodec) } @@ -331,13 +299,17 @@ impl MpcNetworkHandler { &mut self, id: &usize, codec: C, - ) -> Option> { + ) -> Option> { debug_assert!(*id != self.my_id); if let Some(conn) = self.connections.get_mut(id) { if let Some(stream) = conn.streams.pop() { - let recv = TrackingAsyncReader::new(stream.recv, conn.recv.clone()); - let send = TrackingAsyncWriter::new(stream.send, conn.sent.clone()); - return Some(Channel::new(recv, send, codec)); + return Some(Channel::new( + stream.recv, + stream.send, + codec, + conn.recv.clone(), + conn.sent.clone(), + )); } } None @@ -354,7 +326,7 @@ impl MpcNetworkHandler { >( &mut self, codec: C, - ) -> Option>> { + ) -> Option>> { let mut channels = HashMap::new(); let party_ids: Vec<_> = self.connections.keys().cloned().collect(); for id in party_ids { @@ -367,7 +339,7 @@ impl MpcNetworkHandler { /// Get a [Channel] to each party. This pops a stream from each pool. pub fn get_byte_channels( &mut self, - ) -> Option>> { + ) -> Option>> { let mut codec = LengthDelimitedCodec::new(); codec.set_max_frame_length(1_000_000_000); self.get_custom_channels(codec) @@ -376,103 +348,8 @@ impl MpcNetworkHandler { /// Get a [Channel] to each party. This pops a stream from each pool. pub fn get_serde_bincode_channels( &mut self, - ) -> Option>> { + ) -> Option>> { let bincodec = BincodeCodec::::new(); self.get_custom_channels(bincodec) } } - -/// A wrapper around [`AsyncRead`] types that keeps track of the number of read bytes -struct TrackingAsyncReader { - inner: R, - bytes_read: Arc, -} - -impl TrackingAsyncReader { - fn new(inner: R, bytes_read: Arc) -> Self { - Self { inner, bytes_read } - } -} - -impl AsyncRead for TrackingAsyncReader { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - let inner = Pin::new(&mut self.inner); - let initial_len = buf.filled().len(); - let res = inner.poll_read(cx, buf); - - // if the read was ok, update bytes_read - if let Poll::Ready(Ok(())) = &res { - self.bytes_read - .fetch_add(buf.filled().len() - initial_len, Ordering::SeqCst); - } - - res - } -} - -/// A wrapper around [`AsyncWrite`] types that keeps track of the number of written bytes -struct TrackingAsyncWriter { - inner: W, - bytes_written: Arc, -} - -impl TrackingAsyncWriter { - fn new(inner: R, bytes_written: Arc) -> Self { - Self { - inner, - bytes_written, - } - } -} - -impl AsyncWrite for TrackingAsyncWriter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let inner = Pin::new(&mut self.inner); - let res = inner.poll_write(cx, buf); - - // if the write was ok, update bytes_written - if let Poll::Ready(Ok(bytes_written)) = &res { - self.bytes_written - .fetch_add(*bytes_written, Ordering::SeqCst); - } - - res - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) - } - - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - let inner = Pin::new(&mut self.inner); - let res = inner.poll_write_vectored(cx, bufs); - - // if the write was ok, update bytes_written - if let Poll::Ready(Ok(bytes_written)) = &res { - self.bytes_written - .fetch_add(*bytes_written, Ordering::SeqCst); - } - - res - } - - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } -}