Skip to content

Commit

Permalink
Allow specifying a custom ALPN
Browse files Browse the repository at this point in the history
This way you can e.g. use dumbpipe to test iroh-bytes using bash scripts and
predefined request files.
  • Loading branch information
rklaehn committed Dec 14, 2023
1 parent 86697ba commit 99acd59
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 39 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ anyhow = "1.0.75"
base32 = "0.4.0"
clap = { version = "4.4.10", features = ["derive"] }
data-encoding = "2.5.0"
hex = "0.4.3"
iroh-net = "0.11.0"
postcard = "1.0.8"
quinn = "0.10.2"
Expand Down
133 changes: 95 additions & 38 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,59 @@ pub enum Commands {
}

#[derive(Parser, Debug)]
pub struct ListenArgs {
/// The port to listen on.
pub struct CommonArgs {
/// The port to use for the magicsocket. Random by default.
#[clap(long, default_value_t = 0)]
pub magic_port: u16,

/// A custom ALPN to use for the magicsocket.
///
/// This is an expert feature that allows dumbpipe to be used to interact
/// with existing iroh protocols.
///
/// When using this option, the connect side must also specify the same ALPN.
/// The listen side will not expect a handshake, and the connect side will
/// not send one.
///
/// Alpns are byte strings. To specify an utf8 string, prefix it with `utf8:`.
/// Otherwise, it will be parsed as a hex string.
pub custom_alpn: Option<String>,
}

impl CommonArgs {
fn alpns(&self) -> anyhow::Result<Vec<Vec<u8>>> {
Ok(vec![match &self.custom_alpn {
Some(alpn) => parse_alpn(alpn)?,
None => dumbpipe::ALPN.to_vec(),
}])
}

fn is_custom_alpn(&self) -> bool {
self.custom_alpn.is_some()
}
}

fn parse_alpn(alpn: &str) -> anyhow::Result<Vec<u8>> {
Ok(if let Some(text) = alpn.strip_prefix("utf8:") {
text.as_bytes().to_vec()
} else {
hex::decode(alpn)?
})
}

#[derive(Parser, Debug)]
pub struct ListenArgs {
#[clap(flatten)]
pub common: CommonArgs,
}

#[derive(Parser, Debug)]
pub struct ListenTcpArgs {
#[clap(long)]
pub host: String,

/// The port to use for the magicsocket. Random by default.
#[clap(long, default_value_t = 0)]
pub magic_port: u16,
#[clap(flatten)]
pub common: CommonArgs,
}

#[derive(Parser, Debug)]
Expand All @@ -92,22 +131,20 @@ pub struct ConnectTcpArgs {
#[clap(long)]
pub addr: String,

/// The port to use for the magicsocket. Random by default.
#[clap(long, default_value_t = 0)]
pub magic_port: u16,

/// The node to connect to
pub ticket: NodeTicket,

#[clap(flatten)]
pub common: CommonArgs,
}

#[derive(Parser, Debug)]
pub struct ConnectArgs {
/// The node to connect to
pub ticket: NodeTicket,

/// The port to bind to.
#[clap(long, default_value_t = 0)]
pub port: u16,
#[clap(flatten)]
pub common: CommonArgs,
}

/// Copy from a reader to a quinn stream.
Expand Down Expand Up @@ -214,9 +251,9 @@ async fn forward_bidi(
async fn listen_stdio(args: ListenArgs) -> anyhow::Result<()> {
let secret_key = get_or_create_secret()?;
let endpoint = MagicEndpoint::builder()
.alpns(vec![dumbpipe::ALPN.to_vec()])
.alpns(args.common.alpns()?)
.secret_key(secret_key)
.bind(args.magic_port)
.bind(args.common.magic_port)
.await?;
// wait for the endpoint to figure out its address before making a ticket
while endpoint.my_derp().is_none() {
Expand Down Expand Up @@ -252,10 +289,12 @@ async fn listen_stdio(args: ListenArgs) -> anyhow::Result<()> {
}
};
tracing::info!("accepted bidi stream from {}", remote_node_id);
// read the handshake and verify it
let mut buf = [0u8; 5];
r.read_exact(&mut buf).await?;
anyhow::ensure!(buf == dumbpipe::HANDSHAKE, "invalid handshake");
if !args.common.is_custom_alpn() {
// read the handshake and verify it
let mut buf = [0u8; dumbpipe::HANDSHAKE.len()];
r.read_exact(&mut buf).await?;
anyhow::ensure!(buf == dumbpipe::HANDSHAKE, "invalid handshake");
}
tracing::info!("forwarding stdin/stdout to {}", remote_node_id);
forward_bidi(tokio::io::stdin(), tokio::io::stdout(), r, s).await?;
// stop accepting connections after the first successful one
Expand All @@ -268,8 +307,8 @@ async fn connect_stdio(args: ConnectArgs) -> anyhow::Result<()> {
let secret_key = get_or_create_secret()?;
let endpoint = MagicEndpoint::builder()
.secret_key(secret_key)
.alpns(vec![dumbpipe::ALPN.to_vec()])
.bind(args.port)
.alpns(vec![])
.bind(args.common.magic_port)
.await?;
let addr = args.ticket.addr;
let remote_node_id = addr.node_id;
Expand All @@ -279,9 +318,13 @@ async fn connect_stdio(args: ConnectArgs) -> anyhow::Result<()> {
// open a bidi stream, try only once
let (mut s, r) = connection.open_bi().await?;
tracing::info!("opened bidi stream to {}", remote_node_id);
// the connecting side must write first. we don't know if there will be something
// on stdin, so just write a handshake.
s.write_all(&dumbpipe::HANDSHAKE).await?;
// send the handshake unless we are using a custom alpn
// when using a custom alpn, evertyhing is up to the user
if !args.common.is_custom_alpn() {
// the connecting side must write first. we don't know if there will be something
// on stdin, so just write a handshake.
s.write_all(&dumbpipe::HANDSHAKE).await?;
}
tracing::info!("forwarding stdin/stdout to {}", remote_node_id);
forward_bidi(tokio::io::stdin(), tokio::io::stdout(), r, s).await?;
tokio::io::stdout().flush().await?;
Expand All @@ -296,9 +339,9 @@ async fn connect_tcp(args: ConnectTcpArgs) -> anyhow::Result<()> {
.context(format!("invalid host string {}", args.addr))?;
let secret_key = get_or_create_secret()?;
let endpoint = MagicEndpoint::builder()
.alpns(vec![dumbpipe::ALPN.to_vec()])
.alpns(vec![])
.secret_key(secret_key)
.bind(args.magic_port)
.bind(args.common.magic_port)
.await
.context("unable to bind magicsock")?;
tracing::info!("tcp listening on {:?}", addrs);
Expand All @@ -313,6 +356,7 @@ async fn connect_tcp(args: ConnectTcpArgs) -> anyhow::Result<()> {
next: io::Result<(tokio::net::TcpStream, SocketAddr)>,
addr: NodeAddr,
endpoint: MagicEndpoint,
handshake: bool,
) -> anyhow::Result<()> {
let (tcp_stream, tcp_addr) = next.context("error accepting tcp connection")?;
let (tcp_recv, tcp_send) = tcp_stream.into_split();
Expand All @@ -326,7 +370,13 @@ async fn connect_tcp(args: ConnectTcpArgs) -> anyhow::Result<()> {
.open_bi()
.await
.context(format!("error opening bidi stream to {}", remote_node_id))?;
magic_send.write_all(&dumbpipe::HANDSHAKE).await?;
// send the handshake unless we are using a custom alpn
// when using a custom alpn, evertyhing is up to the user
if handshake {
// the connecting side must write first. we don't know if there will be something
// on stdin, so just write a handshake.
magic_send.write_all(&dumbpipe::HANDSHAKE).await?;
}
forward_bidi(tcp_recv, tcp_send, magic_recv, magic_send).await?;
anyhow::Ok(())
}
Expand All @@ -342,8 +392,9 @@ async fn connect_tcp(args: ConnectTcpArgs) -> anyhow::Result<()> {
};
let endpoint = endpoint.clone();
let addr = addr.clone();
let handshake = !args.common.is_custom_alpn();
tokio::spawn(async move {
if let Err(cause) = handle_tcp_accept(next, addr, endpoint).await {
if let Err(cause) = handle_tcp_accept(next, addr, endpoint, handshake).await {
// log error at warn level
//
// we should know about it, but it's not fatal
Expand All @@ -362,31 +413,34 @@ async fn listen_tcp(args: ListenTcpArgs) -> anyhow::Result<()> {
};
let secret_key = get_or_create_secret()?;
let endpoint = MagicEndpoint::builder()
.alpns(vec![dumbpipe::ALPN.to_vec()])
.alpns(args.common.alpns()?)
.secret_key(secret_key)
.bind(args.magic_port)
.bind(args.common.magic_port)
.await?;
// wait for the endpoint to figure out its address before making a ticket
while endpoint.my_derp().is_none() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
let addr = endpoint.my_addr().await?;
let ticket = NodeTicket { addr };
let mut short_ticket = ticket.clone();
short_ticket.addr.info.direct_addresses.clear();

// print the ticket on stderr so it doesn't interfere with the data itself
//
// note that the tests rely on the ticket being the last thing printed
eprintln!(
"Forwarding incoming requests to '{}'. To connect, use e.g.:\ndumbpipe connect {}",
args.host, ticket
);
eprintln!("Forwarding incoming requests to '{}'.", args.host);
eprintln!("To connect, use e.g.:");
eprintln!("dumbpipe connect {short_ticket}");
eprintln!("dumbpipe connect {ticket}");
tracing::info!("node id is {}", ticket.addr.node_id);
tracing::info!("derp region is {:?}", ticket.addr.info.derp_region);

// handle a new incoming connection on the magic endpoint
async fn handle_magic_accept(
connecting: quinn::Connecting,
addrs: Vec<std::net::SocketAddr>,
handshake: bool,
) -> anyhow::Result<()> {
let connection = connecting.await.context("error accepting connection")?;
let remote_node_id = get_remote_node_id(&connection)?;
Expand All @@ -396,10 +450,12 @@ async fn listen_tcp(args: ListenTcpArgs) -> anyhow::Result<()> {
.await
.context("error accepting stream")?;
tracing::info!("accepted bidi stream from {}", remote_node_id);
// read the handshake and verify it
let mut buf = [0u8; 5];
r.read_exact(&mut buf).await?;
anyhow::ensure!(buf == dumbpipe::HANDSHAKE, "invalid handshake");
if handshake {
// read the handshake and verify it
let mut buf = [0u8; dumbpipe::HANDSHAKE.len()];
r.read_exact(&mut buf).await?;
anyhow::ensure!(buf == dumbpipe::HANDSHAKE, "invalid handshake");
}
let connection = tokio::net::TcpStream::connect(addrs.as_slice())
.await
.context(format!("error connecting to {:?}", addrs))?;
Expand All @@ -420,8 +476,9 @@ async fn listen_tcp(args: ListenTcpArgs) -> anyhow::Result<()> {
break;
};
let addrs = addrs.clone();
let handshake = !args.common.is_custom_alpn();
tokio::spawn(async move {
if let Err(cause) = handle_magic_accept(connecting, addrs).await {
if let Err(cause) = handle_magic_accept(connecting, addrs, handshake).await {
// log error at warn level
//
// we should know about it, but it's not fatal
Expand Down
2 changes: 1 addition & 1 deletion tests/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ fn listen_tcp_happy() {
.stderr_to_stdout() //
.reader()
.unwrap();
let header = read_ascii_lines(3, &mut listen_tcp).unwrap();
let header = read_ascii_lines(4, &mut listen_tcp).unwrap();
let header = String::from_utf8(header).unwrap();
let ticket = header.split_ascii_whitespace().last().unwrap();
let ticket = NodeTicket::from_str(ticket).unwrap();
Expand Down

0 comments on commit 99acd59

Please sign in to comment.