Skip to content

Commit

Permalink
Generalize field in CirclePoint
Browse files Browse the repository at this point in the history
  • Loading branch information
atgrosso committed Oct 4, 2024
1 parent 7385e7d commit ca1ea82
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 47 deletions.
201 changes: 163 additions & 38 deletions stwo_cairo_verifier/src/circle.cairo
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use stwo_cairo_verifier::fields::m31::{M31, m31, M31One};
use stwo_cairo_verifier::fields::m31::M31;
use stwo_cairo_verifier::fields::qm31::{QM31, QM31Trait};
use super::utils::pow;
use core::num::traits::zero::Zero;
use core::num::traits::one::One;

pub const M31_CIRCLE_GEN: CirclePointM31 =
CirclePointM31 { x: M31 { inner: 2 }, y: M31 { inner: 1268011823 }, };
pub const M31_CIRCLE_GEN: CirclePoint<M31> =
CirclePoint::<M31> { x: M31 { inner: 2 }, y: M31 { inner: 1268011823 }, };

pub const CIRCLE_LOG_ORDER: u32 = 31;

Expand All @@ -15,17 +18,19 @@ pub const CIRCLE_ORDER_BIT_MASK: u32 = 0x7fffffff;
// `U32_BIT_MASK` equals 2^32 - 1
pub const U32_BIT_MASK: u64 = 0xffffffff;

/// A point on the complex circle. Treated as an additive group.
#[derive(Drop, Copy, Debug, PartialEq, Eq)]
pub struct CirclePointM31 {
pub x: M31,
pub y: M31,
pub struct CirclePoint<F> {
pub x: F,
pub y: F
}

#[generate_trait]
pub impl CirclePointM31Impl of CirclePointM31Trait {
pub trait CirclePointTrait<
F, +Add<F>, +Sub<F>, +Mul<F>, +Drop<F>, +Copy<F>, +Zero<F>, +One<F>, +PartialEq<F>
> {
// Returns the neutral element of the circle.
fn zero() -> CirclePointM31 {
CirclePointM31 { x: m31(1), y: m31(0) }
fn zero() -> CirclePoint<F> {
CirclePoint::<F> { x: One::<F>::one(), y: Zero::<F>::zero() }
}

/// Applies the circle's x-coordinate doubling map.
Expand All @@ -38,9 +43,9 @@ pub impl CirclePointM31Impl of CirclePointM31Trait {
/// let p = M31_CIRCLE_GEN.mul(17);
/// assert_eq!(CirclePoint::double_x(p.x), (p + p).x);
/// ```
fn double_x(x: M31) -> M31 {
fn double_x(x: F) -> F {
let sx = x.clone() * x.clone();
sx.clone() + sx - M31One::one()
sx.clone() + sx - One::<F>::one()
}

/// Returns the log order of a point.
Expand All @@ -54,21 +59,28 @@ pub impl CirclePointM31Impl of CirclePointM31Trait {
/// use stwo_prover::core::fields::m31::M31;
/// assert_eq!(M31_CIRCLE_GEN.log_order(), M31_CIRCLE_LOG_ORDER);
/// ```
fn log_order(self: @CirclePointM31) -> u32 {
fn log_order(
self: @CirclePoint<F>
) -> u32 {
// we only need the x-coordinate to check order since the only point
// with x=1 is the circle's identity
let mut res = 0;
let mut cur = self.x.clone();
while cur != M31One::one() {
while cur != One::<F>::one() {
cur = Self::double_x(cur);
res += 1;
};
res
}

fn mul(self: @CirclePointM31, mut scalar: u32) -> CirclePointM31 {
let mut result = Self::zero();
let mut cur = *self;
fn mul(
self: @CirclePoint<F>, initial_scalar: u128
) -> CirclePoint<
F
> {
let mut scalar = initial_scalar;
let mut result: CirclePoint<F> = Self::zero();
let mut cur: CirclePoint<F> = *self;
while scalar > 0 {
if scalar & 1 == 1 {
result = result + cur;
Expand All @@ -80,13 +92,28 @@ pub impl CirclePointM31Impl of CirclePointM31Trait {
}
}

impl CirclePointM31Add of Add<CirclePointM31> {
impl CirclePointAdd<F, +Add<F>, +Sub<F>, +Mul<F>, +Drop<F>, +Copy<F>> of Add<CirclePoint<F>> {
// The operation of the circle as a group with additive notation.
fn add(lhs: CirclePointM31, rhs: CirclePointM31) -> CirclePointM31 {
CirclePointM31 { x: lhs.x * rhs.x - lhs.y * rhs.y, y: lhs.x * rhs.y + lhs.y * rhs.x }
fn add(lhs: CirclePoint<F>, rhs: CirclePoint<F>) -> CirclePoint<F> {
CirclePoint::<F> { x: lhs.x * rhs.x - lhs.y * rhs.y, y: lhs.x * rhs.y + lhs.y * rhs.x }
}
}

pub impl CirclePointM31Impl of CirclePointTrait<M31> {}

pub impl CirclePointQM31Impl of CirclePointTrait<QM31> {}

trait ComplexConjugate {
fn complex_conjugate(self: CirclePoint<QM31>) -> CirclePoint<QM31>;
}

pub impl ComplexConjugateImpl of ComplexConjugate {
fn complex_conjugate(self: CirclePoint<QM31>) -> CirclePoint<QM31> {
CirclePoint { x: self.x.complex_conjugate(), y: self.y.complex_conjugate() }
}
}

/// Represents the coset initial + <step>.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Drop)]
pub struct Coset {
// This is an index in the range [0, 2^31)
Expand Down Expand Up @@ -124,50 +151,77 @@ pub impl CosetImpl of CosetTrait {
}
}

fn at(self: @Coset, index: usize) -> CirclePointM31 {
M31_CIRCLE_GEN.mul(self.index_at(index))
fn at(self: @Coset, index: usize) -> CirclePoint::<M31> {
M31_CIRCLE_GEN.mul(self.index_at(index).into())
}

/// Returns the size of the coset.
fn size(self: @Coset) -> usize {
pow(2, *self.log_size)
}

/// Creates a coset of the form G_2n + \<G_n\>.
/// For example, for n=8, we get the point indices \[1,3,5,7,9,11,13,15\].
fn odds(log_size: u32) -> Coset {
//CIRCLE_LOG_ORDER
let subgroup_generator_index = Self::subgroup_generator_index(log_size);
Self::new(subgroup_generator_index, log_size)
}

/// Creates a coset of the form G_4n + <G_n>.
/// For example, for n=8, we get the point indices \[1,5,9,13,17,21,25,29\].
/// Its conjugate will be \[3,7,11,15,19,23,27,31\].
fn half_odds(log_size: u32) -> Coset {
Self::new(Self::subgroup_generator_index(log_size + 2), log_size)
}

fn subgroup_generator_index(log_size: u32) -> u32 {
assert!(log_size <= CIRCLE_LOG_ORDER);
pow(2, CIRCLE_LOG_ORDER - log_size)
}
}


#[cfg(test)]
mod tests {
use super::{M31_CIRCLE_GEN, CIRCLE_ORDER, CirclePointM31, CirclePointM31Impl, Coset, CosetImpl};
use stwo_cairo_verifier::fields::m31::m31;
use super::{M31_CIRCLE_GEN, CIRCLE_ORDER, CirclePoint, CirclePointM31Impl, Coset, CosetImpl};
use core::option::OptionTrait;
use core::array::ArrayTrait;
use core::traits::TryInto;
use super::CirclePointQM31Impl;
use stwo_cairo_verifier::fields::m31::{m31, M31};
use stwo_cairo_verifier::fields::qm31::{qm31, QM31, QM31One};
use stwo_cairo_verifier::utils::pow;

#[test]
fn test_add_1() {
let i = CirclePointM31 { x: m31(0), y: m31(1) };
let i = CirclePoint::<M31> { x: m31(0), y: m31(1) };
let result = i + i;
let expected_result = CirclePointM31 { x: -m31(1), y: m31(0) };
let expected_result = CirclePoint::<M31> { x: -m31(1), y: m31(0) };

assert_eq!(result, expected_result);
}

#[test]
fn test_add_2() {
let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) };
let point_2 = CirclePointM31 { x: m31(1737427771), y: m31(309481134) };
let point_1 = CirclePoint::<M31> { x: m31(750649172), y: m31(1991648574) };
let point_2 = CirclePoint::<M31> { x: m31(1737427771), y: m31(309481134) };
let result = point_1 + point_2;
let expected_result = CirclePointM31 { x: m31(1476625263), y: m31(1040927458) };
let expected_result = CirclePoint::<M31> { x: m31(1476625263), y: m31(1040927458) };

assert_eq!(result, expected_result);
}

#[test]
fn test_zero_1() {
let result = CirclePointM31Impl::zero();
let expected_result = CirclePointM31 { x: m31(1), y: m31(0) };
let expected_result = CirclePoint::<M31> { x: m31(1), y: m31(0) };
assert_eq!(result, expected_result);
}

#[test]
fn test_zero_2() {
let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) };
let point_1 = CirclePoint::<M31> { x: m31(750649172), y: m31(1991648574) };
let point_2 = CirclePointM31Impl::zero();
let expected_result = point_1.clone();
let result = point_1 + point_2;
Expand All @@ -177,7 +231,7 @@ mod tests {

#[test]
fn test_mul_1() {
let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) };
let point_1 = CirclePoint::<M31> { x: m31(750649172), y: m31(1991648574) };
let result = point_1.mul(5);
let expected_result = point_1 + point_1 + point_1 + point_1 + point_1;

Expand All @@ -186,7 +240,7 @@ mod tests {

#[test]
fn test_mul_2() {
let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) };
let point_1 = CirclePoint::<M31> { x: m31(750649172), y: m31(1991648574) };
let result = point_1.mul(8);
let mut expected_result = point_1 + point_1;
expected_result = expected_result + expected_result;
Expand All @@ -197,23 +251,30 @@ mod tests {

#[test]
fn test_mul_3() {
let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) };
let point_1 = CirclePoint::<M31> { x: m31(750649172), y: m31(1991648574) };
let result = point_1.mul(418776494);
let expected_result = CirclePointM31 { x: m31(1987283985), y: m31(1500510905) };
let expected_result = CirclePoint::<M31> { x: m31(1987283985), y: m31(1500510905) };

assert_eq!(result, expected_result);
}

#[test]
fn test_generator_order() {
let half_order = CIRCLE_ORDER / 2;
let mut result = M31_CIRCLE_GEN.mul(half_order);
let expected_result = CirclePointM31 { x: -m31(1), y: m31(0) };
let mut result = M31_CIRCLE_GEN.mul(half_order.into());
let expected_result = CirclePoint::<M31> { x: -m31(1), y: m31(0) };

// Assert `M31_CIRCLE_GEN^{2^30}` equals `-1`.
assert_eq!(expected_result, result);
}

#[test]
fn test_generator() {
let mut result = M31_CIRCLE_GEN.mul(pow(2, 30).try_into().unwrap());
let expected_result = CirclePoint::<M31> { x: -m31(1), y: m31(0) };
assert_eq!(expected_result, result);
}

#[test]
fn test_coset_index_at() {
let coset = Coset { initial_index: 16777216, log_size: 5, step_size: 67108864 };
Expand Down Expand Up @@ -242,7 +303,7 @@ mod tests {
fn test_coset_at() {
let coset = Coset { initial_index: 16777216, step_size: 67108864, log_size: 5 };
let result = coset.at(17);
let expected_result = CirclePointM31 { x: m31(7144319), y: m31(1742797653) };
let expected_result = CirclePoint::<M31> { x: m31(7144319), y: m31(1742797653) };
assert_eq!(expected_result, result);
}

Expand All @@ -253,5 +314,69 @@ mod tests {
let expected_result = 32;
assert_eq!(result, expected_result);
}

#[test]
fn test_qm31_circle_gen() {
let P4: u128 = 21267647892944572736998860269687930881;

let QM31_CIRCLE_GEN: CirclePoint<QM31> = CirclePoint::<
QM31
> {
x: qm31(1, 0, 478637715, 513582971),
y: qm31(992285211, 649143431, 740191619, 1186584352),
};

let first_prime = 2;
let last_prime = 368140581013;
let prime_factors: Array<(u128, u32)> = array![
(first_prime, 33),
(3, 2),
(5, 1),
(7, 1),
(11, 1),
(31, 1),
(151, 1),
(331, 1),
(733, 1),
(1709, 1),
(last_prime, 1),
];

let product = iter_product(first_prime, @prime_factors, last_prime);

assert_eq!(product, P4 - 1);

assert_eq!(
QM31_CIRCLE_GEN.x * QM31_CIRCLE_GEN.x + QM31_CIRCLE_GEN.y * QM31_CIRCLE_GEN.y,
QM31One::one()
);

assert_eq!(QM31_CIRCLE_GEN.mul(P4 - 1), CirclePointQM31Impl::zero());

let mut i = 0;
while i < prime_factors.len() {
let (p, _) = *prime_factors.at(i);
assert_ne!(QM31_CIRCLE_GEN.mul((P4 - 1) / p.into()), CirclePointQM31Impl::zero());

i = i + 1;
}
}

fn iter_product(
first_prime: u128, prime_factors: @Array<(u128, u32)>, last_prime: u128
) -> u128 {
let mut accum_product: u128 = 1;
accum_product = accum_product
* pow(first_prime.try_into().unwrap(), 31).into()
* 4; // * 2^33
let mut i = 1;
while i < prime_factors.len() - 1 {
let (prime, exponent): (u128, u32) = *prime_factors.at(i);
accum_product = accum_product * pow(prime.try_into().unwrap(), exponent).into();
i = i + 1;
};
accum_product = accum_product * last_prime;
accum_product
}
}

4 changes: 4 additions & 0 deletions stwo_cairo_verifier/src/fields/qm31.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ pub impl QM31Impl of QM31Trait {
b: CM31 { a: self.b.a * multiplier, b: self.b.b * multiplier }
}
}

fn complex_conjugate(self: QM31) -> QM31 {
QM31 { a: self.a, b: -self.b }
}
}

pub impl QM31Add of core::traits::Add<QM31> {
Expand Down
Loading

0 comments on commit ca1ea82

Please sign in to comment.