From e7880ea1c6417a3400b65c9c5cd0a52555ea6884 Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Sun, 16 Jun 2024 11:24:50 +0300 Subject: [PATCH] QM31 --- stwo_cairo_verifier/src/fields.cairo | 1 + stwo_cairo_verifier/src/fields/cm31.cairo | 20 +++- stwo_cairo_verifier/src/fields/m31.cairo | 28 +++++- stwo_cairo_verifier/src/fields/qm31.cairo | 108 ++++++++++++++++++++++ 4 files changed, 151 insertions(+), 6 deletions(-) create mode 100644 stwo_cairo_verifier/src/fields/qm31.cairo diff --git a/stwo_cairo_verifier/src/fields.cairo b/stwo_cairo_verifier/src/fields.cairo index 146e65d6..67843e9c 100644 --- a/stwo_cairo_verifier/src/fields.cairo +++ b/stwo_cairo_verifier/src/fields.cairo @@ -1,4 +1,5 @@ pub mod m31; pub mod cm31; +pub mod qm31; pub type BaseField = m31::M31; diff --git a/stwo_cairo_verifier/src/fields/cm31.cairo b/stwo_cairo_verifier/src/fields/cm31.cairo index 8a6d8700..0f17441d 100644 --- a/stwo_cairo_verifier/src/fields/cm31.cairo +++ b/stwo_cairo_verifier/src/fields/cm31.cairo @@ -1,9 +1,19 @@ -use super::m31::{M31, m31}; +use core::num::traits::{One, Zero}; +use super::m31::{M31, m31, M31Trait}; #[derive(Copy, Drop, Debug, PartialEq, Eq)] pub struct CM31 { - a: M31, - b: M31, + pub a: M31, + pub b: M31, +} + +#[generate_trait] +pub impl CM31Impl of CM31Trait { + fn inverse(self: CM31) -> CM31 { + assert_ne!(self, Zero::zero()); + let denom_inverse: M31 = (self.a * self.a + self.b * self.b).inverse(); + CM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse } + } } pub impl CM31Add of core::traits::Add { @@ -21,7 +31,7 @@ pub impl CM31Mul of core::traits::Mul { CM31 { a: lhs.a * rhs.a - lhs.b * rhs.b, b: lhs.a * rhs.b + lhs.b * rhs.a } } } -pub impl CM31Zero of core::num::traits::Zero { +pub impl CM31Zero of Zero { fn zero() -> CM31 { cm31(0, 0) } @@ -32,7 +42,7 @@ pub impl CM31Zero of core::num::traits::Zero { (*self).a.is_non_zero() || (*self).b.is_non_zero() } } -pub impl CM31One of core::num::traits::One { +pub impl CM31One of One { fn one() -> CM31 { cm31(1, 0) } diff --git a/stwo_cairo_verifier/src/fields/m31.cairo b/stwo_cairo_verifier/src/fields/m31.cairo index 20d97db4..112e283c 100644 --- a/stwo_cairo_verifier/src/fields/m31.cairo +++ b/stwo_cairo_verifier/src/fields/m31.cairo @@ -22,6 +22,25 @@ pub impl M31Impl of M31Trait { let (_, res) = core::integer::u64_safe_divmod(val, P64NZ); M31 { inner: res.try_into().unwrap() } } + + #[inline] + fn sqn(v: M31, n: usize) -> M31 { + if n == 0 { + return v; + } + Self::sqn(v * v, n - 1) + } + + fn inverse(self: M31) -> M31 { + assert_ne!(self, core::num::traits::Zero::zero()); + let t0 = Self::sqn(self, 2) * self; + let t1 = Self::sqn(t0, 1) * t0; + let t2 = Self::sqn(t1, 3) * t0; + let t3 = Self::sqn(t2, 1) * t0; + let t4 = Self::sqn(t3, 8) * t3; + let t5 = Self::sqn(t4, 8) * t3; + Self::sqn(t5, 7) * t2 + } } pub impl M31Add of core::traits::Add { fn add(lhs: M31, rhs: M31) -> M31 { @@ -78,7 +97,9 @@ pub fn m31(val: u32) -> M31 { #[cfg(test)] mod tests { - use super::{m31, P}; + use super::{m31, P, M31, M31Trait}; + const POW2_15: u32 = 0b1000000000000000; + const POW2_16: u32 = 0b10000000000000000; #[test] fn test_m31() { @@ -90,4 +111,9 @@ mod tests { assert_eq!(m31(0) - m31(1), m31(P - 1)); assert_eq!(m31(0) - m31(P - 1), m31(1)); } + + #[test] + fn test_m31_inv() { + assert_eq!(m31(POW2_15).inverse(), m31(POW2_16)); + } } diff --git a/stwo_cairo_verifier/src/fields/qm31.cairo b/stwo_cairo_verifier/src/fields/qm31.cairo new file mode 100644 index 00000000..8ddd7336 --- /dev/null +++ b/stwo_cairo_verifier/src/fields/qm31.cairo @@ -0,0 +1,108 @@ +use super::m31::{M31, m31}; +use super::cm31::{CM31, cm31, CM31Trait}; +use core::num::traits::zero::Zero; +use core::num::traits::one::One; + +pub const R: CM31 = CM31 { a: M31 { inner: 2 }, b: M31 { inner: 1 } }; + +#[derive(Copy, Drop, Debug, PartialEq, Eq)] +pub struct QM31 { + a: CM31, + b: CM31, +} + +#[generate_trait] +impl QM31Impl of QM31Trait { + fn inverse(self: QM31) -> QM31 { + assert_ne!(self, Zero::zero()); + let b2 = self.b * self.b; + let ib2 = CM31 { a: -b2.b, b: b2.a }; + let denom = self.a * self.a - (b2 + b2 + ib2); + let denom_inverse = denom.inverse(); + QM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse } + } +} + +pub impl QM31Add of core::traits::Add { + fn add(lhs: QM31, rhs: QM31) -> QM31 { + QM31 { a: lhs.a + rhs.a, b: lhs.b + rhs.b } + } +} +pub impl QM31Sub of core::traits::Sub { + fn sub(lhs: QM31, rhs: QM31) -> QM31 { + QM31 { a: lhs.a - rhs.a, b: lhs.b - rhs.b } + } +} +pub impl QM31Mul of core::traits::Mul { + fn mul(lhs: QM31, rhs: QM31) -> QM31 { + // (a + bu) * (c + du) = (ac + rbd) + (ad + bc)u. + QM31 { a: lhs.a * rhs.a + R * lhs.b * rhs.b, b: lhs.a * rhs.b + lhs.b * rhs.a } + } +} +pub impl QM31Zero of Zero { + fn zero() -> QM31 { + QM31 { a: Zero::zero(), b: Zero::zero() } + } + fn is_zero(self: @QM31) -> bool { + (*self).a.is_zero() && (*self).b.is_zero() + } + fn is_non_zero(self: @QM31) -> bool { + (*self).a.is_non_zero() || (*self).b.is_non_zero() + } +} +pub impl QM31One of One { + fn one() -> QM31 { + QM31 { a: One::one(), b: Zero::zero() } + } + fn is_one(self: @QM31) -> bool { + (*self).a.is_one() && (*self).b.is_zero() + } + fn is_non_one(self: @QM31) -> bool { + (*self).a.is_non_one() || (*self).b.is_non_zero() + } +} +pub impl M31IntoQM31 of core::traits::Into { + fn into(self: M31) -> QM31 { + QM31 { a: self.into(), b: Zero::zero() } + } +} +pub impl CM31IntoQM31 of core::traits::Into { + fn into(self: CM31) -> QM31 { + QM31 { a: self, b: Zero::zero() } + } +} +pub impl QM31Neg of Neg { + fn neg(a: QM31) -> QM31 { + QM31 { a: -a.a, b: -a.b } + } +} + +pub fn qm31(a: u32, b: u32, c: u32, d: u32) -> QM31 { + QM31 { a: cm31(a, b), b: cm31(c, d) } +} + + +#[cfg(test)] +mod tests { + use super::{QM31, qm31, QM31Trait}; + use super::super::m31::{m31, P, M31Trait}; + + #[test] + fn test_QM31() { + let qm0 = qm31(1, 2, 3, 4); + let qm1 = qm31(4, 5, 6, 7); + let m = m31(8); + let qm = Into::<_, QM31>::into(m); + let qm0_x_qm1 = qm31(P - 71, 93, P - 16, 50); + + assert_eq!(qm0 + qm1, qm31(5, 7, 9, 11)); + assert_eq!(qm1 + m.into(), qm1 + qm); + assert_eq!(qm0 * qm1, qm0_x_qm1); + assert_eq!(qm1 * m.into(), qm1 * qm); + assert_eq!(-qm0, qm31(P - 1, P - 2, P - 3, P - 4)); + assert_eq!(qm0 - qm1, qm31(P - 3, P - 3, P - 3, P - 3)); + assert_eq!(qm1 - m.into(), qm1 - qm); + assert_eq!(qm0_x_qm1 * qm1.inverse(), qm31(1, 2, 3, 4)); + assert_eq!(qm1 * m.inverse().into(), qm1 * qm.inverse()); + } +}