Skip to content

Commit

Permalink
Extended API
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias J. Kannwischer <[email protected]>
  • Loading branch information
mkannwischer committed Feb 25, 2025
1 parent e319849 commit 3ca90c7
Show file tree
Hide file tree
Showing 4 changed files with 374 additions and 141 deletions.
127 changes: 55 additions & 72 deletions mlkem/indcpa.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,15 @@
* This is to facilitate building multiple instances
* of mlkem-native (e.g. with varying security levels)
* within a single compilation unit. */
#define mlk_pack_pk MLK_ADD_LEVEL(mlk_pack_pk)
#define mlk_unpack_pk MLK_ADD_LEVEL(mlk_unpack_pk)
#define mlk_pack_sk MLK_ADD_LEVEL(mlk_pack_sk)
#define mlk_unpack_sk MLK_ADD_LEVEL(mlk_unpack_sk)
#define mlk_pack_ciphertext MLK_ADD_LEVEL(mlk_pack_ciphertext)
#define mlk_unpack_ciphertext MLK_ADD_LEVEL(mlk_unpack_ciphertext)
#define mlk_matvec_mul MLK_ADD_LEVEL(mlk_matvec_mul)
/* End of level namespacing */

/*************************************************
* Name: mlk_pack_pk
* Name: mlk_indcpa_marshal_pk
*
* Description: Serialize the public key as concatenation of the
* serialized vector of polynomials pk
Expand All @@ -45,16 +43,18 @@
* Implements [FIPS 203, Algorithm 13 (K-PKE.KeyGen), L19]
*
**************************************************/
static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], mlk_polyvec *pk,
const uint8_t seed[MLKEM_SYMBYTES])
MLK_INTERNAL_API
void mlk_indcpa_marshal_pk(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
const mlk_indcpa_public_key *pks)
{
mlk_assert_bound_2d(pk, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
mlk_polyvec_tobytes(r, pk);
memcpy(r + MLKEM_POLYVECBYTES, seed, MLKEM_SYMBYTES);
mlk_assert_bound_2d(pks->pkpv, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
mlk_polyvec_tobytes(pk, &pks->pkpv);
memcpy(pk + MLKEM_POLYVECBYTES, pks->seed, MLKEM_SYMBYTES);
}


/*************************************************
* Name: mlk_unpack_pk
* Name: mlk_indcpa_parse_pk
*
* Description: De-serialize public key from a byte array;
* approximate inverse of mlk_pack_pk
Expand All @@ -69,11 +69,13 @@ static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], mlk_polyvec *pk,
* Implements [FIPS 203, Algorithm 14 (K-PKE.Encrypt), L2-3]
*
**************************************************/
static void mlk_unpack_pk(mlk_polyvec *pk, uint8_t seed[MLKEM_SYMBYTES],
const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES])
MLK_INTERNAL_API
void mlk_indcpa_parse_pk(mlk_indcpa_public_key *pks,
const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES])
{
mlk_polyvec_frombytes(pk, packedpk);
memcpy(seed, packedpk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES);
mlk_polyvec_frombytes(&pks->pkpv, pk);
memcpy(pks->seed, pk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES);
mlk_gen_matrix(pks->at, pks->seed, 1);

/* NOTE: If a modulus check was conducted on the PK, we know at this
* point that the coefficients of `pk` are unsigned canonical. The
Expand All @@ -82,7 +84,7 @@ static void mlk_unpack_pk(mlk_polyvec *pk, uint8_t seed[MLKEM_SYMBYTES],
}

/*************************************************
* Name: mlk_pack_sk
* Name: mlk_indcpa_marshal_sk
*
* Description: Serialize the secret key
*
Expand All @@ -94,14 +96,16 @@ static void mlk_unpack_pk(mlk_polyvec *pk, uint8_t seed[MLKEM_SYMBYTES],
* Implements [FIPS 203, Algorithm 13 (K-PKE.KeyGen), L20]
*
**************************************************/
static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], mlk_polyvec *sk)
MLK_INTERNAL_API
void mlk_indcpa_marshal_sk(uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES],
const mlk_indcpa_secret_key *sks)
{
mlk_assert_bound_2d(sk, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
mlk_polyvec_tobytes(r, sk);
mlk_assert_bound_2d(&sks->skpv, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
mlk_polyvec_tobytes(sk, &sks->skpv);
}

/*************************************************
* Name: mlk_unpack_sk
* Name: mlk_indcpa_parse_sk
*
* Description: De-serialize the secret key; inverse of mlk_pack_sk
*
Expand All @@ -114,10 +118,11 @@ static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], mlk_polyvec *sk)
* Implements [FIPS 203, Algorithm 15 (K-PKE.Decrypt), L5]
*
**************************************************/
static void mlk_unpack_sk(mlk_polyvec *sk,
const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES])
MLK_INTERNAL_API
void mlk_indcpa_parse_sk(mlk_indcpa_secret_key *sks,
const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES])
{
mlk_polyvec_frombytes(sk, packedsk);
mlk_polyvec_frombytes(&sks->skpv, sk);
}

/*************************************************
Expand Down Expand Up @@ -332,14 +337,14 @@ __contract__(
* - We include buffer zeroization.
*/
MLK_INTERNAL_API
void mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES],
void mlk_indcpa_keypair_derand(mlk_indcpa_public_key *pk,
mlk_indcpa_secret_key *sk,
const uint8_t coins[MLKEM_SYMBYTES])
{
MLK_ALIGN uint8_t buf[2 * MLKEM_SYMBYTES];
const uint8_t *publicseed = buf;
const uint8_t *noiseseed = buf + MLKEM_SYMBYTES;
mlk_polyvec a[MLKEM_K], e, pkpv, skpv;
mlk_polyvec e;
mlk_polyvec_mulcache skpv_cache;

MLK_ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1];
Expand All @@ -357,51 +362,48 @@ void mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
*/
MLK_CT_TESTING_DECLASSIFY(publicseed, MLKEM_SYMBYTES);

mlk_gen_matrix(a, publicseed, 0 /* no transpose */);
mlk_gen_matrix(pk->at, publicseed, 0 /* no transpose */);

#if MLKEM_K == 2
mlk_poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, e.vec + 0, e.vec + 1,
noiseseed, 0, 1, 2, 3);
mlk_poly_getnoise_eta1_4x(sk->skpv.vec + 0, sk->skpv.vec + 1, e.vec + 0,
e.vec + 1, noiseseed, 0, 1, 2, 3);
#elif MLKEM_K == 3
/*
* Only the first three output buffers are needed.
* The laster parameter is a dummy that's overwritten later.
*/
mlk_poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2,
pkpv.vec + 0 /* irrelevant */, noiseseed, 0, 1, 2,
0xFF /* irrelevant */);
mlk_poly_getnoise_eta1_4x(sk->skpv.vec + 0, sk->skpv.vec + 1,
sk->skpv.vec + 2, pk->pkpv.vec + 0 /* irrelevant */,
noiseseed, 0, 1, 2, 0xFF /* irrelevant */);
/* Same here */
mlk_poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2,
pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4, 5,
0xFF /* irrelevant */);
pk->pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4,
5, 0xFF /* irrelevant */);
#elif MLKEM_K == 4
mlk_poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2,
skpv.vec + 3, noiseseed, 0, 1, 2, 3);
mlk_poly_getnoise_eta1_4x(sk->skpv.vec + 0, sk->skpv.vec + 1,
sk->skpv.vec + 2, sk->skpv.vec + 3, noiseseed, 0, 1,
2, 3);
mlk_poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, e.vec + 3,
noiseseed, 4, 5, 6, 7);
#endif

mlk_polyvec_ntt(&skpv);
mlk_polyvec_ntt(&sk->skpv);
mlk_polyvec_ntt(&e);

mlk_polyvec_mulcache_compute(&skpv_cache, &skpv);
mlk_matvec_mul(&pkpv, a, &skpv, &skpv_cache);
mlk_polyvec_tomont(&pkpv);

mlk_polyvec_add(&pkpv, &e);
mlk_polyvec_reduce(&pkpv);
mlk_polyvec_reduce(&skpv);
mlk_polyvec_mulcache_compute(&skpv_cache, &sk->skpv);
mlk_matvec_mul(&pk->pkpv, pk->at, &sk->skpv, &skpv_cache);
mlk_polyvec_tomont(&pk->pkpv);

mlk_pack_sk(sk, &skpv);
mlk_pack_pk(pk, &pkpv, publicseed);
mlk_polyvec_add(&pk->pkpv, &e);
mlk_polyvec_reduce(&pk->pkpv);
mlk_polyvec_reduce(&sk->skpv);
memcpy(pk->seed, publicseed, MLKEM_SYMBYTES);

/* Specification: Partially implements
* [FIPS 203, Section 3.3, Destruction of intermediate values] */
mlk_zeroize(buf, sizeof(buf));
mlk_zeroize(coins_with_domain_separator, sizeof(coins_with_domain_separator));
mlk_zeroize(a, sizeof(a));
mlk_zeroize(&e, sizeof(e));
mlk_zeroize(&skpv, sizeof(skpv));
mlk_zeroize(&skpv_cache, sizeof(skpv_cache));
}

Expand All @@ -416,27 +418,14 @@ void mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
MLK_INTERNAL_API
void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
const uint8_t m[MLKEM_INDCPA_MSGBYTES],
const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
const mlk_indcpa_public_key *pk,
const uint8_t coins[MLKEM_SYMBYTES])
{
MLK_ALIGN uint8_t seed[MLKEM_SYMBYTES];
mlk_polyvec sp, pkpv, ep, at[MLKEM_K], b;
mlk_polyvec sp, ep, b;
mlk_poly v, k, epp;
mlk_polyvec_mulcache sp_cache;

mlk_unpack_pk(&pkpv, seed, pk);
mlk_poly_frommsg(&k, m);

/*
* Declassify the public seed.
* Required to use it in conditional-branches in rejection sampling.
* This is needed because in re-encryption the publicseed originated from sk
* which is marked undefined.
*/
MLK_CT_TESTING_DECLASSIFY(seed, MLKEM_SYMBYTES);

mlk_gen_matrix(at, seed, 1 /* transpose */);

#if MLKEM_K == 2
mlk_poly_getnoise_eta1122_4x(sp.vec + 0, sp.vec + 1, ep.vec + 0, ep.vec + 1,
coins, 0, 1, 2, 3);
Expand All @@ -462,8 +451,8 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
mlk_polyvec_ntt(&sp);

mlk_polyvec_mulcache_compute(&sp_cache, &sp);
mlk_matvec_mul(&b, at, &sp, &sp_cache);
mlk_polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache);
mlk_matvec_mul(&b, pk->at, &sp, &sp_cache);
mlk_polyvec_basemul_acc_montgomery_cached(&v, &pk->pkpv, &sp, &sp_cache);

mlk_polyvec_invntt_tomont(&b);
mlk_poly_invntt_tomont(&v);
Expand All @@ -479,12 +468,10 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],

/* Specification: Partially implements
* [FIPS 203, Section 3.3, Destruction of intermediate values] */
mlk_zeroize(seed, sizeof(seed));
mlk_zeroize(&sp, sizeof(sp));
mlk_zeroize(&sp_cache, sizeof(sp_cache));
mlk_zeroize(&b, sizeof(b));
mlk_zeroize(&v, sizeof(v));
mlk_zeroize(at, sizeof(at));
mlk_zeroize(&k, sizeof(k));
mlk_zeroize(&ep, sizeof(ep));
mlk_zeroize(&epp, sizeof(epp));
Expand All @@ -496,18 +483,17 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
MLK_INTERNAL_API
void mlk_indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],
const uint8_t c[MLKEM_INDCPA_BYTES],
const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES])
const mlk_indcpa_secret_key *sk)
{
mlk_polyvec b, skpv;
mlk_polyvec b;
mlk_poly v, sb;
mlk_polyvec_mulcache b_cache;

mlk_unpack_ciphertext(&b, &v, c);
mlk_unpack_sk(&skpv, sk);

mlk_polyvec_ntt(&b);
mlk_polyvec_mulcache_compute(&b_cache, &b);
mlk_polyvec_basemul_acc_montgomery_cached(&sb, &skpv, &b, &b_cache);
mlk_polyvec_basemul_acc_montgomery_cached(&sb, &sk->skpv, &b, &b_cache);
mlk_poly_invntt_tomont(&sb);

mlk_poly_sub(&v, &sb);
Expand All @@ -517,7 +503,6 @@ void mlk_indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],

/* Specification: Partially implements
* [FIPS 203, Section 3.3, Destruction of intermediate values] */
mlk_zeroize(&skpv, sizeof(skpv));
mlk_zeroize(&b, sizeof(b));
mlk_zeroize(&b_cache, sizeof(b_cache));
mlk_zeroize(&v, sizeof(v));
Expand All @@ -526,9 +511,7 @@ void mlk_indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],

/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
* Don't modify by hand -- this is auto-generated by scripts/autogen. */
#undef mlk_pack_pk
#undef mlk_unpack_pk
#undef mlk_pack_sk
#undef mlk_unpack_sk
#undef mlk_pack_ciphertext
#undef mlk_unpack_ciphertext
Expand Down
44 changes: 36 additions & 8 deletions mlkem/indcpa.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,34 @@
#include "common.h"
#include "poly_k.h"

typedef struct
{
mlk_polyvec skpv;
} mlk_indcpa_secret_key;

typedef struct
{
mlk_polyvec at[MLKEM_K]; /* transposed matrix */
mlk_polyvec pkpv;
uint8_t seed[MLKEM_SYMBYTES];
} mlk_indcpa_public_key;

MLK_INTERNAL_API
void mlk_indcpa_marshal_pk(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
const mlk_indcpa_public_key *pks);

MLK_INTERNAL_API
void mlk_indcpa_parse_pk(mlk_indcpa_public_key *pks,
const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES]);

MLK_INTERNAL_API
void mlk_indcpa_marshal_sk(uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES],
const mlk_indcpa_secret_key *sks);

MLK_INTERNAL_API
void mlk_indcpa_parse_sk(mlk_indcpa_secret_key *sks,
const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]);

#define mlk_gen_matrix MLK_NAMESPACE_K(gen_matrix)
/*************************************************
* Name: mlk_gen_matrix
Expand Down Expand Up @@ -58,12 +86,12 @@ __contract__(
*
**************************************************/
MLK_INTERNAL_API
void mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES],
void mlk_indcpa_keypair_derand(mlk_indcpa_public_key *pk,
mlk_indcpa_secret_key *sk,
const uint8_t coins[MLKEM_SYMBYTES])
__contract__(
requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES))
requires(memory_no_alias(sk, MLKEM_INDCPA_SECRETKEYBYTES))
requires(memory_no_alias(pk, sizeof(mlk_indcpa_public_key)))
requires(memory_no_alias(sk, sizeof(mlk_indcpa_secret_key)))
requires(memory_no_alias(coins, MLKEM_SYMBYTES))
assigns(object_whole(pk))
assigns(object_whole(sk))
Expand Down Expand Up @@ -92,12 +120,12 @@ __contract__(
MLK_INTERNAL_API
void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
const uint8_t m[MLKEM_INDCPA_MSGBYTES],
const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
const mlk_indcpa_public_key *pk,
const uint8_t coins[MLKEM_SYMBYTES])
__contract__(
requires(memory_no_alias(c, MLKEM_INDCPA_BYTES))
requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES))
requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES))
requires(memory_no_alias(pk, sizeof(mlk_indcpa_public_key)))
requires(memory_no_alias(coins, MLKEM_SYMBYTES))
assigns(object_whole(c))
);
Expand All @@ -122,11 +150,11 @@ __contract__(
MLK_INTERNAL_API
void mlk_indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],
const uint8_t c[MLKEM_INDCPA_BYTES],
const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES])
const mlk_indcpa_secret_key *sk)
__contract__(
requires(memory_no_alias(c, MLKEM_INDCPA_BYTES))
requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES))
requires(memory_no_alias(sk, MLKEM_INDCPA_SECRETKEYBYTES))
requires(memory_no_alias(sk, sizeof(mlk_indcpa_secret_key)))
assigns(object_whole(m))
);

Expand Down
Loading

0 comments on commit 3ca90c7

Please sign in to comment.