Skip to content

Commit

Permalink
Introduce opt-in access control for iroh-net transport
Browse files Browse the repository at this point in the history
  • Loading branch information
fogodev committed Oct 30, 2024
1 parent 0eabf25 commit 4307dfb
Showing 1 changed file with 79 additions and 12 deletions.
91 changes: 79 additions & 12 deletions src/transport/iroh_net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ use crate::{
};

use std::{
collections::BTreeSet,
fmt,
future::Future,
io,
iter::once,
marker::PhantomData,
net::SocketAddr,
pin::pin,
pin::Pin,
pin::{pin, Pin},
sync::Arc,
task::{Context, Poll},
};

use futures_lite::{Stream, StreamExt};
use futures_sink::Sink;
use futures_util::FutureExt;
use iroh_net::NodeAddr;
use iroh_net::{NodeAddr, NodeId};
use pin_project::pin_project;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::oneshot;
Expand Down Expand Up @@ -70,6 +70,16 @@ impl Drop for ServerEndpointInner {
}
}

/// Access control for the server, either unrestricted or limited to a list of nodes that can
/// connect to the server endpoint
#[derive(Debug, Clone)]
pub enum AccessControl {
/// Unrestricted access, anyone can connect
Unrestricted,
/// Restricted access, only nodes in the list can connect, all other nodes will be rejected
Allowed(Vec<NodeId>),
}

/// A server endpoint using a quinn connection
#[derive(Debug)]
pub struct IrohNetServerEndpoint<In: RpcMessage, Out: RpcMessage> {
Expand Down Expand Up @@ -103,40 +113,97 @@ impl<In: RpcMessage, Out: RpcMessage> IrohNetServerEndpoint<In, Out> {
}
}

async fn endpoint_handler(endpoint: iroh_net::Endpoint, sender: flume::Sender<SocketInner>) {
async fn endpoint_handler(
endpoint: iroh_net::Endpoint,
sender: flume::Sender<SocketInner>,
allowed_node_ids: BTreeSet<NodeId>,
) {
loop {
tracing::debug!("Waiting for incoming connection...");
let connecting = match endpoint.accept().await {
Some(connecting) => connecting,
None => break,
};

tracing::debug!("Awaiting connection from connect...");
let conection = match connecting.await {
Ok(conection) => conection,
let connection = match connecting.await {
Ok(connection) => connection,
Err(e) => {
tracing::warn!("Error accepting connection: {}", e);
continue;
}
};

// When the `allowed_node_ids` is empty, it's empty forever, so the CPU's branch
// prediction should always optimize this block away from this loop.
// The same applies when it isn't empty, ignoring the check for emptiness and always
// extracting the node id and checking if it's in the set.
if !allowed_node_ids.is_empty() {
let Ok(client_node_id) = iroh_net::endpoint::get_remote_node_id(&connection)
.map_err(|e| {
tracing::error!(
?e,
"Failed to extract iroh-net node id from incoming connection from {:?}",
connection.remote_address()
)
})
else {
connection.close(0u32.into(), b"failed to extract iroh-net node id");
continue;
};

if !allowed_node_ids.contains(&client_node_id) {
connection.close(0u32.into(), b"forbidden node id");
continue;
}
}

tracing::debug!(
"Connection established from {:?}",
conection.remote_address()
connection.remote_address()
);

tracing::debug!("Spawning connection handler...");
tokio::spawn(Self::connection_handler(conection, sender.clone()));
tokio::spawn(Self::connection_handler(connection, sender.clone()));
}
}

/// Create a new server channel, given a quinn endpoint.
///
/// The endpoint must be a server endpoint.
/// Create a new server channel, given a quinn endpoint, with unrestricted access by node id
///
/// The server channel will take care of listening on the endpoint and spawning
/// handlers for new connections.
pub fn new(endpoint: iroh_net::Endpoint) -> io::Result<Self> {
Self::new_with_access_control(endpoint, AccessControl::Unrestricted)
}

/// Create a new server endpoint, with specified access control
///
/// The server channel will take care of listening on the endpoint and spawning
/// handlers for new connections.
pub fn new_with_access_control(
endpoint: iroh_net::Endpoint,
access_control: AccessControl,
) -> io::Result<Self> {
let allowed_node_ids = match access_control {
AccessControl::Unrestricted => BTreeSet::new(),
AccessControl::Allowed(list) if list.is_empty() => {
tracing::warn!(
"Allowed list of `NodeId`s is empty, iroh-net \
quic-rpc endpoint will have unrestricted access!"
);
BTreeSet::new()
}
AccessControl::Allowed(list) => BTreeSet::from_iter(list),
};

let (ipv4_socket_addr, maybe_ipv6_socket_addr) = endpoint.bound_sockets();
let (sender, receiver) = flume::bounded(16);
let task = tokio::spawn(Self::endpoint_handler(endpoint.clone(), sender));
let task = tokio::spawn(Self::endpoint_handler(
endpoint.clone(),
sender,
allowed_node_ids,
));

Ok(Self {
inner: Arc::new(ServerEndpointInner {
endpoint: Some(endpoint),
Expand Down

0 comments on commit 4307dfb

Please sign in to comment.