Skip to content

Commit

Permalink
ML-KEM encaps key modulus check optimization (aws#1874)
Browse files Browse the repository at this point in the history
PR aws#1868 implemented the encapsulation key modulus check
naively, as specified in FIPS 203, which introduced a significant
performance hit to encapsulation. In this change we optimize the
byte encoding/decoding functions to reduce the performance
regression.

For example, encapsulation performance (ops/s) measured on an M1 MacBook:

```
| Variant | Before aws#1868 | After aws#1868 | After this change |
| 512     |    94738     |    68293    |       92494       |
| 768     |    58672     |    42083    |       55572       |
| 1024    |    38893     |    28536    |       36820       |
```
  • Loading branch information
dkostic authored Sep 26, 2024
1 parent 8970f68 commit ed6d6ca
Showing 1 changed file with 114 additions and 55 deletions.
169 changes: 114 additions & 55 deletions crypto/fipsmodule/ml_kem/ml_kem_ref/kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,54 +60,77 @@ int crypto_kem_keypair(ml_kem_params *params,
return 0;
}

// REFERENCE IMPLEMENTATION OF SEVERAL FIPS 203 FUNCTIONS.
// Further below we implement optimized versions of the functions
// that are actually used. We commented out and kept the reference
// code for posterity.
//
// FIPS 203. Algorithm 3 BitsToBytes
// Packs a bit array into bytes, eight array entries per output byte.
// Entry |bits[8 * i + j]| lands in bit position j of |bytes[i]|, i.e. the
// first entry of every group of eight becomes the least-significant bit.
//
// |num_bits| must equal |num_bytes| * 8; this is only checked via assert.
static void bits_to_bytes(uint8_t *bytes, size_t num_bytes,
                          const uint8_t *bits, size_t num_bits) {
  assert(num_bits == num_bytes * 8);

  // Cursor over the input; it advances by one entry per packed bit.
  const uint8_t *next_bit = bits;
  for (size_t idx = 0; idx < num_bytes; idx++) {
    unsigned acc = 0;
    for (unsigned shift = 0; shift < 8; shift++) {
      acc |= (unsigned)*next_bit << shift;
      next_bit++;
    }
    // Only the low eight bits of the accumulator are kept.
    bytes[idx] = (uint8_t)acc;
  }
}

// static void bits_to_bytes(uint8_t *bytes, size_t num_bytes,
// const uint8_t *bits, size_t num_bits) {
// assert(num_bits == num_bytes * 8);
//
// for (size_t i = 0; i < num_bytes; i++) {
// uint8_t byte = 0;
// for (size_t j = 0; j < 8; j++) {
// byte |= (bits[i * 8 + j] << j);
// }
// bytes[i] = byte;
// }
// }
// FIPS 203. Algorithm 4 BytesToBits
// Unpacks a byte array into a bit array, the inverse of BitsToBytes.
// Bit position j of |bytes[i]| is written, as 0 or 1, to |bits[8 * i + j]|,
// so each byte is expanded least-significant bit first.
//
// |num_bits| must equal |num_bytes| * 8; this is only checked via assert.
static void bytes_to_bits(uint8_t *bits, size_t num_bits,
                          const uint8_t *bytes, size_t num_bytes) {
  assert(num_bits == num_bytes * 8);

  // Cursor over the output; it advances by one entry per extracted bit.
  uint8_t *dst = bits;
  for (size_t idx = 0; idx < num_bytes; idx++) {
    const uint8_t cur = bytes[idx];
    for (unsigned pos = 0; pos < 8; pos++) {
      *dst++ = (uint8_t)((cur >> pos) & 1u);
    }
  }
}

#define BYTE_ENCODE_12_IN_SIZE (256)
#define BYTE_ENCODE_12_OUT_SIZE (32 * 12)
#define BYTE_ENCODE_12_NUM_BITS (256 * 12)

// FIPS 203. Algorithm 5 ByteEncode_12
// Bit-array implementation: each of the 256 inputs is expanded into 12
// single-bit entries (least-significant bit first), and the resulting
// 256 * 12 entries are then packed into 32 * 12 bytes by bits_to_bytes.
static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE],
                           const int16_t in[BYTE_ENCODE_12_IN_SIZE]) {
  uint8_t bits[BYTE_ENCODE_12_NUM_BITS] = {0};

  // Cursor over the bit array; 12 entries are emitted per input element.
  uint8_t *slot = bits;
  for (size_t idx = 0; idx < BYTE_ENCODE_12_IN_SIZE; idx++) {
    int16_t rest = in[idx];
    for (unsigned taken = 0; taken < 12; taken++) {
      *slot++ = rest & 1;
      rest = rest >> 1;
    }
  }

  bits_to_bytes(out, BYTE_ENCODE_12_OUT_SIZE, bits, BYTE_ENCODE_12_NUM_BITS);
}
// static void bytes_to_bits(uint8_t *bits, size_t num_bits,
// const uint8_t *bytes, size_t num_bytes) {
// assert(num_bits == num_bytes * 8);
//
// for (size_t i = 0; i < num_bytes; i++) {
// uint8_t byte = bytes[i];
// for (size_t j = 0; j < 8; j++) {
// bits[i * 8 + j] = (byte >> j) & 1;
// }
// }
// }
//
// #define BYTE_ENCODE_12_IN_SIZE (256)
// #define BYTE_ENCODE_12_OUT_SIZE (32 * 12)
// #define BYTE_ENCODE_12_NUM_BITS (256 * 12)
//
// // FIPS 203. Algorithm 5 ByteEncode_12
// // Encodes an array of 256 12-bit integers into a byte array.
// static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE],
// const int16_t in[BYTE_ENCODE_12_IN_SIZE]) {
// uint8_t bits[BYTE_ENCODE_12_NUM_BITS] = {0};
// for (size_t i = 0; i < BYTE_ENCODE_12_IN_SIZE; i++) {
// int16_t a = in[i];
// for (size_t j = 0; j < 12; j++) {
// bits[i * 12 + j] = a & 1;
// a = a >> 1;
// }
// }
// bits_to_bytes(out, BYTE_ENCODE_12_OUT_SIZE, bits, BYTE_ENCODE_12_NUM_BITS);
// }
//
// #define BYTE_DECODE_12_OUT_SIZE (256)
// #define BYTE_DECODE_12_IN_SIZE (32 * 12)
// #define BYTE_DECODE_12_NUM_BITS (256 * 12)
//
// // FIPS 203. Algorithm 6 ByteDecode_12
// // Decodes a byte array into an array of 256 12-bit integers.
// static void byte_decode_12(int16_t out[BYTE_DECODE_12_OUT_SIZE],
// const uint8_t in[BYTE_DECODE_12_IN_SIZE]) {
// uint8_t bits[BYTE_DECODE_12_NUM_BITS] = {0};
// bytes_to_bits(bits, BYTE_DECODE_12_NUM_BITS, in, BYTE_DECODE_12_IN_SIZE);
// for (size_t i = 0; i < BYTE_DECODE_12_OUT_SIZE; i++) {
// int16_t val = 0;
// for (size_t j = 0; j < 12; j++) {
// val |= bits[i * 12 + j] << j;
// }
// out[i] = centered_to_positive_representative(barrett_reduce(val));
// }
// }

// Converts a centered representative |in| which is an integer in
// {-(q-1)/2, ..., (q-1)/2}, to a positive representative in {0, ..., q-1}.
Expand All @@ -120,22 +143,58 @@ static int16_t centered_to_positive_representative(int16_t in) {
return constant_time_select_int(mask, in, in_fixed);
}

// Sizes for ByteDecode_12: 256 output integers, 32 * 12 = 384 input bytes,
// and 256 * 12 = 3072 bits in total.
// NOTE(review): BYTE_DECODE_12_OUT_SIZE and BYTE_DECODE_12_IN_SIZE are
// #defined again a few lines further down with different replacement
// lists -- confirm which set is meant to remain.
#define BYTE_DECODE_12_OUT_SIZE (256)
#define BYTE_DECODE_12_IN_SIZE (32 * 12)
#define BYTE_DECODE_12_NUM_BITS (256 * 12)
#define BYTE_ENCODE_12_IN_SIZE (256)
#define BYTE_ENCODE_12_OUT_SIZE (32 * 12)
#define BYTE_DECODE_12_OUT_SIZE (BYTE_ENCODE_12_IN_SIZE)
#define BYTE_DECODE_12_IN_SIZE (BYTE_ENCODE_12_OUT_SIZE)

// FIPS 203. Algorithm 5 ByteEncode_12
// Packs 256 12-bit integers into 384 bytes without going through an
// intermediate bit array. Two consecutive inputs a and b (2 x 12 = 24 bits)
// produce three consecutive output bytes (3 x 8 = 24 bits):
//   byte 0: bits 0..7 of a
//   byte 1: bits 8..11 of a in the low nibble, bits 0..3 of b in the high
//           nibble
//   byte 2: bits 4..11 of b
// Anything above bit 11 of an input is masked away.
static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE],
                           const int16_t in[BYTE_ENCODE_12_IN_SIZE]) {
  // |dst| always points at the three bytes belonging to the current pair.
  uint8_t *dst = out;
  for (size_t pair = 0; pair < BYTE_ENCODE_12_IN_SIZE; pair += 2) {
    const int16_t lo = in[pair];
    const int16_t hi = in[pair + 1];

    dst[0] = lo & 0xff;
    dst[1] = ((lo >> 8) & 0xf) | ((hi & 0xf) << 4);
    dst[2] = (hi >> 4) & 0xff;
    dst += 3;
  }
}

// FIPS 203. Algorithm 6 ByteDecode_12
// Decodes a 384-byte array into an array of 256 12-bit integers, then
// reduces every integer mod Q as specified in FIPS 203.
//
// The input is consumed in triples of bytes (3 x 8 = 24 bits) and the output
// is produced in pairs of integers (2 x 12 = 24 bits):
//   in:  |xxxxxxxx| |yyyyyyyy| |zzzzzzzz| ...
//   out: |xxxxxxxxyyyy| |yyyyzzzzzzzz| ...
//   - out[0] <-- all 8 bits of in[0] (low part) and the low 4 bits of in[1],
//   - out[1] <-- the high 4 bits of in[1] (low part) and all 8 bits of in[2].
//
// Parameters:
//   out - receives BYTE_DECODE_12_OUT_SIZE reduced values.
//   in  - BYTE_DECODE_12_IN_SIZE encoded bytes.
static void byte_decode_12(int16_t out[BYTE_DECODE_12_OUT_SIZE],
                           const uint8_t in[BYTE_DECODE_12_IN_SIZE]) {
  for (size_t i = 0; i < BYTE_DECODE_12_OUT_SIZE / 2; i++) {
    // Widen the bytes to 16 bits up front so the shifts below cannot lose
    // bits to an 8-bit intermediate.
    const uint16_t in0 = (uint16_t)in[3 * i];
    const uint16_t in1 = (uint16_t)in[3 * i + 1];
    const uint16_t in2 = (uint16_t)in[3 * i + 2];

    // Build the output pair; both values fit in 12 bits (at most 0xfff).
    const uint16_t out0 = (uint16_t)(in0 | ((in1 & 0xf) << 8));
    const uint16_t out1 = (uint16_t)((in1 >> 4) | (in2 << 4));

    // Reduce mod Q. barrett_reduce presumably yields a centered
    // representative (TODO confirm against its definition), which is then
    // mapped to the positive representative.
    out[2 * i] = centered_to_positive_representative(barrett_reduce(out0));
    out[2 * i + 1] = centered_to_positive_representative(barrett_reduce(out1));
  }
}

Expand Down

0 comments on commit ed6d6ca

Please sign in to comment.