Skip to content

Commit

Permalink
Channel
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Jul 9, 2024
1 parent 7af2aaf commit 232859c
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 3 deletions.
2 changes: 1 addition & 1 deletion stwo_cairo_verifier/.tool-versions
Original file line number Diff line number Diff line change
@@ -1 +1 @@
scarb nightly-2024-06-01
scarb nightly-2024-06-15
231 changes: 231 additions & 0 deletions stwo_cairo_verifier/src/channel.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
use core::array::SpanTrait;
use core::poseidon::{poseidon_hash_span, hades_permutation};
use core::traits::DivRem;

use stwo_cairo_verifier::{BaseField, SecureField};
use stwo_cairo_verifier::fields::qm31::QM31Trait;
use stwo_cairo_verifier::utils::pack4;

const M31_SHIFT: felt252 = 0x80000000; // 2**31.
const M31_SHIFT_NZ_U256: NonZero<u256> = 0x80000000; // 2**31.
pub const EXTENSION_FELTS_PER_HASH: usize = 2;
pub const FELTS_PER_HASH: usize = 8;

#[derive(Default, Drop)]
pub struct ChannelTime {
n_challenges: usize,
n_sent: usize,
}

#[generate_trait]
impl ChannelTimeImpl of ChannelTimeTrait {
fn inc_sent(ref self: ChannelTime) {
self.n_sent += 1;
}

fn inc_challenges(ref self: ChannelTime) {
self.n_challenges += 1;
self.n_sent = 0;
}
}

#[derive(Drop)]
pub struct Channel {
digest: felt252,
channel_time: ChannelTime,
}

#[generate_trait]
pub impl ChannelImpl of ChannelTrait {
fn new(digest: felt252) -> Channel {
Channel { digest, channel_time: Default::default(), }
}

fn get_digest(ref self: Channel) -> felt252 {
self.digest
}

fn draw_felt252(ref self: Channel) -> felt252 {
let (res, _, _) = hades_permutation(self.digest, self.channel_time.n_sent.into(), 2);
self.channel_time.inc_sent();
res
}

// TODO(spapini): Check that this is sound.
#[inline]
fn draw_base_felts(ref self: Channel) -> [BaseField; FELTS_PER_HASH] {
let mut cur = self.draw_felt252().into();
[
extract_m31(ref cur),
extract_m31(ref cur),
extract_m31(ref cur),
extract_m31(ref cur),
extract_m31(ref cur),
extract_m31(ref cur),
extract_m31(ref cur),
extract_m31(ref cur),
]
}

fn mix_digest(ref self: Channel, digest: felt252) {
let (s0, _, _) = hades_permutation(self.digest, digest, 2);
self.digest = s0;
self.channel_time.inc_challenges();
}

fn mix_felts(ref self: Channel, mut felts: Span<SecureField>) {
let mut res = array![self.digest];
loop {
match (felts.pop_front(), felts.pop_front()) {
(Option::None, _) => { break; },
(Option::Some(x), Option::None) => {
res.append(pack4(0, (*x).to_array()));
break;
},
(
Option::Some(x), Option::Some(y)
) => {
let cur = pack4(0, (*x).to_array());
res.append(pack4(cur, (*y).to_array()));
},
};
};

self.digest = poseidon_hash_span(res.span());

// TODO(spapini): do we need length padding?
self.channel_time.inc_challenges();
}

fn mix_nonce(ref self: Channel, nonce: u64) {
self.mix_digest(nonce.into())
}

fn draw_felt(ref self: Channel) -> SecureField {
let [r0, r1, r2, r3, _, _, _, _] = self.draw_base_felts();
QM31Trait::from_array([r0, r1, r2, r3])
}

fn draw_felts(ref self: Channel, mut n_felts: usize) -> Array<SecureField> {
let mut res: Array = Default::default();
loop {
if n_felts == 0 {
break;
}
let [r0, r1, r2, r3, r4, r5, r6, r7] = self.draw_base_felts();
res.append(QM31Trait::from_array([r0, r1, r2, r3]));
if n_felts == 1 {
break;
}
res.append(QM31Trait::from_array([r4, r5, r6, r7]));
n_felts -= 2;
};
res
}
}

#[inline]
fn extract_m31<const N: usize>(ref num: u256) -> BaseField {
let (q, r) = DivRem::div_rem(num, M31_SHIFT_NZ_U256);
num = q;
let r: u32 = r.try_into().unwrap();
if r.into() == M31_SHIFT - 1 {
BaseField { inner: 0 }
} else {
BaseField { inner: r }
}
}


#[cfg(test)]
mod tests {
use super::{Channel, ChannelTrait};
use stwo_cairo_verifier::fields::qm31::qm31;

#[test]
fn test_initialize_channel() {
let initial_digest = 0;
let channel = ChannelTrait::new(initial_digest);

// Assert that the channel is initialized correctly.
assert_eq!(channel.digest, initial_digest);
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 0);
}

#[test]
fn test_channel_time() {
let initial_digest = 0;
let mut channel = ChannelTrait::new(initial_digest);

assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 0);

channel.draw_felt();
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 1);

channel.draw_felts(9);
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 6);

channel.mix_digest(0);
assert_eq!(channel.channel_time.n_challenges, 1);
assert_eq!(channel.channel_time.n_sent, 0);

channel.draw_felt();
assert_eq!(channel.channel_time.n_challenges, 1);
assert_eq!(channel.channel_time.n_sent, 1);
assert_ne!(channel.digest, initial_digest);
}


#[test]
pub fn test_draw_felt() {
let initial_digest = 0;
let mut channel = ChannelTrait::new(initial_digest);

let first_random_felt = channel.draw_felt();

// Assert that next random felt is different.
assert_ne!(first_random_felt, channel.draw_felt());
}

#[test]
pub fn test_draw_felts() {
let initial_digest = 0;
let mut channel = ChannelTrait::new(initial_digest);

let mut random_felts = channel.draw_felts(5);
random_felts.append_span(channel.draw_felts(4).span());

// Assert that all the random felts are unique.
assert_ne!(random_felts[0], random_felts[5]);
}

#[test]
pub fn test_mix_digest() {
let initial_digest = 0;
let mut channel = ChannelTrait::new(initial_digest);

let mut n: usize = 10;
while n > 0 {
n -= 1;
channel.draw_felt();
};

let prev_digest = channel.digest;
channel.mix_digest(0);
assert_ne!(prev_digest, channel.digest);
}

#[test]
pub fn test_mix_felts() {
let initial_digest = 0;
let mut channel = ChannelTrait::new(initial_digest);

channel.mix_felts(array![qm31(1, 2, 3, 4), qm31(5, 6, 7, 8), qm31(9, 10, 11, 12)].span());

assert_ne!(initial_digest, channel.digest);
}
}
1 change: 1 addition & 0 deletions stwo_cairo_verifier/src/fields.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pub mod cm31;
pub mod qm31;

pub type BaseField = m31::M31;
pub type SecureField = qm31::QM31;
5 changes: 5 additions & 0 deletions stwo_cairo_verifier/src/fields/m31.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ pub impl M31Neg of Neg<M31> {
}
}
}
impl M31IntoFelt252 of Into<M31, felt252> {
fn into(self: M31) -> felt252 {
self.inner.into()
}
}

pub fn m31(val: u32) -> M31 {
M31Impl::reduce_u32(val)
Expand Down
11 changes: 10 additions & 1 deletion stwo_cairo_verifier/src/fields/qm31.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@ pub struct QM31 {
}

#[generate_trait]
impl QM31Impl of QM31Trait {
pub impl QM31Impl of QM31Trait {
#[inline]
fn from_array(arr: [M31; 4]) -> QM31 {
let [a, b, c, d] = arr;
QM31 { a: CM31 { a: a, b: b }, b: CM31 { a: c, b: d } }
}
#[inline]
fn to_array(self: QM31) -> [M31; 4] {
[self.a.a, self.a.b, self.b.a, self.b.b]
}
fn inverse(self: QM31) -> QM31 {
assert_ne!(self, Zero::zero());
let b2 = self.b * self.b;
Expand Down
3 changes: 2 additions & 1 deletion stwo_cairo_verifier/src/lib.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
mod channel;
mod fields;
mod utils;
mod vcs;

pub use fields::BaseField;
pub use fields::{BaseField, SecureField};

fn main() {}
11 changes: 11 additions & 0 deletions stwo_cairo_verifier/src/utils.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use core::dict::Felt252DictEntryTrait;
use core::dict::Felt252DictTrait;
use core::iter::Iterator;

use stwo_cairo_verifier::BaseField;

#[generate_trait]
pub impl DictImpl<T, +Felt252DictValue<T>, +PanicDestruct<T>> of DictTrait<T> {
fn replace(ref self: Felt252Dict<T>, key: felt252, new_value: T) -> T {
Expand Down Expand Up @@ -69,3 +71,12 @@ pub impl SpanImpl<T> of SpanExTrait<T> {
Option::Some(max)
}
}

const M31_SHIFT: felt252 = 0x80000000; // 2**31.
// Packs 4 BaseField values and "append" to a felt252.
// The resulting felt252 is: cur || x0 || x1 || x2 || x3.
pub fn pack4(cur: felt252, values: [BaseField; 4]) -> felt252 {
let [x0, x1, x2, x3] = values;
(((cur * M31_SHIFT + x0.into()) * M31_SHIFT + x1.into()) * M31_SHIFT + x2.into()) * M31_SHIFT
+ x3.into()
}

0 comments on commit 232859c

Please sign in to comment.