Skip to content

Commit

Permalink
Add test for custom alpn feature
Browse files Browse the repository at this point in the history
just a happy case test that tests that the args are properly parsed etc.
  • Loading branch information
rklaehn committed Dec 14, 2023
1 parent 99acd59 commit b9318bb
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,16 @@ pub struct CommonArgs {
///
/// Alpns are byte strings. To specify an utf8 string, prefix it with `utf8:`.
/// Otherwise, it will be parsed as a hex string.
#[clap(long)]
pub custom_alpn: Option<String>,
}

impl CommonArgs {
fn alpns(&self) -> anyhow::Result<Vec<Vec<u8>>> {
Ok(vec![match &self.custom_alpn {
fn alpn(&self) -> anyhow::Result<Vec<u8>> {
Ok(match &self.custom_alpn {
Some(alpn) => parse_alpn(alpn)?,
None => dumbpipe::ALPN.to_vec(),
}])
})
}

fn is_custom_alpn(&self) -> bool {
Expand Down Expand Up @@ -251,7 +252,7 @@ async fn forward_bidi(
async fn listen_stdio(args: ListenArgs) -> anyhow::Result<()> {
let secret_key = get_or_create_secret()?;
let endpoint = MagicEndpoint::builder()
.alpns(args.common.alpns()?)
.alpns(vec![args.common.alpn()?])
.secret_key(secret_key)
.bind(args.common.magic_port)
.await?;
Expand Down Expand Up @@ -313,7 +314,7 @@ async fn connect_stdio(args: ConnectArgs) -> anyhow::Result<()> {
let addr = args.ticket.addr;
let remote_node_id = addr.node_id;
// connect to the node, try only once
let connection = endpoint.connect(addr, dumbpipe::ALPN).await?;
let connection = endpoint.connect(addr, &args.common.alpn()?).await?;
tracing::info!("connected to {}", remote_node_id);
// open a bidi stream, try only once
let (mut s, r) = connection.open_bi().await?;
Expand Down Expand Up @@ -357,13 +358,14 @@ async fn connect_tcp(args: ConnectTcpArgs) -> anyhow::Result<()> {
addr: NodeAddr,
endpoint: MagicEndpoint,
handshake: bool,
alpn: &[u8],
) -> anyhow::Result<()> {
let (tcp_stream, tcp_addr) = next.context("error accepting tcp connection")?;
let (tcp_recv, tcp_send) = tcp_stream.into_split();
tracing::info!("got tcp connection from {}", tcp_addr);
let remote_node_id = addr.node_id;
let connection = endpoint
.connect(addr, dumbpipe::ALPN)
.connect(addr, alpn)
.await
.context(format!("error connecting to {}", remote_node_id))?;
let (mut magic_send, magic_recv) = connection
Expand Down Expand Up @@ -393,8 +395,9 @@ async fn connect_tcp(args: ConnectTcpArgs) -> anyhow::Result<()> {
let endpoint = endpoint.clone();
let addr = addr.clone();
let handshake = !args.common.is_custom_alpn();
let alpn = args.common.alpn()?;
tokio::spawn(async move {
if let Err(cause) = handle_tcp_accept(next, addr, endpoint, handshake).await {
if let Err(cause) = handle_tcp_accept(next, addr, endpoint, handshake, &alpn).await {
// log error at warn level
//
// we should know about it, but it's not fatal
Expand All @@ -413,7 +416,7 @@ async fn listen_tcp(args: ListenTcpArgs) -> anyhow::Result<()> {
};
let secret_key = get_or_create_secret()?;
let endpoint = MagicEndpoint::builder()
.alpns(args.common.alpns()?)
.alpns(vec![args.common.alpn()?])
.secret_key(secret_key)
.bind(args.common.magic_port)
.await?;
Expand Down
48 changes: 48 additions & 0 deletions tests/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,54 @@ fn connect_listen_happy() {
assert_eq!(&listen_stdout, connect_to_listen);
}

/// Tests the basic functionality of the connect and listen pair
///
/// Connect and listen both write a limited amount of data and then EOF.
/// The interaction should stop when both sides have EOF'd.
#[test]
fn connect_listen_custom_alpn_happy() {
// the bytes provided by the listen command
let listen_to_connect = b"hello from listen";
let connect_to_listen = b"hello from connect";
let mut listen = duct::cmd(
dumbpipe_bin(),
["listen", "--custom-alpn", "utf8:mysuperalpn/0.1.0"],
)
.env_remove("RUST_LOG") // disable tracing
.stdin_bytes(listen_to_connect)
.stderr_to_stdout() //
.reader()
.unwrap();
// read the first 3 lines of the header, and parse the last token as a ticket
let header = read_ascii_lines(3, &mut listen).unwrap();
let header = String::from_utf8(header).unwrap();
let ticket = header.split_ascii_whitespace().last().unwrap();
let ticket = NodeTicket::from_str(ticket).unwrap();

let connect = duct::cmd(
dumbpipe_bin(),
[
"connect",
&ticket.to_string(),
"--custom-alpn",
"utf8:mysuperalpn/0.1.0",
],
)
.env_remove("RUST_LOG") // disable tracing
.stdin_bytes(connect_to_listen)
.stderr_null()
.stdout_capture()
.run()
.unwrap();

assert!(connect.status.success());
assert_eq!(&connect.stdout, listen_to_connect);

let mut listen_stdout = Vec::new();
listen.read_to_end(&mut listen_stdout).unwrap();
assert_eq!(&listen_stdout, connect_to_listen);
}

#[cfg(unix)]
#[test]
fn connect_listen_ctrlc_connect() {
Expand Down

0 comments on commit b9318bb

Please sign in to comment.