From 3ec3391c3f71b3d82cdfe0064cb8a390e03440d5 Mon Sep 17 00:00:00 2001 From: fionser Date: Sat, 13 Jan 2024 23:13:09 +0800 Subject: [PATCH 1/4] change small polysize to 16 --- crates/fhe-math/src/rq/context.rs | 14 +++--- crates/fhe-math/src/rq/convert.rs | 64 ++++++++++++------------- crates/fhe-math/src/rq/mod.rs | 74 ++++++++++++++++------------- crates/fhe-math/src/rq/ops.rs | 35 +++++++------- crates/fhe-math/src/rq/scaler.rs | 4 +- crates/fhe-math/src/rq/serialize.rs | 4 +- 6 files changed, 102 insertions(+), 93 deletions(-) diff --git a/crates/fhe-math/src/rq/context.rs b/crates/fhe-math/src/rq/context.rs index 49119311..edd80fd3 100644 --- a/crates/fhe-math/src/rq/context.rs +++ b/crates/fhe-math/src/rq/context.rs @@ -172,7 +172,7 @@ mod tests { fn context_constructor() { for modulus in MODULI { // modulus is = 1 modulo 2 * 8 - assert!(Context::new(&[*modulus], 8).is_ok()); + assert!(Context::new(&[*modulus], 16).is_ok()); if supports_ntt(*modulus, 128) { assert!(Context::new(&[*modulus], 128).is_ok()); @@ -182,7 +182,7 @@ mod tests { } // All moduli in MODULI are = 1 modulo 2 * 8 - assert!(Context::new(MODULI, 8).is_ok()); + assert!(Context::new(MODULI, 16).is_ok()); // This should fail since 1153 != 1 moduli 2 * 128 assert!(Context::new(MODULI, 128).is_err()); @@ -191,10 +191,10 @@ mod tests { #[test] fn next_context() -> Result<(), Box> { // A context should have a children pointing to a context with one less modulus. - let context = Arc::new(Context::new(MODULI, 8)?); + let context = Arc::new(Context::new(MODULI, 16)?); assert_eq!( context.next_context, - Some(Arc::new(Context::new(&MODULI[..MODULI.len() - 1], 8)?)) + Some(Arc::new(Context::new(&MODULI[..MODULI.len() - 1], 16)?)) ); // We can go down the chain of the MODULI.len() - 1 context's. @@ -212,13 +212,13 @@ mod tests { #[test] fn niterations_to() -> Result<(), Box> { // A context should have a children pointing to a context with one less modulus. - let context = Arc::new(Context::new(MODULI, 8)?); + let context = Arc::new(Context::new(MODULI, 16)?); assert_eq!(context.niterations_to(&context).ok(), Some(0)); assert_eq!( context - .niterations_to(&Arc::new(Context::new(&MODULI[1..], 8)?)) + .niterations_to(&Arc::new(Context::new(&MODULI[1..], 16)?)) .err(), Some(crate::Error::InvalidContext) ); @@ -226,7 +226,7 @@ mod tests { for i in 1..MODULI.len() { assert_eq!( context - .niterations_to(&Arc::new(Context::new(&MODULI[..MODULI.len() - i], 8)?)) + .niterations_to(&Arc::new(Context::new(&MODULI[..MODULI.len() - i], 16)?)) .ok(), Some(i) ); diff --git a/crates/fhe-math/src/rq/convert.rs b/crates/fhe-math/src/rq/convert.rs index d4c0c607..0f505ab5 100644 --- a/crates/fhe-math/src/rq/convert.rs +++ b/crates/fhe-math/src/rq/convert.rs @@ -439,7 +439,7 @@ mod tests { fn proto() -> Result<(), Box> { let mut rng = thread_rng(); for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); let proto = Rq::from(&p); assert_eq!(Poly::try_convert_from(&proto, &ctx, false, None)?, p); @@ -459,7 +459,7 @@ mod tests { ); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); let proto = Rq::from(&p); assert_eq!(Poly::try_convert_from(&proto, &ctx, false, None)?, p); @@ -478,7 +478,7 @@ mod tests { CrateError::Default("The representation asked for does not match the representation in the serialization".to_string()) ); - let ctx = Arc::new(Context::new(&MODULI[0..1], 8)?); + let ctx = Arc::new(Context::new(&MODULI[0..1], 16)?); assert_eq!( Poly::try_convert_from(&proto, &ctx, false, None) .expect_err("Should fail because of incorrect context"), @@ -491,7 +491,7 @@ mod tests { #[test] fn try_convert_from_slice_zero() -> Result<(), Box> { for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); // Power Basis assert_eq!( @@ -503,15 +503,15 @@ mod tests { Poly::zero(&ctx, Representation::PowerBasis) ); assert_eq!( - Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::PowerBasis)?, + Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::PowerBasis)?, Poly::zero(&ctx, Representation::PowerBasis) ); assert_eq!( - Poly::try_convert_from(&[0i64; 8], &ctx, false, Representation::PowerBasis)?, + Poly::try_convert_from(&[0i64; 16], &ctx, false, Representation::PowerBasis)?, Poly::zero(&ctx, Representation::PowerBasis) ); assert!(Poly::try_convert_from( - &[0u64; 9], // One too many + &[0u64; 17], // One too many &ctx, false, Representation::PowerBasis, @@ -522,12 +522,12 @@ mod tests { assert!(Poly::try_convert_from(&[0u64], &ctx, false, Representation::Ntt).is_err()); assert!(Poly::try_convert_from(&[0i64], &ctx, false, Representation::Ntt).is_err()); assert_eq!( - Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::Ntt)?, - Poly::zero(&ctx, Representation::Ntt) + Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::Ntt)?, + Poly::zero(&ctx, Representation::Ntt) ); - assert!(Poly::try_convert_from(&[0i64; 8], &ctx, false, Representation::Ntt).is_err()); + assert!(Poly::try_convert_from(&[0i64; 16], &ctx, false, Representation::Ntt).is_err()); assert!(Poly::try_convert_from( - &[0u64; 9], // One too many + &[0u64; 17], // One too many &ctx, false, Representation::Ntt, @@ -535,7 +535,7 @@ mod tests { .is_err()); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); assert_eq!( Poly::try_convert_from( Vec::::default(), @@ -557,22 +557,22 @@ mod tests { assert!(Poly::try_convert_from(&[0u64], &ctx, false, Representation::Ntt).is_err()); assert_eq!( - Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::PowerBasis)?, + Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::PowerBasis)?, Poly::zero(&ctx, Representation::PowerBasis) ); - assert!(Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::Ntt).is_err()); + assert!(Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::Ntt).is_err()); assert!( - Poly::try_convert_from(&[0u64; 9], &ctx, false, Representation::PowerBasis).is_err() + Poly::try_convert_from(&[0u64; 17], &ctx, false, Representation::PowerBasis).is_err() ); - assert!(Poly::try_convert_from(&[0u64; 9], &ctx, false, Representation::Ntt).is_err()); + assert!(Poly::try_convert_from(&[0u64; 17], &ctx, false, Representation::Ntt).is_err()); assert_eq!( - Poly::try_convert_from(&[0u64; 24], &ctx, false, Representation::PowerBasis)?, + Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::PowerBasis)?, Poly::zero(&ctx, Representation::PowerBasis) ); assert_eq!( - Poly::try_convert_from(&[0u64; 24], &ctx, false, Representation::Ntt)?, + Poly::try_convert_from(&[0u64; 48], &ctx, false, Representation::Ntt)?, Poly::zero(&ctx, Representation::Ntt) ); @@ -582,7 +582,7 @@ mod tests { #[test] fn try_convert_from_vec_zero() -> Result<(), Box> { for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); assert_eq!( Poly::try_convert_from(vec![], &ctx, false, Representation::PowerBasis)?, Poly::zero(&ctx, Representation::PowerBasis) @@ -596,22 +596,22 @@ mod tests { assert!(Poly::try_convert_from(vec![0], &ctx, false, Representation::Ntt).is_err()); assert_eq!( - Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::PowerBasis)?, + Poly::try_convert_from(vec![0; 16], &ctx, false, Representation::PowerBasis)?, Poly::zero(&ctx, Representation::PowerBasis) ); assert_eq!( - Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::Ntt)?, + Poly::try_convert_from(vec![0; 16], &ctx, false, Representation::Ntt)?, Poly::zero(&ctx, Representation::Ntt) ); assert!( - Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::PowerBasis) + Poly::try_convert_from(vec![0; 17], &ctx, false, Representation::PowerBasis) .is_err() ); - assert!(Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::Ntt).is_err()); + assert!(Poly::try_convert_from(vec![0; 17], &ctx, false, Representation::Ntt).is_err()); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); assert_eq!( Poly::try_convert_from(vec![], &ctx, false, Representation::PowerBasis)?, Poly::zero(&ctx, Representation::PowerBasis) @@ -625,22 +625,22 @@ mod tests { assert!(Poly::try_convert_from(vec![0], &ctx, false, Representation::Ntt).is_err()); assert_eq!( - Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::PowerBasis)?, + Poly::try_convert_from(vec![0; 16], &ctx, false, Representation::PowerBasis)?, Poly::zero(&ctx, Representation::PowerBasis) ); - assert!(Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::Ntt).is_err()); + assert!(Poly::try_convert_from(vec![0; 16], &ctx, false, Representation::Ntt).is_err()); assert!( - Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::PowerBasis).is_err() + Poly::try_convert_from(vec![0; 17], &ctx, false, Representation::PowerBasis).is_err() ); - assert!(Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::Ntt).is_err()); + assert!(Poly::try_convert_from(vec![0; 17], &ctx, false, Representation::Ntt).is_err()); assert_eq!( - Poly::try_convert_from(vec![0; 24], &ctx, false, Representation::PowerBasis)?, + Poly::try_convert_from(vec![0; 48], &ctx, false, Representation::PowerBasis)?, Poly::zero(&ctx, Representation::PowerBasis) ); assert_eq!( - Poly::try_convert_from(vec![0; 24], &ctx, false, Representation::Ntt)?, + Poly::try_convert_from(vec![0; 48], &ctx, false, Representation::Ntt)?, Poly::zero(&ctx, Representation::Ntt) ); @@ -652,7 +652,7 @@ mod tests { let mut rng = thread_rng(); for _ in 0..100 { for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); let p_coeffs = Vec::::from(&p); let q = Poly::try_convert_from( @@ -664,7 +664,7 @@ mod tests { assert_eq!(p, q); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); let p_coeffs = Vec::::from(&p); assert_eq!(p_coeffs.len(), ctx.degree); diff --git a/crates/fhe-math/src/rq/mod.rs b/crates/fhe-math/src/rq/mod.rs index ea16fed1..f782d514 100644 --- a/crates/fhe-math/src/rq/mod.rs +++ b/crates/fhe-math/src/rq/mod.rs @@ -608,23 +608,31 @@ mod tests { BigUint::zero(), BigUint::zero(), BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), ]; for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let p = Poly::zero(&ctx, Representation::PowerBasis); let q = Poly::zero(&ctx, Representation::Ntt); assert_ne!(p, q); - assert_eq!(Vec::::from(&p), &[0; 8]); - assert_eq!(Vec::::from(&q), &[0; 8]); + assert_eq!(Vec::::from(&p), &[0; 16]); + assert_eq!(Vec::::from(&q), &[0; 16]); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::zero(&ctx, Representation::PowerBasis); let q = Poly::zero(&ctx, Representation::Ntt); assert_ne!(p, q); - assert_eq!(Vec::::from(&p), [0; 8 * MODULI.len()]); - assert_eq!(Vec::::from(&q), [0; 8 * MODULI.len()]); + assert_eq!(Vec::::from(&p), [0; 16 * MODULI.len()]); + assert_eq!(Vec::::from(&q), [0; 16 * MODULI.len()]); assert_eq!(Vec::::from(&p), reference); assert_eq!(Vec::::from(&q), reference); @@ -634,12 +642,12 @@ mod tests { #[test] fn ctx() -> Result<(), Box> { for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let p = Poly::zero(&ctx, Representation::PowerBasis); assert_eq!(p.ctx(), &ctx); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::zero(&ctx, Representation::PowerBasis); assert_eq!(p.ctx(), &ctx); @@ -654,13 +662,13 @@ mod tests { thread_rng().fill(&mut seed); for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed); let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed); assert_eq!(p, q); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed); let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed); assert_eq!(p, q); @@ -681,13 +689,13 @@ mod tests { let mut rng = thread_rng(); for _ in 0..50 { for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let p = Poly::random(&ctx, Representation::Ntt, &mut rng); let p_coefficients = Vec::::from(&p); assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap()) } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::random(&ctx, Representation::Ntt, &mut rng); let p_coefficients = Vec::::from(&p); assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap()) @@ -699,13 +707,13 @@ mod tests { fn modulus() -> Result<(), Box> { for modulus in MODULI { let modulus_biguint = BigUint::from(*modulus); - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); assert_eq!(ctx.modulus(), &modulus_biguint) } let mut modulus_biguint = BigUint::one(); MODULI.iter().for_each(|m| modulus_biguint *= *m); - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); assert_eq!(ctx.modulus(), &modulus_biguint); Ok(()) @@ -715,7 +723,7 @@ mod tests { fn allow_variable_time_computations() -> Result<(), Box> { let mut rng = thread_rng(); for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let mut p = Poly::random(&ctx, Representation::default(), &mut rng); assert!(!p.allow_variable_time_computations); @@ -729,7 +737,7 @@ mod tests { assert!(!p.allow_variable_time_computations); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let mut p = Poly::random(&ctx, Representation::default(), &mut rng); assert!(!p.allow_variable_time_computations); @@ -765,7 +773,7 @@ mod tests { #[test] fn change_representation() -> Result<(), Box> { let mut rng = thread_rng(); - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let mut p = Poly::random(&ctx, Representation::default(), &mut rng); assert_eq!(p.representation, Representation::default()); @@ -809,7 +817,7 @@ mod tests { #[test] fn override_representation() -> Result<(), Box> { let mut rng = thread_rng(); - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); assert_eq!(p.representation(), &p.representation); @@ -843,7 +851,7 @@ mod tests { fn small() -> Result<(), Box> { let mut rng = thread_rng(); for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let q = Modulus::new(*modulus).unwrap(); let e = Poly::small(&ctx, Representation::PowerBasis, 0, &mut rng); @@ -871,11 +879,11 @@ mod tests { // Generate a very large polynomial to check the variance (here equal to 8). let ctx = Arc::new(Context::new(&[4611686018326724609], 1 << 18)?); let q = Modulus::new(4611686018326724609).unwrap(); - let p = Poly::small(&ctx, Representation::PowerBasis, 8, &mut thread_rng())?; + let p = Poly::small(&ctx, Representation::PowerBasis, 16, &mut thread_rng())?; let coefficients = p.coefficients().to_slice().unwrap(); let v = unsafe { q.center_vec_vt(coefficients) }; - assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 16); - assert_eq!(variance(&v).round(), 8.0); + assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 32); + assert_eq!(variance(&v).round(), 16.0); Ok(()) } @@ -884,7 +892,7 @@ mod tests { fn substitute() -> Result<(), Box> { let mut rng = thread_rng(); for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); let mut p_ntt = p.clone(); p_ntt.change_representation(Representation::Ntt); @@ -910,9 +918,9 @@ mod tests { // Substitution by 3 let mut q = p.substitute(&SubstitutionExponent::new(&ctx, 3)?)?; - let mut v = vec![0u64; 8]; - for i in 0..8 { - v[(3 * i) % 8] = if ((3 * i) / 8) & 1 == 1 && p_coeffs[i] > 0 { + let mut v = vec![0u64; 16]; + for i in 0..16 { + v[(3 * i) % 16] = if ((3 * i) / 16) & 1 == 1 && p_coeffs[i] > 0 { *modulus - p_coeffs[i] } else { p_coeffs[i] @@ -948,7 +956,7 @@ mod tests { ); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); let mut p_ntt = p.clone(); p_ntt.change_representation(Representation::Ntt); @@ -980,7 +988,7 @@ mod tests { fn mod_switch_down_next() -> Result<(), Box> { let mut rng = thread_rng(); let ntests = 100; - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); for _ in 0..ntests { // If the polynomial has incorrect representation, an error is returned @@ -1026,8 +1034,8 @@ mod tests { fn mod_switch_down_to() -> Result<(), Box> { let mut rng = thread_rng(); let ntests = 100; - let ctx1 = Arc::new(Context::new(MODULI, 8)?); - let ctx2 = Arc::new(Context::new(&MODULI[..2], 8)?); + let ctx1 = Arc::new(Context::new(MODULI, 16)?); + let ctx2 = Arc::new(Context::new(&MODULI[..2], 16)?); for _ in 0..ntests { let mut p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng); @@ -1052,8 +1060,8 @@ mod tests { fn mod_switch_to() -> Result<(), Box> { let mut rng = thread_rng(); let ntests = 100; - let ctx1 = Arc::new(Context::new(&MODULI[..2], 8)?); - let ctx2 = Arc::new(Context::new(&MODULI[3..], 8)?); + let ctx1 = Arc::new(Context::new(&MODULI[..2], 16)?); + let ctx2 = Arc::new(Context::new(&MODULI[3..], 16)?); let switcher = Switcher::new(&ctx1, &ctx2)?; for _ in 0..ntests { let p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng); @@ -1076,7 +1084,7 @@ mod tests { #[test] fn mul_x_power() -> Result<(), Box> { let mut rng = thread_rng(); - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let e = Poly::random(&ctx, Representation::Ntt, &mut rng).multiply_inverse_power_of_x(1); assert!(e.is_err()); assert_eq!( diff --git a/crates/fhe-math/src/rq/ops.rs b/crates/fhe-math/src/rq/ops.rs index 71994e1b..74da4c56 100644 --- a/crates/fhe-math/src/rq/ops.rs +++ b/crates/fhe-math/src/rq/ops.rs @@ -472,9 +472,10 @@ mod tests { #[test] fn add() -> Result<(), Box> { let mut rng = thread_rng(); + let n = 16; for _ in 0..100 { for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], n)?); let m = Modulus::new(*modulus).unwrap(); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); @@ -494,14 +495,14 @@ mod tests { assert_eq!(Vec::::from(&r), a); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); let mut a = Vec::::from(&p); let b = Vec::::from(&q); for i in 0..MODULI.len() { let m = Modulus::new(MODULI[i]).unwrap(); - m.add_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) + m.add_vec(&mut a[i *16..(i + 1) * 16], &b[i * 16..(i + 1) * 16]) } let r = &p + &q; assert_eq!(r.representation, Representation::PowerBasis); @@ -515,7 +516,7 @@ mod tests { let mut rng = thread_rng(); for _ in 0..100 { for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let m = Modulus::new(*modulus).unwrap(); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); @@ -535,14 +536,14 @@ mod tests { assert_eq!(Vec::::from(&r), a); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); let mut a = Vec::::from(&p); let b = Vec::::from(&q); for i in 0..MODULI.len() { let m = Modulus::new(MODULI[i]).unwrap(); - m.sub_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) + m.sub_vec(&mut a[i * 16..(i + 1) * 16], &b[i * 16..(i + 1) * 16]) } let r = &p - &q; assert_eq!(r.representation, Representation::PowerBasis); @@ -556,7 +557,7 @@ mod tests { let mut rng = thread_rng(); for _ in 0..100 { for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let m = Modulus::new(*modulus).unwrap(); let p = Poly::random(&ctx, Representation::Ntt, &mut rng); @@ -568,14 +569,14 @@ mod tests { assert_eq!(Vec::::from(&r), a); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::random(&ctx, Representation::Ntt, &mut rng); let q = Poly::random(&ctx, Representation::Ntt, &mut rng); let mut a = Vec::::from(&p); let b = Vec::::from(&q); for i in 0..MODULI.len() { let m = Modulus::new(MODULI[i]).unwrap(); - m.mul_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) + m.mul_vec(&mut a[i * 16..(i + 1) * 16], &b[i * 16..(i + 1) * 16]) } let r = &p * &q; assert_eq!(r.representation, Representation::Ntt); @@ -589,7 +590,7 @@ mod tests { let mut rng = thread_rng(); for _ in 0..100 { for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let m = Modulus::new(*modulus).unwrap(); let p = Poly::random(&ctx, Representation::Ntt, &mut rng); @@ -601,14 +602,14 @@ mod tests { assert_eq!(Vec::::from(&r), a); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::random(&ctx, Representation::Ntt, &mut rng); let q = Poly::random(&ctx, Representation::NttShoup, &mut rng); let mut a = Vec::::from(&p); let b = Vec::::from(&q); for i in 0..MODULI.len() { let m = Modulus::new(MODULI[i]).unwrap(); - m.mul_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) + m.mul_vec(&mut a[i * 16..(i + 1) * 16], &b[i * 16..(i + 1) * 16]) } let r = &p * &q; assert_eq!(r.representation, Representation::Ntt); @@ -622,7 +623,7 @@ mod tests { let mut rng = thread_rng(); for _ in 0..100 { for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); let m = Modulus::new(*modulus).unwrap(); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); @@ -640,12 +641,12 @@ mod tests { assert_eq!(Vec::::from(&r), a); } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); let mut a = Vec::::from(&p); for i in 0..MODULI.len() { let m = Modulus::new(MODULI[i]).unwrap(); - m.neg_vec(&mut a[i * 8..(i + 1) * 8]) + m.neg_vec(&mut a[i * 16..(i + 1) * 16]) } let r = -&p; assert_eq!(r.representation, Representation::PowerBasis); @@ -663,7 +664,7 @@ mod tests { let mut rng = thread_rng(); for _ in 0..20 { for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let ctx = Arc::new(Context::new(&[*modulus], 16)?); for len in 1..50 { let p = (0..len) @@ -680,7 +681,7 @@ mod tests { } } - let ctx = Arc::new(Context::new(MODULI, 8)?); + let ctx = Arc::new(Context::new(MODULI, 16)?); for len in 1..50 { let p = (0..len) .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) diff --git a/crates/fhe-math/src/rq/scaler.rs b/crates/fhe-math/src/rq/scaler.rs index 3a11a939..e03c3d45 100644 --- a/crates/fhe-math/src/rq/scaler.rs +++ b/crates/fhe-math/src/rq/scaler.rs @@ -161,8 +161,8 @@ mod tests { fn scaler() -> Result<(), Box> { let mut rng = thread_rng(); let ntests = 100; - let from = Arc::new(Context::new(Q, 8)?); - let to = Arc::new(Context::new(P, 8)?); + let from = Arc::new(Context::new(Q, 16)?); + let to = Arc::new(Context::new(P, 16)?); for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] { for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] { diff --git a/crates/fhe-math/src/rq/serialize.rs b/crates/fhe-math/src/rq/serialize.rs index 85a99d9b..658fa588 100644 --- a/crates/fhe-math/src/rq/serialize.rs +++ b/crates/fhe-math/src/rq/serialize.rs @@ -43,7 +43,7 @@ mod tests { let mut rng = thread_rng(); for qi in Q { - let ctx = Arc::new(Context::new(&[*qi], 8)?); + let ctx = Arc::new(Context::new(&[*qi], 16)?); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); let p = Poly::random(&ctx, Representation::Ntt, &mut rng); @@ -52,7 +52,7 @@ mod tests { assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); } - let ctx = Arc::new(Context::new(Q, 8)?); + let ctx = Arc::new(Context::new(Q, 16)?); let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); let p = Poly::random(&ctx, Representation::Ntt, &mut rng); From c52c425c8b10b053bd45cc7641f22d4e32d3aed0 Mon Sep 17 00:00:00 2001 From: fionser Date: Sat, 13 Jan 2024 23:28:51 +0800 Subject: [PATCH 2/4] change small polysize from 8 to a larger value --- crates/fhe-math/src/ntt/mod.rs | 6 +- crates/fhe/src/bfv/ciphertext.rs | 16 +- .../bfv/keys/.key_switching_key.rs.rustfmt | 549 ++++++++++++++++++ crates/fhe/src/bfv/keys/evaluation_key.rs | 28 +- crates/fhe/src/bfv/keys/galois_key.rs | 8 +- crates/fhe/src/bfv/keys/key_switching_key.rs | 16 +- crates/fhe/src/bfv/keys/public_key.rs | 10 +- .../fhe/src/bfv/keys/relinearization_key.rs | 12 +- crates/fhe/src/bfv/keys/secret_key.rs | 6 +- crates/fhe/src/bfv/ops/dot_product.rs | 4 +- crates/fhe/src/bfv/ops/mod.rs | 30 +- crates/fhe/src/bfv/ops/mul.rs | 8 +- crates/fhe/src/bfv/parameters.rs | 28 +- crates/fhe/src/bfv/plaintext.rs | 20 +- crates/fhe/src/bfv/plaintext_vec.rs | 2 +- crates/fhe/src/bfv/rgsw_ciphertext.rs | 8 +- crates/fhe/src/mbfv/public_key_gen.rs | 4 +- crates/fhe/src/mbfv/public_key_switch.rs | 4 +- crates/fhe/src/mbfv/relin_key_gen.rs | 4 +- crates/fhe/src/mbfv/secret_key_switch.rs | 12 +- 20 files changed, 662 insertions(+), 113 deletions(-) create mode 100644 crates/fhe/src/bfv/keys/.key_switching_key.rs.rustfmt diff --git a/crates/fhe-math/src/ntt/mod.rs b/crates/fhe-math/src/ntt/mod.rs index 24bcefe0..5294b870 100644 --- a/crates/fhe-math/src/ntt/mod.rs +++ b/crates/fhe-math/src/ntt/mod.rs @@ -24,7 +24,7 @@ mod tests { #[test] fn constructor() { - for size in [8, 1024] { + for size in [32, 1024] { for p in [1153, 4611686018326724609] { let q = Modulus::new(p).unwrap(); let supports_ntt = supports_ntt(p, size); @@ -45,7 +45,7 @@ mod tests { let ntests = 100; let mut rng = thread_rng(); - for size in [8, 1024] { + for size in [32, 1024] { for p in [1153, 4611686018326724609] { let q = Modulus::new(p).unwrap(); @@ -79,7 +79,7 @@ mod tests { let ntests = 100; let mut rng = thread_rng(); - for size in [8, 1024] { + for size in [32, 1024] { for p in [1153, 4611686018326724609] { let q = Modulus::new(p).unwrap(); diff --git a/crates/fhe/src/bfv/ciphertext.rs b/crates/fhe/src/bfv/ciphertext.rs index 91fb76e1..995f0c54 100644 --- a/crates/fhe/src/bfv/ciphertext.rs +++ b/crates/fhe/src/bfv/ciphertext.rs @@ -205,8 +205,8 @@ mod tests { fn proto_conversion() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); let v = params.plaintext.random_vec(params.degree(), &mut rng); @@ -226,8 +226,8 @@ mod tests { fn serialize() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); let v = params.plaintext.random_vec(params.degree(), &mut rng); @@ -243,8 +243,8 @@ mod tests { fn new() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); let v = params.plaintext.random_vec(params.degree(), &mut rng); @@ -281,8 +281,8 @@ mod tests { fn mod_switch_to_last_level() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); let v = params.plaintext.random_vec(params.degree(), &mut rng); diff --git a/crates/fhe/src/bfv/keys/.key_switching_key.rs.rustfmt b/crates/fhe/src/bfv/keys/.key_switching_key.rs.rustfmt new file mode 100644 index 00000000..e3a03098 --- /dev/null +++ b/crates/fhe/src/bfv/keys/.key_switching_key.rs.rustfmt @@ -0,0 +1,549 @@ +//! Key-switching keys for the BFV encryption scheme + +use crate::bfv::{traits::TryConvertFrom as BfvTryConvertFrom, BfvParameters, SecretKey}; +use crate::proto::bfv::KeySwitchingKey as KeySwitchingKeyProto; +use crate::{Error, Result}; +use fhe_math::rq::traits::TryConvertFrom; +use fhe_math::rq::Context; +use fhe_math::{ + rns::RnsContext, + rq::{Poly, Representation}, +}; +use fhe_traits::{DeserializeWithContext, Serialize}; +use itertools::{izip, Itertools}; +use num_bigint::BigUint; +use rand::{CryptoRng, Rng, RngCore, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use std::sync::Arc; +use zeroize::Zeroizing; + +/// Key switching key for the BFV encryption scheme. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct KeySwitchingKey { + /// The parameters of the underlying BFV encryption scheme. + pub(crate) par: Arc, + + /// The (optional) seed that generated the polynomials c1. + pub(crate) seed: Option<::Seed>, + + /// The key switching elements c0. + pub(crate) c0: Box<[Poly]>, + + /// The key switching elements c1. + pub(crate) c1: Box<[Poly]>, + + /// The level and context of the polynomials that will be key switched. + pub(crate) ciphertext_level: usize, + pub(crate) ctx_ciphertext: Arc, + + /// The level and context of the key switching key. + pub(crate) ksk_level: usize, + pub(crate) ctx_ksk: Arc, + + // For level with only one modulus, we will use basis + pub(crate) log_base: usize, +} + +impl KeySwitchingKey { + /// Generate a [`KeySwitchingKey`] to this [`SecretKey`] from a polynomial + /// `from`. + pub fn new( + sk: &SecretKey, + from: &Poly, + ciphertext_level: usize, + ksk_level: usize, + rng: &mut R, + ) -> Result { + let ctx_ksk = sk.par.ctx_at_level(ksk_level)?; + let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?; + + if from.ctx() != ctx_ksk { + return Err(Error::DefaultError( + "Incorrect context for polynomial from".to_string(), + )); + } + + let mut seed = ::Seed::default(); + rng.fill(&mut seed); + + if ctx_ksk.moduli().len() == 1 { + let modulus = ctx_ksk.moduli().first().unwrap(); + let log_modulus = modulus.next_power_of_two().ilog2() as usize; + let log_base = log_modulus / 2; + + let c1 = Self::generate_c1(ctx_ksk, seed, (log_modulus + log_base - 1) / log_base); + let c0 = Self::generate_c0_decomposition(sk, from, &c1, rng, log_base)?; + + Ok(Self { + par: sk.par.clone(), + seed: Some(seed), + c0: c0.into_boxed_slice(), + c1: c1.into_boxed_slice(), + ciphertext_level, + ctx_ciphertext: ctx_ciphertext.clone(), + ksk_level, + ctx_ksk: ctx_ksk.clone(), + log_base, + }) + } else { + let c1 = Self::generate_c1(ctx_ksk, seed, ctx_ciphertext.moduli().len()); + let c0 = Self::generate_c0(sk, from, &c1, rng)?; + + Ok(Self { + par: sk.par.clone(), + seed: Some(seed), + c0: c0.into_boxed_slice(), + c1: c1.into_boxed_slice(), + ciphertext_level, + ctx_ciphertext: ctx_ciphertext.clone(), + ksk_level, + ctx_ksk: ctx_ksk.clone(), + log_base: 0, + }) + } + } + + /// Generate the c1's from the seed + fn generate_c1( + ctx: &Arc, + seed: ::Seed, + size: usize, + ) -> Vec { + let mut c1 = Vec::with_capacity(size); + let mut rng = ChaCha8Rng::from_seed(seed); + (0..size).for_each(|_| { + let mut seed_i = ::Seed::default(); + rng.fill(&mut seed_i); + let mut a = Poly::random_from_seed(ctx, Representation::NttShoup, seed_i); + unsafe { a.allow_variable_time_computations() } + c1.push(a); + }); + c1 + } + + /// Generate the c0's from the c1's and the secret key + fn generate_c0( + sk: &SecretKey, + from: &Poly, + c1: &[Poly], + rng: &mut R, + ) -> Result> { + if c1.is_empty() { + return Err(Error::DefaultError("Empty number of c1's".to_string())); + } + if from.representation() != &Representation::PowerBasis { + return Err(Error::DefaultError( + "Unexpected representation for from".to_string(), + )); + } + + let size = c1.len(); + + let mut s = Zeroizing::new(Poly::try_convert_from( + sk.coeffs.as_ref(), + c1[0].ctx(), + false, + Representation::PowerBasis, + )?); + s.change_representation(Representation::Ntt); + + let rns = RnsContext::new(&sk.par.moduli[..size])?; + let c0 = c1 + .iter() + .enumerate() + .map(|(i, c1i)| { + let mut a_s = Zeroizing::new(c1i.clone()); + a_s.disallow_variable_time_computations(); + a_s.change_representation(Representation::Ntt); + *a_s.as_mut() *= s.as_ref(); + a_s.change_representation(Representation::PowerBasis); + + let mut b = + Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?; + b -= &a_s; + + let gi = rns.get_garner(i).unwrap(); + let g_i_from = Zeroizing::new(gi * from); + b += &g_i_from; + + // It is now safe to enable variable time computations. + unsafe { b.allow_variable_time_computations() } + b.change_representation(Representation::NttShoup); + Ok(b) + }) + .collect::>>()?; + + Ok(c0) + } + + /// Generate the c0's from the c1's and the secret key + fn generate_c0_decomposition( + sk: &SecretKey, + from: &Poly, + c1: &[Poly], + rng: &mut R, + log_base: usize, + ) -> Result> { + if c1.is_empty() { + return Err(Error::DefaultError("Empty number of c1's".to_string())); + } + + if from.representation() != &Representation::PowerBasis { + return Err(Error::DefaultError( + "Unexpected representation for from".to_string(), + )); + } + + let mut s = Zeroizing::new(Poly::try_convert_from( + sk.coeffs.as_ref(), + c1[0].ctx(), + false, + Representation::PowerBasis, + )?); + s.change_representation(Representation::Ntt); + + let c0 = c1 + .iter() + .enumerate() + .map(|(i, c1i)| { + let mut a_s = Zeroizing::new(c1i.clone()); + a_s.disallow_variable_time_computations(); + a_s.change_representation(Representation::Ntt); + *a_s.as_mut() *= s.as_ref(); + a_s.change_representation(Representation::PowerBasis); + + let mut b = + Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?; + b -= &a_s; + + let power = BigUint::from(1u64 << (i * log_base)); + b += &(from * &power); + + // It is now safe to enable variable time computations. + unsafe { b.allow_variable_time_computations() } + b.change_representation(Representation::NttShoup); + Ok(b) + }) + .collect::>>()?; + + Ok(c0) + } + + /// Key switch a polynomial. + pub fn key_switch(&self, p: &Poly) -> Result<(Poly, Poly)> { + if self.log_base != 0 { + return self.key_switch_decomposition(p); + } + + if p.ctx().as_ref() != self.ctx_ciphertext.as_ref() { + return Err(Error::DefaultError( + "The input polynomial does not have the correct context.".to_string(), + )); + } + if p.representation() != &Representation::PowerBasis { + return Err(Error::DefaultError("Incorrect representation".to_string())); + } + + let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt); + let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt); + for (c2_i_coefficients, c0_i, c1_i) in izip!( + p.coefficients().outer_iter(), + self.c0.iter(), + self.c1.iter() + ) { + let mut c2_i = unsafe { + Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( + c2_i_coefficients.as_slice().unwrap(), + &self.ctx_ksk, + ) + }; + c0 += &(&c2_i * c0_i); + c2_i *= c1_i; + c1 += &c2_i; + } + Ok((c0, c1)) + } + + /// Key switch a polynomial. + fn key_switch_decomposition(&self, p: &Poly) -> Result<(Poly, Poly)> { + if p.ctx().as_ref() != self.ctx_ciphertext.as_ref() { + return Err(Error::DefaultError( + "The input polynomial does not have the correct context.".to_string(), + )); + } + if p.representation() != &Representation::PowerBasis { + return Err(Error::DefaultError("Incorrect representation".to_string())); + } + + let log_modulus = p + .ctx() + .moduli() + .first() + .unwrap() + .next_power_of_two() + .ilog2() as usize; + + let mut coefficients = p.coefficients().to_slice().unwrap().to_vec(); + let mut c2i = vec![]; + let mask = (1u64 << self.log_base) - 1; + (0..(log_modulus + self.log_base - 1) / self.log_base).for_each(|_| { + c2i.push(coefficients.iter().map(|c| c & mask).collect_vec()); + coefficients.iter_mut().for_each(|c| *c >>= self.log_base); + }); + + let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt); + let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt); + for (c2_i_coefficients, c0_i, c1_i) in izip!(c2i.iter(), self.c0.iter(), self.c1.iter()) { + let mut c2_i = unsafe { + Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( + c2_i_coefficients.as_slice(), + &self.ctx_ksk, + ) + }; + c0 += &(&c2_i * c0_i); + c2_i *= c1_i; + c1 += &c2_i; + } + Ok((c0, c1)) + } +} + +impl From<&KeySwitchingKey> for KeySwitchingKeyProto { + fn from(value: &KeySwitchingKey) -> Self { + let mut ksk = KeySwitchingKeyProto::default(); + if let Some(seed) = value.seed.as_ref() { + ksk.seed = seed.to_vec(); + } else { + ksk.c1.reserve_exact(value.c1.len()); + for c1 in value.c1.iter() { + ksk.c1.push(c1.to_bytes()) + } + } + ksk.c0.reserve_exact(value.c0.len()); + for c0 in value.c0.iter() { + ksk.c0.push(c0.to_bytes()) + } + ksk.ciphertext_level = value.ciphertext_level as u32; + ksk.ksk_level = value.ksk_level as u32; + ksk.log_base = value.log_base as u32; + ksk + } +} + +impl BfvTryConvertFrom<&KeySwitchingKeyProto> for KeySwitchingKey { + fn try_convert_from(value: &KeySwitchingKeyProto, par: &Arc) -> Result { + let ciphertext_level = value.ciphertext_level as usize; + let ksk_level = value.ksk_level as usize; + let ctx_ksk = par.ctx_at_level(ksk_level)?; + let ctx_ciphertext = par.ctx_at_level(ciphertext_level)?; + + let c0_size: usize; + let log_base = value.log_base as usize; + if log_base != 0 { + if ksk_level != par.max_level() || ciphertext_level != par.max_level() { + return Err(Error::DefaultError( + "A decomposition size is specified but the levels are not maximal".to_string(), + )); + } else { + let log_modulus: usize = + par.moduli().first().unwrap().next_power_of_two().ilog2() as usize; + c0_size = (log_modulus + log_base - 1) / log_base; + } + } else { + c0_size = ctx_ciphertext.moduli().len(); + } + + if value.c0.len() != c0_size { + return Err(Error::DefaultError( + "Incorrect number of values in c0".to_string(), + )); + } + + let seed = if value.seed.is_empty() { + if value.c1.len() != c0_size { + return Err(Error::DefaultError( + "Incorrect number of values in c1".to_string(), + )); + } + None + } else { + let unwrapped = ::Seed::try_from(value.seed.clone()); + if unwrapped.is_err() { + return Err(Error::DefaultError("Invalid seed".to_string())); + } + Some(unwrapped.unwrap()) + }; + + let c1 = if let Some(seed) = seed { + Self::generate_c1(ctx_ksk, seed, value.c0.len()) + } else { + value + .c1 + .iter() + .map(|c1i| Poly::from_bytes(c1i, ctx_ksk).map_err(Error::MathError)) + .collect::>>()? + }; + + let c0 = value + .c0 + .iter() + .map(|c0i| Poly::from_bytes(c0i, ctx_ksk).map_err(Error::MathError)) + .collect::>>()?; + + Ok(Self { + par: par.clone(), + seed, + c0: c0.into_boxed_slice(), + c1: c1.into_boxed_slice(), + ciphertext_level, + ctx_ciphertext: ctx_ciphertext.clone(), + ksk_level, + ctx_ksk: ctx_ksk.clone(), + log_base: value.log_base as usize, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::bfv::{ + keys::key_switching_key::KeySwitchingKey, traits::TryConvertFrom, BfvParameters, SecretKey, + }; + use crate::proto::bfv::KeySwitchingKey as KeySwitchingKeyProto; + use fhe_math::{ + rns::RnsContext, + rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation}, + }; + use num_bigint::BigUint; + use rand::thread_rng; + use std::error::Error; + + #[test] + fn constructor() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(3, 8), + ] { + let sk = SecretKey::random(¶ms, &mut rng); + let ctx = params.ctx_at_level(0)?; + let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng); + assert!(ksk.is_ok()); + } + Ok(()) + } + + #[test] + fn constructor_last_level() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(3, 8), + ] { + let level = params.moduli().len() - 1; + let sk = SecretKey::random(¶ms, &mut rng); + let ctx = params.ctx_at_level(level)?; + let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let ksk = KeySwitchingKey::new(&sk, &p, level, level, &mut rng); + assert!(ksk.is_ok()); + } + Ok(()) + } + + #[test] + fn key_switch() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [BfvParameters::default_arc(6, 8)] { + for _ in 0..100 { + let sk = SecretKey::random(¶ms, &mut rng); + let ctx = params.ctx_at_level(0)?; + let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?; + let mut s = Poly::try_convert_from( + sk.coeffs.as_ref(), + ctx, + false, + Representation::PowerBasis, + ) + .map_err(crate::Error::MathError)?; + s.change_representation(Representation::Ntt); + + let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng); + let (c0, c1) = ksk.key_switch(&input)?; + + let mut c2 = &c0 + &(&c1 * &s); + c2.change_representation(Representation::PowerBasis); + + input.change_representation(Representation::Ntt); + p.change_representation(Representation::Ntt); + let mut c3 = &input * &p; + c3.change_representation(Representation::PowerBasis); + + let rns = RnsContext::new(¶ms.moduli)?; + Vec::::from(&(&c2 - &c3)).iter().for_each(|b| { + assert!(std::cmp::min(b.bits(), (rns.modulus() - b).bits()) <= 70) + }); + } + } + Ok(()) + } + + #[test] + fn key_switch_decomposition() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [BfvParameters::default_arc(6, 8)] { + for _ in 0..100 { + let sk = SecretKey::random(¶ms, &mut rng); + let ctx = params.ctx_at_level(5)?; + let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let ksk = KeySwitchingKey::new(&sk, &p, 5, 5, &mut rng)?; + let mut s = Poly::try_convert_from( + sk.coeffs.as_ref(), + ctx, + false, + Representation::PowerBasis, + ) + .map_err(crate::Error::MathError)?; + s.change_representation(Representation::Ntt); + + let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng); + let (c0, c1) = ksk.key_switch(&input)?; + + let mut c2 = &c0 + &(&c1 * &s); + c2.change_representation(Representation::PowerBasis); + + input.change_representation(Representation::Ntt); + p.change_representation(Representation::Ntt); + let mut c3 = &input * &p; + c3.change_representation(Representation::PowerBasis); + + let rns = RnsContext::new(ctx.moduli())?; + Vec::::from(&(&c2 - &c3)).iter().for_each(|b| { + assert!( + std::cmp::min(b.bits(), (rns.modulus() - b).bits()) + <= (rns.modulus().bits() / 2) + 10 + ) + }); + } + } + Ok(()) + } + + #[test] + fn proto_conversion() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(3, 8), + ] { + let sk = SecretKey::random(¶ms, &mut rng); + let ctx = params.ctx_at_level(0)?; + let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?; + let ksk_proto = KeySwitchingKeyProto::from(&ksk); + assert_eq!(ksk, KeySwitchingKey::try_convert_from(&ksk_proto, ¶ms)?); + } + Ok(()) + } +} diff --git a/crates/fhe/src/bfv/keys/evaluation_key.rs b/crates/fhe/src/bfv/keys/evaluation_key.rs index c23a8ed1..ab379e24 100644 --- a/crates/fhe/src/bfv/keys/evaluation_key.rs +++ b/crates/fhe/src/bfv/keys/evaluation_key.rs @@ -443,7 +443,7 @@ mod tests { #[test] fn builder() -> Result<(), Box> { let mut rng = thread_rng(); - let params = BfvParameters::default_arc(6, 8); + let params = BfvParameters::default_arc(6, 16); let sk = SecretKey::random(¶ms, &mut rng); let max_level = params.max_level(); @@ -517,8 +517,8 @@ mod tests { fn inner_sum() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(5, 8), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(5, 16), ] { for _ in 0..25 { for ciphertext_level in 0..=params.max_level() { @@ -561,8 +561,8 @@ mod tests { fn row_rotation() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(5, 8), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(5, 16), ] { for _ in 0..50 { for ciphertext_level in 0..=params.max_level() { @@ -606,8 +606,8 @@ mod tests { fn column_rotation() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(5, 8), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(5, 16), ] { let row_size = params.degree() >> 1; for _ in 0..50 { @@ -661,8 +661,8 @@ mod tests { fn expansion() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(5, 8), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(5, 16), ] { let log_degree = 64 - 1 - params.degree().leading_zeros(); for _ in 0..15 { @@ -715,9 +715,9 @@ mod tests { fn proto_conversion() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(5, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(5, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); @@ -759,8 +759,8 @@ mod tests { fn serialize() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); diff --git a/crates/fhe/src/bfv/keys/galois_key.rs b/crates/fhe/src/bfv/keys/galois_key.rs index ca1a641c..65a9a7b6 100644 --- a/crates/fhe/src/bfv/keys/galois_key.rs +++ b/crates/fhe/src/bfv/keys/galois_key.rs @@ -128,8 +128,8 @@ mod tests { fn relinearization() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(3, 8), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(3, 16), ] { for _ in 0..30 { let sk = SecretKey::random(¶ms, &mut rng); @@ -178,8 +178,8 @@ mod tests { fn proto_conversion() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(4, 8), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(4, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); let gk = GaloisKey::new(&sk, 9, 0, 0, &mut rng)?; diff --git a/crates/fhe/src/bfv/keys/key_switching_key.rs b/crates/fhe/src/bfv/keys/key_switching_key.rs index e3a03098..0e663cc6 100644 --- a/crates/fhe/src/bfv/keys/key_switching_key.rs +++ b/crates/fhe/src/bfv/keys/key_switching_key.rs @@ -422,8 +422,8 @@ mod tests { fn constructor() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(3, 8), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(3, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); let ctx = params.ctx_at_level(0)?; @@ -438,8 +438,8 @@ mod tests { fn constructor_last_level() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(3, 8), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(3, 16), ] { let level = params.moduli().len() - 1; let sk = SecretKey::random(¶ms, &mut rng); @@ -454,7 +454,7 @@ mod tests { #[test] fn key_switch() -> Result<(), Box> { let mut rng = thread_rng(); - for params in [BfvParameters::default_arc(6, 8)] { + for params in [BfvParameters::default_arc(6, 16)] { for _ in 0..100 { let sk = SecretKey::random(¶ms, &mut rng); let ctx = params.ctx_at_level(0)?; @@ -492,7 +492,7 @@ mod tests { #[test] fn key_switch_decomposition() -> Result<(), Box> { let mut rng = thread_rng(); - for params in [BfvParameters::default_arc(6, 8)] { + for params in [BfvParameters::default_arc(6, 16)] { for _ in 0..100 { let sk = SecretKey::random(¶ms, &mut rng); let ctx = params.ctx_at_level(5)?; @@ -534,8 +534,8 @@ mod tests { fn proto_conversion() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(3, 8), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(3, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); let ctx = params.ctx_at_level(0)?; diff --git a/crates/fhe/src/bfv/keys/public_key.rs b/crates/fhe/src/bfv/keys/public_key.rs index 7ad0ccb8..f0d61cf8 100644 --- a/crates/fhe/src/bfv/keys/public_key.rs +++ b/crates/fhe/src/bfv/keys/public_key.rs @@ -146,7 +146,7 @@ mod tests { #[test] fn keygen() -> Result<(), Box> { let mut rng = thread_rng(); - let params = BfvParameters::default_arc(1, 8); + let params = BfvParameters::default_arc(1, 16); let sk = SecretKey::random(¶ms, &mut rng); let pk = PublicKey::new(&sk, &mut rng); assert_eq!(pk.par, params); @@ -161,8 +161,8 @@ mod tests { fn encrypt_decrypt() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { for level in 0..params.max_level() { for _ in 0..20 { @@ -190,8 +190,8 @@ mod tests { fn test_serialize() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); let pk = PublicKey::new(&sk, &mut rng); diff --git a/crates/fhe/src/bfv/keys/relinearization_key.rs b/crates/fhe/src/bfv/keys/relinearization_key.rs index 816a8ddd..b16ece6f 100644 --- a/crates/fhe/src/bfv/keys/relinearization_key.rs +++ b/crates/fhe/src/bfv/keys/relinearization_key.rs @@ -165,7 +165,7 @@ mod tests { #[test] fn relinearization() -> Result<(), Box> { let mut rng = thread_rng(); - for params in [BfvParameters::default_arc(6, 8)] { + for params in [BfvParameters::default_arc(6, 16)] { for _ in 0..100 { let sk = SecretKey::random(¶ms, &mut rng); let rk = RelinearizationKey::new(&sk, &mut rng)?; @@ -210,7 +210,7 @@ mod tests { println!("Noise: {}", unsafe { sk.measure_noise(&ct)? }); let pt = sk.try_decrypt(&ct)?; let w = Vec::::try_decode(&pt, Encoding::poly())?; - assert_eq!(w, &[0u64; 8]); + assert_eq!(w, &[0u64; 16]); } } Ok(()) @@ -219,7 +219,7 @@ mod tests { #[test] fn relinearization_leveled() -> Result<(), Box> { let mut rng = thread_rng(); - for params in [BfvParameters::default_arc(5, 8)] { + for params in [BfvParameters::default_arc(5, 16)] { for ciphertext_level in 0..params.max_level() { for key_level in 0..=ciphertext_level { for _ in 0..10 { @@ -271,7 +271,7 @@ mod tests { println!("Noise: {}", unsafe { sk.measure_noise(&ct)? }); let pt = sk.try_decrypt(&ct)?; let w = Vec::::try_decode(&pt, Encoding::poly())?; - assert_eq!(w, &[0u64; 8]); + assert_eq!(w, &[0u64; 16]); } } } @@ -283,8 +283,8 @@ mod tests { fn proto_conversion() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(3, 8), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(3, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); let rk = RelinearizationKey::new(&sk, &mut rng)?; diff --git a/crates/fhe/src/bfv/keys/secret_key.rs b/crates/fhe/src/bfv/keys/secret_key.rs index 802af582..447b064b 100644 --- a/crates/fhe/src/bfv/keys/secret_key.rs +++ b/crates/fhe/src/bfv/keys/secret_key.rs @@ -226,7 +226,7 @@ mod tests { #[test] fn keygen() { let mut rng = thread_rng(); - let params = BfvParameters::default_arc(1, 8); + let params = BfvParameters::default_arc(1, 16); let sk = SecretKey::random(¶ms, &mut rng); assert_eq!(sk.par, params); @@ -240,8 +240,8 @@ mod tests { fn encrypt_decrypt() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { for level in 0..params.max_level() { for _ in 0..20 { diff --git a/crates/fhe/src/bfv/ops/dot_product.rs b/crates/fhe/src/bfv/ops/dot_product.rs index af0129b7..b80e7fba 100644 --- a/crates/fhe/src/bfv/ops/dot_product.rs +++ b/crates/fhe/src/bfv/ops/dot_product.rs @@ -167,8 +167,8 @@ mod tests { fn test_dot_product_scalar() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(2, 16), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(2, 32), ] { let sk = SecretKey::random(¶ms, &mut rng); for size in 1..128 { diff --git a/crates/fhe/src/bfv/ops/mod.rs b/crates/fhe/src/bfv/ops/mod.rs index 8985f632..2e0bce04 100644 --- a/crates/fhe/src/bfv/ops/mod.rs +++ b/crates/fhe/src/bfv/ops/mod.rs @@ -290,8 +290,8 @@ mod tests { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { let zero = Ciphertext::zero(¶ms); for _ in 0..50 { @@ -329,8 +329,8 @@ mod tests { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { for _ in 0..50 { let a = params.plaintext.random_vec(params.degree(), &mut rng); @@ -378,8 +378,8 @@ mod tests { fn sub() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { let zero = Ciphertext::zero(¶ms); for _ in 0..50 { @@ -424,8 +424,8 @@ mod tests { fn sub_scalar() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { for _ in 0..50 { let a = params.plaintext.random_vec(params.degree(), &mut rng); @@ -475,8 +475,8 @@ mod tests { fn neg() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { for _ in 0..50 { let a = params.plaintext.random_vec(params.degree(), &mut rng); @@ -508,8 +508,8 @@ mod tests { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 16), ] { for _ in 0..50 { let a = params.plaintext.random_vec(params.degree(), &mut rng); @@ -563,8 +563,8 @@ mod tests { fn mul() -> Result<(), Box> { let mut rng = thread_rng(); for par in [ - BfvParameters::default_arc(2, 8), - BfvParameters::default_arc(8, 8), + BfvParameters::default_arc(2, 16), + BfvParameters::default_arc(8, 16), ] { for _ in 0..1 { // We will encode `values` in an Simd format, and check that the product is @@ -600,7 +600,7 @@ mod tests { #[test] fn square() -> Result<(), Box> { let mut rng = thread_rng(); - let par = BfvParameters::default_arc(6, 8); + let par = BfvParameters::default_arc(6, 16); for _ in 0..20 { // We will encode `values` in an Simd format, and check that the product is // computed correctly. diff --git a/crates/fhe/src/bfv/ops/mul.rs b/crates/fhe/src/bfv/ops/mul.rs index 9097ee5c..b753251f 100644 --- a/crates/fhe/src/bfv/ops/mul.rs +++ b/crates/fhe/src/bfv/ops/mul.rs @@ -257,7 +257,7 @@ mod tests { #[test] fn mul() -> Result<(), Box> { let mut rng = thread_rng(); - let par = BfvParameters::default_arc(3, 8); + let par = BfvParameters::default_arc(3, 16); for _ in 0..30 { // We will encode `values` in an Simd format, and check that the product is // computed correctly. @@ -290,7 +290,7 @@ mod tests { #[test] fn mul_at_level() -> Result<(), Box> { let mut rng = thread_rng(); - let par = BfvParameters::default_arc(3, 8); + let par = BfvParameters::default_arc(3, 16); for _ in 0..15 { for level in 0..2 { let values = par.plaintext.random_vec(par.degree(), &mut rng); @@ -325,7 +325,7 @@ mod tests { #[test] fn mul_no_relin() -> Result<(), Box> { let mut rng = thread_rng(); - let par = BfvParameters::default_arc(6, 8); + let par = BfvParameters::default_arc(6, 16); for _ in 0..30 { // We will encode `values` in an Simd format, and check that the product is // computed correctly. @@ -362,7 +362,7 @@ mod tests { // Implement the second multiplication strategy from let mut rng = thread_rng(); - let par = BfvParameters::default_arc(3, 8); + let par = BfvParameters::default_arc(3, 16); let mut extended_basis = par.moduli().to_vec(); extended_basis .push(generate_prime(62, 2 * par.degree() as u64, extended_basis[2]).unwrap()); diff --git a/crates/fhe/src/bfv/parameters.rs b/crates/fhe/src/bfv/parameters.rs index 5cca0640..44afecc9 100644 --- a/crates/fhe/src/bfv/parameters.rs +++ b/crates/fhe/src/bfv/parameters.rs @@ -590,8 +590,8 @@ mod tests { // assert_eq!(params.ciphertext_moduli, vec![1153]); // assert_eq!(params.moduli(), vec![1153]); // assert_eq!(params.plaintext_modulus, 2); - // assert_eq!(params.polynomial_degree, 8); - // assert_eq!(params.degree(), 8); + // assert_eq!(params.polynomial_degree, 16); + // assert_eq!(params.degree(), 16); // assert_eq!(params.variance, 1); // assert!(params.op.is_none()); @@ -600,9 +600,9 @@ mod tests { #[test] fn default() { - let params = BfvParameters::default_arc(1, 8); + let params = BfvParameters::default_arc(1, 16); assert_eq!(params.moduli.len(), 1); - assert_eq!(params.degree(), 8); + assert_eq!(params.degree(), 16); let params = BfvParameters::default_arc(2, 16); assert_eq!(params.moduli.len(), 2); @@ -612,32 +612,32 @@ mod tests { #[test] fn ciphertext_moduli() -> Result<(), Box> { let params = BfvParametersBuilder::new() - .set_degree(8) + .set_degree(16) .set_plaintext_modulus(2) .set_moduli_sizes(&[62, 62, 62, 61, 60, 11]) .build()?; assert_eq!( params.moduli.to_vec(), &[ - 4611686018427387761, 4611686018427387617, - 4611686018427387409, + 4611686018427387329, + 4611686018427387073, 2305843009213693921, - 1152921504606846577, + 1152921504606845473, 2017 ] ); let params = BfvParametersBuilder::new() - .set_degree(8) + .set_degree(16) .set_plaintext_modulus(2) .set_moduli(&[ - 4611686018427387761, 4611686018427387617, - 4611686018427387409, + 4611686018427387329, + 4611686018427387073, 2305843009213693921, - 1152921504606846577, - 2017, + 1152921504606845473, + 2017 ]) .build()?; assert_eq!(params.moduli_sizes.to_vec(), &[62, 62, 62, 61, 60, 11]); @@ -648,7 +648,7 @@ mod tests { #[test] fn serialize() -> Result<(), Box> { let params = BfvParametersBuilder::new() - .set_degree(8) + .set_degree(16) .set_plaintext_modulus(2) .set_moduli_sizes(&[62, 62, 62, 61, 60, 11]) .set_variance(4) diff --git a/crates/fhe/src/bfv/plaintext.rs b/crates/fhe/src/bfv/plaintext.rs index 5983661b..ef37d7da 100644 --- a/crates/fhe/src/bfv/plaintext.rs +++ b/crates/fhe/src/bfv/plaintext.rs @@ -239,10 +239,10 @@ mod tests { fn try_encode() -> Result<(), Box> { let mut rng = thread_rng(); // The default test parameters support both Poly and Simd encodings - let params = BfvParameters::default_arc(1, 8); + let params = BfvParameters::default_arc(1, 16); let a = params.plaintext.random_vec(params.degree(), &mut rng); - let plaintext = Plaintext::try_encode(&[0u64; 9], Encoding::poly(), ¶ms); + let plaintext = Plaintext::try_encode(&[0u64; 17], Encoding::poly(), ¶ms); assert!(plaintext.is_err()); let plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms); @@ -256,7 +256,7 @@ mod tests { // The following parameters do not allow for Simd encoding let params = BfvParametersBuilder::new() - .set_degree(8) + .set_degree(16) .set_plaintext_modulus(2) .set_moduli(&[4611686018326724609]) .build_arc()?; @@ -275,7 +275,7 @@ mod tests { #[test] fn encode_decode() -> Result<(), Box> { let mut rng = thread_rng(); - let params = BfvParameters::default_arc(1, 8); + let params = BfvParameters::default_arc(1, 16); let a = params.plaintext.random_vec(params.degree(), &mut rng); let plaintext = Plaintext::try_encode(&a, Encoding::simd(), ¶ms); @@ -300,7 +300,7 @@ mod tests { #[test] fn partial_eq() -> Result<(), Box> { let mut rng = thread_rng(); - let params = BfvParameters::default_arc(1, 8); + let params = BfvParameters::default_arc(1, 16); let a = params.plaintext.random_vec(params.degree(), &mut rng); let plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms)?; @@ -320,7 +320,7 @@ mod tests { #[test] fn try_decode_errors() -> Result<(), Box> { let mut rng = thread_rng(); - let params = BfvParameters::default_arc(1, 8); + let params = BfvParameters::default_arc(1, 16); let a = params.plaintext.random_vec(params.degree(), &mut rng); let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms)?; @@ -355,10 +355,10 @@ mod tests { #[test] fn zero() -> Result<(), Box> { - let params = BfvParameters::default_arc(1, 8); + let params = BfvParameters::default_arc(1, 16); let plaintext = Plaintext::zero(Encoding::poly(), ¶ms)?; - assert_eq!(plaintext.value, Box::<[u64]>::from([0u64; 8])); + assert_eq!(plaintext.value, Box::<[u64]>::from([0u64; 16])); assert_eq!( plaintext.poly_ntt, Poly::zero(¶ms.ctx[0], Representation::Ntt) @@ -370,7 +370,7 @@ mod tests { #[test] fn zeroize() -> Result<(), Box> { let mut rng = thread_rng(); - let params = BfvParameters::default_arc(1, 8); + let params = BfvParameters::default_arc(1, 16); let a = params.plaintext.random_vec(params.degree(), &mut rng); let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms)?; @@ -385,7 +385,7 @@ mod tests { fn try_encode_level() -> Result<(), Box> { let mut rng = thread_rng(); // The default test parameters support both Poly and Simd encodings - let params = BfvParameters::default_arc(10, 8); + let params = BfvParameters::default_arc(10, 16); let a = params.plaintext.random_vec(params.degree(), &mut rng); for level in 0..10 { diff --git a/crates/fhe/src/bfv/plaintext_vec.rs b/crates/fhe/src/bfv/plaintext_vec.rs index 25f32ce5..8fdecc19 100644 --- a/crates/fhe/src/bfv/plaintext_vec.rs +++ b/crates/fhe/src/bfv/plaintext_vec.rs @@ -142,7 +142,7 @@ mod tests { let mut rng = thread_rng(); for _ in 0..20 { for i in 1..5 { - let params = BfvParameters::default_arc(1, 8); + let params = BfvParameters::default_arc(1, 16); let a = params.plaintext.random_vec(params.degree() * i, &mut rng); let plaintexts = PlaintextVec::try_encode(&a, Encoding::poly_at_level(0), ¶ms)?; diff --git a/crates/fhe/src/bfv/rgsw_ciphertext.rs b/crates/fhe/src/bfv/rgsw_ciphertext.rs index 5dcf1643..626fb9b1 100644 --- a/crates/fhe/src/bfv/rgsw_ciphertext.rs +++ b/crates/fhe/src/bfv/rgsw_ciphertext.rs @@ -160,8 +160,8 @@ mod tests { fn external_product() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(2, 8), - BfvParameters::default_arc(8, 8), + BfvParameters::default_arc(2, 16), + BfvParameters::default_arc(8, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); let v1 = params.plaintext.random_vec(params.degree(), &mut rng); @@ -192,8 +192,8 @@ mod tests { fn serialize() -> Result<(), Box> { let mut rng = thread_rng(); for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(5, 8), + BfvParameters::default_arc(6, 16), + BfvParameters::default_arc(5, 16), ] { let sk = SecretKey::random(¶ms, &mut rng); let v = params.plaintext.random_vec(params.degree(), &mut rng); diff --git a/crates/fhe/src/mbfv/public_key_gen.rs b/crates/fhe/src/mbfv/public_key_gen.rs index 2335b58b..cb778390 100644 --- a/crates/fhe/src/mbfv/public_key_gen.rs +++ b/crates/fhe/src/mbfv/public_key_gen.rs @@ -98,8 +98,8 @@ mod tests { fn protocol_creates_valid_pk() { let mut rng = thread_rng(); for par in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 32), ] { for level in 0..=par.max_level() { for _ in 0..20 { diff --git a/crates/fhe/src/mbfv/public_key_switch.rs b/crates/fhe/src/mbfv/public_key_switch.rs index fc5f355b..e4fd84f6 100644 --- a/crates/fhe/src/mbfv/public_key_switch.rs +++ b/crates/fhe/src/mbfv/public_key_switch.rs @@ -135,8 +135,8 @@ mod tests { fn encrypt_keyswitch_decrypt() { let mut rng = thread_rng(); for par in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 32), ] { for level in 0..=par.max_level() { for _ in 0..20 { diff --git a/crates/fhe/src/mbfv/relin_key_gen.rs b/crates/fhe/src/mbfv/relin_key_gen.rs index d58d8ad0..6adf2d34 100644 --- a/crates/fhe/src/mbfv/relin_key_gen.rs +++ b/crates/fhe/src/mbfv/relin_key_gen.rs @@ -383,8 +383,8 @@ mod tests { fn relinearization_works() { let mut rng = thread_rng(); for par in [ - BfvParameters::default_arc(3, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(3, 16), + BfvParameters::default_arc(6, 32), ] { // Just support level 0 for now. let level = 0; diff --git a/crates/fhe/src/mbfv/secret_key_switch.rs b/crates/fhe/src/mbfv/secret_key_switch.rs index 4338d31a..68381bd8 100644 --- a/crates/fhe/src/mbfv/secret_key_switch.rs +++ b/crates/fhe/src/mbfv/secret_key_switch.rs @@ -203,8 +203,8 @@ mod tests { fn encrypt_decrypt() { let mut rng = thread_rng(); for par in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 32), ] { for level in 0..=par.max_level() { for _ in 0..20 { @@ -250,8 +250,8 @@ mod tests { fn encrypt_keyswitch_decrypt() { let mut rng = thread_rng(); for par in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 32), ] { for level in 0..=par.max_level() { for _ in 0..20 { @@ -318,8 +318,8 @@ mod tests { fn collective_keys_enable_homomorphic_addition() { let mut rng = thread_rng(); for par in [ - BfvParameters::default_arc(1, 8), - BfvParameters::default_arc(6, 8), + BfvParameters::default_arc(1, 16), + BfvParameters::default_arc(6, 32), ] { for level in 0..=par.max_level() { for _ in 0..20 { From ca4a2564dc70d6e86b4e6be2a0f03a36d4943a96 Mon Sep 17 00:00:00 2001 From: fionser Date: Sat, 13 Jan 2024 23:35:09 +0800 Subject: [PATCH 3/4] remove temporary file --- .../bfv/keys/.key_switching_key.rs.rustfmt | 549 ------------------ 1 file changed, 549 deletions(-) delete mode 100644 crates/fhe/src/bfv/keys/.key_switching_key.rs.rustfmt diff --git a/crates/fhe/src/bfv/keys/.key_switching_key.rs.rustfmt b/crates/fhe/src/bfv/keys/.key_switching_key.rs.rustfmt deleted file mode 100644 index e3a03098..00000000 --- a/crates/fhe/src/bfv/keys/.key_switching_key.rs.rustfmt +++ /dev/null @@ -1,549 +0,0 @@ -//! Key-switching keys for the BFV encryption scheme - -use crate::bfv::{traits::TryConvertFrom as BfvTryConvertFrom, BfvParameters, SecretKey}; -use crate::proto::bfv::KeySwitchingKey as KeySwitchingKeyProto; -use crate::{Error, Result}; -use fhe_math::rq::traits::TryConvertFrom; -use fhe_math::rq::Context; -use fhe_math::{ - rns::RnsContext, - rq::{Poly, Representation}, -}; -use fhe_traits::{DeserializeWithContext, Serialize}; -use itertools::{izip, Itertools}; -use num_bigint::BigUint; -use rand::{CryptoRng, Rng, RngCore, SeedableRng}; -use rand_chacha::ChaCha8Rng; -use std::sync::Arc; -use zeroize::Zeroizing; - -/// Key switching key for the BFV encryption scheme. -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct KeySwitchingKey { - /// The parameters of the underlying BFV encryption scheme. - pub(crate) par: Arc, - - /// The (optional) seed that generated the polynomials c1. - pub(crate) seed: Option<::Seed>, - - /// The key switching elements c0. - pub(crate) c0: Box<[Poly]>, - - /// The key switching elements c1. - pub(crate) c1: Box<[Poly]>, - - /// The level and context of the polynomials that will be key switched. - pub(crate) ciphertext_level: usize, - pub(crate) ctx_ciphertext: Arc, - - /// The level and context of the key switching key. - pub(crate) ksk_level: usize, - pub(crate) ctx_ksk: Arc, - - // For level with only one modulus, we will use basis - pub(crate) log_base: usize, -} - -impl KeySwitchingKey { - /// Generate a [`KeySwitchingKey`] to this [`SecretKey`] from a polynomial - /// `from`. - pub fn new( - sk: &SecretKey, - from: &Poly, - ciphertext_level: usize, - ksk_level: usize, - rng: &mut R, - ) -> Result { - let ctx_ksk = sk.par.ctx_at_level(ksk_level)?; - let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?; - - if from.ctx() != ctx_ksk { - return Err(Error::DefaultError( - "Incorrect context for polynomial from".to_string(), - )); - } - - let mut seed = ::Seed::default(); - rng.fill(&mut seed); - - if ctx_ksk.moduli().len() == 1 { - let modulus = ctx_ksk.moduli().first().unwrap(); - let log_modulus = modulus.next_power_of_two().ilog2() as usize; - let log_base = log_modulus / 2; - - let c1 = Self::generate_c1(ctx_ksk, seed, (log_modulus + log_base - 1) / log_base); - let c0 = Self::generate_c0_decomposition(sk, from, &c1, rng, log_base)?; - - Ok(Self { - par: sk.par.clone(), - seed: Some(seed), - c0: c0.into_boxed_slice(), - c1: c1.into_boxed_slice(), - ciphertext_level, - ctx_ciphertext: ctx_ciphertext.clone(), - ksk_level, - ctx_ksk: ctx_ksk.clone(), - log_base, - }) - } else { - let c1 = Self::generate_c1(ctx_ksk, seed, ctx_ciphertext.moduli().len()); - let c0 = Self::generate_c0(sk, from, &c1, rng)?; - - Ok(Self { - par: sk.par.clone(), - seed: Some(seed), - c0: c0.into_boxed_slice(), - c1: c1.into_boxed_slice(), - ciphertext_level, - ctx_ciphertext: ctx_ciphertext.clone(), - ksk_level, - ctx_ksk: ctx_ksk.clone(), - log_base: 0, - }) - } - } - - /// Generate the c1's from the seed - fn generate_c1( - ctx: &Arc, - seed: ::Seed, - size: usize, - ) -> Vec { - let mut c1 = Vec::with_capacity(size); - let mut rng = ChaCha8Rng::from_seed(seed); - (0..size).for_each(|_| { - let mut seed_i = ::Seed::default(); - rng.fill(&mut seed_i); - let mut a = Poly::random_from_seed(ctx, Representation::NttShoup, seed_i); - unsafe { a.allow_variable_time_computations() } - c1.push(a); - }); - c1 - } - - /// Generate the c0's from the c1's and the secret key - fn generate_c0( - sk: &SecretKey, - from: &Poly, - c1: &[Poly], - rng: &mut R, - ) -> Result> { - if c1.is_empty() { - return Err(Error::DefaultError("Empty number of c1's".to_string())); - } - if from.representation() != &Representation::PowerBasis { - return Err(Error::DefaultError( - "Unexpected representation for from".to_string(), - )); - } - - let size = c1.len(); - - let mut s = Zeroizing::new(Poly::try_convert_from( - sk.coeffs.as_ref(), - c1[0].ctx(), - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); - - let rns = RnsContext::new(&sk.par.moduli[..size])?; - let c0 = c1 - .iter() - .enumerate() - .map(|(i, c1i)| { - let mut a_s = Zeroizing::new(c1i.clone()); - a_s.disallow_variable_time_computations(); - a_s.change_representation(Representation::Ntt); - *a_s.as_mut() *= s.as_ref(); - a_s.change_representation(Representation::PowerBasis); - - let mut b = - Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?; - b -= &a_s; - - let gi = rns.get_garner(i).unwrap(); - let g_i_from = Zeroizing::new(gi * from); - b += &g_i_from; - - // It is now safe to enable variable time computations. - unsafe { b.allow_variable_time_computations() } - b.change_representation(Representation::NttShoup); - Ok(b) - }) - .collect::>>()?; - - Ok(c0) - } - - /// Generate the c0's from the c1's and the secret key - fn generate_c0_decomposition( - sk: &SecretKey, - from: &Poly, - c1: &[Poly], - rng: &mut R, - log_base: usize, - ) -> Result> { - if c1.is_empty() { - return Err(Error::DefaultError("Empty number of c1's".to_string())); - } - - if from.representation() != &Representation::PowerBasis { - return Err(Error::DefaultError( - "Unexpected representation for from".to_string(), - )); - } - - let mut s = Zeroizing::new(Poly::try_convert_from( - sk.coeffs.as_ref(), - c1[0].ctx(), - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); - - let c0 = c1 - .iter() - .enumerate() - .map(|(i, c1i)| { - let mut a_s = Zeroizing::new(c1i.clone()); - a_s.disallow_variable_time_computations(); - a_s.change_representation(Representation::Ntt); - *a_s.as_mut() *= s.as_ref(); - a_s.change_representation(Representation::PowerBasis); - - let mut b = - Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?; - b -= &a_s; - - let power = BigUint::from(1u64 << (i * log_base)); - b += &(from * &power); - - // It is now safe to enable variable time computations. - unsafe { b.allow_variable_time_computations() } - b.change_representation(Representation::NttShoup); - Ok(b) - }) - .collect::>>()?; - - Ok(c0) - } - - /// Key switch a polynomial. - pub fn key_switch(&self, p: &Poly) -> Result<(Poly, Poly)> { - if self.log_base != 0 { - return self.key_switch_decomposition(p); - } - - if p.ctx().as_ref() != self.ctx_ciphertext.as_ref() { - return Err(Error::DefaultError( - "The input polynomial does not have the correct context.".to_string(), - )); - } - if p.representation() != &Representation::PowerBasis { - return Err(Error::DefaultError("Incorrect representation".to_string())); - } - - let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt); - let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt); - for (c2_i_coefficients, c0_i, c1_i) in izip!( - p.coefficients().outer_iter(), - self.c0.iter(), - self.c1.iter() - ) { - let mut c2_i = unsafe { - Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( - c2_i_coefficients.as_slice().unwrap(), - &self.ctx_ksk, - ) - }; - c0 += &(&c2_i * c0_i); - c2_i *= c1_i; - c1 += &c2_i; - } - Ok((c0, c1)) - } - - /// Key switch a polynomial. - fn key_switch_decomposition(&self, p: &Poly) -> Result<(Poly, Poly)> { - if p.ctx().as_ref() != self.ctx_ciphertext.as_ref() { - return Err(Error::DefaultError( - "The input polynomial does not have the correct context.".to_string(), - )); - } - if p.representation() != &Representation::PowerBasis { - return Err(Error::DefaultError("Incorrect representation".to_string())); - } - - let log_modulus = p - .ctx() - .moduli() - .first() - .unwrap() - .next_power_of_two() - .ilog2() as usize; - - let mut coefficients = p.coefficients().to_slice().unwrap().to_vec(); - let mut c2i = vec![]; - let mask = (1u64 << self.log_base) - 1; - (0..(log_modulus + self.log_base - 1) / self.log_base).for_each(|_| { - c2i.push(coefficients.iter().map(|c| c & mask).collect_vec()); - coefficients.iter_mut().for_each(|c| *c >>= self.log_base); - }); - - let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt); - let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt); - for (c2_i_coefficients, c0_i, c1_i) in izip!(c2i.iter(), self.c0.iter(), self.c1.iter()) { - let mut c2_i = unsafe { - Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( - c2_i_coefficients.as_slice(), - &self.ctx_ksk, - ) - }; - c0 += &(&c2_i * c0_i); - c2_i *= c1_i; - c1 += &c2_i; - } - Ok((c0, c1)) - } -} - -impl From<&KeySwitchingKey> for KeySwitchingKeyProto { - fn from(value: &KeySwitchingKey) -> Self { - let mut ksk = KeySwitchingKeyProto::default(); - if let Some(seed) = value.seed.as_ref() { - ksk.seed = seed.to_vec(); - } else { - ksk.c1.reserve_exact(value.c1.len()); - for c1 in value.c1.iter() { - ksk.c1.push(c1.to_bytes()) - } - } - ksk.c0.reserve_exact(value.c0.len()); - for c0 in value.c0.iter() { - ksk.c0.push(c0.to_bytes()) - } - ksk.ciphertext_level = value.ciphertext_level as u32; - ksk.ksk_level = value.ksk_level as u32; - ksk.log_base = value.log_base as u32; - ksk - } -} - -impl BfvTryConvertFrom<&KeySwitchingKeyProto> for KeySwitchingKey { - fn try_convert_from(value: &KeySwitchingKeyProto, par: &Arc) -> Result { - let ciphertext_level = value.ciphertext_level as usize; - let ksk_level = value.ksk_level as usize; - let ctx_ksk = par.ctx_at_level(ksk_level)?; - let ctx_ciphertext = par.ctx_at_level(ciphertext_level)?; - - let c0_size: usize; - let log_base = value.log_base as usize; - if log_base != 0 { - if ksk_level != par.max_level() || ciphertext_level != par.max_level() { - return Err(Error::DefaultError( - "A decomposition size is specified but the levels are not maximal".to_string(), - )); - } else { - let log_modulus: usize = - par.moduli().first().unwrap().next_power_of_two().ilog2() as usize; - c0_size = (log_modulus + log_base - 1) / log_base; - } - } else { - c0_size = ctx_ciphertext.moduli().len(); - } - - if value.c0.len() != c0_size { - return Err(Error::DefaultError( - "Incorrect number of values in c0".to_string(), - )); - } - - let seed = if value.seed.is_empty() { - if value.c1.len() != c0_size { - return Err(Error::DefaultError( - "Incorrect number of values in c1".to_string(), - )); - } - None - } else { - let unwrapped = ::Seed::try_from(value.seed.clone()); - if unwrapped.is_err() { - return Err(Error::DefaultError("Invalid seed".to_string())); - } - Some(unwrapped.unwrap()) - }; - - let c1 = if let Some(seed) = seed { - Self::generate_c1(ctx_ksk, seed, value.c0.len()) - } else { - value - .c1 - .iter() - .map(|c1i| Poly::from_bytes(c1i, ctx_ksk).map_err(Error::MathError)) - .collect::>>()? - }; - - let c0 = value - .c0 - .iter() - .map(|c0i| Poly::from_bytes(c0i, ctx_ksk).map_err(Error::MathError)) - .collect::>>()?; - - Ok(Self { - par: par.clone(), - seed, - c0: c0.into_boxed_slice(), - c1: c1.into_boxed_slice(), - ciphertext_level, - ctx_ciphertext: ctx_ciphertext.clone(), - ksk_level, - ctx_ksk: ctx_ksk.clone(), - log_base: value.log_base as usize, - }) - } -} - -#[cfg(test)] -mod tests { - use crate::bfv::{ - keys::key_switching_key::KeySwitchingKey, traits::TryConvertFrom, BfvParameters, SecretKey, - }; - use crate::proto::bfv::KeySwitchingKey as KeySwitchingKeyProto; - use fhe_math::{ - rns::RnsContext, - rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation}, - }; - use num_bigint::BigUint; - use rand::thread_rng; - use std::error::Error; - - #[test] - fn constructor() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(3, 8), - ] { - let sk = SecretKey::random(¶ms, &mut rng); - let ctx = params.ctx_at_level(0)?; - let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; - let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng); - assert!(ksk.is_ok()); - } - Ok(()) - } - - #[test] - fn constructor_last_level() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(3, 8), - ] { - let level = params.moduli().len() - 1; - let sk = SecretKey::random(¶ms, &mut rng); - let ctx = params.ctx_at_level(level)?; - let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; - let ksk = KeySwitchingKey::new(&sk, &p, level, level, &mut rng); - assert!(ksk.is_ok()); - } - Ok(()) - } - - #[test] - fn key_switch() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [BfvParameters::default_arc(6, 8)] { - for _ in 0..100 { - let sk = SecretKey::random(¶ms, &mut rng); - let ctx = params.ctx_at_level(0)?; - let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; - let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?; - let mut s = Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - ) - .map_err(crate::Error::MathError)?; - s.change_representation(Representation::Ntt); - - let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng); - let (c0, c1) = ksk.key_switch(&input)?; - - let mut c2 = &c0 + &(&c1 * &s); - c2.change_representation(Representation::PowerBasis); - - input.change_representation(Representation::Ntt); - p.change_representation(Representation::Ntt); - let mut c3 = &input * &p; - c3.change_representation(Representation::PowerBasis); - - let rns = RnsContext::new(¶ms.moduli)?; - Vec::::from(&(&c2 - &c3)).iter().for_each(|b| { - assert!(std::cmp::min(b.bits(), (rns.modulus() - b).bits()) <= 70) - }); - } - } - Ok(()) - } - - #[test] - fn key_switch_decomposition() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [BfvParameters::default_arc(6, 8)] { - for _ in 0..100 { - let sk = SecretKey::random(¶ms, &mut rng); - let ctx = params.ctx_at_level(5)?; - let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; - let ksk = KeySwitchingKey::new(&sk, &p, 5, 5, &mut rng)?; - let mut s = Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - ) - .map_err(crate::Error::MathError)?; - s.change_representation(Representation::Ntt); - - let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng); - let (c0, c1) = ksk.key_switch(&input)?; - - let mut c2 = &c0 + &(&c1 * &s); - c2.change_representation(Representation::PowerBasis); - - input.change_representation(Representation::Ntt); - p.change_representation(Representation::Ntt); - let mut c3 = &input * &p; - c3.change_representation(Representation::PowerBasis); - - let rns = RnsContext::new(ctx.moduli())?; - Vec::::from(&(&c2 - &c3)).iter().for_each(|b| { - assert!( - std::cmp::min(b.bits(), (rns.modulus() - b).bits()) - <= (rns.modulus().bits() / 2) + 10 - ) - }); - } - } - Ok(()) - } - - #[test] - fn proto_conversion() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - BfvParameters::default_arc(6, 8), - BfvParameters::default_arc(3, 8), - ] { - let sk = SecretKey::random(¶ms, &mut rng); - let ctx = params.ctx_at_level(0)?; - let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; - let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?; - let ksk_proto = KeySwitchingKeyProto::from(&ksk); - assert_eq!(ksk, KeySwitchingKey::try_convert_from(&ksk_proto, ¶ms)?); - } - Ok(()) - } -} From 3e3a6ed99097651c5ba4c9b8f408d74b0f336e06 Mon Sep 17 00:00:00 2001 From: "juhou.lwj" Date: Sun, 14 Jan 2024 10:28:18 +0800 Subject: [PATCH 4/4] fmt --- crates/fhe-math/src/rq/convert.rs | 4 ++-- crates/fhe-math/src/rq/ops.rs | 2 +- crates/fhe/src/bfv/parameters.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/fhe-math/src/rq/convert.rs b/crates/fhe-math/src/rq/convert.rs index 0f505ab5..efe1fb80 100644 --- a/crates/fhe-math/src/rq/convert.rs +++ b/crates/fhe-math/src/rq/convert.rs @@ -522,8 +522,8 @@ mod tests { assert!(Poly::try_convert_from(&[0u64], &ctx, false, Representation::Ntt).is_err()); assert!(Poly::try_convert_from(&[0i64], &ctx, false, Representation::Ntt).is_err()); assert_eq!( - Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::Ntt)?, - Poly::zero(&ctx, Representation::Ntt) + Poly::try_convert_from(&[0u64; 16], &ctx, false, Representation::Ntt)?, + Poly::zero(&ctx, Representation::Ntt) ); assert!(Poly::try_convert_from(&[0i64; 16], &ctx, false, Representation::Ntt).is_err()); assert!(Poly::try_convert_from( diff --git a/crates/fhe-math/src/rq/ops.rs b/crates/fhe-math/src/rq/ops.rs index 74da4c56..61b4a299 100644 --- a/crates/fhe-math/src/rq/ops.rs +++ b/crates/fhe-math/src/rq/ops.rs @@ -502,7 +502,7 @@ mod tests { let b = Vec::::from(&q); for i in 0..MODULI.len() { let m = Modulus::new(MODULI[i]).unwrap(); - m.add_vec(&mut a[i *16..(i + 1) * 16], &b[i * 16..(i + 1) * 16]) + m.add_vec(&mut a[i * 16..(i + 1) * 16], &b[i * 16..(i + 1) * 16]) } let r = &p + &q; assert_eq!(r.representation, Representation::PowerBasis); diff --git a/crates/fhe/src/bfv/parameters.rs b/crates/fhe/src/bfv/parameters.rs index 44afecc9..600e9263 100644 --- a/crates/fhe/src/bfv/parameters.rs +++ b/crates/fhe/src/bfv/parameters.rs @@ -637,7 +637,7 @@ mod tests { 4611686018427387073, 2305843009213693921, 1152921504606845473, - 2017 + 2017, ]) .build()?; assert_eq!(params.moduli_sizes.to_vec(), &[62, 62, 62, 61, 60, 11]);