Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ml-kem: Adds feature flag to use key or seed #83

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ml-kem/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ default = ["std"]
std = ["sha3/std"]
deterministic = [] # Expose deterministic generation and encapsulation functions
zeroize = ["dep:zeroize"]
decap_key = [] # Use seed for decapsulation key (default behaviour) or not. If set, will use standard decapsulation key.

[dependencies]
kem = "0.3.0-pre.0"
Expand Down
264 changes: 254 additions & 10 deletions ml-kem/src/kem.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use core::convert::Infallible;
use core::marker::PhantomData;
use hybrid_array::typenum::U32;
#[cfg(not(feature = "decap_key"))]
use hybrid_array::typenum::U64;
use rand_core::CryptoRngCore;

use crate::crypto::{rand, G, H, J};
Expand All @@ -18,10 +20,19 @@ pub use ::kem::{Decapsulate, Encapsulate};
/// A shared key resulting from an ML-KEM transaction
pub(crate) type SharedKey = B32;

/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
/// encapsulated shared key.
#[cfg(not(feature = "decap_key"))]
#[derive(Clone, Debug, PartialEq)]
pub struct DecapsulationKey<P>
struct DecapsulationSeed<P>
where
P: KemParams,
{
d: B32,
z: B32,
_phantom: PhantomData<P>,
}

#[derive(Clone, Debug, PartialEq)]
struct DecapsulationKeyInner<P>
where
P: KemParams,
{
Expand All @@ -30,8 +41,29 @@ where
z: B32,
}

/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
/// encapsulated shared key.
#[cfg(feature = "decap_key")]
#[derive(Clone, Debug, PartialEq)]
pub struct DecapsulationKey<P>
where
P: KemParams,
{
key: DecapsulationKeyInner<P>,
}
/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
/// encapsulated shared key.
#[cfg(not(feature = "decap_key"))]
#[derive(Clone, Debug, PartialEq)]
pub struct DecapsulationKey<P>
where
P: KemParams,
{
key: DecapsulationSeed<P>,
}

#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
impl<P> Drop for DecapsulationKeyInner<P>
where
P: KemParams,
{
Expand All @@ -41,10 +73,59 @@ where
}
}

#[cfg(all(feature = "zeroize", not(feature = "decap_key")))]
impl<P> Drop for DecapsulationSeed<P>
where
P: KemParams,
{
fn drop(&mut self) {
self.d.zeroize();
self.z.zeroize();
}
}

#[cfg(feature = "zeroize")]
impl<P> Zeroize for DecapsulationKeyInner<P>
where
P: KemParams,
{
fn zeroize(&mut self) {
self.dk_pke.zeroize();
self.z.zeroize();
}
}

#[cfg(all(feature = "zeroize", not(feature = "decap_key")))]
impl<P> Zeroize for DecapsulationSeed<P>
where
P: KemParams,
{
fn zeroize(&mut self) {
self.d.zeroize();
self.z.zeroize();
}
}

#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
P: KemParams,
{
fn drop(&mut self) {
self.key.zeroize();
}
}

#[cfg(feature = "zeroize")]
impl<P> ZeroizeOnDrop for DecapsulationKeyInner<P> where P: KemParams {}

#[cfg(all(feature = "zeroize", not(feature = "decap_key")))]
impl<P> ZeroizeOnDrop for DecapsulationSeed<P> where P: KemParams {}

#[cfg(feature = "zeroize")]
impl<P> ZeroizeOnDrop for DecapsulationKey<P> where P: KemParams {}

impl<P> EncodedSizeUser for DecapsulationKey<P>
impl<P> EncodedSizeUser for DecapsulationKeyInner<P>
where
P: KemParams,
{
Expand Down Expand Up @@ -75,14 +156,67 @@ where
}
}

#[cfg(not(feature = "decap_key"))]
impl<P> EncodedSizeUser for DecapsulationSeed<P>
where
P: KemParams,
{
type EncodedSize = U64;

#[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
fn from_bytes(enc: &Encoded<Self>) -> Self {
let (d, z) = P::split_seed(enc);

Self {
d: d.clone(),
z: z.clone(),
_phantom: PhantomData,
}
}

fn as_bytes(&self) -> Encoded<Self> {
self.d.clone().concat(self.z.clone())
}
}

impl<P> EncodedSizeUser for DecapsulationKey<P>
where
P: KemParams,
{
#[cfg(feature = "decap_key")]
type EncodedSize = DecapsulationKeySize<P>;
#[cfg(not(feature = "decap_key"))]
type EncodedSize = U64;

#[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
fn from_bytes(enc: &Encoded<Self>) -> Self {
#[cfg(feature = "decap_key")]
{
Self {
key: DecapsulationKeyInner::<P>::from_bytes(enc),
}
}
#[cfg(not(feature = "decap_key"))]
{
Self {
key: DecapsulationSeed::<P>::from_bytes(enc),
}
}
}

fn as_bytes(&self) -> Encoded<Self> {
self.key.as_bytes()
}
}

// 0xff if x == y, 0x00 otherwise
fn constant_time_eq(x: u8, y: u8) -> u8 {
let diff = x ^ y;
let is_zero = !diff & diff.wrapping_sub(1);
0u8.wrapping_sub(is_zero >> 7)
}

impl<P> ::kem::Decapsulate<EncodedCiphertext<P>, SharedKey> for DecapsulationKey<P>
impl<P> ::kem::Decapsulate<EncodedCiphertext<P>, SharedKey> for DecapsulationKeyInner<P>
where
P: KemParams,
{
Expand Down Expand Up @@ -117,15 +251,46 @@ where
}
}

impl<P> DecapsulationKey<P>
#[cfg(not(feature = "decap_key"))]
impl<P> ::kem::Decapsulate<EncodedCiphertext<P>, SharedKey> for DecapsulationSeed<P>
where
P: KemParams,
{
/// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`].
pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
&self.ek
type Error = Infallible;

fn decapsulate(
&self,
encapsulated_key: &EncodedCiphertext<P>,
) -> Result<SharedKey, Self::Error> {
DecapsulationKeyInner::<P>::generate_deterministic(&self.d, &self.z)
.decapsulate(encapsulated_key)
}
}

impl<P> ::kem::Decapsulate<EncodedCiphertext<P>, SharedKey> for DecapsulationKey<P>
where
P: KemParams,
{
type Error = Infallible;

fn decapsulate(
&self,
encapsulated_key: &EncodedCiphertext<P>,
) -> Result<SharedKey, Self::Error> {
self.key.decapsulate(encapsulated_key)
}
}

impl<P> DecapsulationKeyInner<P>
where
P: KemParams,
{
/// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKeyInner`].
pub fn encapsulation_key(&self) -> EncapsulationKey<P> {
self.ek.clone()
}

#[cfg(feature = "decap_key")]
pub(crate) fn generate(rng: &mut impl CryptoRngCore) -> Self {
let d: B32 = rand(rng);
let z: B32 = rand(rng);
Expand All @@ -142,6 +307,85 @@ where
}
}

#[cfg(not(feature = "decap_key"))]
impl<P> DecapsulationSeed<P>
where
P: KemParams,
{
/// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationSeed`].
#[must_use]
pub fn encapsulation_key(&self) -> EncapsulationKey<P> {
DecapsulationKeyInner::<P>::generate_deterministic(&self.d, &self.z)
.encapsulation_key()
.clone()
}

pub(crate) fn generate(rng: &mut impl CryptoRngCore) -> Self {
let d: B32 = rand(rng);
let z: B32 = rand(rng);
Self {
d,
z,
_phantom: PhantomData,
}
}

#[must_use]
#[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
#[cfg(feature = "deterministic")]
pub(crate) fn generate_deterministic(d: &B32, z: &B32) -> Self {
Self {
d: *d,
z: *z,
_phantom: PhantomData,
}
}
}

impl<P> DecapsulationKey<P>
where
P: KemParams,
{
/// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`].
#[must_use]
pub fn encapsulation_key(&self) -> EncapsulationKey<P> {
self.key.encapsulation_key()
}

pub(crate) fn generate(rng: &mut impl CryptoRngCore) -> Self {
#[cfg(not(feature = "decap_key"))]
{
DecapsulationKey {
key: DecapsulationSeed::<P>::generate(rng),
}
}
#[cfg(feature = "decap_key")]
{
DecapsulationKey {
key: DecapsulationKeyInner::<P>::generate(rng),
}
}
}

#[must_use]
#[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
#[cfg(feature = "deterministic")]
pub(crate) fn generate_deterministic(d: &B32, z: &B32) -> Self {
#[cfg(not(feature = "decap_key"))]
{
DecapsulationKey {
key: DecapsulationSeed::<P>::generate_deterministic(d, z),
}
}
#[cfg(feature = "decap_key")]
{
DecapsulationKey {
key: DecapsulationKeyInner::<P>::generate_deterministic(d, z),
}
}
}
}

/// An `EncapsulationKey` provides the ability to encapsulate a shared key so that it can only be
/// decapsulated by the holder of the corresponding decapsulation key.
#[derive(Clone, Debug, PartialEq)]
Expand Down
8 changes: 8 additions & 0 deletions ml-kem/src/param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,15 @@ pub trait KemParams: PkeParams {
&B32,
&B32,
);

fn split_seed(enc: &EncodedDecapsulationSeed) -> (&B32, &B32);
}

pub type DecapsulationKeySize<P> = <P as KemParams>::DecapsulationKeySize;
pub type EncapsulationKeySize<P> = <P as PkeParams>::EncryptionKeySize;

pub type EncodedDecapsulationKey<P> = Array<u8, <P as KemParams>::DecapsulationKeySize>;
pub type EncodedDecapsulationSeed = Array<u8, U64>;

impl<P> KemParams for P
where
Expand Down Expand Up @@ -295,4 +298,9 @@ where
let (dk_pke, ek_pke) = enc.split_ref();
(dk_pke, ek_pke, h, z)
}

fn split_seed(enc: &EncodedDecapsulationSeed) -> (&B32, &B32) {
let (d, z) = enc.split_ref();
(d, z)
}
}
1 change: 1 addition & 0 deletions ml-kem/tests/encap-decap.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![cfg(feature = "deterministic")]
#![cfg(feature = "decap_key")]

use ml_kem::*;

Expand Down
1 change: 1 addition & 0 deletions ml-kem/tests/key-gen.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![cfg(feature = "deterministic")]
#![cfg(feature = "decap_key")]

use ml_kem::*;

Expand Down
Loading