Skip to content

Commit

Permalink
cleanup: clean group context serialization in the keyschedule
Browse files Browse the repository at this point in the history
  • Loading branch information
TWal committed Dec 25, 2024
1 parent 5cdf922 commit 5526769
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 12 deletions.
2 changes: 1 addition & 1 deletion fstar/api/MLS.fst
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ let process_welcome_message w (sign_pk, sign_sk) lookup =
let? interim_transcript_hash = MLS.TreeDEM.Message.Transcript.compute_interim_transcript_hash #bytes group_info.tbs.confirmation_tag group_info.tbs.group_context.confirmed_transcript_hash in
let? tree_hash = MLS.TreeSync.API.compute_tree_hash treesync_state in
let? group_context = compute_group_context group_info.tbs.group_context.group_id group_info.tbs.group_context.epoch tree_hash group_info.tbs.group_context.confirmed_transcript_hash in
let? epoch_secret = MLS.TreeKEM.KeySchedule.secret_joiner_to_epoch (secrets.joiner_secret <: bytes) [] (serialize _ group_context) in
let? epoch_secret = MLS.TreeKEM.KeySchedule.secret_joiner_to_epoch (secrets.joiner_secret <: bytes) [] group_context in
let? (treekem_state, encryption_secret) = MLS.TreeKEM.API.welcome treekem leaf_decryption_key opt_path_secret_and_inviter_ind leaf_index epoch_secret in

let? tree_hash = MLS.TreeSync.API.compute_tree_hash treesync_state in
Expand Down
6 changes: 3 additions & 3 deletions fstar/test/MLS.Test.FromExt.KeySchedule.fst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ let gen_epoch_output #cb group_id last_init_secret epoch inp =
let commit_secret = hex_string_to_bytes inp.commit_secret in
let psk_secret = hex_string_to_bytes inp.psk_secret in

let group_context = (ps_prefix_to_ps_whole ps_group_context_nt).serialize (gen_group_context (ciphersuite #bytes) (hex_string_to_bytes group_id) epoch (hex_string_to_bytes inp.tree_hash) (hex_string_to_bytes inp.confirmed_transcript_hash)) in
let group_context = gen_group_context (ciphersuite #bytes) (hex_string_to_bytes group_id) epoch (hex_string_to_bytes inp.tree_hash) (hex_string_to_bytes inp.confirmed_transcript_hash) in
let last_init_secret = hex_string_to_bytes last_init_secret in
let joiner_secret = extract_result (secret_init_to_joiner last_init_secret (Some commit_secret) group_context) in
let welcome_secret = extract_result (secret_joiner_to_welcome_internal #bytes joiner_secret psk_secret) in
Expand All @@ -36,11 +36,11 @@ let gen_epoch_output #cb group_id last_init_secret epoch inp =
if not (string_is_ascii inp.exporter_label && String.strlen inp.exporter_label < pow2 30 - 8) then
failwith "gen_epoch_output: exporter label is not ascii"
else
extract_result (mls_exporter exporter_secret inp.exporter_label (hex_string_to_bytes inp.exporter_context) (UInt32.v inp.exporter_length))
extract_result (mls_exporter (exporter_secret <: bytes) inp.exporter_label (hex_string_to_bytes inp.exporter_context) (UInt32.v inp.exporter_length))
in

{
group_context = bytes_to_hex_string group_context;
group_context = bytes_to_hex_string (serialize _ group_context);
joiner_secret = bytes_to_hex_string joiner_secret;
welcome_secret = bytes_to_hex_string welcome_secret;
init_secret = bytes_to_hex_string init_secret;
Expand Down
2 changes: 1 addition & 1 deletion fstar/test/MLS.Test.FromExt.Welcome.fst
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ let test_welcome_one t =
failwith "test_welcome_one: bad signature"
);

let group_context = (ps_prefix_to_ps_whole ps_group_context_nt).serialize (gen_group_context (ciphersuite #bytes) group_info.tbs.group_context.group_id group_info.tbs.group_context.epoch group_info.tbs.group_context.tree_hash group_info.tbs.group_context.confirmed_transcript_hash) in
let group_context = gen_group_context (ciphersuite #bytes) group_info.tbs.group_context.group_id group_info.tbs.group_context.epoch group_info.tbs.group_context.tree_hash group_info.tbs.group_context.confirmed_transcript_hash in
let joiner_secret = group_secrets.joiner_secret in
let epoch_secret = extract_result (secret_joiner_to_epoch joiner_secret [] group_context) in
let confirmation_key = extract_result (secret_epoch_to_confirmation #bytes epoch_secret) in
Expand Down
4 changes: 2 additions & 2 deletions fstar/treekem/code/MLS.TreeKEM.API.KeySchedule.fst
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ val commit:
group_context_nt bytes ->
result (treekem_keyschedule_state bytes & bytes & secrets_for_welcome bytes)
let commit #bytes #cb st opt_commit_secret psks new_group_context =
let? joiner_secret: bytes = secret_init_to_joiner st.init_secret opt_commit_secret (serialize _ new_group_context) in
let? epoch_secret: bytes = secret_joiner_to_epoch joiner_secret psks (serialize _ new_group_context) in
let? joiner_secret: bytes = secret_init_to_joiner st.init_secret opt_commit_secret new_group_context in
let? epoch_secret: bytes = secret_joiner_to_epoch joiner_secret psks new_group_context in
let? welcome_secret: bytes = secret_joiner_to_welcome joiner_secret psks in
let? (new_st, encryption_secret) = create epoch_secret in
return (new_st, encryption_secret, {
Expand Down
11 changes: 6 additions & 5 deletions fstar/treekem/code/MLS.TreeKEM.KeySchedule.fst
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ module MLS.TreeKEM.KeySchedule

open Comparse
open MLS.Crypto
open MLS.NetworkTypes
open MLS.TreeKEM.NetworkTypes
open MLS.TreeKEM.PSK
open MLS.Result

val secret_init_to_joiner:
#bytes:Type0 -> {|crypto_bytes bytes|} ->
bytes -> option bytes -> bytes ->
bytes -> option bytes -> group_context_nt bytes ->
result (lbytes bytes (kdf_length #bytes))
let secret_init_to_joiner #bytes #cb init_secret opt_commit_secret group_context =
let commit_secret =
Expand All @@ -17,7 +18,7 @@ let secret_init_to_joiner #bytes #cb init_secret opt_commit_secret group_context
| None -> zero_vector #bytes
in
let? prk = kdf_extract init_secret commit_secret in
expand_with_label #bytes prk "joiner" group_context (kdf_length #bytes)
expand_with_label #bytes prk "joiner" (serialize _ group_context) (kdf_length #bytes)

// this version is tested in the test vectors
val secret_joiner_to_welcome_internal:
Expand All @@ -39,15 +40,15 @@ let secret_joiner_to_welcome #bytes #cb joiner_secret psks =
// this version is tested in the test vectors
val secret_joiner_to_epoch_internal:
#bytes:Type0 -> {|crypto_bytes bytes|} ->
bytes -> bytes -> bytes ->
bytes -> bytes -> group_context_nt bytes ->
result (lbytes bytes (kdf_length #bytes))
let secret_joiner_to_epoch_internal #bytes #cb joiner_secret psk_secret group_context =
let? prk = kdf_extract joiner_secret psk_secret in
expand_with_label #bytes prk "epoch" group_context (kdf_length #bytes)
expand_with_label #bytes prk "epoch" (serialize _ group_context) (kdf_length #bytes)

val secret_joiner_to_epoch:
#bytes:Type0 -> {|crypto_bytes bytes|} ->
bytes -> list (pre_shared_key_id_nt bytes & bytes) -> bytes ->
bytes -> list (pre_shared_key_id_nt bytes & bytes) -> group_context_nt bytes ->
result (lbytes bytes (kdf_length #bytes))
let secret_joiner_to_epoch #bytes #cb joiner_secret psks group_context =
let? psk_secret = compute_psk_secret psks in
Expand Down

0 comments on commit 5526769

Please sign in to comment.