From c736bd34555797f02d05b5da52e851934b8e2d56 Mon Sep 17 00:00:00 2001 From: Lasse Letager Hansen Date: Mon, 27 Jan 2025 16:01:35 +0100 Subject: [PATCH] Pre-share key test --- src/tls13formats.rs | 16 +++++-- src/tls13handshake.rs | 1 + tests/test_tls13api.rs | 106 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 5 deletions(-) diff --git a/src/tls13formats.rs b/src/tls13formats.rs index d7430764..775d76d9 100644 --- a/src/tls13formats.rs +++ b/src/tls13formats.rs @@ -178,10 +178,11 @@ fn pre_shared_key(algs: &Algorithms, session_ticket: &Bytes) -> Result<(Bytes, u let binders = encode_length_u16(encode_length_u8(zero_key(&algs.hash()).as_raw())?)?; let binders_len = binders.len(); let ext = bytes2(0, 41).concat(encode_length_u16(identities.concat(binders))?); - Ok((ext, binders_len)) + let ext_len = ext.len(); + Ok((ext, ext_len+binders_len+199-16-82)) } -fn check_psk_shared_key(algs: &Algorithms, ch: &[U8]) -> Result<(), TLSError> { +fn check_psk_shared_key(algs: &Algorithms, ch: &[U8]) -> Result<(Bytes, Bytes), TLSError> { let len_id = length_u16_encoded(ch)?; let len_tkt = length_u16_encoded(&ch[2..2 + len_id])?; if len_id == len_tkt + 6 { @@ -190,7 +191,7 @@ fn check_psk_shared_key(algs: &Algorithms, ch: &[U8]) -> Result<(), TLSError> { if ch.len() - 5 - len_id != algs.hash().hash_len() { tlserr(parse_failed()) } else { - Ok(()) + Ok((Bytes::from(&ch[4..4+len_tkt]), Bytes::from([0; 0]))) } } else { tlserr(parse_failed()) @@ -290,8 +291,13 @@ fn check_extension(algs: &Algorithms, bytes: &[U8]) -> Result<(usize, Extensions Err(_) => tlserr(MISSING_KEY_SHARE), }, (0, 41) => { - check_psk_shared_key(algs, &bytes[4..4 + len])?; - Ok((4 + len, out)) + let (tkt,binder) = check_psk_shared_key(algs, &bytes[4..4 + len])?; + Ok((4 + len, Extensions { + sni: None, + key_share: None, + ticket: Some(tkt), + binder: Some(binder), + })) } _ => Ok((4 + len, out)), } diff --git a/src/tls13handshake.rs b/src/tls13handshake.rs index 63067db1..70628a12 100644 --- a/src/tls13handshake.rs +++ b/src/tls13handshake.rs @@ -602,6 +602,7 @@ fn process_psk_binder_zero_rtt( match (ciphersuite.psk_mode, psko, bindero) { (true, Some(k), Some(binder)) => { let mk = derive_binder_key(&ciphersuite.hash, k)?; + let binder = hmac_tag(&ciphersuite.hash, &mk, &th_trunc)?; hmac_verify(&ciphersuite.hash, &mk, &th_trunc, &binder)?; if ciphersuite.zero_rtt { let (key_iv, early_exporter_ms) = diff --git a/tests/test_tls13api.rs b/tests/test_tls13api.rs index a39ca338..2ee55986 100644 --- a/tests/test_tls13api.rs +++ b/tests/test_tls13api.rs @@ -99,6 +99,15 @@ const TLS_CHACHA20_POLY1305_SHA256_X25519: Algorithms = Algorithms::new( false, ); +const TLS_WITH_PSK_CHACHA20_POLY1305_SHA256_X25519: Algorithms = Algorithms::new( + HashAlgorithm::SHA256, + AeadAlgorithm::Chacha20Poly1305, + SignatureScheme::EcdsaSecp256r1Sha256, + KemScheme::X25519, + true, + true, +); + #[test] fn test_full_round_trip() { let cr = random_bytes(32); @@ -187,3 +196,100 @@ fn test_full_round_trip() { } assert!(b); } + +#[test] +fn test_full_round_trip_with_psk() { + let cr = random_bytes(32); + let x = cr.concat(load_hex(client_x25519_priv)); + let mut client_rng = TestRng::new(x.declassify()); + let server_name = load_hex("6c 6f 63 61 6c 68 6f 73 74"); + let sr = random_bytes(64); + let y = load_hex(server_x25519_priv); + let ent_s = sr.concat(y); + let mut server_rng = TestRng::new(ent_s.declassify()); + let session_ticket = random_bytes(32); + let psk = random_bytes(32); + + let db = ServerDB::new( + server_name.clone(), + Bytes::from(&ECDSA_P256_SHA256_CERT), + SignatureKey::from(&ECDSA_P256_SHA256_Key), + Some((session_ticket.clone(), psk.clone())), + ); + + let mut b = true; + const ciphersuite: Algorithms = TLS_WITH_PSK_CHACHA20_POLY1305_SHA256_X25519; + + match Client::connect( + ciphersuite, + &server_name, + Some(session_ticket), + Some(psk), + &mut client_rng, + ) { + Err(x) => { + println!("Client0 Error {}", x); + b = false; + } + Ok((client_hello, client)) => { + println!("Client0 Complete {}", server_rng.raw().len()); + match Server::accept(ciphersuite, db, &client_hello, &mut server_rng) { + Err(x) => { + println!("ServerInit Error {}", x); + b = false; + } + Ok((sh, sf, server)) => { + println!("Server0 Complete"); + match client.read_handshake(&sh) { + Err(x) => { + println!("ServerHello Error {}", x); + b = false; + } + Ok((Some(_), _)) => { + println!("ServerHello State Error"); + b = false; + } + Ok((None, client_state)) => match client_state.read_handshake(&sf) { + Err(x) => { + println!("ClientFinish Error {}", x); + b = false; + } + Ok((None, _)) => { + println!("ClientFinish State Error"); + b = false; + } + Ok((Some(cf), client)) => { + println!("Client Complete"); + match server.read_handshake(&cf) { + Err(x) => { + println!("Server1 Error {}", x); + b = false; + } + Ok(server) => { + println!("Server Complete"); + + // Send data from client to server. + let data = Bytes::from(b"Hello server, here is the client"); + let (ap, client) = + client.write(AppData::new(data.clone())).unwrap(); + let (apo, server) = server.read(&ap).unwrap(); + assert!(eq(&data, apo.unwrap().as_raw())); + + // Send data from server to client. + let data = + Bytes::from(b"Hello client, here is the server."); + let (ap, _server) = + server.write(AppData::new(data.clone())).unwrap(); + let (application_data, _cstate) = client.read(&ap).unwrap(); + assert!(eq(&data, application_data.unwrap().as_raw())); + } + } + } + }, + } + } + } + } + } + assert!(b); +}