Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: replace quinn with tokio_rustls #227

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ noirc-artifacts = { version = "1.0.0-beta.0", git = "https://github.com/noir-lan
num-bigint = { version = "0.4.5" }
num-traits = { version = "0.2.18", default-features = false }
paste = "1.0.15"
quinn = "0.11"
rand = "0.8.5"
rand_chacha = "0.3"
rayon = "1.8.1"
Expand All @@ -80,6 +79,7 @@ tokio = { version = "1.34.0", features = [
"io-util",
"macros",
] }
tokio-rustls = "0.26.0"
tokio-util = { version = "0.7.10", features = ["codec"] }
toml = "0.8.13"
tracing = { version = "0.1.40" }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# split input into shares
cargo run --release --bin co-circom -- split-input --circuit test_vectors/multiplier2/circuit.circom --input test_vectors/multiplier2/input.json --protocol REP3 --curve BN254 --out-dir test_vectors/multiplier2 --config test_vectors/kyc/config.toml
# run witness extension in MPC
cargo run --release --bin co-circom -- generate-witness -O2 --input test_vectors/multiplier2/input.json.0.shared --circuit test_vectors/multiplier2/circuit.circom --protocol REP3 --curve BN254 --config ../configs/party1.toml --out test_vectors/multiplier2/witness.wtns.0.shared &
cargo run --release --bin co-circom -- generate-witness -O2 --input test_vectors/multiplier2/input.json.1.shared --circuit test_vectors/multiplier2/circuit.circom --protocol REP3 --curve BN254 --config ../configs/party2.toml --out test_vectors/multiplier2/witness.wtns.1.shared &
cargo run --release --bin co-circom -- generate-witness -O2 --input test_vectors/multiplier2/input.json.2.shared --circuit test_vectors/multiplier2/circuit.circom --protocol REP3 --curve BN254 --config ../configs/party3.toml --out test_vectors/multiplier2/witness.wtns.2.shared
wait $(jobs -p)
# run translation from REP3 to Shamir
cargo run --release --bin co-circom -- translate-witness --witness test_vectors/multiplier2/witness.wtns.0.shared --src-protocol REP3 --target-protocol SHAMIR --curve BN254 --config ../configs/party1.toml --out test_vectors/multiplier2/shamir_witness.wtns.0.shared &
cargo run --release --bin co-circom -- translate-witness --witness test_vectors/multiplier2/witness.wtns.1.shared --src-protocol REP3 --target-protocol SHAMIR --curve BN254 --config ../configs/party2.toml --out test_vectors/multiplier2/shamir_witness.wtns.1.shared &
cargo run --release --bin co-circom -- translate-witness --witness test_vectors/multiplier2/witness.wtns.2.shared --src-protocol REP3 --target-protocol SHAMIR --curve BN254 --config ../configs/party3.toml --out test_vectors/multiplier2/shamir_witness.wtns.2.shared
wait $(jobs -p)
# run proving in MPC
cargo run --release --bin co-circom -- generate-proof groth16 --witness test_vectors/multiplier2/shamir_witness.wtns.0.shared --zkey test_vectors/multiplier2/multiplier2.zkey --protocol SHAMIR --curve BN254 --config ../configs/party1.toml --out proof.0.json --public-input public_input.json &
cargo run --release --bin co-circom -- generate-proof groth16 --witness test_vectors/multiplier2/shamir_witness.wtns.1.shared --zkey test_vectors/multiplier2/multiplier2.zkey --protocol SHAMIR --curve BN254 --config ../configs/party2.toml --out proof.1.json &
cargo run --release --bin co-circom -- generate-proof groth16 --witness test_vectors/multiplier2/shamir_witness.wtns.2.shared --zkey test_vectors/multiplier2/multiplier2.zkey --protocol SHAMIR --curve BN254 --config ../configs/party3.toml --out proof.2.json
wait $(jobs -p)
# verify proof
cargo run --release --bin co-circom -- verify groth16 --proof proof.0.json --vk test_vectors/multiplier2/verification_key.json --public-input public_input.json --curve BN254
6 changes: 4 additions & 2 deletions mpc-core/src/protocols/bridges/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ impl RepToShamirNetwork<ShamirMpcNet> for Rep3MpcNet {
fn to_shamir_net(self) -> ShamirMpcNet {
let Self {
id,
net_handler,
chan_next,
chan_prev,
queue,
runtime,
} = self;

let mut channels = HashMap::with_capacity(2);
Expand All @@ -31,8 +32,9 @@ impl RepToShamirNetwork<ShamirMpcNet> for Rep3MpcNet {
ShamirMpcNet {
id: id.into(),
num_parties: 3,
net_handler,
channels,
queue,
runtime,
}
}
}
15 changes: 3 additions & 12 deletions mpc-core/src/protocols/rep3/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,24 +129,15 @@ pub fn shift_l_public_by_shared<F: PrimeField, N: Rep3Network>(
// Strategy: limit size of b to k bits
// bit-decompose b into bits b_i

// TODO: this sucks... we need something better here...
let io_0 = io_context.fork()?;
let io_1 = io_context.fork()?;
let io_2 = io_context.fork()?;
let io_3 = io_context.fork()?;
let io_4 = io_context.fork()?;
let io_5 = io_context.fork()?;
let io_6 = io_context.fork()?;
let io_7 = io_context.fork()?;
let mut contexts = [io_0, io_1, io_2, io_3, io_4, io_5, io_6, io_7];
// TODO: make the b2a conversion concurrent again
let party_id = io_context.id;
let mut individual_bit_shares = Vec::with_capacity(8);
for (i, context) in izip!((0..8), contexts.iter_mut()) {
for i in 0..8 {
let bit = Rep3BigUintShare::new(
(shared.a.clone() >> i) & BigUint::one(),
(shared.b.clone() >> i) & BigUint::one(),
);
individual_bit_shares.push(conversion::b2a_selector(&bit, context)?);
individual_bit_shares.push(conversion::b2a_selector(&bit, io_context)?);
}
// v_i = 2^2^i * <b_i> + 1 - <b_i>
let mut vs: Vec<_> = individual_bit_shares
Expand Down
76 changes: 34 additions & 42 deletions mpc-core/src/protocols/rep3/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ use crate::RngType;
use ark_ff::PrimeField;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use bytes::{Bytes, BytesMut};
use eyre::{bail, eyre, Report};
use eyre::{bail, ContextCompat};
use mpc_net::{
channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler, MpcNetworkHandlerWrapper,
channel::ChannelHandle, config::NetworkConfig, queue::ChannelQueue, MpcNetworkHandler,
};
use tokio::runtime::Runtime;

use super::{
conversion::A2BType,
Expand Down Expand Up @@ -245,48 +246,45 @@ pub trait Rep3Network: Send {
Self: Sized;
}

// TODO make generic over codec?
/// This struct can be used to facilitate network communication for the REP3 MPC protocol.
#[derive(Debug)]
pub struct Rep3MpcNet {
pub(crate) id: PartyID,
pub(crate) chan_next: ChannelHandle<Bytes, BytesMut>,
pub(crate) chan_prev: ChannelHandle<Bytes, BytesMut>,
pub(crate) net_handler: Arc<MpcNetworkHandlerWrapper>,
pub(crate) queue: ChannelQueue,
// order is important, runtime MUST be dropped last
pub(crate) runtime: Arc<Runtime>,
}

impl Rep3MpcNet {
/// Takes a [NetworkConfig] struct and constructs the network interface. The network needs to contain exactly 3 parties with ids 0, 1, and 2.
pub fn new(config: NetworkConfig) -> Result<Self, Report> {
pub fn new(config: NetworkConfig) -> eyre::Result<Self> {
if config.parties.len() != 3 {
bail!("REP3 protocol requires exactly 3 parties")
}
let queue_size = config.conn_queue_size;
let id = PartyID::try_from(config.my_id)?;
let runtime = tokio::runtime::Builder::new_multi_thread()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to limit the # of worker threads here to 2 or so?

.enable_all()
.build()?;
let (net_handler, chan_next, chan_prev) = runtime.block_on(async {
let net_handler = MpcNetworkHandler::establish(config).await?;
let mut channels = net_handler.get_byte_channels().await?;
let chan_next = channels
.remove(&id.next_id().into())
.ok_or(eyre!("no next channel found"))?;
let chan_prev = channels
.remove(&id.prev_id().into())
.ok_or(eyre!("no prev channel found"))?;
if !channels.is_empty() {
bail!("unexpected channels found")
}
let net_handler = runtime.block_on(MpcNetworkHandler::init(config))?;
let queue = runtime.block_on(MpcNetworkHandler::queue(net_handler, queue_size))?;

let mut channels = queue.get_channels()?;
let chan_next = channels
.remove(&id.next_id().into())
.context("while removing channel")?;
let chan_prev = channels
.remove(&id.prev_id().into())
.context("while removing channel")?;

let chan_next = ChannelHandle::manage(chan_next);
let chan_prev = ChannelHandle::manage(chan_prev);
Ok((net_handler, chan_next, chan_prev))
})?;
Ok(Self {
id,
net_handler: Arc::new(MpcNetworkHandlerWrapper::new(runtime, net_handler)),
chan_next,
chan_prev,
queue,
runtime: Arc::new(runtime),
})
}

Expand Down Expand Up @@ -371,31 +369,25 @@ impl Rep3Network for Rep3MpcNet {
}

fn fork(&mut self) -> std::io::Result<Self> {
let id = self.id;
let net_handler = Arc::clone(&self.net_handler);
let (chan_next, chan_prev) = net_handler.runtime.block_on(async {
let mut channels = net_handler.inner.get_byte_channels().await?;

let chan_next = channels
.remove(&id.next_id().into())
.expect("to find next channel");
let chan_prev = channels
.remove(&id.prev_id().into())
.expect("to find prev channel");
if !channels.is_empty() {
panic!("unexpected channels found")
}

let chan_next = ChannelHandle::manage(chan_next);
let chan_prev = ChannelHandle::manage(chan_prev);
Ok::<_, std::io::Error>((chan_next, chan_prev))
let mut channels = self.queue.get_channels().map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::Other,
"could not get channels from queue, channel died",
)
})?;
let chan_next = channels
.remove(&self.id.next_id().into())
.expect("to find next channel");
let chan_prev = channels
.remove(&self.id.prev_id().into())
.expect("to find prev channel");

Ok(Self {
id,
net_handler,
id: self.id,
chan_next,
chan_prev,
queue: self.queue.clone(),
runtime: self.runtime.clone(),
})
}
}
70 changes: 23 additions & 47 deletions mpc-core/src/protocols/shamir/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use bytes::{Bytes, BytesMut};
use eyre::{bail, eyre, Report};
use eyre::bail;
use mpc_net::{
channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler, MpcNetworkHandlerWrapper,
channel::ChannelHandle, config::NetworkConfig, queue::ChannelQueue, MpcNetworkHandler,
};
use std::{collections::HashMap, sync::Arc};
use tokio::runtime::Runtime;

/// This trait defines the network interface for the Shamir protocol.
pub trait ShamirNetwork: Send {
Expand Down Expand Up @@ -74,17 +75,21 @@ pub trait ShamirNetwork: Send {
}

/// This struct can be used to facilitate network communication for the Shamir MPC protocol.
#[derive(Debug)]
pub struct ShamirMpcNet {
pub(crate) id: usize, // 0 <= id < num_parties
pub(crate) num_parties: usize,
pub(crate) channels: HashMap<usize, ChannelHandle<Bytes, BytesMut>>,
pub(crate) net_handler: Arc<MpcNetworkHandlerWrapper>,
pub(crate) queue: ChannelQueue,
// order is important, runtime MUST be dropped last
pub(crate) runtime: Arc<Runtime>,
}

impl ShamirMpcNet {
/// Takes a [NetworkConfig] struct and constructs the network interface. The network needs to contain at least 3 parties and all ids need to be in the range of 0 <= id < num_parties.
pub fn new(config: NetworkConfig) -> Result<Self, Report> {
pub fn new(config: NetworkConfig) -> eyre::Result<Self> {
let num_parties = config.parties.len();
let queue_size = config.conn_queue_size;

if config.parties.len() <= 2 {
bail!("Shamir protocol requires at least 3 parties")
Expand All @@ -97,32 +102,17 @@ impl ShamirMpcNet {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
let (net_handler, channels) = runtime.block_on(async {
let net_handler = MpcNetworkHandler::establish(config).await?;
let mut channels = net_handler.get_byte_channels().await?;

let mut channels_ = HashMap::with_capacity(num_parties - 1);

for other_id in 0..num_parties {
if other_id != id {
let chan = channels
.remove(&other_id)
.ok_or_else(|| eyre!("no channel found for party id={}", other_id))?;
channels_.insert(other_id, ChannelHandle::manage(chan));
}
}
let net_handler = runtime.block_on(MpcNetworkHandler::init(config))?;
let queue = runtime.block_on(MpcNetworkHandler::queue(net_handler, queue_size))?;

if !channels.is_empty() {
bail!("unexpected channels found")
}
let channels = queue.get_channels()?;

Ok((net_handler, channels_))
})?;
Ok(Self {
id,
num_parties,
net_handler: Arc::new(MpcNetworkHandlerWrapper::new(runtime, net_handler)),
channels,
queue,
runtime: Arc::new(runtime),
})
}

Expand Down Expand Up @@ -262,33 +252,19 @@ impl ShamirNetwork for ShamirMpcNet {
}

fn fork(&mut self) -> std::io::Result<Self> {
let id = self.id;
let num_parties = self.num_parties;
let net_handler = Arc::clone(&self.net_handler);
let channels = net_handler.runtime.block_on(async {
let mut channels = net_handler.inner.get_byte_channels().await?;

let mut channels_ = HashMap::with_capacity(num_parties - 1);

for other_id in 0..num_parties {
if other_id != id {
let chan = channels.remove(&other_id).expect("to find channel");
channels_.insert(other_id, ChannelHandle::manage(chan));
}
}

if !channels.is_empty() {
panic!("unexpected channels found")
}

Ok::<_, std::io::Error>(channels_)
let channels = self.queue.get_channels().map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::Other,
"could not get channels from queue, channel died",
)
})?;

Ok(Self {
id,
num_parties,
net_handler,
id: self.id,
num_parties: self.num_parties,
channels,
queue: self.queue.clone(),
runtime: self.runtime.clone(),
})
}

Expand Down
3 changes: 2 additions & 1 deletion mpc-net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ bincode = { workspace = true }
bytes = { workspace = true }
clap = { workspace = true }
color-eyre = { workspace = true }
eyre = { workspace = true }
futures = { workspace = true }
quinn.workspace = true
rcgen = { workspace = true }
rustls = { workspace = true }
serde = { workspace = true }
tokio = { workspace = true }
tokio-rustls.workspace = true
tokio-util.workspace = true
toml.workspace = true
tracing = { workspace = true }
3 changes: 2 additions & 1 deletion mpc-net/examples/three_party.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async fn main() -> Result<()> {
let config = NetworkConfig::try_from(config).context("converting network config")?;
let my_id = config.my_id;

let network = MpcNetworkHandler::establish(config).await?;
let network = MpcNetworkHandler::init(config).await?;

let mut channels = network.get_byte_channels().await?;

Expand All @@ -48,6 +48,7 @@ async fn main() -> Result<()> {
assert!(b.iter().all(|&x| x == my_id as u8))
}
}

network.print_connection_stats(&mut std::io::stdout())?;

Ok(())
Expand Down
3 changes: 2 additions & 1 deletion mpc-net/examples/three_party_bincode_channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async fn main() -> Result<()> {
let config = NetworkConfig::try_from(config).context("converting network config")?;
let my_id = config.my_id;

let network = MpcNetworkHandler::establish(config).await?;
let network = MpcNetworkHandler::init(config).await?;

let mut channels = network.get_serde_bincode_channels().await?;

Expand Down Expand Up @@ -64,6 +64,7 @@ async fn main() -> Result<()> {
panic!("could not receive message");
}
}

network.print_connection_stats(&mut std::io::stdout())?;

Ok(())
Expand Down
3 changes: 2 additions & 1 deletion mpc-net/examples/three_party_custom_channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async fn main() -> Result<()> {
let config = NetworkConfig::try_from(config).context("converting network config")?;
let my_id = config.my_id;

let network = MpcNetworkHandler::establish(config).await?;
let network = MpcNetworkHandler::init(config).await?;

let codec = MessageCodec;
let mut channels = network.get_custom_channels(codec).await?;
Expand Down Expand Up @@ -66,6 +66,7 @@ async fn main() -> Result<()> {
panic!("could not receive message");
}
}

network.print_connection_stats(&mut std::io::stdout())?;

Ok(())
Expand Down
Loading
Loading