From 27c3d603208386df93ecbe188ba5ce90a38fa434 Mon Sep 17 00:00:00 2001
From: Fabian Gruber
Date: Mon, 16 Dec 2024 13:55:29 +0100
Subject: [PATCH] refactor: replaced constant channel pool with actor queue that generates new channels in the background

---
 .../groth16/run_full_multiplier2_shamir.sh    |  19 +
 mpc-core/src/protocols/bridges/network.rs     |   4 +-
 mpc-core/src/protocols/rep3/network.rs        |  64 +--
 mpc-core/src/protocols/shamir/network.rs      |  60 +--
 mpc-net/Cargo.toml                            |   1 +
 mpc-net/examples/three_party.rs               |   6 +-
 .../examples/three_party_bincode_channels.rs  |   8 +-
 .../examples/three_party_custom_channels.rs   |   6 +-
 mpc-net/examples/three_party_managed.rs       |  12 +-
 mpc-net/src/channel.rs                        | 201 +++----
 mpc-net/src/config.rs                         |  11 +
 mpc-net/src/lib.rs                            | 493 +++++++-----------
 mpc-net/src/queue.rs                          |  80 +++
 mpc-net/src/tracking_rw.rs                    | 116 +++++
 14 files changed, 603 insertions(+), 478 deletions(-)
 create mode 100755 co-circom/co-circom/examples/groth16/run_full_multiplier2_shamir.sh
 create mode 100644 mpc-net/src/queue.rs
 create mode 100644 mpc-net/src/tracking_rw.rs

diff --git a/co-circom/co-circom/examples/groth16/run_full_multiplier2_shamir.sh b/co-circom/co-circom/examples/groth16/run_full_multiplier2_shamir.sh
new file mode 100755
index 000000000..4d8af297c
--- /dev/null
+++ b/co-circom/co-circom/examples/groth16/run_full_multiplier2_shamir.sh
@@ -0,0 +1,19 @@
+# split input into shares
+cargo run --release --bin co-circom -- split-input --circuit test_vectors/multiplier2/circuit.circom --input test_vectors/multiplier2/input.json --protocol REP3 --curve BN254 --out-dir test_vectors/multiplier2 --config test_vectors/kyc/config.toml
+# run witness extension in MPC
+cargo run --release --bin co-circom -- generate-witness -O2 --input test_vectors/multiplier2/input.json.0.shared --circuit test_vectors/multiplier2/circuit.circom --protocol REP3 --curve BN254 --config ../configs/party1.toml --out test_vectors/multiplier2/witness.wtns.0.shared &
+cargo run --release --bin co-circom -- generate-witness -O2 --input test_vectors/multiplier2/input.json.1.shared --circuit test_vectors/multiplier2/circuit.circom --protocol REP3 --curve BN254 --config ../configs/party2.toml --out test_vectors/multiplier2/witness.wtns.1.shared &
+cargo run --release --bin co-circom -- generate-witness -O2 --input test_vectors/multiplier2/input.json.2.shared --circuit test_vectors/multiplier2/circuit.circom --protocol REP3 --curve BN254 --config ../configs/party3.toml --out test_vectors/multiplier2/witness.wtns.2.shared
+wait $(jobs -p)
+# run translation from REP3 to Shamir
+cargo run --release --bin co-circom -- translate-witness --witness test_vectors/multiplier2/witness.wtns.0.shared --src-protocol REP3 --target-protocol SHAMIR --curve BN254 --config ../configs/party1.toml --out test_vectors/multiplier2/shamir_witness.wtns.0.shared &
+cargo run --release --bin co-circom -- translate-witness --witness test_vectors/multiplier2/witness.wtns.1.shared --src-protocol REP3 --target-protocol SHAMIR --curve BN254 --config ../configs/party2.toml --out test_vectors/multiplier2/shamir_witness.wtns.1.shared &
+cargo run --release --bin co-circom -- translate-witness --witness test_vectors/multiplier2/witness.wtns.2.shared --src-protocol REP3 --target-protocol SHAMIR --curve BN254 --config ../configs/party3.toml --out test_vectors/multiplier2/shamir_witness.wtns.2.shared
+wait $(jobs -p)
+# run proving in MPC
+cargo run --release --bin co-circom -- generate-proof groth16 --witness test_vectors/multiplier2/shamir_witness.wtns.0.shared --zkey
test_vectors/multiplier2/multiplier2.zkey --protocol SHAMIR --curve BN254 --config ../configs/party1.toml --out proof.0.json --public-input public_input.json & +cargo run --release --bin co-circom -- generate-proof groth16 --witness test_vectors/multiplier2/shamir_witness.wtns.1.shared --zkey test_vectors/multiplier2/multiplier2.zkey --protocol SHAMIR --curve BN254 --config ../configs/party2.toml --out proof.1.json & +cargo run --release --bin co-circom -- generate-proof groth16 --witness test_vectors/multiplier2/shamir_witness.wtns.2.shared --zkey test_vectors/multiplier2/multiplier2.zkey --protocol SHAMIR --curve BN254 --config ../configs/party3.toml --out proof.2.json +wait $(jobs -p) +# verify proof +cargo run --release --bin co-circom -- verify groth16 --proof proof.0.json --vk test_vectors/multiplier2/verification_key.json --public-input public_input.json --curve BN254 diff --git a/mpc-core/src/protocols/bridges/network.rs b/mpc-core/src/protocols/bridges/network.rs index c97f00af3..cfdc39981 100644 --- a/mpc-core/src/protocols/bridges/network.rs +++ b/mpc-core/src/protocols/bridges/network.rs @@ -19,9 +19,9 @@ impl RepToShamirNetwork for Rep3MpcNet { fn to_shamir_net(self) -> ShamirMpcNet { let Self { id, - net_handler, chan_next, chan_prev, + queue, runtime, } = self; @@ -32,8 +32,8 @@ impl RepToShamirNetwork for Rep3MpcNet { ShamirMpcNet { id: id.into(), num_parties: 3, - net_handler, channels, + queue, runtime, } } diff --git a/mpc-core/src/protocols/rep3/network.rs b/mpc-core/src/protocols/rep3/network.rs index 48632c26c..310c15577 100644 --- a/mpc-core/src/protocols/rep3/network.rs +++ b/mpc-core/src/protocols/rep3/network.rs @@ -2,13 +2,15 @@ //! //! This module contains implementation of the rep3 mpc network -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use crate::RngType; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use bytes::{Bytes, BytesMut}; -use eyre::{bail, eyre, Report}; -use mpc_net::{channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler}; +use eyre::{bail, ContextCompat}; +use mpc_net::{ + channel::ChannelHandle, config::NetworkConfig, queue::ChannelQueue, MpcNetworkHandler, +}; use tokio::runtime::Runtime; use super::{ @@ -224,46 +226,44 @@ pub trait Rep3Network: Send { Self: Sized; } -// TODO make generic over codec? /// This struct can be used to facilitate network communication for the REP3 MPC protocol. #[derive(Debug)] pub struct Rep3MpcNet { pub(crate) id: PartyID, pub(crate) chan_next: ChannelHandle, pub(crate) chan_prev: ChannelHandle, - // TODO we should be able to get rid of this mutex once we dont remove streams from the pool anymore - pub(crate) net_handler: Arc>, - // order is important, runtime MUST be dropped after network handler + pub(crate) queue: ChannelQueue, + // order is important, runtime MUST be dropped last pub(crate) runtime: Arc, } impl Rep3MpcNet { /// Takes a [NetworkConfig] struct and constructs the network interface. The network needs to contain exactly 3 parties with ids 0, 1, and 2. 
- pub fn new(config: NetworkConfig) -> Result { + pub fn new(config: NetworkConfig) -> eyre::Result { if config.parties.len() != 3 { bail!("REP3 protocol requires exactly 3 parties") } + let queue_size = config.conn_queue_size; let id = PartyID::try_from(config.my_id)?; let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() .build()?; - let mut net_handler = runtime.block_on(MpcNetworkHandler::establish(config))?; - - let chan_next = net_handler - .get_byte_channel(&id.next_id().into()) - .ok_or(eyre!("no next channel found"))?; - let chan_prev = net_handler - .get_byte_channel(&id.prev_id().into()) - .ok_or(eyre!("no prev channel found"))?; + let net_handler = runtime.block_on(MpcNetworkHandler::init(config))?; + let queue = runtime.block_on(MpcNetworkHandler::queue(net_handler, queue_size))?; - let chan_next = net_handler.spawn(chan_next); - let chan_prev = net_handler.spawn(chan_prev); + let mut channels = queue.get_channels()?; + let chan_next = channels + .remove(&id.next_id().into()) + .context("while removing channel")?; + let chan_prev = channels + .remove(&id.prev_id().into()) + .context("while removing channel")?; Ok(Self { id, - net_handler: Arc::new(Mutex::new(net_handler)), chan_next, chan_prev, + queue, runtime: Arc::new(runtime), }) } @@ -349,22 +349,24 @@ impl Rep3Network for Rep3MpcNet { } fn fork(&mut self) -> std::io::Result { - let id = self.id; - let mut net_handler = self.net_handler.lock().unwrap(); - let chan_next = net_handler - .get_byte_channel(&id.next_id().into()) - .expect("no next channel found"); - let chan_prev = net_handler - .get_byte_channel(&id.prev_id().into()) - .expect("no prev channel found"); - let chan_next = net_handler.spawn(chan_next); - let chan_prev = net_handler.spawn(chan_prev); + let mut channels = self.queue.get_channels().map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::Other, + "could not get channels from queue, channel died", + ) + })?; + let chan_next = channels + .remove(&self.id.next_id().into()) + .expect("to find next channel"); + let chan_prev = channels + .remove(&self.id.prev_id().into()) + .expect("to find prev channel"); Ok(Self { - id, - net_handler: self.net_handler.clone(), + id: self.id, chan_next, chan_prev, + queue: self.queue.clone(), runtime: self.runtime.clone(), }) } diff --git a/mpc-core/src/protocols/shamir/network.rs b/mpc-core/src/protocols/shamir/network.rs index 3b77ace96..07bf8ffde 100644 --- a/mpc-core/src/protocols/shamir/network.rs +++ b/mpc-core/src/protocols/shamir/network.rs @@ -4,12 +4,11 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use bytes::{Bytes, BytesMut}; -use eyre::{bail, eyre, Report}; -use mpc_net::{channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler}; -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, +use eyre::bail; +use mpc_net::{ + channel::ChannelHandle, config::NetworkConfig, queue::ChannelQueue, MpcNetworkHandler, }; +use std::{collections::HashMap, sync::Arc}; use tokio::runtime::Runtime; /// This trait defines the network interface for the Shamir protocol. @@ -76,20 +75,21 @@ pub trait ShamirNetwork: Send { } /// This struct can be used to facilitate network communication for the Shamir MPC protocol. 
+#[derive(Debug)] pub struct ShamirMpcNet { pub(crate) id: usize, // 0 <= id < num_parties pub(crate) num_parties: usize, pub(crate) channels: HashMap>, - // TODO we should be able to get rid of this mutex once we dont remove streams from the pool anymore - pub(crate) net_handler: Arc>, - // order is important, runtime MUST be dropped after network handler + pub(crate) queue: ChannelQueue, + // order is important, runtime MUST be dropped last pub(crate) runtime: Arc, } impl ShamirMpcNet { /// Takes a [NetworkConfig] struct and constructs the network interface. The network needs to contain at least 3 parties and all ids need to be in the range of 0 <= id < num_parties. - pub fn new(config: NetworkConfig) -> Result { + pub fn new(config: NetworkConfig) -> eyre::Result { let num_parties = config.parties.len(); + let queue_size = config.conn_queue_size; if config.parties.len() <= 2 { bail!("Shamir protocol requires at least 3 parties") @@ -102,23 +102,16 @@ impl ShamirMpcNet { let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() .build()?; - let mut net_handler = runtime.block_on(MpcNetworkHandler::establish(config))?; - let mut channels = HashMap::with_capacity(num_parties - 1); - - for other_id in 0..num_parties { - if other_id != id { - let chan = net_handler - .get_byte_channel(&other_id) - .ok_or_else(|| eyre!("no channel found for party id={}", other_id))?; - channels.insert(other_id, net_handler.spawn(chan)); - } - } + let net_handler = runtime.block_on(MpcNetworkHandler::init(config))?; + let queue = runtime.block_on(MpcNetworkHandler::queue(net_handler, queue_size))?; + + let channels = queue.get_channels()?; Ok(Self { id, num_parties, - net_handler: Arc::new(Mutex::new(net_handler)), channels, + queue, runtime: Arc::new(runtime), }) } @@ -259,25 +252,18 @@ impl ShamirNetwork for ShamirMpcNet { } fn fork(&mut self) -> std::io::Result { - let id = self.id; - let num_parties = self.num_parties; - let mut net_handler = self.net_handler.lock().unwrap(); - - let mut channels = HashMap::with_capacity(num_parties - 1); - for other_id in 0..num_parties { - if other_id != id { - let chan = net_handler - .get_byte_channel(&other_id) - .expect("to find channel"); - channels.insert(other_id, net_handler.spawn(chan)); - } - } + let channels = self.queue.get_channels().map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::Other, + "could not get channels from queue, channel died", + ) + })?; Ok(Self { - id, - num_parties, - net_handler: self.net_handler.clone(), + id: self.id, + num_parties: self.num_parties, channels, + queue: self.queue.clone(), runtime: self.runtime.clone(), }) } diff --git a/mpc-net/Cargo.toml b/mpc-net/Cargo.toml index 95989b380..6168b55ac 100644 --- a/mpc-net/Cargo.toml +++ b/mpc-net/Cargo.toml @@ -17,6 +17,7 @@ bincode = { workspace = true } bytes = { workspace = true } clap = { workspace = true } color-eyre = { workspace = true } +eyre = { workspace = true } futures = { workspace = true } rcgen = { workspace = true } rustls = { workspace = true } diff --git a/mpc-net/examples/three_party.rs b/mpc-net/examples/three_party.rs index dae77509c..2f75b323e 100644 --- a/mpc-net/examples/three_party.rs +++ b/mpc-net/examples/three_party.rs @@ -2,7 +2,7 @@ use std::path::PathBuf; use clap::Parser; use color_eyre::{ - eyre::{eyre, Context, ContextCompat}, + eyre::{eyre, Context}, Result, }; use futures::{SinkExt, StreamExt}; @@ -31,9 +31,9 @@ async fn main() -> Result<()> { let config = NetworkConfig::try_from(config).context("converting network config")?; let 
my_id = config.my_id; - let mut network = MpcNetworkHandler::establish(config).await?; + let network = MpcNetworkHandler::init(config).await?; - let mut channels = network.get_byte_channels().context("get channels")?; + let mut channels = network.get_byte_channels().await?; // send to all channels for (&i, channel) in channels.iter_mut() { diff --git a/mpc-net/examples/three_party_bincode_channels.rs b/mpc-net/examples/three_party_bincode_channels.rs index afb62aa25..5749d48fa 100644 --- a/mpc-net/examples/three_party_bincode_channels.rs +++ b/mpc-net/examples/three_party_bincode_channels.rs @@ -2,7 +2,7 @@ use std::path::PathBuf; use clap::Parser; use color_eyre::{ - eyre::{eyre, Context, ContextCompat}, + eyre::{eyre, Context}, Result, }; use futures::{SinkExt, StreamExt}; @@ -32,11 +32,9 @@ async fn main() -> Result<()> { let config = NetworkConfig::try_from(config).context("converting network config")?; let my_id = config.my_id; - let mut network = MpcNetworkHandler::establish(config).await?; + let network = MpcNetworkHandler::init(config).await?; - let mut channels = network - .get_serde_bincode_channels() - .context("get channels")?; + let mut channels = network.get_serde_bincode_channels().await?; // send to all channels for (&i, channel) in channels.iter_mut() { diff --git a/mpc-net/examples/three_party_custom_channels.rs b/mpc-net/examples/three_party_custom_channels.rs index 64d72151f..dbde16885 100644 --- a/mpc-net/examples/three_party_custom_channels.rs +++ b/mpc-net/examples/three_party_custom_channels.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use bytes::{Buf, BufMut}; use clap::Parser; use color_eyre::{ - eyre::{eyre, Context, ContextCompat}, + eyre::{eyre, Context}, Result, }; use futures::{SinkExt, StreamExt}; @@ -33,10 +33,10 @@ async fn main() -> Result<()> { let config = NetworkConfig::try_from(config).context("converting network config")?; let my_id = config.my_id; - let mut network = MpcNetworkHandler::establish(config).await?; + let network = MpcNetworkHandler::init(config).await?; let codec = MessageCodec; - let mut channels = network.get_custom_channels(codec).context("get channels")?; + let mut channels = network.get_custom_channels(codec).await?; // send to all channels for (&i, channel) in channels.iter_mut() { diff --git a/mpc-net/examples/three_party_managed.rs b/mpc-net/examples/three_party_managed.rs index a65b3f711..dd00cee2e 100644 --- a/mpc-net/examples/three_party_managed.rs +++ b/mpc-net/examples/three_party_managed.rs @@ -1,8 +1,8 @@ -use std::{collections::HashMap, path::PathBuf}; +use std::path::PathBuf; use clap::Parser; use color_eyre::{ - eyre::{eyre, Context, ContextCompat}, + eyre::{eyre, Context}, Result, }; use mpc_net::{ @@ -30,13 +30,9 @@ async fn main() -> Result<()> { let config = NetworkConfig::try_from(config).context("converting network config")?; let my_id = config.my_id; - let mut network = MpcNetworkHandler::establish(config).await?; + let network = MpcNetworkHandler::init(config).await?; - let channels = network.get_byte_channels().context("get channels")?; - let mut managed_channels = channels - .into_iter() - .map(|(i, c)| (i, network.spawn(c))) - .collect::>(); + let mut managed_channels = network.get_byte_channels_managed().await?; // send to all channels for (&i, channel) in managed_channels.iter_mut() { diff --git a/mpc-net/src/channel.rs b/mpc-net/src/channel.rs index 2f5166eaf..8ffcee8ba 100644 --- a/mpc-net/src/channel.rs +++ b/mpc-net/src/channel.rs @@ -5,9 +5,12 @@ use tokio::{ io::{AsyncRead, AsyncWrite, AsyncWriteExt}, 
     runtime::Handle,
     sync::{mpsc, oneshot},
-    task::{JoinError, JoinHandle},
+    task::JoinHandle,
+};
+use tokio_util::{
+    codec::{Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec},
+    sync::CancellationToken,
 };
-use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec};
 
 use crate::codecs::BincodeCodec;
 
@@ -152,8 +155,11 @@ struct ReadJob<MRecv> {
 /// A handle to a channel that allows sending and receiving messages.
 #[derive(Debug)]
 pub struct ChannelHandle<MRecv, MSend> {
-    write_job_queue: mpsc::Sender<WriteJob<MSend>>,
-    read_job_queue: mpsc::Sender<ReadJob<MRecv>>,
+    write_job_queue: Option<mpsc::Sender<WriteJob<MSend>>>,
+    read_job_queue: Option<mpsc::Sender<ReadJob<MRecv>>>,
+    tasks: Vec<JoinHandle<()>>,
+    handle: Handle,
+    token: CancellationToken,
 }
 
 impl<MRecv, MSend> ChannelHandle<MRecv, MSend>
 where
     MRecv: Send + std::fmt::Debug + 'static,
     MSend: Send + std::fmt::Debug + 'static,
 {
+    /// Create a new [`ChannelHandle`] from a [`Channel`]. This spawns new tokio tasks that handle the read and write jobs so they can happen concurrently.
+    pub fn spawn<R, W, C>(chan: Channel<R, W, C>) -> ChannelHandle<MRecv, MSend>
+    where
+        C: 'static,
+        R: AsyncRead + Unpin + 'static,
+        W: AsyncWrite + Unpin + std::marker::Send + 'static,
+        FramedRead<R, C>: Stream<Item = Result<MRecv, std::io::Error>> + Send,
+        FramedWrite<W, C>: Sink<MSend, Error = std::io::Error> + Send,
+    {
+        let handle = Handle::current();
+        let (write_send, mut write_recv) = mpsc::channel::<WriteJob<MSend>>(1024);
+        let (read_send, mut read_recv) = mpsc::channel::<ReadJob<MRecv>>(1024);
+
+        let (mut write, mut read) = chan.split();
+        let token = CancellationToken::new();
+        let token_ = token.clone();
+
+        let mut tasks = Vec::new();
+        // terminates if received None (other side of network terminated first and closed channel) or in drop via cancellation token
+        tasks.push(handle.spawn(async move {
+            loop {
+                tokio::select! {
+                    frame = read.next() => {
+                        if let Some(frame) = frame {
+                            let job = read_recv.recv().await;
+                            match job {
+                                Some(job) => {
+                                    if job.ret.send(frame).is_err() {
+                                        tracing::warn!("Warning: Read Job finished but receiver is gone!");
+                                    }
+                                }
+                                None => {
+                                    if frame.is_ok() {
+                                        tracing::warn!("Warning: received Ok frame but receiver is gone!");
+                                    }
+                                    break;
+                                }
+                            }
+                        } else {
+                            break;
+                        }
+                    }
+                    _ = token_.cancelled() => break,
+                }
+            }
+        }));
+        // terminates once we drop the corresponding sender in drop
+        tasks.push(handle.spawn(async move {
+            while let Some(write_job) = write_recv.recv().await {
+                match write.send(write_job.data).await {
+                    Ok(_) => {
+                        // we don't really care if the receiver for a write job is gone, as this is a common case
+                        // therefore we only emit a trace message
+                        if write_job.ret.send(Ok(())).is_err() {
+                            tracing::trace!("Debug: Write Job finished but receiver is gone!");
+                        }
+                    }
+                    Err(err) => {
+                        tracing::error!("Write job failed: {err}");
+                    }
+                }
+            }
+            // make sure all data is sent
+            if write.into_inner().shutdown().await.is_err() {
+                tracing::warn!("Warning: shutdown of stream failed!");
+            }
+        }));
+
+        ChannelHandle {
+            write_job_queue: Some(write_send),
+            read_job_queue: Some(read_send),
+            tasks,
+            handle,
+            token,
+        }
+    }
+
     /// Instructs the channel to send a message. Returns a [oneshot::Receiver] that will return the result of the send operation.
pub async fn send(&mut self, data: MSend) -> oneshot::Receiver> { let (ret, recv) = oneshot::channel(); let job = WriteJob { data, ret }; - match self.write_job_queue.send(job).await { + // unwrap is fine because the value is only set to None in drop + match self.write_job_queue.as_mut().unwrap().send(job).await { Ok(_) => {} Err(job) => job .0 @@ -183,7 +267,8 @@ where pub async fn recv(&mut self) -> oneshot::Receiver> { let (ret, recv) = oneshot::channel(); let job = ReadJob { ret }; - match self.read_job_queue.send(job).await { + // unwrap is fine because the value is only set to None in drop + match self.read_job_queue.as_mut().unwrap().send(job).await { Ok(_) => {} Err(job) => job .0 @@ -201,7 +286,8 @@ where pub fn blocking_send(&mut self, data: MSend) -> oneshot::Receiver> { let (ret, recv) = oneshot::channel(); let job = WriteJob { data, ret }; - match self.write_job_queue.blocking_send(job) { + // unwrap is fine because the value is only set to None in drop + match self.write_job_queue.as_mut().unwrap().blocking_send(job) { Ok(_) => {} Err(job) => job .0 @@ -219,7 +305,8 @@ where pub fn blocking_recv(&mut self) -> oneshot::Receiver> { let (ret, recv) = oneshot::channel(); let job = ReadJob { ret }; - match self.read_job_queue.blocking_send(job) { + // unwrap is fine because the value is only set to None in drop + match self.read_job_queue.as_mut().unwrap().blocking_send(job) { Ok(_) => {} Err(job) => job .0 @@ -234,99 +321,17 @@ where } } -/// Handles spawing and shutdown of channels. On drop, joins all [`JoinHandle`]s. The [`Handle`] musst be valid for the entire lifetime of this type. -#[derive(Debug)] -pub(crate) struct ChannelTasks { - tasks: Vec>, - handle: Handle, -} - -impl ChannelTasks { - /// Create a new [`ChannelTasks`] instance. - pub fn new(handle: Handle) -> Self { - Self { - tasks: Vec::new(), - handle, - } - } - - /// Create a new [`ChannelHandle`] from a [`Channel`]. This spawns a new tokio task that handles the read and write jobs so they can happen concurrently. 
-    pub(crate) fn spawn<MRecv, MSend, R, W, C>(
-        &mut self,
-        chan: Channel<R, W, C>,
-    ) -> ChannelHandle<MRecv, MSend>
-    where
-        C: 'static,
-        R: AsyncRead + Unpin + 'static,
-        W: AsyncWrite + Unpin + std::marker::Send + 'static,
-        FramedRead<R, C>: Stream<Item = Result<MRecv, std::io::Error>> + Send,
-        FramedWrite<W, C>: Sink<MSend, Error = std::io::Error> + Send,
-        MRecv: Send + std::fmt::Debug + 'static,
-        MSend: Send + std::fmt::Debug + 'static,
-    {
-        let (write_send, mut write_recv) = mpsc::channel::<WriteJob<MSend>>(1024);
-        let (read_send, mut read_recv) = mpsc::channel::<ReadJob<MRecv>>(1024);
-
-        let (mut write, mut read) = chan.split();
-
-        self.tasks.push(self.handle.spawn(async move {
-            while let Some(frame) = read.next().await {
-                let job = read_recv.recv().await;
-                match job {
-                    Some(job) => {
-                        if job.ret.send(frame).is_err() {
-                            tracing::warn!("Warning: Read Job finished but receiver is gone!");
-                        }
-                    }
-                    None => {
-                        if frame.is_ok() {
-                            tracing::warn!("Warning: received Ok frame but receiver is gone!");
-                        }
-                        break;
-                    }
-                }
-            }
-        }));
-        self.tasks.push(self.handle.spawn(async move {
-            while let Some(write_job) = write_recv.recv().await {
-                let write_result = write.send(write_job.data).await;
-                // we don't really care if the receiver for a write job is gone, as this is a common case
-                // therefore we only emit a trace message
-                match write_job.ret.send(write_result) {
-                    Ok(_) => {}
-                    Err(_) => {
-                        tracing::trace!("Debug: Write Job finished but receiver is gone!");
-                    }
-                }
-            }
-            // make sure all data is sent
-            if write.into_inner().shutdown().await.is_err() {
-                tracing::warn!("Warning: shutdown of stream failed!");
-            }
-        }));
-
-        ChannelHandle {
-            write_job_queue: write_send,
-            read_job_queue: read_send,
-        }
-    }
-
-    /// Join all [`JoinHandle`]s and remove them.
-    pub(crate) async fn shutdown(&mut self) -> Result<(), JoinError> {
-        futures::future::try_join_all(std::mem::take(&mut self.tasks))
-            .await
-            .map(|_| ())
-    }
-}
-
-impl Drop for ChannelTasks {
+impl<MRecv, MSend> Drop for ChannelHandle<MRecv, MSend> {
     fn drop(&mut self) {
+        // drop the senders so the write task can finish
+        std::mem::drop(self.write_job_queue.take());
+        std::mem::drop(self.read_job_queue.take());
+        // cancel read task, we are in drop of ChannelHandle, so no one can read anymore
+        self.token.cancel();
+        // ignore results, the queue eagerly creates new tasks after a channel was taken, so joining these can fail
         tokio::task::block_in_place(move || {
             self.handle
-                .block_on(futures::future::try_join_all(std::mem::take(
-                    &mut self.tasks,
-                )))
-                .expect("can join all tasks");
+                .block_on(futures::future::join_all(std::mem::take(&mut self.tasks)));
         });
     }
 }
diff --git a/mpc-net/src/config.rs b/mpc-net/src/config.rs
index 10f5b27cd..d53ffc59a 100644
--- a/mpc-net/src/config.rs
+++ b/mpc-net/src/config.rs
@@ -114,6 +114,10 @@ impl TryFrom for NetworkParty {
     }
 }
 
+fn default_conn_queue_size() -> usize {
+    3
+}
+
 /// The network configuration file.
 #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, PartialOrd, Ord, Hash)]
 pub struct NetworkConfigFile {
@@ -125,6 +129,9 @@ pub struct NetworkConfigFile {
     pub bind_addr: SocketAddr,
     /// The path to our private key file.
     pub key_path: PathBuf,
+    /// The initial size of the connection queue.
+    #[serde(default = "default_conn_queue_size")]
+    pub conn_queue_size: usize,
 }
 
 /// The network configuration.
@@ -138,6 +145,8 @@ pub struct NetworkConfig {
     pub bind_addr: SocketAddr,
     /// The private key.
     pub key: PrivateKeyDer<'static>,
+    /// The initial size of the connection queue.
+ pub conn_queue_size: usize, } impl TryFrom for NetworkConfig { @@ -155,6 +164,7 @@ impl TryFrom for NetworkConfig { my_id: value.my_id, bind_addr: value.bind_addr, key, + conn_queue_size: value.conn_queue_size, }) } } @@ -166,6 +176,7 @@ impl Clone for NetworkConfig { my_id: self.my_id, bind_addr: self.bind_addr, key: self.key.clone_key(), + conn_queue_size: self.conn_queue_size, } } } diff --git a/mpc-net/src/lib.rs b/mpc-net/src/lib.rs index 72a7fce8b..51dce4f4c 100644 --- a/mpc-net/src/lib.rs +++ b/mpc-net/src/lib.rs @@ -3,27 +3,25 @@ use std::{ collections::{BTreeMap, HashMap}, io, - net::ToSocketAddrs, - pin::Pin, + net::{SocketAddr, ToSocketAddrs}, sync::{ atomic::{AtomicUsize, Ordering}, Arc, }, - task::{Context, Poll}, time::Duration, }; -use channel::{BincodeChannel, BytesChannel, Channel, ChannelHandle, ChannelTasks}; +use bytes::{Bytes, BytesMut}; +use channel::{BincodeChannel, BytesChannel, Channel, ChannelHandle}; use codecs::BincodeCodec; -use color_eyre::eyre::{bail, Context as Ctx, ContextCompat, Report}; +use color_eyre::eyre::{bail, Context, ContextCompat, Report}; use config::NetworkConfig; -use futures::{Sink, Stream}; +use queue::{ChannelQueue, CreateJob, QueueJob}; use serde::{de::DeserializeOwned, Serialize}; use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, net::{TcpListener, TcpStream}, - runtime::Handle, - task::JoinError, + sync::mpsc::{self}, }; use tokio_rustls::{ rustls::{ @@ -32,14 +30,14 @@ use tokio_rustls::{ }, TlsAcceptor, TlsConnector, }; -use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec}; +use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec}; +use tracking_rw::{TrackingAsyncReader, TrackingAsyncWriter}; pub mod channel; pub mod codecs; pub mod config; - -// TODO get this from network config -const STREAMS_PER_CONN: usize = 8; +pub mod queue; +mod tracking_rw; /// Type alias for a [tokio_rustls::TlsStream] over a [TcpStream]. type TlsStream = tokio_rustls::TlsStream; @@ -47,7 +45,7 @@ type TlsStream = tokio_rustls::TlsStream; /// A duplex TLS stream that uses one stream for sending and one for receiving. /// Splitting a single stream would add unwanted syncronization primitives. #[derive(Debug)] -struct DuplexTlsStream { +pub(crate) struct DuplexTlsStream { send: TlsStream, recv: TlsStream, } @@ -62,26 +60,38 @@ impl DuplexTlsStream { } } -/// A connection with a pool of streams and total sent/recv stats. -#[derive(Debug, Default)] -struct Connection { - streams: Vec, +#[derive(Debug)] +struct ConnectionInfo { + party_hostname: String, + party_addr: SocketAddr, sent: Arc, recv: Arc, } +impl ConnectionInfo { + pub fn new(party_hostname: String, party_addr: SocketAddr) -> Self { + Self { + party_hostname, + party_addr, + sent: Arc::default(), + recv: Arc::default(), + } + } +} + /// A network handler for MPC protocols. -#[derive(Debug)] pub struct MpcNetworkHandler { // this is a btreemap because we rely on iteration order - connections: BTreeMap, - tasks: ChannelTasks, + conn_infos: BTreeMap, my_id: usize, + listener: TcpListener, + acceptor: TlsAcceptor, + connector: TlsConnector, } impl MpcNetworkHandler { - /// Tries to establish a connection to other parties in the network based on the provided [NetworkConfig]. - pub async fn establish(config: NetworkConfig) -> Result { + /// Initialize the [NetworkHandler] based on the provided [NetworkConfig]. 
+ pub async fn init(config: NetworkConfig) -> Result { config.check_config()?; let certs: HashMap = config .parties @@ -111,154 +121,148 @@ impl MpcNetworkHandler { tracing::trace!("Party {}: listening on {our_socket_addr}", config.my_id); - let mut connections: BTreeMap = BTreeMap::new(); - - let mut accpected_streams = BTreeMap::new(); + let mut conn_infos = BTreeMap::new(); - let num_parties = config.parties.len(); - - for party in config.parties { + for party in config.parties.iter() { if party.id == config.my_id { // skip self continue; } - if party.id < config.my_id { + let party_addr = party + .dns_name + .to_socket_addrs() + .with_context(|| format!("while resolving DNS name for {}", party.dns_name))? + .next() + .with_context(|| format!("could not resolve DNS name {}", party.dns_name))?; + let party_hostname = party.dns_name.hostname.clone(); + + for _ in 0..config.conn_queue_size { + let conn = ConnectionInfo::new(party_hostname.clone(), party_addr); + conn_infos.insert(party.id, conn); + } + } + + Ok(MpcNetworkHandler { + conn_infos, + my_id: config.my_id, + listener, + acceptor, + connector, + }) + } + + /// Tries to establish a connection to other parties in the network. + pub(crate) async fn establish(&self) -> Result, Report> { + let mut streams = HashMap::new(); + let mut unordered_strames = HashMap::new(); + for (id, conn_info) in self.conn_infos.iter() { + if *id < self.my_id { // connect to party, we are client - let party_addr = party - .dns_name - .to_socket_addrs() - .with_context(|| format!("while resolving DNS name for {}", party.dns_name))? - .next() - .with_context(|| format!("could not resolve DNS name {}", party.dns_name))?; - - let domain = ServerName::try_from(party.dns_name.hostname.clone()) + let domain = ServerName::try_from(conn_info.party_hostname.as_str()) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))? 
                .to_owned();
 
-                // create all streams for this connection
-                for stream_id in 0..STREAMS_PER_CONN {
-                    let mut send = None;
-                    let mut recv = None;
-                    // create 2 streams per stream to get full duplex with tls streams
-                    for split in [DuplexTlsStream::SPLIT0, DuplexTlsStream::SPLIT1] {
-                        let stream = loop {
-                            if let Ok(stream) = TcpStream::connect(party_addr).await {
-                                break stream;
-                            }
-                            std::thread::sleep(Duration::from_millis(100));
-                        };
-                        // this removes buffering of tcp packets, very important for latency of small packets
-                        stream.set_nodelay(true)?;
-                        let mut stream = connector.connect(domain.clone(), stream).await?;
-                        stream.write_u64(config.my_id as u64).await?;
-                        stream.write_u64(stream_id as u64).await?;
-                        stream.write_u8(split).await?;
-                        if split == DuplexTlsStream::SPLIT0 {
-                            send = Some(stream);
-                        } else {
-                            recv = Some(stream);
+            let mut send = None;
+            let mut recv = None;
+            // create 2 streams per connection to get full duplex with tls streams
+            for split in [DuplexTlsStream::SPLIT0, DuplexTlsStream::SPLIT1] {
+                let stream = loop {
+                    if let Ok(stream) = TcpStream::connect(conn_info.party_addr).await {
+                        break stream;
                     }
+                    std::thread::sleep(Duration::from_millis(100));
+                };
+                // this removes buffering of tcp packets, very important for latency of small packets
+                stream.set_nodelay(true)?;
+                let mut stream = self.connector.connect(domain.clone(), stream).await?;
+                stream.write_u64(self.my_id as u64).await?;
+                stream.write_u8(split).await?;
+                if split == DuplexTlsStream::SPLIT0 {
+                    send = Some(stream);
+                } else {
+                    recv = Some(stream);
                 }
+            }
 
-                    tracing::trace!(
-                        "Party {}: connected stream {stream_id} to party {}",
-                        party.id,
-                        config.my_id
-                    );
+            tracing::trace!("Party {}: connected stream to party {}", id, self.my_id);
 
-                    let send = send.expect("not none after connect was succesful");
-                    let recv = recv.expect("not none after connect was succesful");
+            let send = send.expect("not none after connect was successful");
+            let recv = recv.expect("not none after connect was successful");
 
-                    if let Some(conn) = connections.get_mut(&party.id) {
-                        conn.streams
-                            .push(DuplexTlsStream::new(send.into(), recv.into()));
-                    } else {
-                        let mut conn = Connection::default();
-                        conn.streams
-                            .push(DuplexTlsStream::new(send.into(), recv.into()));
-                        connections.insert(party.id, conn);
-                    }
-                }
+            assert!(streams
+                .insert(*id, DuplexTlsStream::new(send.into(), recv.into()))
+                .is_none());
         } else {
             // we are the server, accept connections
-                // accept all 2 splits for n streams and store them with (party_id, stream_id, split) so we know were they belong
-                for _ in 0..STREAMS_PER_CONN * 2 {
-                    let (stream, _peer_addr) = listener.accept().await?;
+            // accept 2 splits and store them with (party_id, split) so we know where they belong
+            for _ in 0..2 {
+                let (stream, _peer_addr) = self.listener.accept().await?;
                 // this removes buffering of tcp packets, very important for latency of small packets
                 stream.set_nodelay(true)?;
-                let mut stream = acceptor.accept(stream).await?;
+                let mut stream = self.acceptor.accept(stream).await?;
                 let party_id = stream.read_u64().await? as usize;
-                let stream_id = stream.read_u64().await?
as usize; let split = stream.read_u8().await?; - assert!(accpected_streams - .insert((party_id, stream_id, split), stream) + assert!(unordered_strames + .insert((party_id, split), stream) .is_none()); } } } - // assign streams to the right party, stream and duplex half + // assign streams to the right party and duplex half // we accepted streams for all parties with id > my_id, so we can iter from my_id + 1..num_parties - for party_id in config.my_id + 1..num_parties { - for stream_id in 0..STREAMS_PER_CONN { - // send and recv is swapped here compared to above - let recv = accpected_streams - .remove(&(party_id, stream_id, DuplexTlsStream::SPLIT0)) - .context(format!("get recv for stream {stream_id} party {party_id}"))?; - let send = accpected_streams - .remove(&(party_id, stream_id, DuplexTlsStream::SPLIT1)) - .context(format!("get send for stream {stream_id} party {party_id}"))?; - if let Some(conn) = connections.get_mut(&party_id) { - conn.streams - .push(DuplexTlsStream::new(send.into(), recv.into())); - } else { - let mut conn = Connection::default(); - conn.streams - .push(DuplexTlsStream::new(send.into(), recv.into())); - connections.insert(party_id, conn); - } - } + for id in self.my_id + 1..self.conn_infos.len() + 1 { + // send and recv is swapped here compared to above + let recv = unordered_strames + .remove(&(id, DuplexTlsStream::SPLIT0)) + .context(format!("get recv for party {}", id)) + .unwrap(); + let send = unordered_strames + .remove(&(id, DuplexTlsStream::SPLIT1)) + .context(format!("get send for party {}", id)) + .unwrap(); + assert!(streams + .insert(id, DuplexTlsStream::new(send.into(), recv.into())) + .is_none()); } - if !accpected_streams.is_empty() { - bail!("not accepted connections should remain"); + if !unordered_strames.is_empty() { + bail!("no stream should remain"); } - tracing::trace!("Party {}: established network handler", config.my_id); - - Ok(MpcNetworkHandler { - connections, - tasks: ChannelTasks::new(Handle::current()), - my_id: config.my_id, - }) + Ok(streams) } - /// Create a new [`ChannelHandle`] from a [`Channel`]. This spawns a new tokio task that handles the read and write jobs so they can happen concurrently. - pub fn spawn( - &mut self, - chan: Channel, - ) -> ChannelHandle - where - C: 'static, - R: AsyncRead + Unpin + 'static, - W: AsyncWrite + Unpin + std::marker::Send + 'static, - FramedRead: Stream> + Send, - FramedWrite: Sink + Send, - MRecv: Send + std::fmt::Debug + 'static, - MSend: Send + std::fmt::Debug + 'static, - { - self.tasks.spawn(chan) - } + /// Create a [ChannelQueue] that holds `size` number of [ChannelHandle]s per party. + /// This queue can be used to quickly get existing connections and create new ones in the background. + pub async fn queue(net_handler: Self, size: usize) -> eyre::Result { + let mut init_queue = Vec::new(); + for _ in 0..size { + init_queue.push(net_handler.get_byte_channels_managed().await?); + } - /// Shutdown the network, waiting until all read and write tasks are completed. This happens automatically, when the network handler is dropped. 
- pub async fn shutdown(&mut self) -> Result<(), JoinError> { - self.tasks.shutdown().await + let (queue_sender, queue_receiver) = mpsc::channel::(size); + let (create_sender, create_receiver) = mpsc::channel::(size); + + tokio::spawn(queue::create_channel_actor( + net_handler, + create_receiver, + queue_sender.clone(), + )); + + tokio::spawn(queue::get_channel_actor( + init_queue, + create_sender, + queue_receiver, + )); + + Ok(ChannelQueue::new(queue_sender)) } /// Returns the number of sent and received bytes. pub fn get_send_receive(&self, i: usize) -> std::io::Result<(usize, usize)> { let conn = self - .connections + .conn_infos .get(&i) .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "no such connection"))?; Ok(( @@ -269,7 +273,7 @@ impl MpcNetworkHandler { /// Prints the connection statistics. pub fn print_connection_stats(&self, out: &mut impl std::io::Write) -> std::io::Result<()> { - for (i, conn) in &self.connections { + for (i, conn) in &self.conn_infos { writeln!( out, "Connection {} stats:\n\tSENT: {} bytes\n\tRECV: {} bytes", @@ -281,27 +285,8 @@ impl MpcNetworkHandler { Ok(()) } - /// Get a [Channel] to party with `id`. This pops a stream from the pool. - pub fn get_byte_channel( - &mut self, - id: &usize, - ) -> Option> { - let mut codec = LengthDelimitedCodec::new(); - codec.set_max_frame_length(1_000_000_000); - self.get_custom_channel(id, codec) - } - - /// Get a [Channel] to party with `id`. This pops a stream from the pool. - pub fn get_serde_bincode_channel( - &mut self, - id: &usize, - ) -> Option>> { - let bincodec = BincodeCodec::::new(); - self.get_custom_channel(id, bincodec) - } - - /// Get a [Channel] to party with `id` using the provided codec. This pops a stream from the pool. - pub fn get_custom_channel< + /// Get a [Channel] to each party. This establishes a new Connection with each party. + pub async fn get_custom_channels< MSend, MRecv, C: Encoder @@ -309,151 +294,77 @@ impl MpcNetworkHandler { + 'static + Clone, >( - &mut self, - id: &usize, + &self, codec: C, - ) -> Option> { - debug_assert!(*id != self.my_id); - if let Some(conn) = self.connections.get_mut(id) { - if let Some(stream) = conn.streams.pop() { - let recv = TrackingAsyncReader::new(stream.recv, conn.recv.clone()); - let send = TrackingAsyncWriter::new(stream.send, conn.sent.clone()); - return Some(Channel::new(recv, send, codec)); - } - } - None + ) -> eyre::Result>> { + self.establish() + .await? + .into_iter() + .map(|(id, stream)| { + let conn_info = self.conn_infos.get(&id).context("while get conn info")?; + let recv = TrackingAsyncReader::new(stream.recv, conn_info.recv.clone()); + let send = TrackingAsyncWriter::new(stream.send, conn_info.sent.clone()); + Ok((id, Channel::new(recv, send, codec.clone()))) + }) + .collect() } - /// Get a [Channel] to each party using the provided codec. This pops a stream from each pool. - pub fn get_custom_channels< - MSend, - MRecv, - C: Encoder - + Decoder - + 'static - + Clone, - >( - &mut self, - codec: C, - ) -> Option>> { - let mut channels = HashMap::new(); - let party_ids: Vec<_> = self.connections.keys().cloned().collect(); - for id in party_ids { - let chan = self.get_custom_channel(&id, codec.clone())?; - channels.insert(id, chan); - } - Some(channels) + /// Get a [Channel] to each party. This establishes a new Connection with each party. 
+ pub async fn get_byte_channels( + &self, + ) -> eyre::Result>> { + // set max frame length to 1Tb and length_field_length to 5 bytes + const NUM_BYTES: usize = 5; + let codec = LengthDelimitedCodec::builder() + .length_field_type::() // u64 because this is the type the length is decoded into, and u32 doesnt fit 5 bytes + .length_field_length(NUM_BYTES) + .max_frame_length(1usize << (NUM_BYTES * 8)) + .new_codec(); + self.get_custom_channels(codec).await } - /// Get a [Channel] to each party. This pops a stream from each pool. - pub fn get_byte_channels( - &mut self, - ) -> Option>> { - let mut codec = LengthDelimitedCodec::new(); - codec.set_max_frame_length(1_000_000_000); - self.get_custom_channels(codec) - } - - /// Get a [Channel] to each party. This pops a stream from each pool. - pub fn get_serde_bincode_channels( - &mut self, - ) -> Option>> { + /// Get a [Channel] to each party. This establishes a new Connection with each party. + pub async fn get_serde_bincode_channels( + &self, + ) -> eyre::Result>> { let bincodec = BincodeCodec::::new(); - self.get_custom_channels(bincodec) - } -} - -/// A wrapper around [`AsyncRead`] types that keeps track of the number of read bytes -struct TrackingAsyncReader { - inner: R, - bytes_read: Arc, -} - -impl TrackingAsyncReader { - fn new(inner: R, bytes_read: Arc) -> Self { - Self { inner, bytes_read } + self.get_custom_channels(bincodec).await } -} -impl AsyncRead for TrackingAsyncReader { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - let inner = Pin::new(&mut self.inner); - let initial_len = buf.filled().len(); - let res = inner.poll_read(cx, buf); - - // if the read was ok, update bytes_read - if let Poll::Ready(Ok(())) = &res { - self.bytes_read - .fetch_add(buf.filled().len() - initial_len, Ordering::SeqCst); - } - - res - } -} - -/// A wrapper around [`AsyncWrite`] types that keeps track of the number of written bytes -struct TrackingAsyncWriter { - inner: W, - bytes_written: Arc, -} - -impl TrackingAsyncWriter { - fn new(inner: R, bytes_written: Arc) -> Self { - Self { - inner, - bytes_written, - } - } -} - -impl AsyncWrite for TrackingAsyncWriter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let inner = Pin::new(&mut self.inner); - let res = inner.poll_write(cx, buf); - - // if the write was ok, update bytes_written - if let Poll::Ready(Ok(bytes_written)) = &res { - self.bytes_written - .fetch_add(*bytes_written, Ordering::SeqCst); - } - - res - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) - } - - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - let inner = Pin::new(&mut self.inner); - let res = inner.poll_write_vectored(cx, bufs); - - // if the write was ok, update bytes_written - if let Poll::Ready(Ok(bytes_written)) = &res { - self.bytes_written - .fetch_add(*bytes_written, Ordering::SeqCst); - } - - res + /// Get a [ChannelHandle] to each party. This establishes a new Connection with each party. + /// Reads and writes are handled in tokio tasks. On drop, these tasks are awaited. 
+    pub async fn get_custom_channels_managed<
+        MSend: std::fmt::Debug + std::marker::Send + 'static,
+        MRecv: std::fmt::Debug + std::marker::Send + 'static,
+        C: Encoder<MSend, Error = io::Error>
+            + Decoder<Item = MRecv, Error = io::Error>
+            + 'static
+            + Clone
+            + std::marker::Send,
+    >(
+        &self,
+        codec: C,
+    ) -> eyre::Result<HashMap<usize, ChannelHandle<MRecv, MSend>>> {
+        Ok(self
+            .get_custom_channels(codec)
+            .await?
+            .into_iter()
+            .map(|(id, chan)| (id, ChannelHandle::spawn(chan)))
+            .collect())
+    }
+
+    /// Get a [ChannelHandle] to each party. This establishes a new connection with each party.
+    /// Reads and writes are handled in tokio tasks. On drop, these tasks are awaited.
+    pub async fn get_byte_channels_managed(
+        &self,
+    ) -> eyre::Result<HashMap<usize, ChannelHandle<BytesMut, Bytes>>> {
+        // set max frame length to 1 TiB and length_field_length to 5 bytes
+        const NUM_BYTES: usize = 5;
+        let codec = LengthDelimitedCodec::builder()
+            .length_field_type::<u64>() // u64 because this is the type the length is decoded into, and u32 doesn't fit 5 bytes
+            .length_field_length(NUM_BYTES)
+            .max_frame_length(1usize << (NUM_BYTES * 8))
+            .new_codec();
+        self.get_custom_channels_managed(codec).await
+    }
 }
diff --git a/mpc-net/src/queue.rs b/mpc-net/src/queue.rs
new file mode 100644
index 000000000..cbbfe8026
--- /dev/null
+++ b/mpc-net/src/queue.rs
@@ -0,0 +1,80 @@
+//! A queue implementation that uses actors to create new channels.
+
+use std::collections::{HashMap, VecDeque};
+
+use bytes::{Bytes, BytesMut};
+use tokio::sync::{
+    mpsc::{Receiver, Sender},
+    oneshot,
+};
+
+use crate::{channel::ChannelHandle, MpcNetworkHandler};
+
+pub(crate) struct CreateJob;
+
+pub(crate) enum QueueJob {
+    GetChannels(oneshot::Sender<HashMap<usize, ChannelHandle<BytesMut, Bytes>>>),
+    PutChannels(HashMap<usize, ChannelHandle<BytesMut, Bytes>>),
+}
+
+/// A queue of [ChannelHandle]s for each party.
+#[derive(Debug, Clone)]
+pub struct ChannelQueue {
+    sender: Sender<QueueJob>,
+}
+
+impl ChannelQueue {
+    pub(crate) fn new(sender: Sender<QueueJob>) -> Self {
+        Self { sender }
+    }
+
+    /// Get a [ChannelHandle] for each party from the queue. New connections will be created in the background.
+    pub fn get_channels(&self) -> eyre::Result<HashMap<usize, ChannelHandle<BytesMut, Bytes>>> {
+        let (send, recv) = oneshot::channel();
+        self.sender.blocking_send(QueueJob::GetChannels(send))?;
+        Ok(recv.blocking_recv()?)
+    }
+}
+
+/// The connection-creating actor. It owns the `net_handler` and creates a new set of channels for every [CreateJob].
+pub(crate) async fn create_channel_actor(
+    net_handler: MpcNetworkHandler,
+    mut receiver: Receiver<CreateJob>,
+    queue_sender: Sender<QueueJob>,
+) -> eyre::Result<()> {
+    while (receiver.recv().await).is_some() {
+        let handles = net_handler.get_byte_channels_managed().await?;
+        queue_sender.send(QueueJob::PutChannels(handles)).await?;
+    }
+    Ok(())
+}
+
+/// The queue actor. It holds the created channels and requests new ones in the background.
+pub(crate) async fn get_channel_actor(
+    init_queue: Vec<HashMap<usize, ChannelHandle<BytesMut, Bytes>>>,
+    create_sender: Sender<CreateJob>,
+    mut receiver: Receiver<QueueJob>,
+) -> eyre::Result<()> {
+    let mut queue = VecDeque::from(init_queue);
+    let mut open_get_jobs = VecDeque::new();
+    while let Some(job) = receiver.recv().await {
+        match job {
+            QueueJob::GetChannels(sender) => {
+                if let Some(handles) = queue.pop_back() {
+                    sender.send(handles).expect("recv is alive");
+                } else {
+                    open_get_jobs.push_front(sender);
+                }
+                // eagerly request a replacement set for every take
+                create_sender.send(CreateJob).await?;
+            }
+            QueueJob::PutChannels(handles) => {
+                if let Some(sender) = open_get_jobs.pop_back() {
+                    sender.send(handles).expect("recv is alive");
+                } else {
+                    queue.push_front(handles);
+                }
+            }
+        }
+    }
+    Ok(())
+}
diff --git a/mpc-net/src/tracking_rw.rs b/mpc-net/src/tracking_rw.rs
new file mode 100644
index 000000000..b969b4ed4
--- /dev/null
+++ b/mpc-net/src/tracking_rw.rs
@@ -0,0 +1,116 @@
+use std::{
+    io,
+    pin::Pin,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+    task::{Context, Poll},
+};
+
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+/// A wrapper around [`AsyncRead`] types that keeps track of the number of read bytes
+pub(crate) struct TrackingAsyncReader<R> {
+    inner: R,
+    bytes_read: Arc<AtomicUsize>,
+}
+
+impl<R> TrackingAsyncReader<R> {
+    pub fn new(inner: R, bytes_read: Arc<AtomicUsize>) -> Self {
+        Self { inner, bytes_read }
+    }
+
+    #[allow(unused)]
+    pub fn into_inner(self) -> R {
+        self.inner
+    }
+}
+
+impl<R: AsyncRead + Unpin> AsyncRead for TrackingAsyncReader<R> {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        let inner = Pin::new(&mut self.inner);
+        let initial_len = buf.filled().len();
+        let res = inner.poll_read(cx, buf);
+
+        // if the read was ok, update bytes_read
+        if let Poll::Ready(Ok(())) = &res {
+            self.bytes_read
+                .fetch_add(buf.filled().len() - initial_len, Ordering::SeqCst);
+        }
+
+        res
+    }
+}
+
+/// A wrapper around [`AsyncWrite`] types that keeps track of the number of written bytes
+pub(crate) struct TrackingAsyncWriter<W> {
+    inner: W,
+    bytes_written: Arc<AtomicUsize>,
+}
+
+impl<W> TrackingAsyncWriter<W> {
+    pub fn new(inner: W, bytes_written: Arc<AtomicUsize>) -> Self {
+        Self {
+            inner,
+            bytes_written,
+        }
+    }
+
+    #[allow(unused)]
+    pub fn into_inner(self) -> W {
+        self.inner
+    }
+}
+
+impl<W: AsyncWrite + Unpin> AsyncWrite for TrackingAsyncWriter<W> {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        let inner = Pin::new(&mut self.inner);
+        let res = inner.poll_write(cx, buf);
+
+        // if the write was ok, update bytes_written
+        if let Poll::Ready(Ok(bytes_written)) = &res {
+            self.bytes_written
+                .fetch_add(*bytes_written, Ordering::SeqCst);
+        }
+
+        res
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Pin::new(&mut self.inner).poll_flush(cx)
+    }
+
+    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Pin::new(&mut self.inner).poll_shutdown(cx)
+    }
+
+    fn poll_write_vectored(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        bufs: &[io::IoSlice<'_>],
+    ) -> Poll<io::Result<usize>> {
+        let inner = Pin::new(&mut self.inner);
+        let res
= inner.poll_write_vectored(cx, bufs); + + // if the write was ok, update bytes_written + if let Poll::Ready(Ok(bytes_written)) = &res { + self.bytes_written + .fetch_add(*bytes_written, Ordering::SeqCst); + } + + res + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +}
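
Usage sketch (not part of the patch): with this refactor the consumer-facing flow becomes init, then queue, then get_channels. MpcNetworkHandler::init only binds the listener and resolves the peers; MpcNetworkHandler::queue pre-fills conn_queue_size sets of per-party channels and spawns the two actors that eagerly create a replacement set whenever one is taken. Since ChannelQueue::get_channels uses blocking_send/blocking_recv internally, it must be called from outside the tokio runtime, which is how Rep3MpcNet::new and ShamirMpcNet::new use it. The following is a minimal sketch under these assumptions; the peer id 1 and the message contents are illustrative, and error handling is abbreviated.

use bytes::Bytes;
use mpc_net::{config::NetworkConfig, MpcNetworkHandler};

fn example(config: NetworkConfig) -> eyre::Result<()> {
    let queue_size = config.conn_queue_size;
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()?;
    // init resolves peers and binds the listener, then the handler is handed to the queue actors
    let net_handler = runtime.block_on(MpcNetworkHandler::init(config))?;
    let queue = runtime.block_on(MpcNetworkHandler::queue(net_handler, queue_size))?;

    // take one pre-established set of channels; the actors refill the queue in the background
    // (called from a non-runtime thread, as in Rep3MpcNet::new/fork)
    let mut channels = queue.get_channels()?;
    let mut chan = channels.remove(&1).expect("have a channel to party 1");

    // sends and recvs are jobs handled by the channel tasks; each returns a oneshot receiver with the result
    chan.blocking_send(Bytes::from_static(b"hello")).blocking_recv()??;
    let frame = chan.blocking_recv().blocking_recv()??;
    println!("got {} bytes from party 1", frame.len());
    Ok(())
}

Note that a protocol fork (Rep3Network::fork / ShamirNetwork::fork) maps to exactly one get_channels call, so conn_queue_size bounds how many forks can proceed without waiting on fresh TLS handshakes.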