Skip to content

Commit

Permalink
[feat] Implement Stateful for VmChipComplex (#1211)
Browse files Browse the repository at this point in the history
* Implement Stateful for system chips

* Implement Stateful for VmChipComplex

* Testing simulates all steps in the serving flow

* Move Stateful macro into BytesStateful in openvm-circuit-primitives-derive
  • Loading branch information
nyunyunyunyu authored Jan 15, 2025
1 parent 3fe3dc5 commit ebf7aa6
Show file tree
Hide file tree
Showing 53 changed files with 635 additions and 357 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions benchmarks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ license.workspace = true
[dependencies]
openvm-build.workspace = true
openvm-circuit.workspace = true
openvm-circuit-primitives-derive.workspace = true
openvm-sdk.workspace = true
openvm-stark-backend.workspace = true
openvm-stark-sdk.workspace = true
Expand Down
10 changes: 4 additions & 6 deletions crates/circuits/mod-builder/src/core_chip.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use std::sync::Arc;

use itertools::Itertools;
use num_bigint_dig::BigUint;
use openvm_circuit::arch::{
AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, DynArray, MinimalInstruction,
Result, VmAdapterInterface, VmCoreAir, VmCoreChip,
};
use openvm_circuit_primitives::{
var_range::VariableRangeCheckerChip, SubAir, TraceSubRowGenerator,
var_range::SharedVariableRangeCheckerChip, SubAir, TraceSubRowGenerator,
};
use openvm_instructions::instruction::Instruction;
use openvm_stark_backend::{
Expand Down Expand Up @@ -170,7 +168,7 @@ pub struct FieldExpressionRecord {

pub struct FieldExpressionCoreChip {
pub air: FieldExpressionCoreAir,
pub range_checker: Arc<VariableRangeCheckerChip>,
pub range_checker: SharedVariableRangeCheckerChip,

pub name: String,

Expand All @@ -184,7 +182,7 @@ impl FieldExpressionCoreChip {
offset: usize,
local_opcode_idx: Vec<usize>,
opcode_flag_idx: Vec<usize>,
range_checker: Arc<VariableRangeCheckerChip>,
range_checker: SharedVariableRangeCheckerChip,
name: &str,
should_finalize: bool,
) -> Self {
Expand Down Expand Up @@ -277,7 +275,7 @@ where

fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
self.air.expr.generate_subrow(
(&self.range_checker, record.inputs, record.flags),
(self.range_checker.as_ref(), record.inputs, record.flags),
row_slice,
);
}
Expand Down
88 changes: 88 additions & 0 deletions crates/circuits/primitives/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,91 @@ pub fn chip_usage_getter_derive(input: TokenStream) -> TokenStream {
Data::Union(_) => unimplemented!("Unions are not supported"),
}
}

#[proc_macro_derive(BytesStateful)]
pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();

let name = &ast.ident;
let generics = &ast.generics;
let (impl_generics, ty_generics, _) = generics.split_for_impl();

match &ast.data {
Data::Struct(inner) => {
// Check if the struct has only one unnamed field
let inner_ty = match &inner.fields {
Fields::Unnamed(fields) => {
if fields.unnamed.len() != 1 {
panic!("Only one unnamed field is supported");
}
fields.unnamed.first().unwrap().ty.clone()
}
_ => panic!("Only unnamed fields are supported"),
};
// Use full path ::openvm_circuit... so it can be used either within or outside the vm crate.
// Assume F is already generic of the field.
let mut new_generics = generics.clone();
let where_clause = new_generics.make_where_clause();
where_clause
.predicates
.push(syn::parse_quote! { #inner_ty: ::openvm_stark_backend::Stateful<Vec<u8>> });

quote! {
impl #impl_generics ::openvm_stark_backend::Stateful<Vec<u8>> for #name #ty_generics #where_clause {
fn load_state(&mut self, state: Vec<u8>) {
self.0.load_state(state)
}

fn store_state(&self) -> Vec<u8> {
self.0.store_state()
}
}
}
.into()
}
Data::Enum(e) => {
let variants = e
.variants
.iter()
.map(|variant| {
let variant_name = &variant.ident;

let mut fields = variant.fields.iter();
let field = fields.next().unwrap();
assert!(fields.next().is_none(), "Only one field is supported");
(variant_name, field)
})
.collect::<Vec<_>>();
// Use full path ::openvm_stark_backend... so it can be used either within or outside the vm crate.
let (load_state_arms, store_state_arms): (Vec<_>, Vec<_>) =
multiunzip(variants.iter().map(|(variant_name, field)| {
let field_ty = &field.ty;
let load_state_arm = quote! {
#name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful<Vec<u8>>>::load_state(x, state)
};
let store_state_arm = quote! {
#name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful<Vec<u8>>>::store_state(x)
};

(load_state_arm, store_state_arm)
}));
quote! {
impl #impl_generics ::openvm_stark_backend::Stateful<Vec<u8>> for #name #ty_generics {
fn load_state(&mut self, state: Vec<u8>) {
match self {
#(#load_state_arms,)*
}
}

fn store_state(&self) -> Vec<u8> {
match self {
#(#store_state_arms,)*
}
}
}
}
.into()
}
_ => unimplemented!(),
}
}
86 changes: 84 additions & 2 deletions crates/circuits/primitives/src/var_range/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
use core::mem::size_of;
use std::{
borrow::{Borrow, BorrowMut},
sync::{atomic::AtomicU32, Arc},
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
};

use openvm_circuit_primitives_derive::AlignedBorrow;
Expand All @@ -18,7 +21,7 @@ use openvm_stark_backend::{
p3_matrix::{dense::RowMajorMatrix, Matrix},
prover::types::AirProofInput,
rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir},
Chip, ChipUsageGetter,
Chip, ChipUsageGetter, Stateful,
};
use tracing::instrument;

Expand Down Expand Up @@ -99,6 +102,9 @@ pub struct VariableRangeCheckerChip {
count: Vec<AtomicU32>,
}

#[derive(Clone)]
pub struct SharedVariableRangeCheckerChip(Arc<VariableRangeCheckerChip>);

impl VariableRangeCheckerChip {
pub fn new(bus: VariableRangeCheckerBus) -> Self {
let num_rows = (1 << (bus.range_max_bits + 1)) as usize;
Expand Down Expand Up @@ -179,6 +185,36 @@ impl VariableRangeCheckerChip {
}
}

impl SharedVariableRangeCheckerChip {
pub fn new(bus: VariableRangeCheckerBus) -> Self {
Self(Arc::new(VariableRangeCheckerChip::new(bus)))
}

pub fn bus(&self) -> VariableRangeCheckerBus {
self.0.bus()
}

pub fn range_max_bits(&self) -> usize {
self.0.range_max_bits()
}

pub fn air_width(&self) -> usize {
self.0.air_width()
}

pub fn add_count(&self, value: u32, max_bits: usize) {
self.0.add_count(value, max_bits)
}

pub fn clear(&self) {
self.0.clear()
}

pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
self.0.generate_trace()
}
}

impl<SC: StarkGenericConfig> Chip<SC> for VariableRangeCheckerChip
where
Val<SC>: PrimeField32,
Expand All @@ -193,6 +229,19 @@ where
}
}

impl<SC: StarkGenericConfig> Chip<SC> for SharedVariableRangeCheckerChip
where
Val<SC>: PrimeField32,
{
fn air(&self) -> Arc<dyn AnyRap<SC>> {
self.0.air()
}

fn generate_air_proof_input(self) -> AirProofInput<SC> {
self.0.generate_air_proof_input()
}
}

impl ChipUsageGetter for VariableRangeCheckerChip {
fn air_name(&self) -> String {
get_air_name(&self.air)
Expand All @@ -207,3 +256,36 @@ impl ChipUsageGetter for VariableRangeCheckerChip {
NUM_VARIABLE_RANGE_COLS
}
}

impl ChipUsageGetter for SharedVariableRangeCheckerChip {
fn air_name(&self) -> String {
self.0.air_name()
}

fn current_trace_height(&self) -> usize {
self.0.current_trace_height()
}

fn trace_width(&self) -> usize {
self.0.trace_width()
}
}

impl Stateful<Vec<u8>> for SharedVariableRangeCheckerChip {
fn load_state(&mut self, state: Vec<u8>) {
let count_vals: Vec<u32> = bitcode::deserialize(&state).unwrap();
for (x, val) in self.0.count.iter().zip(count_vals) {
x.store(val, Ordering::Relaxed);
}
}

fn store_state(&self) -> Vec<u8> {
bitcode::serialize(&self.0.count).unwrap()
}
}

impl AsRef<VariableRangeCheckerChip> for SharedVariableRangeCheckerChip {
fn as_ref(&self) -> &VariableRangeCheckerChip {
&self.0
}
}
6 changes: 3 additions & 3 deletions crates/sdk/src/config/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use openvm_circuit::{
arch::{
SystemConfig, SystemExecutor, SystemPeriphery, VmChipComplex, VmConfig, VmInventoryError,
},
circuit_derive::{Chip, ChipUsageGetter},
circuit_derive::{BytesStateful, Chip, ChipUsageGetter},
derive::{AnyEnum, InstructionExecutor},
};
use openvm_ecc_circuit::{
Expand Down Expand Up @@ -63,7 +63,7 @@ pub struct SdkVmConfig {
pub castf: Option<CastFExtension>,
}

#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)]
#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, BytesStateful)]
pub enum SdkVmConfigExecutor<F: PrimeField32> {
#[any_enum]
System(SystemExecutor<F>),
Expand Down Expand Up @@ -93,7 +93,7 @@ pub enum SdkVmConfigExecutor<F: PrimeField32> {
CastF(CastFExtensionExecutor<F>),
}

#[derive(From, ChipUsageGetter, Chip, AnyEnum)]
#[derive(From, ChipUsageGetter, Chip, AnyEnum, BytesStateful)]
pub enum SdkVmConfigPeriphery<F: PrimeField32> {
#[any_enum]
System(SystemPeriphery<F>),
Expand Down
Loading

0 comments on commit ebf7aa6

Please sign in to comment.