Skip to content

Commit

Permalink
refactor: replaced constant channel pool with actor queue that genera…
Browse files Browse the repository at this point in the history
…te new channels in the background
  • Loading branch information
fabian1409 committed Dec 16, 2024
1 parent b91f2cd commit 27c3d60
Show file tree
Hide file tree
Showing 14 changed files with 603 additions and 478 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# split input into shares
cargo run --release --bin co-circom -- split-input --circuit test_vectors/multiplier2/circuit.circom --input test_vectors/multiplier2/input.json --protocol REP3 --curve BN254 --out-dir test_vectors/multiplier2 --config test_vectors/kyc/config.toml
# run witness extension in MPC
cargo run --release --bin co-circom -- generate-witness -O2 --input test_vectors/multiplier2/input.json.0.shared --circuit test_vectors/multiplier2/circuit.circom --protocol REP3 --curve BN254 --config ../configs/party1.toml --out test_vectors/multiplier2/witness.wtns.0.shared &
cargo run --release --bin co-circom -- generate-witness -O2 --input test_vectors/multiplier2/input.json.1.shared --circuit test_vectors/multiplier2/circuit.circom --protocol REP3 --curve BN254 --config ../configs/party2.toml --out test_vectors/multiplier2/witness.wtns.1.shared &
cargo run --release --bin co-circom -- generate-witness -O2 --input test_vectors/multiplier2/input.json.2.shared --circuit test_vectors/multiplier2/circuit.circom --protocol REP3 --curve BN254 --config ../configs/party3.toml --out test_vectors/multiplier2/witness.wtns.2.shared
wait $(jobs -p)
# run translation from REP3 to Shamir
cargo run --release --bin co-circom -- translate-witness --witness test_vectors/multiplier2/witness.wtns.0.shared --src-protocol REP3 --target-protocol SHAMIR --curve BN254 --config ../configs/party1.toml --out test_vectors/multiplier2/shamir_witness.wtns.0.shared &
cargo run --release --bin co-circom -- translate-witness --witness test_vectors/multiplier2/witness.wtns.1.shared --src-protocol REP3 --target-protocol SHAMIR --curve BN254 --config ../configs/party2.toml --out test_vectors/multiplier2/shamir_witness.wtns.1.shared &
cargo run --release --bin co-circom -- translate-witness --witness test_vectors/multiplier2/witness.wtns.2.shared --src-protocol REP3 --target-protocol SHAMIR --curve BN254 --config ../configs/party3.toml --out test_vectors/multiplier2/shamir_witness.wtns.2.shared
wait $(jobs -p)
# run proving in MPC
cargo run --release --bin co-circom -- generate-proof groth16 --witness test_vectors/multiplier2/shamir_witness.wtns.0.shared --zkey test_vectors/multiplier2/multiplier2.zkey --protocol SHAMIR --curve BN254 --config ../configs/party1.toml --out proof.0.json --public-input public_input.json &
cargo run --release --bin co-circom -- generate-proof groth16 --witness test_vectors/multiplier2/shamir_witness.wtns.1.shared --zkey test_vectors/multiplier2/multiplier2.zkey --protocol SHAMIR --curve BN254 --config ../configs/party2.toml --out proof.1.json &
cargo run --release --bin co-circom -- generate-proof groth16 --witness test_vectors/multiplier2/shamir_witness.wtns.2.shared --zkey test_vectors/multiplier2/multiplier2.zkey --protocol SHAMIR --curve BN254 --config ../configs/party3.toml --out proof.2.json
wait $(jobs -p)
# verify proof
cargo run --release --bin co-circom -- verify groth16 --proof proof.0.json --vk test_vectors/multiplier2/verification_key.json --public-input public_input.json --curve BN254
4 changes: 2 additions & 2 deletions mpc-core/src/protocols/bridges/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ impl RepToShamirNetwork<ShamirMpcNet> for Rep3MpcNet {
fn to_shamir_net(self) -> ShamirMpcNet {
let Self {
id,
net_handler,
chan_next,
chan_prev,
queue,
runtime,
} = self;

Expand All @@ -32,8 +32,8 @@ impl RepToShamirNetwork<ShamirMpcNet> for Rep3MpcNet {
ShamirMpcNet {
id: id.into(),
num_parties: 3,
net_handler,
channels,
queue,
runtime,
}
}
Expand Down
64 changes: 33 additions & 31 deletions mpc-core/src/protocols/rep3/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
//!
//! This module contains implementation of the rep3 mpc network
use std::sync::{Arc, Mutex};
use std::sync::Arc;

use crate::RngType;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use bytes::{Bytes, BytesMut};
use eyre::{bail, eyre, Report};
use mpc_net::{channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler};
use eyre::{bail, ContextCompat};
use mpc_net::{
channel::ChannelHandle, config::NetworkConfig, queue::ChannelQueue, MpcNetworkHandler,
};
use tokio::runtime::Runtime;

use super::{
Expand Down Expand Up @@ -224,46 +226,44 @@ pub trait Rep3Network: Send {
Self: Sized;
}

// TODO make generic over codec?
/// This struct can be used to facilitate network communication for the REP3 MPC protocol.
#[derive(Debug)]
pub struct Rep3MpcNet {
pub(crate) id: PartyID,
pub(crate) chan_next: ChannelHandle<Bytes, BytesMut>,
pub(crate) chan_prev: ChannelHandle<Bytes, BytesMut>,
// TODO we should be able to get rid of this mutex once we dont remove streams from the pool anymore
pub(crate) net_handler: Arc<Mutex<MpcNetworkHandler>>,
// order is important, runtime MUST be dropped after network handler
pub(crate) queue: ChannelQueue,
// order is important, runtime MUST be dropped last
pub(crate) runtime: Arc<Runtime>,
}

impl Rep3MpcNet {
/// Takes a [NetworkConfig] struct and constructs the network interface. The network needs to contain exactly 3 parties with ids 0, 1, and 2.
pub fn new(config: NetworkConfig) -> Result<Self, Report> {
pub fn new(config: NetworkConfig) -> eyre::Result<Self> {
if config.parties.len() != 3 {
bail!("REP3 protocol requires exactly 3 parties")
}
let queue_size = config.conn_queue_size;
let id = PartyID::try_from(config.my_id)?;
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
let mut net_handler = runtime.block_on(MpcNetworkHandler::establish(config))?;

let chan_next = net_handler
.get_byte_channel(&id.next_id().into())
.ok_or(eyre!("no next channel found"))?;
let chan_prev = net_handler
.get_byte_channel(&id.prev_id().into())
.ok_or(eyre!("no prev channel found"))?;
let net_handler = runtime.block_on(MpcNetworkHandler::init(config))?;
let queue = runtime.block_on(MpcNetworkHandler::queue(net_handler, queue_size))?;

let chan_next = net_handler.spawn(chan_next);
let chan_prev = net_handler.spawn(chan_prev);
let mut channels = queue.get_channels()?;
let chan_next = channels
.remove(&id.next_id().into())
.context("while removing channel")?;
let chan_prev = channels
.remove(&id.prev_id().into())
.context("while removing channel")?;

Ok(Self {
id,
net_handler: Arc::new(Mutex::new(net_handler)),
chan_next,
chan_prev,
queue,
runtime: Arc::new(runtime),
})
}
Expand Down Expand Up @@ -349,22 +349,24 @@ impl Rep3Network for Rep3MpcNet {
}

fn fork(&mut self) -> std::io::Result<Self> {
let id = self.id;
let mut net_handler = self.net_handler.lock().unwrap();
let chan_next = net_handler
.get_byte_channel(&id.next_id().into())
.expect("no next channel found");
let chan_prev = net_handler
.get_byte_channel(&id.prev_id().into())
.expect("no prev channel found");
let chan_next = net_handler.spawn(chan_next);
let chan_prev = net_handler.spawn(chan_prev);
let mut channels = self.queue.get_channels().map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::Other,
"could not get channels from queue, channel died",
)
})?;
let chan_next = channels
.remove(&self.id.next_id().into())
.expect("to find next channel");
let chan_prev = channels
.remove(&self.id.prev_id().into())
.expect("to find prev channel");

Ok(Self {
id,
net_handler: self.net_handler.clone(),
id: self.id,
chan_next,
chan_prev,
queue: self.queue.clone(),
runtime: self.runtime.clone(),
})
}
Expand Down
60 changes: 23 additions & 37 deletions mpc-core/src/protocols/shamir/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use bytes::{Bytes, BytesMut};
use eyre::{bail, eyre, Report};
use mpc_net::{channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
use eyre::bail;
use mpc_net::{
channel::ChannelHandle, config::NetworkConfig, queue::ChannelQueue, MpcNetworkHandler,
};
use std::{collections::HashMap, sync::Arc};
use tokio::runtime::Runtime;

/// This trait defines the network interface for the Shamir protocol.
Expand Down Expand Up @@ -76,20 +75,21 @@ pub trait ShamirNetwork: Send {
}

/// This struct can be used to facilitate network communication for the Shamir MPC protocol.
#[derive(Debug)]
pub struct ShamirMpcNet {
pub(crate) id: usize, // 0 <= id < num_parties
pub(crate) num_parties: usize,
pub(crate) channels: HashMap<usize, ChannelHandle<Bytes, BytesMut>>,
// TODO we should be able to get rid of this mutex once we dont remove streams from the pool anymore
pub(crate) net_handler: Arc<Mutex<MpcNetworkHandler>>,
// order is important, runtime MUST be dropped after network handler
pub(crate) queue: ChannelQueue,
// order is important, runtime MUST be dropped last
pub(crate) runtime: Arc<Runtime>,
}

impl ShamirMpcNet {
/// Takes a [NetworkConfig] struct and constructs the network interface. The network needs to contain at least 3 parties and all ids need to be in the range of 0 <= id < num_parties.
pub fn new(config: NetworkConfig) -> Result<Self, Report> {
pub fn new(config: NetworkConfig) -> eyre::Result<Self> {
let num_parties = config.parties.len();
let queue_size = config.conn_queue_size;

if config.parties.len() <= 2 {
bail!("Shamir protocol requires at least 3 parties")
Expand All @@ -102,23 +102,16 @@ impl ShamirMpcNet {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
let mut net_handler = runtime.block_on(MpcNetworkHandler::establish(config))?;
let mut channels = HashMap::with_capacity(num_parties - 1);

for other_id in 0..num_parties {
if other_id != id {
let chan = net_handler
.get_byte_channel(&other_id)
.ok_or_else(|| eyre!("no channel found for party id={}", other_id))?;
channels.insert(other_id, net_handler.spawn(chan));
}
}
let net_handler = runtime.block_on(MpcNetworkHandler::init(config))?;
let queue = runtime.block_on(MpcNetworkHandler::queue(net_handler, queue_size))?;

let channels = queue.get_channels()?;

Ok(Self {
id,
num_parties,
net_handler: Arc::new(Mutex::new(net_handler)),
channels,
queue,
runtime: Arc::new(runtime),
})
}
Expand Down Expand Up @@ -259,25 +252,18 @@ impl ShamirNetwork for ShamirMpcNet {
}

fn fork(&mut self) -> std::io::Result<Self> {
let id = self.id;
let num_parties = self.num_parties;
let mut net_handler = self.net_handler.lock().unwrap();

let mut channels = HashMap::with_capacity(num_parties - 1);
for other_id in 0..num_parties {
if other_id != id {
let chan = net_handler
.get_byte_channel(&other_id)
.expect("to find channel");
channels.insert(other_id, net_handler.spawn(chan));
}
}
let channels = self.queue.get_channels().map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::Other,
"could not get channels from queue, channel died",
)
})?;

Ok(Self {
id,
num_parties,
net_handler: self.net_handler.clone(),
id: self.id,
num_parties: self.num_parties,
channels,
queue: self.queue.clone(),
runtime: self.runtime.clone(),
})
}
Expand Down
1 change: 1 addition & 0 deletions mpc-net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ bincode = { workspace = true }
bytes = { workspace = true }
clap = { workspace = true }
color-eyre = { workspace = true }
eyre = { workspace = true }
futures = { workspace = true }
rcgen = { workspace = true }
rustls = { workspace = true }
Expand Down
6 changes: 3 additions & 3 deletions mpc-net/examples/three_party.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::path::PathBuf;

use clap::Parser;
use color_eyre::{
eyre::{eyre, Context, ContextCompat},
eyre::{eyre, Context},
Result,
};
use futures::{SinkExt, StreamExt};
Expand Down Expand Up @@ -31,9 +31,9 @@ async fn main() -> Result<()> {
let config = NetworkConfig::try_from(config).context("converting network config")?;
let my_id = config.my_id;

let mut network = MpcNetworkHandler::establish(config).await?;
let network = MpcNetworkHandler::init(config).await?;

let mut channels = network.get_byte_channels().context("get channels")?;
let mut channels = network.get_byte_channels().await?;

// send to all channels
for (&i, channel) in channels.iter_mut() {
Expand Down
8 changes: 3 additions & 5 deletions mpc-net/examples/three_party_bincode_channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::path::PathBuf;

use clap::Parser;
use color_eyre::{
eyre::{eyre, Context, ContextCompat},
eyre::{eyre, Context},
Result,
};
use futures::{SinkExt, StreamExt};
Expand Down Expand Up @@ -32,11 +32,9 @@ async fn main() -> Result<()> {
let config = NetworkConfig::try_from(config).context("converting network config")?;
let my_id = config.my_id;

let mut network = MpcNetworkHandler::establish(config).await?;
let network = MpcNetworkHandler::init(config).await?;

let mut channels = network
.get_serde_bincode_channels()
.context("get channels")?;
let mut channels = network.get_serde_bincode_channels().await?;

// send to all channels
for (&i, channel) in channels.iter_mut() {
Expand Down
6 changes: 3 additions & 3 deletions mpc-net/examples/three_party_custom_channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::path::PathBuf;
use bytes::{Buf, BufMut};
use clap::Parser;
use color_eyre::{
eyre::{eyre, Context, ContextCompat},
eyre::{eyre, Context},
Result,
};
use futures::{SinkExt, StreamExt};
Expand Down Expand Up @@ -33,10 +33,10 @@ async fn main() -> Result<()> {
let config = NetworkConfig::try_from(config).context("converting network config")?;
let my_id = config.my_id;

let mut network = MpcNetworkHandler::establish(config).await?;
let network = MpcNetworkHandler::init(config).await?;

let codec = MessageCodec;
let mut channels = network.get_custom_channels(codec).context("get channels")?;
let mut channels = network.get_custom_channels(codec).await?;

// send to all channels
for (&i, channel) in channels.iter_mut() {
Expand Down
12 changes: 4 additions & 8 deletions mpc-net/examples/three_party_managed.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::{collections::HashMap, path::PathBuf};
use std::path::PathBuf;

use clap::Parser;
use color_eyre::{
eyre::{eyre, Context, ContextCompat},
eyre::{eyre, Context},
Result,
};
use mpc_net::{
Expand Down Expand Up @@ -30,13 +30,9 @@ async fn main() -> Result<()> {
let config = NetworkConfig::try_from(config).context("converting network config")?;
let my_id = config.my_id;

let mut network = MpcNetworkHandler::establish(config).await?;
let network = MpcNetworkHandler::init(config).await?;

let channels = network.get_byte_channels().context("get channels")?;
let mut managed_channels = channels
.into_iter()
.map(|(i, c)| (i, network.spawn(c)))
.collect::<HashMap<_, _>>();
let mut managed_channels = network.get_byte_channels_managed().await?;

// send to all channels
for (&i, channel) in managed_channels.iter_mut() {
Expand Down
Loading

0 comments on commit 27c3d60

Please sign in to comment.