Skip to content

Commit

Permalink
encpedpop: Add coordinator step function
Browse files Browse the repository at this point in the history
  • Loading branch information
real-or-random committed Mar 22, 2024
1 parent e31bf4a commit 17b1d62
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 40 deletions.
23 changes: 21 additions & 2 deletions reference/encpedpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ class SignerMsg(NamedTuple):

class CoordinatorMsg(NamedTuple):
simpl_cmsg: simplpedpop.CoordinatorMsg
enc_shares_sum: Scalar


# TODO Define a CoordinatorUnicastMsg to imrpove handling of the enc_shares_sums?


###
Expand Down Expand Up @@ -103,9 +105,10 @@ def signer_step(
def signer_pre_finalize(
state: SignerState,
cmsg: CoordinatorMsg,
enc_shares_sum: Scalar,
) -> Tuple[bytes, simplpedpop.DKGOutput]:
t, deckey, enckeys, idx, self_share, simpl_state = state
simpl_cmsg, enc_shares_sum = cmsg
simpl_cmsg, = cmsg # Unpack unary tuple # fmt: skip

enc_context = t.to_bytes(4, byteorder="big") + b"".join(enckeys)
shares_sum = decrypt_sum(enc_shares_sum, deckey, enckeys, idx, enc_context)
Expand All @@ -115,3 +118,19 @@ def signer_pre_finalize(
)
eta += b"".join(enckeys)
return eta, dkg_output


###
### Coordinator
###


def coordinator_step(
smsgs: List[SignerMsg], t: int
) -> Tuple[CoordinatorMsg, List[Scalar]]:
n = len(smsgs)
simpl_cmsg = simplpedpop.coordinator_step([smsg.simpl_smsg for smsg in smsgs], t)
enc_shares_sums = [
Scalar.sum(*([smsg.enc_shares[i] for smsg in smsgs])) for i in range(n)
]
return CoordinatorMsg(simpl_cmsg), enc_shares_sums
50 changes: 23 additions & 27 deletions reference/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def chilldkg_round1(
def chilldkg_round2(
seed: bytes,
state1: ChillDKGStateR1,
vss_commitments_sum: simplpedpop.CoordinatorMsg,
all_enc_shares_sum: List[Scalar],
enc_cmsg: encpedpop.CoordinatorMsg,
enc_shares_sums: List[Scalar],
) -> Tuple[ChillDKGStateR2, bytes]:
(hostseckey, _) = chilldkg_hostkey_gen(seed)
(params, idx, enc_state1) = state1
Expand All @@ -77,11 +77,11 @@ def chilldkg_round2(
# shares, which in turn ensures that they have the right backup.
# TODO This means all parties who hold the "backup" in the end should
# participate in Eq?
enc_share = all_enc_shares_sum[idx]

enc_cmsg = encpedpop.CoordinatorMsg(vss_commitments_sum, enc_share)
eta, dkg_output = encpedpop.signer_pre_finalize(enc_state1, enc_cmsg)
eta += b"".join([bytes_from_int(int(share)) for share in all_enc_shares_sum])
eta, dkg_output = encpedpop.signer_pre_finalize(
enc_state1, enc_cmsg, enc_shares_sums[idx]
)
eta += b"".join([bytes_from_int(int(share)) for share in enc_shares_sums])
state2 = (params, eta, dkg_output)
return state2, certifying_eq_round1(hostseckey, eta)

Expand Down Expand Up @@ -118,12 +118,10 @@ async def chilldkg(
) -> Optional[Tuple[simplpedpop.DKGOutput, Any]]:
# TODO Top-level error handling
state1, vss_commitment_ext, enc_gen_shares = chilldkg_round1(seed, params)
chan.send((vss_commitment_ext, enc_gen_shares))
vss_commitments_sum, all_enc_shares_sum = await chan.receive()
chan.send(encpedpop.SignerMsg(vss_commitment_ext, enc_gen_shares))
enc_cmsg, enc_shares_sums = await chan.receive()

state2, eq_round1 = chilldkg_round2(
seed, state1, vss_commitments_sum, all_enc_shares_sum
)
state2, eq_round1 = chilldkg_round2(seed, state1, enc_cmsg, enc_shares_sums)

chan.send(eq_round1)
cert = await chan.receive()
Expand Down Expand Up @@ -175,13 +173,13 @@ def serialize_eta(
t: int,
vss_commit: VSSCommitment,
hostpubkeys: List[bytes],
all_enc_shares_sum: List[Scalar],
enc_shares_sums: List[Scalar],
) -> bytes:
return (
t.to_bytes(4, byteorder="big")
+ vss_commit.to_bytes()
+ b"".join(hostpubkeys)
+ b"".join([bytes_from_int(int(share)) for share in all_enc_shares_sum])
+ b"".join([bytes_from_int(int(share)) for share in enc_shares_sums])
)


Expand All @@ -190,21 +188,19 @@ async def chilldkg_coordinate(
) -> Union[GroupInfo, Literal[False]]:
(hostpubkeys, t, params_id) = params
n = len(hostpubkeys)
simpl_round1_ins = []
all_enc_shares_sum = [Scalar(0)] * n
enc_round1_ins = []
for i in range(n):
simpl_round1_in, enc_shares = await chans.receive_from(i)
simpl_round1_ins += [simpl_round1_in]
all_enc_shares_sum = [all_enc_shares_sum[j] + enc_shares[j] for j in range(n)]
simpl_round1_outs = simplpedpop.coordinator_step(simpl_round1_ins, t)
chans.send_all((simpl_round1_outs, all_enc_shares_sum))
enc_round1_ins.append(encpedpop.SignerMsg(simpl_round1_in, enc_shares))
enc_round1_out, enc_shares_sums = encpedpop.coordinator_step(enc_round1_ins, t)
chans.send_all((enc_round1_out, enc_shares_sums))
vss_commitment = simplpedpop.assemble_sum_vss_commitment(
simpl_round1_outs.coms_to_secrets,
simpl_round1_outs.sum_coms_to_nonconst_terms,
enc_round1_out.simpl_cmsg.coms_to_secrets,
enc_round1_out.simpl_cmsg.sum_coms_to_nonconst_terms,
t,
n,
)
eta = serialize_eta(t, vss_commitment, hostpubkeys, all_enc_shares_sum)
eta = serialize_eta(t, vss_commitment, hostpubkeys, enc_shares_sums)
cert = await certifying_eq_coordinate(chans, hostpubkeys)
if not verify_cert(hostpubkeys, eta, cert):
return False
Expand Down Expand Up @@ -234,17 +230,17 @@ def deserialize_eta(b: bytes) -> Any:
raise DeserializationError
hostpubkeys, rest = [rest[i : i + 33] for i in range(0, 33 * n, 33)], rest[33 * n :]

# Read all_enc_shares_sum (32*n bytes)
# Read enc_shares_sums (32*n bytes)
if len(rest) < 32 * n:
raise DeserializationError
all_enc_shares_sum, rest = (
enc_shares_sums, rest = (
[Scalar(int_from_bytes(rest[i : i + 32])) for i in range(0, 32 * n, 32)],
rest[32 * n :],
)

if len(rest) != 0:
raise DeserializationError
return (t, vss_commit, hostpubkeys, all_enc_shares_sum)
return (t, vss_commit, hostpubkeys, enc_shares_sums)


# Recovery requires the seed and the public backup
Expand All @@ -253,7 +249,7 @@ def chilldkg_recover(
) -> Union[Tuple[simplpedpop.DKGOutput, SessionParams], Literal[False]]:
(eta, cert) = backup
try:
(t, vss_commit, hostpubkeys, all_enc_shares_sum) = deserialize_eta(eta)
(t, vss_commit, hostpubkeys, enc_shares_sums) = deserialize_eta(eta)
except DeserializationError as e:
raise InvalidBackupError("Failed to deserialize backup") from e

Expand All @@ -273,7 +269,7 @@ def chilldkg_recover(
raise InvalidBackupError("Seed and backup don't match") from e

shares_sum = encpedpop.decrypt_sum(
all_enc_shares_sum[idx], hostseckey, hostpubkeys, idx, enc_context
enc_shares_sums[idx], hostseckey, hostpubkeys, idx, enc_context
)
# TODO: don't call full round1 function
(state1, (_, _)) = encpedpop.signer_step(
Expand Down
20 changes: 9 additions & 11 deletions reference/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,12 @@ def simulate_encpedpop(seeds, t):
deckey = enc_soutputs0[i][0]
enc_soutputs1 += [encpedpop.signer_step(seeds[i], t, n, deckey, enckeys, i)]

simpl_smsgs = [out[1][0] for out in enc_soutputs1]
simpl_cmsg = simplpedpop.coordinator_step(simpl_smsgs, t)
smsgs = [smsg for (_, smsg) in enc_soutputs1]
sstates = [sstate for (sstate, _) in enc_soutputs1]
cmsg, enc_shares_sums = encpedpop.coordinator_step(smsgs, t)
for i in range(n):
enc_shares_sum = Scalar.sum(*([out[1][1][i] for out in enc_soutputs1]))
dkg_outputs += [
encpedpop.signer_pre_finalize(
enc_soutputs1[i][0], (simpl_cmsg, enc_shares_sum)
)
encpedpop.signer_pre_finalize(sstates[i], cmsg, enc_shares_sums[i])
]
return dkg_outputs

Expand All @@ -101,16 +99,16 @@ def simulate_chilldkg(seeds, t):
chill_sstate1s = [out[0] for out in chill_soutputs1]
simpl_smsgs = [out[1] for out in chill_soutputs1]
simpl_cmsg = simplpedpop.coordinator_step(simpl_smsgs, t)
enc_cmsg = encpedpop.CoordinatorMsg(simpl_cmsg)

dkg_outputs = []
all_enc_shares_sum = []
enc_shares_sums = []
for i in range(n):
all_enc_shares_sum += [Scalar.sum(*([out[2][i] for out in chill_soutputs1]))]
enc_shares_sums += [Scalar.sum(*([out[2][i] for out in chill_soutputs1]))]
round2_outputs = []
for i in range(n):
round2_outputs += [
chilldkg_round2(
seeds[i], chill_sstate1s[i], simpl_cmsg, all_enc_shares_sum
)
chilldkg_round2(seeds[i], chill_sstate1s[i], enc_cmsg, enc_shares_sums)
]

cert = b"".join([out[1] for out in round2_outputs])
Expand Down

0 comments on commit 17b1d62

Please sign in to comment.