diff --git a/crypto/fipsmodule/ml_kem/ml_kem_ref/kem.c b/crypto/fipsmodule/ml_kem/ml_kem_ref/kem.c index 5b056256b8..732ee61028 100644 --- a/crypto/fipsmodule/ml_kem/ml_kem_ref/kem.c +++ b/crypto/fipsmodule/ml_kem/ml_kem_ref/kem.c @@ -60,54 +60,77 @@ int crypto_kem_keypair(ml_kem_params *params, return 0; } +// REFERENCE IMPLEMENTATION OF SEVERAL FIPS 203 FUNCTIONS. +// Further below we implement optimized versions of the functions +// that are actually used. We commented out and kept the reference +// code for posterity. +// // FIPS 203. Algorithm 3 BitsToBytes // Converts a bit array (of a length that is a multiple of eight) // into an array of bytes. -static void bits_to_bytes(uint8_t *bytes, size_t num_bytes, - const uint8_t *bits, size_t num_bits) { - assert(num_bits == num_bytes * 8); - - for (size_t i = 0; i < num_bytes; i++) { - uint8_t byte = 0; - for (size_t j = 0; j < 8; j++) { - byte |= (bits[i * 8 + j] << j); - } - bytes[i] = byte; - } -} - +// static void bits_to_bytes(uint8_t *bytes, size_t num_bytes, +// const uint8_t *bits, size_t num_bits) { +// assert(num_bits == num_bytes * 8); +// +// for (size_t i = 0; i < num_bytes; i++) { +// uint8_t byte = 0; +// for (size_t j = 0; j < 8; j++) { +// byte |= (bits[i * 8 + j] << j); +// } +// bytes[i] = byte; +// } +// } // FIPS 203. Algorithm 4 BytesToBits // Performs the inverse of BitsToBytes, converting a byte array into a bit array. -static void bytes_to_bits(uint8_t *bits, size_t num_bits, - const uint8_t *bytes, size_t num_bytes) { - assert(num_bits == num_bytes * 8); - - for (size_t i = 0; i < num_bytes; i++) { - uint8_t byte = bytes[i]; - for (size_t j = 0; j < 8; j++) { - bits[i * 8 + j] = (byte >> j) & 1; - } - } -} - -#define BYTE_ENCODE_12_IN_SIZE (256) -#define BYTE_ENCODE_12_OUT_SIZE (32 * 12) -#define BYTE_ENCODE_12_NUM_BITS (256 * 12) - -// FIPS 203. Algorithm 5 ByteEncode_12 -// Encodes an array of 256 12-bit integers into a byte array. 
-static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE], - const int16_t in[BYTE_ENCODE_12_IN_SIZE]) { - uint8_t bits[BYTE_ENCODE_12_NUM_BITS] = {0}; - for (size_t i = 0; i < BYTE_ENCODE_12_IN_SIZE; i++) { - int16_t a = in[i]; - for (size_t j = 0; j < 12; j++) { - bits[i * 12 + j] = a & 1; - a = a >> 1; - } - } - bits_to_bytes(out, BYTE_ENCODE_12_OUT_SIZE, bits, BYTE_ENCODE_12_NUM_BITS); -} +// static void bytes_to_bits(uint8_t *bits, size_t num_bits, +// const uint8_t *bytes, size_t num_bytes) { +// assert(num_bits == num_bytes * 8); +// +// for (size_t i = 0; i < num_bytes; i++) { +// uint8_t byte = bytes[i]; +// for (size_t j = 0; j < 8; j++) { +// bits[i * 8 + j] = (byte >> j) & 1; +// } +// } +// } +// +// #define BYTE_ENCODE_12_IN_SIZE (256) +// #define BYTE_ENCODE_12_OUT_SIZE (32 * 12) +// #define BYTE_ENCODE_12_NUM_BITS (256 * 12) +// +// // FIPS 203. Algorithm 5 ByteEncode_12 +// // Encodes an array of 256 12-bit integers into a byte array. +// static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE], +// const int16_t in[BYTE_ENCODE_12_IN_SIZE]) { +// uint8_t bits[BYTE_ENCODE_12_NUM_BITS] = {0}; +// for (size_t i = 0; i < BYTE_ENCODE_12_IN_SIZE; i++) { +// int16_t a = in[i]; +// for (size_t j = 0; j < 12; j++) { +// bits[i * 12 + j] = a & 1; +// a = a >> 1; +// } +// } +// bits_to_bytes(out, BYTE_ENCODE_12_OUT_SIZE, bits, BYTE_ENCODE_12_NUM_BITS); +// } +// +// #define BYTE_DECODE_12_OUT_SIZE (256) +// #define BYTE_DECODE_12_IN_SIZE (32 * 12) +// #define BYTE_DECODE_12_NUM_BITS (256 * 12) +// +// // FIPS 203. Algorithm 6 ByteDecode_12 +// // Decodes a byte array into an array of 256 12-bit integers. 
+// static void byte_decode_12(int16_t out[BYTE_DECODE_12_OUT_SIZE], +// const uint8_t in[BYTE_DECODE_12_IN_SIZE]) { +// uint8_t bits[BYTE_DECODE_12_NUM_BITS] = {0}; +// bytes_to_bits(bits, BYTE_DECODE_12_NUM_BITS, in, BYTE_DECODE_12_IN_SIZE); +// for (size_t i = 0; i < BYTE_DECODE_12_OUT_SIZE; i++) { +// int16_t val = 0; +// for (size_t j = 0; j < 12; j++) { +// val |= bits[i * 12 + j] << j; +// } +// out[i] = centered_to_positive_representative(barrett_reduce(val)); +// } +// } // Converts a centered representative |in| which is an integer in // {-(q-1)/2, ..., (q-1)/2}, to a positive representative in {0, ..., q-1}. @@ -120,22 +143,58 @@ static int16_t centered_to_positive_representative(int16_t in) { return constant_time_select_int(mask, in, in_fixed); } -#define BYTE_DECODE_12_OUT_SIZE (256) -#define BYTE_DECODE_12_IN_SIZE (32 * 12) -#define BYTE_DECODE_12_NUM_BITS (256 * 12) +#define BYTE_ENCODE_12_IN_SIZE (256) +#define BYTE_ENCODE_12_OUT_SIZE (32 * 12) +#define BYTE_DECODE_12_OUT_SIZE (BYTE_ENCODE_12_IN_SIZE) +#define BYTE_DECODE_12_IN_SIZE (BYTE_ENCODE_12_OUT_SIZE) + +// FIPS 203. Algorithm 5 ByteEncode_12 +// Encodes an array of 256 12-bit integers into a byte array. +// Intuition for the implementation: +// in: |xxxxxxxxyyyy| |yyyyzzzzzzzz| ... +// out: |xxxxxxxx| |yyyyyyyy| |zzzzzzzz| ... +// We divide the input in pairs of elements (2 x 12 bits = 24 bits), +// and the output in triples (3 x 8 bits = 24 bits). For each pair/triplet we: +// - out0 <-- first eight bits of in0, +// - out1 <-- concatenate last 4 bits of in0 and first 4 bits of in1, +// - out2 <-- last 8 bits of in1. +static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE], + const int16_t in[BYTE_ENCODE_12_IN_SIZE]) { + for (size_t i = 0; i < BYTE_ENCODE_12_IN_SIZE / 2; i++) { + int16_t in0 = in[2 * i]; + int16_t in1 = in[2 * i + 1]; + out[3 * i] = in0 & 0xff; + out[3 * i + 1] = ((in0 >> 8) & 0xf) | ((in1 & 0xf) << 4); + out[3 * i + 2] = (in1 >> 4) & 0xff; + } +} -// FIPS 203. 
Algorithm 5 ByteDecode_12 +// FIPS 203. Algorithm 6 ByteDecode_12 // Decodes a byte array into an array of 256 12-bit integers. +// Intuition for the implementation: +// in: |xxxxxxxx| |yyyyyyyy| |zzzzzzzz| ... +// out: |xxxxxxxxyyyy| |yyyyzzzzzzzz| ... +// We divide the input in triples of elements (3 x 8 bits = 24 bits), +// and the output in pairs (2 x 12 bits = 24 bits). For each pair/triplet we: +// - out[0] <-- concatenate eight bits of in[0] and first 4 bits of in[1], +// - out[1] <-- concatenate last 4 bits of in[1] and 8 bits of in[2]. +// Additionally we reduce the output elements mod Q as specified in FIPS 203. static void byte_decode_12(int16_t out[BYTE_DECODE_12_OUT_SIZE], const uint8_t in[BYTE_DECODE_12_IN_SIZE]) { - uint8_t bits[BYTE_DECODE_12_NUM_BITS] = {0}; - bytes_to_bits(bits, BYTE_DECODE_12_NUM_BITS, in, BYTE_DECODE_12_IN_SIZE); - for (size_t i = 0; i < BYTE_DECODE_12_OUT_SIZE; i++) { - int16_t val = 0; - for (size_t j = 0; j < 12; j++) { - val |= bits[i * 12 + j] << j; - } - out[i] = centered_to_positive_representative(barrett_reduce(val)); + for(size_t i = 0; i < BYTE_DECODE_12_OUT_SIZE / 2; i++) { + // Cast to 16-bit wide uints to avoid any issues + // with shifting and implicit casting. + uint16_t in0 = (uint16_t) in[3 * i]; + uint16_t in1 = (uint16_t) in[3 * i + 1]; + uint16_t in2 = (uint16_t) in[3 * i + 2]; + + // Build the output pair. + uint16_t out0 = in0 | ((in1 & 0xf) << 8); + uint16_t out1 = (in1 >> 4) | (in2 << 4); + + // Reduce mod Q. + out[2 * i] = centered_to_positive_representative(barrett_reduce(out0)); + out[2 * i + 1] = centered_to_positive_representative(barrett_reduce(out1)); } }