Skip to content

Commit

Permalink
Implement Stateful for all Chips (#1199)
Browse files Browse the repository at this point in the history
  • Loading branch information
nyunyunyunyu authored Jan 13, 2025
1 parent bca69b3 commit 8ccdf01
Show file tree
Hide file tree
Showing 82 changed files with 699 additions and 482 deletions.
40 changes: 24 additions & 16 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ lto = "thin"

[workspace.dependencies]
# Stark Backend
openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", rev = "8c9f94", default-features = false }
openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", rev = "8c9f94", default-features = false }
openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", rev = "25265e7", default-features = false }
openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", rev = "25265e7", default-features = false }

# OpenVM
openvm-sdk = { path = "crates/sdk", default-features = false }
Expand Down
1 change: 1 addition & 0 deletions benchmarks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ openvm-sdk.workspace = true
openvm-stark-backend.workspace = true
openvm-stark-sdk.workspace = true
openvm-transpiler.workspace = true
openvm-circuit-derive.workspace = true

openvm-algebra-circuit.workspace = true
openvm-algebra-transpiler.workspace = true
Expand Down
1 change: 1 addition & 0 deletions crates/circuits/primitives/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ num-bigint-dig.workspace = true
num-traits.workspace = true
lazy_static.workspace = true
tracing.workspace = true
bitcode.workspace = true

[dev-dependencies]
p3-dft = { workspace = true }
Expand Down
85 changes: 83 additions & 2 deletions crates/circuits/primitives/src/bitwise_op_lookup/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::{
borrow::{Borrow, BorrowMut},
mem::size_of,
sync::{atomic::AtomicU32, Arc},
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
};

use itertools::Itertools;
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_stark_backend::{
config::{StarkGenericConfig, Val},
Expand All @@ -13,7 +17,7 @@ use openvm_stark_backend::{
p3_matrix::{dense::RowMajorMatrix, Matrix},
prover::types::AirProofInput,
rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir},
Chip, ChipUsageGetter,
Chip, ChipUsageGetter, Stateful,
};

mod bus;
Expand Down Expand Up @@ -109,6 +113,11 @@ pub struct BitwiseOperationLookupChip<const NUM_BITS: usize> {
count_xor: Vec<AtomicU32>,
}

#[derive(Clone)]
pub struct SharedBitwiseOperationLookupChip<const NUM_BITS: usize>(
Arc<BitwiseOperationLookupChip<NUM_BITS>>,
);

impl<const NUM_BITS: usize> BitwiseOperationLookupChip<NUM_BITS> {
pub fn new(bus: BitwiseOperationLookupBus) -> Self {
let num_rows = (1 << NUM_BITS) * (1 << NUM_BITS);
Expand Down Expand Up @@ -169,6 +178,35 @@ impl<const NUM_BITS: usize> BitwiseOperationLookupChip<NUM_BITS> {
}
}

impl<const NUM_BITS: usize> SharedBitwiseOperationLookupChip<NUM_BITS> {
pub fn new(bus: BitwiseOperationLookupBus) -> Self {
Self(Arc::new(BitwiseOperationLookupChip::new(bus)))
}
pub fn bus(&self) -> BitwiseOperationLookupBus {
self.0.bus()
}

pub fn air_width(&self) -> usize {
self.0.air_width()
}

pub fn request_range(&self, x: u32, y: u32) {
self.0.request_range(x, y);
}

pub fn request_xor(&self, x: u32, y: u32) -> u32 {
self.0.request_xor(x, y)
}

pub fn clear(&self) {
self.0.clear()
}

pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
self.0.generate_trace()
}
}

impl<SC: StarkGenericConfig, const NUM_BITS: usize> Chip<SC>
for BitwiseOperationLookupChip<NUM_BITS>
{
Expand All @@ -182,6 +220,18 @@ impl<SC: StarkGenericConfig, const NUM_BITS: usize> Chip<SC>
}
}

impl<SC: StarkGenericConfig, const NUM_BITS: usize> Chip<SC>
for SharedBitwiseOperationLookupChip<NUM_BITS>
{
fn air(&self) -> Arc<dyn AnyRap<SC>> {
self.0.air()
}

fn generate_air_proof_input(self) -> AirProofInput<SC> {
self.0.generate_air_proof_input()
}
}

impl<const NUM_BITS: usize> ChipUsageGetter for BitwiseOperationLookupChip<NUM_BITS> {
fn air_name(&self) -> String {
get_air_name(&self.air)
Expand All @@ -196,3 +246,34 @@ impl<const NUM_BITS: usize> ChipUsageGetter for BitwiseOperationLookupChip<NUM_B
NUM_BITWISE_OP_LOOKUP_COLS
}
}

impl<const NUM_BITS: usize> ChipUsageGetter for SharedBitwiseOperationLookupChip<NUM_BITS> {
fn air_name(&self) -> String {
self.0.air_name()
}

fn current_trace_height(&self) -> usize {
self.0.current_trace_height()
}

fn trace_width(&self) -> usize {
self.0.trace_width()
}
}

impl<const NUM_BITS: usize> Stateful<Vec<u8>> for SharedBitwiseOperationLookupChip<NUM_BITS> {
fn load_state(&mut self, state: Vec<u8>) {
// AtomicU32 can be deserialized as u32
let (count_range, count_xor): (Vec<u32>, Vec<u32>) = bitcode::deserialize(&state).unwrap();
for (x, v) in self.0.count_range.iter().zip_eq(count_range) {
x.store(v, Ordering::Relaxed);
}
for (x, v) in self.0.count_xor.iter().zip_eq(count_xor) {
x.store(v, Ordering::Relaxed);
}
}

fn store_state(&self) -> Vec<u8> {
bitcode::serialize(&(&self.0.count_range, &self.0.count_xor)).unwrap()
}
}
Loading

0 comments on commit 8ccdf01

Please sign in to comment.