From ebf7aa6a27ac14cbdd1aadece153b2c83c89efe0 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Thu, 16 Jan 2025 00:01:48 +0800 Subject: [PATCH] [feat] Implement Stateful for VmChipComplex (#1211) * Implement Stateful for system chips * Implement Stateful for VmChipComplex * Testing simulates all steps in the serving flow * Move Stateful macro into BytesStateful in openvm-circuit-primitives-derive --- Cargo.lock | 1 + benchmarks/Cargo.toml | 1 + crates/circuits/mod-builder/src/core_chip.rs | 10 +- crates/circuits/primitives/derive/src/lib.rs | 88 +++++++++++++ .../circuits/primitives/src/var_range/mod.rs | 86 +++++++++++- crates/sdk/src/config/global.rs | 6 +- crates/vm/derive/src/lib.rs | 92 +------------ crates/vm/src/arch/config.rs | 6 +- crates/vm/src/arch/extensions.rs | 122 +++++++++++++++--- crates/vm/src/arch/segment.rs | 34 ++++- crates/vm/src/arch/testing/mod.rs | 10 +- crates/vm/src/system/connector/mod.rs | 15 ++- crates/vm/src/system/memory/adapter/mod.rs | 14 +- crates/vm/src/system/memory/controller/mod.rs | 28 ++-- crates/vm/src/system/memory/mod.rs | 2 +- crates/vm/src/system/memory/offline.rs | 50 +++---- crates/vm/src/system/memory/online.rs | 3 +- crates/vm/src/system/memory/tests.rs | 8 +- crates/vm/src/system/memory/volatile/mod.rs | 10 +- crates/vm/src/system/memory/volatile/tests.rs | 8 +- crates/vm/src/system/poseidon2/mod.rs | 4 +- crates/vm/src/system/program/mod.rs | 12 +- crates/vm/tests/integration_test.rs | 64 +++++++-- .../algebra/circuit/src/fp2_chip/addsub.rs | 12 +- .../algebra/circuit/src/fp2_chip/muldiv.rs | 12 +- .../algebra/circuit/src/fp2_extension.rs | 8 +- .../circuit/src/modular_chip/addsub.rs | 12 +- .../circuit/src/modular_chip/muldiv.rs | 12 +- .../algebra/circuit/src/modular_extension.rs | 8 +- extensions/bigint/circuit/src/extension.rs | 8 +- .../circuit/src/weierstrass_chip/double.rs | 10 +- .../ecc/circuit/src/weierstrass_chip/mod.rs | 14 +- .../ecc/circuit/src/weierstrass_extension.rs | 8 +- extensions/keccak256/circuit/src/extension.rs | 8 +- extensions/native/circuit/src/castf/core.rs | 13 +- extensions/native/circuit/src/extension.rs | 12 +- .../native/circuit/src/poseidon2/mod.rs | 16 +-- .../pairing/circuit/src/fp12_chip/mul.rs | 12 +- .../line/d_type/mul_013_by_013.rs | 12 +- .../pairing_chip/line/d_type/mul_by_01234.rs | 12 +- .../src/pairing_chip/line/evaluate_line.rs | 12 +- .../line/m_type/mul_023_by_023.rs | 12 +- .../pairing_chip/line/m_type/mul_by_02345.rs | 12 +- .../miller_double_and_add_step.rs | 12 +- .../src/pairing_chip/miller_double_step.rs | 12 +- .../pairing/circuit/src/pairing_extension.rs | 8 +- .../rv32im/circuit/src/adapters/hintstore.rs | 9 +- .../rv32im/circuit/src/adapters/loadstore.rs | 7 +- extensions/rv32im/circuit/src/extension.rs | 16 +-- extensions/rv32im/circuit/src/jalr/core.rs | 7 +- .../circuit/src/load_sign_extend/core.rs | 7 +- extensions/rv32im/circuit/src/shift/core.rs | 7 +- extensions/sha256/circuit/src/extension.rs | 8 +- 53 files changed, 635 insertions(+), 357 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 78c0329008..81d3c06916 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3499,6 +3499,7 @@ dependencies = [ "openvm-algebra-transpiler", "openvm-build", "openvm-circuit", + "openvm-circuit-primitives-derive", "openvm-ecc-circuit", "openvm-ecc-transpiler", "openvm-keccak256-circuit", diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 5a38c380d9..fa3797077e 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -10,6 +10,7 @@ license.workspace = true [dependencies] openvm-build.workspace = true openvm-circuit.workspace = true +openvm-circuit-primitives-derive.workspace = true openvm-sdk.workspace = true openvm-stark-backend.workspace = true openvm-stark-sdk.workspace = true diff --git a/crates/circuits/mod-builder/src/core_chip.rs b/crates/circuits/mod-builder/src/core_chip.rs index ca5033d462..b35aff0c36 100644 --- a/crates/circuits/mod-builder/src/core_chip.rs +++ b/crates/circuits/mod-builder/src/core_chip.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use itertools::Itertools; use num_bigint_dig::BigUint; use openvm_circuit::arch::{ @@ -7,7 +5,7 @@ use openvm_circuit::arch::{ Result, VmAdapterInterface, VmCoreAir, VmCoreChip, }; use openvm_circuit_primitives::{ - var_range::VariableRangeCheckerChip, SubAir, TraceSubRowGenerator, + var_range::SharedVariableRangeCheckerChip, SubAir, TraceSubRowGenerator, }; use openvm_instructions::instruction::Instruction; use openvm_stark_backend::{ @@ -170,7 +168,7 @@ pub struct FieldExpressionRecord { pub struct FieldExpressionCoreChip { pub air: FieldExpressionCoreAir, - pub range_checker: Arc, + pub range_checker: SharedVariableRangeCheckerChip, pub name: String, @@ -184,7 +182,7 @@ impl FieldExpressionCoreChip { offset: usize, local_opcode_idx: Vec, opcode_flag_idx: Vec, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, name: &str, should_finalize: bool, ) -> Self { @@ -277,7 +275,7 @@ where fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { self.air.expr.generate_subrow( - (&self.range_checker, record.inputs, record.flags), + (self.range_checker.as_ref(), record.inputs, record.flags), row_slice, ); } diff --git a/crates/circuits/primitives/derive/src/lib.rs b/crates/circuits/primitives/derive/src/lib.rs index 18db7d2970..298452ff14 100644 --- a/crates/circuits/primitives/derive/src/lib.rs +++ b/crates/circuits/primitives/derive/src/lib.rs @@ -310,3 +310,91 @@ pub fn chip_usage_getter_derive(input: TokenStream) -> TokenStream { Data::Union(_) => unimplemented!("Unions are not supported"), } } + +#[proc_macro_derive(BytesStateful)] +pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let generics = &ast.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(inner) => { + // Check if the struct has only one unnamed field + let inner_ty = match &inner.fields { + Fields::Unnamed(fields) => { + if fields.unnamed.len() != 1 { + panic!("Only one unnamed field is supported"); + } + fields.unnamed.first().unwrap().ty.clone() + } + _ => panic!("Only unnamed fields are supported"), + }; + // Use full path ::openvm_circuit... so it can be used either within or outside the vm crate. + // Assume F is already generic of the field. + let mut new_generics = generics.clone(); + let where_clause = new_generics.make_where_clause(); + where_clause + .predicates + .push(syn::parse_quote! { #inner_ty: ::openvm_stark_backend::Stateful> }); + + quote! { + impl #impl_generics ::openvm_stark_backend::Stateful> for #name #ty_generics #where_clause { + fn load_state(&mut self, state: Vec) { + self.0.load_state(state) + } + + fn store_state(&self) -> Vec { + self.0.store_state() + } + } + } + .into() + } + Data::Enum(e) => { + let variants = e + .variants + .iter() + .map(|variant| { + let variant_name = &variant.ident; + + let mut fields = variant.fields.iter(); + let field = fields.next().unwrap(); + assert!(fields.next().is_none(), "Only one field is supported"); + (variant_name, field) + }) + .collect::>(); + // Use full path ::openvm_stark_backend... so it can be used either within or outside the vm crate. + let (load_state_arms, store_state_arms): (Vec<_>, Vec<_>) = + multiunzip(variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + let load_state_arm = quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful>>::load_state(x, state) + }; + let store_state_arm = quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful>>::store_state(x) + }; + + (load_state_arm, store_state_arm) + })); + quote! { + impl #impl_generics ::openvm_stark_backend::Stateful> for #name #ty_generics { + fn load_state(&mut self, state: Vec) { + match self { + #(#load_state_arms,)* + } + } + + fn store_state(&self) -> Vec { + match self { + #(#store_state_arms,)* + } + } + } + } + .into() + } + _ => unimplemented!(), + } +} diff --git a/crates/circuits/primitives/src/var_range/mod.rs b/crates/circuits/primitives/src/var_range/mod.rs index 8df350eb9d..feef96a17a 100644 --- a/crates/circuits/primitives/src/var_range/mod.rs +++ b/crates/circuits/primitives/src/var_range/mod.rs @@ -6,7 +6,10 @@ use core::mem::size_of; use std::{ borrow::{Borrow, BorrowMut}, - sync::{atomic::AtomicU32, Arc}, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -18,7 +21,7 @@ use openvm_stark_backend::{ p3_matrix::{dense::RowMajorMatrix, Matrix}, prover::types::AirProofInput, rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, ChipUsageGetter, + Chip, ChipUsageGetter, Stateful, }; use tracing::instrument; @@ -99,6 +102,9 @@ pub struct VariableRangeCheckerChip { count: Vec, } +#[derive(Clone)] +pub struct SharedVariableRangeCheckerChip(Arc); + impl VariableRangeCheckerChip { pub fn new(bus: VariableRangeCheckerBus) -> Self { let num_rows = (1 << (bus.range_max_bits + 1)) as usize; @@ -179,6 +185,36 @@ impl VariableRangeCheckerChip { } } +impl SharedVariableRangeCheckerChip { + pub fn new(bus: VariableRangeCheckerBus) -> Self { + Self(Arc::new(VariableRangeCheckerChip::new(bus))) + } + + pub fn bus(&self) -> VariableRangeCheckerBus { + self.0.bus() + } + + pub fn range_max_bits(&self) -> usize { + self.0.range_max_bits() + } + + pub fn air_width(&self) -> usize { + self.0.air_width() + } + + pub fn add_count(&self, value: u32, max_bits: usize) { + self.0.add_count(value, max_bits) + } + + pub fn clear(&self) { + self.0.clear() + } + + pub fn generate_trace(&self) -> RowMajorMatrix { + self.0.generate_trace() + } +} + impl Chip for VariableRangeCheckerChip where Val: PrimeField32, @@ -193,6 +229,19 @@ where } } +impl Chip for SharedVariableRangeCheckerChip +where + Val: PrimeField32, +{ + fn air(&self) -> Arc> { + self.0.air() + } + + fn generate_air_proof_input(self) -> AirProofInput { + self.0.generate_air_proof_input() + } +} + impl ChipUsageGetter for VariableRangeCheckerChip { fn air_name(&self) -> String { get_air_name(&self.air) @@ -207,3 +256,36 @@ impl ChipUsageGetter for VariableRangeCheckerChip { NUM_VARIABLE_RANGE_COLS } } + +impl ChipUsageGetter for SharedVariableRangeCheckerChip { + fn air_name(&self) -> String { + self.0.air_name() + } + + fn current_trace_height(&self) -> usize { + self.0.current_trace_height() + } + + fn trace_width(&self) -> usize { + self.0.trace_width() + } +} + +impl Stateful> for SharedVariableRangeCheckerChip { + fn load_state(&mut self, state: Vec) { + let count_vals: Vec = bitcode::deserialize(&state).unwrap(); + for (x, val) in self.0.count.iter().zip(count_vals) { + x.store(val, Ordering::Relaxed); + } + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&self.0.count).unwrap() + } +} + +impl AsRef for SharedVariableRangeCheckerChip { + fn as_ref(&self) -> &VariableRangeCheckerChip { + &self.0 + } +} diff --git a/crates/sdk/src/config/global.rs b/crates/sdk/src/config/global.rs index cc1a97ef70..a062269b53 100644 --- a/crates/sdk/src/config/global.rs +++ b/crates/sdk/src/config/global.rs @@ -11,7 +11,7 @@ use openvm_circuit::{ arch::{ SystemConfig, SystemExecutor, SystemPeriphery, VmChipComplex, VmConfig, VmInventoryError, }, - circuit_derive::{Chip, ChipUsageGetter}, + circuit_derive::{BytesStateful, Chip, ChipUsageGetter}, derive::{AnyEnum, InstructionExecutor}, }; use openvm_ecc_circuit::{ @@ -63,7 +63,7 @@ pub struct SdkVmConfig { pub castf: Option, } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, BytesStateful)] pub enum SdkVmConfigExecutor { #[any_enum] System(SystemExecutor), @@ -93,7 +93,7 @@ pub enum SdkVmConfigExecutor { CastF(CastFExtensionExecutor), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, BytesStateful)] pub enum SdkVmConfigPeriphery { #[any_enum] System(SystemPeriphery), diff --git a/crates/vm/derive/src/lib.rs b/crates/vm/derive/src/lib.rs index 5c36046877..90a9a0e03b 100644 --- a/crates/vm/derive/src/lib.rs +++ b/crates/vm/derive/src/lib.rs @@ -113,94 +113,6 @@ pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { } } -#[proc_macro_derive(Stateful)] -pub fn stateful_derive(input: TokenStream) -> TokenStream { - let ast: syn::DeriveInput = syn::parse(input).unwrap(); - - let name = &ast.ident; - let generics = &ast.generics; - let (impl_generics, ty_generics, _) = generics.split_for_impl(); - - match &ast.data { - Data::Struct(inner) => { - // Check if the struct has only one unnamed field - let inner_ty = match &inner.fields { - Fields::Unnamed(fields) => { - if fields.unnamed.len() != 1 { - panic!("Only one unnamed field is supported"); - } - fields.unnamed.first().unwrap().ty.clone() - } - _ => panic!("Only unnamed fields are supported"), - }; - // Use full path ::openvm_circuit... so it can be used either within or outside the vm crate. - // Assume F is already generic of the field. - let mut new_generics = generics.clone(); - let where_clause = new_generics.make_where_clause(); - where_clause - .predicates - .push(syn::parse_quote! { #inner_ty: ::openvm_stark_backend::Stateful> }); - - quote! { - impl #impl_generics ::openvm_stark_backend::Stateful> for #name #ty_generics #where_clause { - fn load_state(&mut self, state: Vec) { - self.0.load_state(state) - } - - fn store_state(&self) -> Vec { - self.0.store_state() - } - } - } - .into() - } - Data::Enum(e) => { - let variants = e - .variants - .iter() - .map(|variant| { - let variant_name = &variant.ident; - - let mut fields = variant.fields.iter(); - let field = fields.next().unwrap(); - assert!(fields.next().is_none(), "Only one field is supported"); - (variant_name, field) - }) - .collect::>(); - // Use full path ::openvm_stark_backend... so it can be used either within or outside the vm crate. - let (load_state_arms, store_state_arms): (Vec<_>, Vec<_>) = - multiunzip(variants.iter().map(|(variant_name, field)| { - let field_ty = &field.ty; - let load_state_arm = quote! { - #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful>>::load_state(x, state) - }; - let store_state_arm = quote! { - #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful>>::store_state(x) - }; - - (load_state_arm, store_state_arm) - })); - quote! { - impl #impl_generics ::openvm_stark_backend::Stateful> for #name #ty_generics { - fn load_state(&mut self, state: Vec) { - match self { - #(#load_state_arms,)* - } - } - - fn store_state(&self) -> Vec { - match self { - #(#store_state_arms,)* - } - } - } - } - .into() - } - _ => unimplemented!(), - } -} - /// Derives `AnyEnum` trait on an enum type. /// By default an enum arm will just return `self` as `&dyn Any`. /// @@ -396,14 +308,14 @@ pub fn vm_generic_config_derive(input: proc_macro::TokenStream) -> proc_macro::T let executor_type = Ident::new(&format!("{}Executor", name), name.span()); let periphery_type = Ident::new(&format!("{}Periphery", name), name.span()); TokenStream::from(quote! { - #[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, ::openvm_circuit::derive::Stateful)] + #[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, ::openvm_circuit_primitives_derive::BytesStateful)] pub enum #executor_type { #[any_enum] #system_name_upper(SystemExecutor), #(#executor_enum_fields)* } - #[derive(ChipUsageGetter, Chip, From, AnyEnum)] + #[derive(ChipUsageGetter, Chip, From, AnyEnum, ::openvm_circuit_primitives_derive::BytesStateful)] pub enum #periphery_type { #[any_enum] #system_name_upper(SystemPeriphery), diff --git a/crates/vm/src/arch/config.rs b/crates/vm/src/arch/config.rs index 5b1335cb6f..d789e040f7 100644 --- a/crates/vm/src/arch/config.rs +++ b/crates/vm/src/arch/config.rs @@ -2,7 +2,7 @@ use derive_new::new; use openvm_circuit::system::memory::MemoryTraceHeights; use openvm_instructions::program::DEFAULT_MAX_NUM_PUBLIC_VALUES; use openvm_poseidon2_air::Poseidon2Config; -use openvm_stark_backend::{p3_field::PrimeField32, ChipUsageGetter}; +use openvm_stark_backend::{p3_field::PrimeField32, ChipUsageGetter, Stateful}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; // TODO[jpw]: re-exporting hardcoded bus constants for tests. Import paths should be @@ -30,8 +30,8 @@ pub fn vm_poseidon2_config() -> Poseidon2Config { } pub trait VmConfig: Clone + Serialize + DeserializeOwned { - type Executor: InstructionExecutor + AnyEnum + ChipUsageGetter; - type Periphery: AnyEnum + ChipUsageGetter; + type Executor: InstructionExecutor + AnyEnum + ChipUsageGetter + Stateful>; + type Periphery: AnyEnum + ChipUsageGetter + Stateful>; /// Must contain system config fn system(&self) -> &SystemConfig; diff --git a/crates/vm/src/arch/extensions.rs b/crates/vm/src/arch/extensions.rs index 3018c30b0b..aafbaa10e6 100644 --- a/crates/vm/src/arch/extensions.rs +++ b/crates/vm/src/arch/extensions.rs @@ -7,14 +7,15 @@ use std::{ use derive_more::derive::From; use getset::Getters; +use itertools::Itertools; #[cfg(feature = "bench-metrics")] use metrics::counter; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; use openvm_circuit_primitives::{ utils::next_power_of_two_or_zero, - var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_instructions::{ program::Program, PhantomDiscriminant, PublishOpcode, SystemOpcode, UsizeOpcode, VmOpcode, }; @@ -25,7 +26,7 @@ use openvm_stark_backend::{ p3_matrix::Matrix, prover::types::{AirProofInput, CommittedTraceData, ProofInput}, rap::AnyRap, - Chip, ChipUsageGetter, + Chip, ChipUsageGetter, Stateful, }; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; @@ -39,8 +40,10 @@ use crate::metrics::VmMetrics; use crate::system::{ connector::VmConnectorChip, memory::{ + interface::MemoryInterface, merkle::{DirectCompressionBus, MemoryMerkleBus}, offline_checker::{MemoryBridge, MemoryBus}, + online::MemoryLogEntry, MemoryController, MemoryImage, OfflineMemory, BOUNDARY_AIR_OFFSET, MERKLE_AIR_OFFSET, }, native_adapter::NativeAdapterChip, @@ -197,6 +200,14 @@ pub struct VmInventory { insertion_order: Vec, } +#[derive(Clone, Serialize, Deserialize)] +pub struct VmInventoryState { + /// Executor states in order + executors: Vec>, + /// Periphery states in order + periphery: Vec>, +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct VmInventoryTraceHeights { pub chips: FxHashMap, @@ -380,6 +391,27 @@ impl VmInventory { } } +impl>, P: Stateful>> Stateful for VmInventory { + fn load_state(&mut self, state: VmInventoryState) { + for (e, s) in self.executors.iter_mut().zip_eq(state.executors) { + e.load_state(s) + } + for (p, s) in self.periphery.iter_mut().zip_eq(state.periphery) { + p.load_state(s) + } + } + + fn store_state(&self) -> VmInventoryState { + // TODO: parallelize this. Now some implementations of Executor/Periphery are not Send + Sync. + let executors = self.executors.iter().map(|e| e.store_state()).collect(); + let periphery = self.periphery.iter().map(|p| p.store_state()).collect(); + VmInventoryState { + executors, + periphery, + } + } +} + impl VmInventoryTraceHeights { /// Round all trace heights to the next power of two. This will round trace heights of 0 to 1. pub fn round_to_next_power_of_two(&mut self) { @@ -431,6 +463,12 @@ pub struct VmChipComplex { bus_idx_max: usize, } +#[derive(Clone, Serialize, Deserialize)] +pub struct VmChipComplexState { + base: SystemBaseState, + inventory: VmInventoryState, +} + /// The base [VmChipComplex] with only system chips. pub type SystemComplex = VmChipComplex, SystemPeriphery>; @@ -439,17 +477,24 @@ pub type SystemComplex = VmChipComplex, SystemPeriphery< /// for the VM architecture. pub struct SystemBase { // RangeCheckerChip **must** be the last chip to have trace generation called on - pub range_checker_chip: Arc, + pub range_checker_chip: SharedVariableRangeCheckerChip, pub memory_controller: MemoryController, pub connector_chip: VmConnectorChip, pub program_chip: ProgramChip, +} - range_checker_bus: VariableRangeCheckerBus, +#[derive(Clone, Serialize, Deserialize)] +pub struct SystemBaseState { + pub range_checker_chip: Vec, + pub initial_memory: Option>, + pub memory_logs: Vec>, + pub connector_chip: Vec, + pub program_chip: Vec, } impl SystemBase { pub fn range_checker_bus(&self) -> VariableRangeCheckerBus { - self.range_checker_bus + self.range_checker_chip.bus() } pub fn memory_bus(&self) -> MemoryBus { @@ -489,13 +534,37 @@ impl SystemBase { } } -#[derive(ChipUsageGetter, Chip, AnyEnum, From, InstructionExecutor, Stateful)] +impl Stateful> for SystemBase { + fn load_state(&mut self, state: SystemBaseState) { + self.range_checker_chip.load_state(state.range_checker_chip); + if let Some(initial_memory) = state.initial_memory { + self.memory_controller.set_initial_memory(initial_memory); + } + self.memory_controller.set_memory_logs(state.memory_logs); + self.connector_chip.load_state(state.connector_chip); + self.program_chip.load_state(state.program_chip); + } + fn store_state(&self) -> SystemBaseState { + SystemBaseState { + range_checker_chip: self.range_checker_chip.store_state(), + initial_memory: match &self.memory_controller.interface_chip { + MemoryInterface::Volatile { .. } => None, + MemoryInterface::Persistent { initial_memory, .. } => Some(initial_memory.clone()), + }, + memory_logs: self.memory_controller.get_memory_logs(), + connector_chip: self.connector_chip.store_state(), + program_chip: self.program_chip.store_state(), + } + } +} + +#[derive(ChipUsageGetter, Chip, AnyEnum, From, InstructionExecutor, BytesStateful)] pub enum SystemExecutor { PublicValues(PublicValuesChip), Phantom(RefCell>), } -#[derive(ChipUsageGetter, Chip, AnyEnum, From, Stateful)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, BytesStateful)] pub enum SystemPeriphery { /// Poseidon2 chip with direct compression interactions Poseidon2(Poseidon2PeripheryChip), @@ -507,7 +576,7 @@ impl SystemComplex { VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, config.memory_config.decomp); let mut bus_idx_max = RANGE_CHECKER_BUS; - let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); + let range_checker = SharedVariableRangeCheckerChip::new(range_bus); let memory_controller = if config.continuation_enabled { bus_idx_max += 2; MemoryController::with_persistent_memory( @@ -581,7 +650,6 @@ impl SystemComplex { connector_chip, memory_controller, range_checker_chip: range_checker, - range_checker_bus: range_bus, }; Self { @@ -683,7 +751,7 @@ impl VmChipComplex { &self.base.memory_controller } - pub fn range_checker_chip(&self) -> &Arc { + pub fn range_checker_chip(&self) -> &SharedVariableRangeCheckerChip { &self.base.range_checker_chip } @@ -1043,6 +1111,22 @@ impl VmChipComplex { } } +impl>, P: Stateful>> Stateful> + for VmChipComplex +{ + fn load_state(&mut self, state: VmChipComplexState) { + self.base.load_state(state.base); + self.inventory.load_state(state.inventory); + } + + fn store_state(&self) -> VmChipComplexState { + VmChipComplexState { + base: self.base.store_state(), + inventory: self.inventory.store_state(), + } + } +} + struct VmProofInputBuilder { curr_air_id: usize, proof_input_per_air: Vec<(usize, AirProofInput)>, @@ -1120,7 +1204,7 @@ impl AnyEnum for () { } } -impl AnyEnum for Arc { +impl AnyEnum for SharedVariableRangeCheckerChip { fn as_any_kind(&self) -> &dyn Any { self } @@ -1151,18 +1235,18 @@ where Either::Periphery(chip) => chip.current_trace_height(), } } - fn current_trace_cells(&self) -> usize { - match self { - Either::Executor(chip) => chip.current_trace_cells(), - Either::Periphery(chip) => chip.current_trace_cells(), - } - } fn trace_width(&self) -> usize { match self { Either::Executor(chip) => chip.trace_width(), Either::Periphery(chip) => chip.trace_width(), } } + fn current_trace_cells(&self) -> usize { + match self { + Either::Executor(chip) => chip.current_trace_cells(), + Either::Periphery(chip) => chip.current_trace_cells(), + } + } } #[cfg(test)] diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs index c95a921101..3d4d2bb546 100644 --- a/crates/vm/src/arch/segment.rs +++ b/crates/vm/src/arch/segment.rs @@ -10,12 +10,12 @@ use openvm_stark_backend::{ p3_field::PrimeField32, prover::types::{CommittedTraceData, ProofInput}, utils::metrics_span, - Chip, + Chip, Stateful, }; use super::{ - ExecutionError, Streams, SystemBase, SystemConfig, VmChipComplex, VmComplexTraceHeights, - VmConfig, + ExecutionError, Streams, SystemBase, SystemConfig, VmChipComplex, VmChipComplexState, + VmComplexTraceHeights, VmConfig, }; #[cfg(feature = "bench-metrics")] use crate::metrics::VmMetrics; @@ -85,6 +85,31 @@ impl> ExecutionSegment { } } + /// Creates a new execution segment just for proving. + pub fn new_for_proving( + config: &VC, + program: Program, + vm_chip_complex_state: VmChipComplexState, + ) -> Self { + let mut chip_complex = config.create_chip_complex().unwrap(); + let program = if !config.system().profiling { + program.strip_debug_infos() + } else { + program + }; + chip_complex.set_program(program); + chip_complex.load_state(vm_chip_complex_state); + let air_names = chip_complex.air_names(); + Self { + chip_complex, + final_memory: None, + air_names, + #[cfg(feature = "bench-metrics")] + metrics: Default::default(), + since_last_segment_check: 0, + } + } + pub fn system_config(&self) -> &SystemConfig { self.chip_complex.config() } @@ -274,4 +299,7 @@ impl> ExecutionSegment { pub fn current_trace_heights(&self) -> Vec { self.chip_complex.current_trace_heights() } + pub fn store_chip_complex_state(&self) -> VmChipComplexState { + self.chip_complex.store_state() + } } diff --git a/crates/vm/src/arch/testing/mod.rs b/crates/vm/src/arch/testing/mod.rs index 2dd1d2963e..bc6da329e2 100644 --- a/crates/vm/src/arch/testing/mod.rs +++ b/crates/vm/src/arch/testing/mod.rs @@ -4,7 +4,9 @@ use std::{ sync::{Arc, Mutex}, }; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; use openvm_instructions::instruction::Instruction; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, @@ -170,7 +172,7 @@ impl VmChipTestBuilder { self.memory.controller.clone() } - pub fn range_checker(&self) -> Arc { + pub fn range_checker(&self) -> SharedVariableRangeCheckerChip { self.memory.controller.borrow().range_checker.clone() } @@ -241,10 +243,10 @@ impl VmChipTestBuilder { impl Default for VmChipTestBuilder { fn default() -> Self { let mem_config = MemoryConfig::new(2, 1, 29, 29, 17, 64, 1 << 22); - let range_checker = Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new( + let range_checker = SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new( RANGE_CHECKER_BUS, mem_config.decomp, - ))); + )); let memory_controller = MemoryController::with_volatile_memory( MemoryBus(MEMORY_BUS), mem_config, diff --git a/crates/vm/src/system/connector/mod.rs b/crates/vm/src/system/connector/mod.rs index 88cdc88ed1..f4c8bc322c 100644 --- a/crates/vm/src/system/connector/mod.rs +++ b/crates/vm/src/system/connector/mod.rs @@ -14,8 +14,9 @@ use openvm_stark_backend::{ p3_matrix::{dense::RowMajorMatrix, Matrix}, prover::types::AirProofInput, rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, ChipUsageGetter, + Chip, ChipUsageGetter, Stateful, }; +use serde::{Deserialize, Serialize}; use crate::{ arch::{instructions::SystemOpcode::TERMINATE, ExecutionBus, ExecutionState}, @@ -66,7 +67,7 @@ impl BaseAir for VmConnectorAir { } } -#[derive(Debug, Copy, Clone, AlignedBorrow)] +#[derive(Debug, Copy, Clone, AlignedBorrow, Serialize, Deserialize)] #[repr(C)] pub struct ConnectorCols { pub pc: T, @@ -218,3 +219,13 @@ impl ChipUsageGetter for VmConnectorChip { 4 } } + +impl Stateful> for VmConnectorChip { + fn load_state(&mut self, state: Vec) { + self.boundary_states = bitcode::deserialize(&state).unwrap(); + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&self.boundary_states).unwrap() + } +} diff --git a/crates/vm/src/system/memory/adapter/mod.rs b/crates/vm/src/system/memory/adapter/mod.rs index 8341b8a769..9f0d202fe4 100644 --- a/crates/vm/src/system/memory/adapter/mod.rs +++ b/crates/vm/src/system/memory/adapter/mod.rs @@ -5,7 +5,7 @@ pub use columns::*; use enum_dispatch::enum_dispatch; use openvm_circuit_primitives::{ is_less_than::IsLtSubAir, utils::next_power_of_two_or_zero, - var_range::VariableRangeCheckerChip, TraceSubRowGenerator, + var_range::SharedVariableRangeCheckerChip, TraceSubRowGenerator, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_stark_backend::{ @@ -35,7 +35,7 @@ pub struct AccessAdapterInventory { impl AccessAdapterInventory { pub fn new( - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, clk_max_bits: usize, max_access_adapter_n: usize, @@ -119,7 +119,7 @@ impl AccessAdapterInventory { } fn create_access_adapter_chip( - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, clk_max_bits: usize, max_access_adapter_n: usize, @@ -178,7 +178,7 @@ enum GenericAccessAdapterChip { impl GenericAccessAdapterChip { fn new( - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, clk_max_bits: usize, ) -> Self { @@ -198,13 +198,13 @@ impl GenericAccessAdapterChip { } pub struct AccessAdapterChip { air: AccessAdapterAir, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, records: Vec>, overridden_height: Option, } impl AccessAdapterChip { pub fn new( - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, clk_max_bits: usize, ) -> Self { @@ -268,7 +268,7 @@ impl GenericAccessAdapterChipTrait for AccessAdapterChip { #[getset(get = "pub")] pub(crate) mem_config: MemoryConfig, - pub range_checker: Arc, + pub range_checker: SharedVariableRangeCheckerChip, // Store separately to avoid smart pointer reference each time range_checker_bus: VariableRangeCheckerBus, @@ -226,7 +226,7 @@ impl MemoryController { pub fn with_volatile_memory( memory_bus: MemoryBus, mem_config: MemoryConfig, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, ) -> Self { let range_checker_bus = range_checker.bus(); let initial_memory = MemoryImage::default(); @@ -267,7 +267,7 @@ impl MemoryController { pub fn with_persistent_memory( memory_bus: MemoryBus, mem_config: MemoryConfig, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, merkle_bus: MemoryMerkleBus, compression_bus: DirectCompressionBus, ) -> Self { @@ -694,10 +694,17 @@ impl MemoryController { pub fn offline_memory(&self) -> Arc>> { self.offline_memory.clone() } + pub fn get_memory_logs(&self) -> Vec> { + // TODO: can we avoid clone? + self.memory.log.clone() + } + pub fn set_memory_logs(&mut self, logs: Vec>) { + self.memory.log = logs; + } } pub struct MemoryAuxColsFactory { - pub(crate) range_checker: Arc, + pub(crate) range_checker: SharedVariableRangeCheckerChip, pub(crate) timestamp_lt_air: AssertLtSubAir, pub(crate) _marker: PhantomData, } @@ -753,7 +760,7 @@ impl MemoryAuxColsFactory { debug_assert!(prev_timestamp < timestamp); let mut decomp = [F::ZERO; AUX_LEN]; self.timestamp_lt_air.generate_subrow( - (&self.range_checker, prev_timestamp, timestamp), + (self.range_checker.as_ref(), prev_timestamp, timestamp), &mut decomp, ); LessThanAuxCols::new(decomp) @@ -762,9 +769,10 @@ impl MemoryAuxColsFactory { #[cfg(test)] mod tests { - use std::sync::Arc; - use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; + use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, + }; use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::p3_baby_bear::BabyBear; use rand::{prelude::SliceRandom, thread_rng, Rng}; @@ -784,7 +792,7 @@ mod tests { let memory_bus = MemoryBus(MEMORY_BUS); let memory_config = MemoryConfig::default(); let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, memory_config.decomp); - let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); + let range_checker = SharedVariableRangeCheckerChip::new(range_bus); let mut memory_controller = MemoryController::with_volatile_memory( memory_bus, diff --git a/crates/vm/src/system/memory/mod.rs b/crates/vm/src/system/memory/mod.rs index 9bac67aae6..f8c7b6451b 100644 --- a/crates/vm/src/system/memory/mod.rs +++ b/crates/vm/src/system/memory/mod.rs @@ -5,7 +5,7 @@ mod controller; pub mod merkle; mod offline; pub mod offline_checker; -mod online; +pub mod online; mod persistent; #[cfg(test)] mod tests; diff --git a/crates/vm/src/system/memory/offline.rs b/crates/vm/src/system/memory/offline.rs index 89ab2b2908..195fe595e2 100644 --- a/crates/vm/src/system/memory/offline.rs +++ b/crates/vm/src/system/memory/offline.rs @@ -1,7 +1,7 @@ -use std::{array, cmp::max, sync::Arc}; +use std::{array, cmp::max}; use openvm_circuit_primitives::{ - assert_less_than::AssertLtSubAir, var_range::VariableRangeCheckerChip, + assert_less_than::AssertLtSubAir, var_range::SharedVariableRangeCheckerChip, }; use openvm_stark_backend::p3_field::PrimeField32; use rustc_hash::{FxHashMap, FxHashSet}; @@ -41,7 +41,7 @@ pub struct OfflineMemory { timestamp_max_bits: usize, memory_bus: MemoryBus, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, log: Vec>>, } @@ -54,7 +54,7 @@ impl OfflineMemory { initial_memory: MemoryImage, initial_block_size: usize, memory_bus: MemoryBus, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, timestamp_max_bits: usize, ) -> Self { assert!(initial_block_size.is_power_of_two()); @@ -446,9 +446,9 @@ impl OfflineMemory { #[cfg(test)] mod tests { - use std::sync::Arc; - - use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; + use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, + }; use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -486,9 +486,7 @@ mod tests { initial_memory, 8, MemoryBus(0), - Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new( - 1, 29, - ))), + SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)), 29, ); assert_eq!( @@ -535,9 +533,7 @@ mod tests { initial_memory, 1, MemoryBus(0), - Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new( - 1, 29, - ))), + SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)), 29, ); let address_space = 1; @@ -563,9 +559,7 @@ mod tests { initial_memory, 1, MemoryBus(0), - Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new( - 1, 29, - ))), + SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)), 29, ); @@ -710,9 +704,7 @@ mod tests { initial_memory, 8, MemoryBus(0), - Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new( - 1, 29, - ))), + SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)), 29, ); @@ -827,9 +819,7 @@ mod tests { initial_memory, 1, MemoryBus(0), - Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new( - 1, 29, - ))), + SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)), 29, ); @@ -851,9 +841,7 @@ mod tests { initial_memory, 8, MemoryBus(0), - Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new( - 1, 29, - ))), + SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)), 29, ); @@ -875,9 +863,7 @@ mod tests { initial_memory, 4, MemoryBus(0), - Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new( - 1, 29, - ))), + SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)), 29, ); @@ -893,9 +879,7 @@ mod tests { initial_memory, 8, MemoryBus(0), - Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new( - 1, 29, - ))), + SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)), 29, ); // Make block 0:4 in address space 1 active. @@ -965,9 +949,7 @@ mod tests { initial_memory, 8, MemoryBus(0), - Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new( - 1, 29, - ))), + SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)), 29, ); diff --git a/crates/vm/src/system/memory/online.rs b/crates/vm/src/system/memory/online.rs index 61fb2773b2..6c30176a51 100644 --- a/crates/vm/src/system/memory/online.rs +++ b/crates/vm/src/system/memory/online.rs @@ -2,10 +2,11 @@ use std::{array, fmt::Debug}; use openvm_stark_backend::p3_field::PrimeField32; use rustc_hash::FxHashMap; +use serde::{Deserialize, Serialize}; use crate::system::memory::{offline::INITIAL_TIMESTAMP, MemoryImage, RecordId}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum MemoryLogEntry { Read { address_space: u32, diff --git a/crates/vm/src/system/memory/tests.rs b/crates/vm/src/system/memory/tests.rs index 9a73601535..094b2204c5 100644 --- a/crates/vm/src/system/memory/tests.rs +++ b/crates/vm/src/system/memory/tests.rs @@ -5,7 +5,9 @@ use std::{ }; use itertools::Itertools; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_poseidon2_air::Poseidon2Config; use openvm_stark_backend::{ @@ -205,7 +207,7 @@ fn test_memory_controller() { let memory_bus = MemoryBus(MEMORY_BUS); let memory_config = MemoryConfig::default(); let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, memory_config.decomp); - let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); + let range_checker = SharedVariableRangeCheckerChip::new(range_bus); let mut memory_controller = MemoryController::with_volatile_memory(memory_bus, memory_config, range_checker.clone()); @@ -241,7 +243,7 @@ fn test_memory_controller_persistent() { let compression_bus = DirectCompressionBus(POSEIDON2_DIRECT_BUS); let memory_config = MemoryConfig::default(); let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, memory_config.decomp); - let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); + let range_checker = SharedVariableRangeCheckerChip::new(range_bus); let mut memory_controller = MemoryController::with_persistent_memory( memory_bus, diff --git a/crates/vm/src/system/memory/volatile/mod.rs b/crates/vm/src/system/memory/volatile/mod.rs index 074fe6fe18..0924c15dd4 100644 --- a/crates/vm/src/system/memory/volatile/mod.rs +++ b/crates/vm/src/system/memory/volatile/mod.rs @@ -8,7 +8,7 @@ use openvm_circuit_primitives::{ IsLtArrayAuxCols, IsLtArrayIo, IsLtArraySubAir, IsLtArrayWhenTransitionAir, }, utils::implies, - var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, SubAir, TraceSubRowGenerator, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -133,7 +133,7 @@ impl Air for VolatileBoundaryAir { pub struct VolatileBoundaryChip { pub air: VolatileBoundaryAir, touched_addresses: FxHashSet<(u32, u32)>, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, overridden_height: Option, final_memory: Option>, } @@ -143,7 +143,7 @@ impl VolatileBoundaryChip { memory_bus: MemoryBus, addr_space_max_bits: usize, pointer_max_bits: usize, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, ) -> Self { let range_bus = range_checker.bus(); Self { @@ -232,7 +232,7 @@ where let mut out = Val::::ZERO; air.addr_lt_air.0.generate_subrow( ( - &self.range_checker, + self.range_checker.as_ref(), &[row.addr_space, row.pointer], &[ Val::::from_canonical_u32(next_addr_space), @@ -250,7 +250,7 @@ where let row: &mut VolatileBoundaryCols<_> = rows[width * (trace_height - 1)..].borrow_mut(); air.addr_lt_air.0.generate_subrow( ( - &self.range_checker, + self.range_checker.as_ref(), &[Val::::ZERO, Val::::ZERO], &[Val::::ZERO, Val::::ZERO], ), diff --git a/crates/vm/src/system/memory/volatile/tests.rs b/crates/vm/src/system/memory/volatile/tests.rs index b3ce49c4f0..8149232e2f 100644 --- a/crates/vm/src/system/memory/volatile/tests.rs +++ b/crates/vm/src/system/memory/volatile/tests.rs @@ -1,6 +1,8 @@ use std::{collections::HashSet, iter, sync::Arc}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; use openvm_stark_backend::{ p3_field::FieldAlgebra, p3_matrix::dense::RowMajorMatrix, prover::types::AirProofInput, Chip, }; @@ -42,7 +44,7 @@ fn boundary_air_test() { } let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, DECOMP); - let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); + let range_checker = SharedVariableRangeCheckerChip::new(range_bus); let mut boundary_chip = VolatileBoundaryChip::new(memory_bus, 2, LIMB_BITS, range_checker.clone()); @@ -110,7 +112,7 @@ fn boundary_air_test() { // test trace height override { let overridden_height = boundary_api.main_trace_height() * 2; - let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); + let range_checker = SharedVariableRangeCheckerChip::new(range_bus); let mut boundary_chip = VolatileBoundaryChip::new(memory_bus, 2, LIMB_BITS, range_checker.clone()); boundary_chip.set_overridden_height(overridden_height); diff --git a/crates/vm/src/system/poseidon2/mod.rs b/crates/vm/src/system/poseidon2/mod.rs index 2c31feae4e..87849f6c33 100644 --- a/crates/vm/src/system/poseidon2/mod.rs +++ b/crates/vm/src/system/poseidon2/mod.rs @@ -24,7 +24,7 @@ pub mod tests; pub mod air; mod chip; pub use chip::*; -use openvm_circuit_derive::Stateful; +use openvm_circuit_primitives_derive::BytesStateful; use crate::arch::hasher::{Hasher, HasherChip}; pub mod columns; @@ -33,7 +33,7 @@ pub mod trace; pub const PERIPHERY_POSEIDON2_WIDTH: usize = 16; pub const PERIPHERY_POSEIDON2_CHUNK_SIZE: usize = 8; -#[derive(Stateful)] +#[derive(BytesStateful)] pub enum Poseidon2PeripheryChip { Register0(Poseidon2PeripheryBaseChip), Register1(Poseidon2PeripheryBaseChip), diff --git a/crates/vm/src/system/program/mod.rs b/crates/vm/src/system/program/mod.rs index 3d68632d0f..d4c0addc48 100644 --- a/crates/vm/src/system/program/mod.rs +++ b/crates/vm/src/system/program/mod.rs @@ -2,7 +2,7 @@ use openvm_instructions::{ instruction::{DebugInfo, Instruction}, program::Program, }; -use openvm_stark_backend::{p3_field::PrimeField64, ChipUsageGetter}; +use openvm_stark_backend::{p3_field::PrimeField64, ChipUsageGetter, Stateful}; use crate::{arch::ExecutionError, system::program::trace::padding_instruction}; @@ -103,3 +103,13 @@ impl ChipUsageGetter for ProgramChip { 1 } } + +impl Stateful> for ProgramChip { + fn load_state(&mut self, state: Vec) { + self.execution_frequencies = bitcode::deserialize(&state).unwrap(); + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&self.execution_frequencies).unwrap() + } +} diff --git a/crates/vm/tests/integration_test.rs b/crates/vm/tests/integration_test.rs index 4a45ba7e92..ad77b22786 100644 --- a/crates/vm/tests/integration_test.rs +++ b/crates/vm/tests/integration_test.rs @@ -4,9 +4,10 @@ use derive_more::derive::From; use openvm_circuit::{ arch::{ hasher::{poseidon2::vm_poseidon2_hasher, Hasher}, - ChipId, ExitCode, MemoryConfig, SingleSegmentVmExecutor, SystemConfig, SystemExecutor, - SystemPeriphery, SystemTraceHeights, VirtualMachine, VmChipComplex, VmComplexTraceHeights, - VmConfig, VmInventoryError, VmInventoryTraceHeights, + ChipId, ExecutionSegment, ExitCode, MemoryConfig, SingleSegmentVmExecutor, Streams, + SystemConfig, SystemExecutor, SystemPeriphery, SystemTraceHeights, VirtualMachine, + VmChipComplex, VmComplexTraceHeights, VmConfig, VmExecutorResult, VmInventoryError, + VmInventoryTraceHeights, }, derive::{AnyEnum, InstructionExecutor, VmConfig}, system::{ @@ -437,16 +438,13 @@ fn test_vm_1_persistent() { .expect("Verification failed"); } -#[test] -fn test_vm_continuations() { - let n = 200000; - +fn gen_continuation_test_program(n: isize) -> Program { // Simple Fibonacci program to compute nth Fibonacci number mod BabyBear (with F_0 = 1). // Register [0]_1 <- stores the loop counter. // Register [1]_1 <- stores F_i at the beginning of iteration i. // Register [2]_1 <- stores F_{i+1} at the beginning of iteration i. // Register [3]_1 is used as a temporary register. - let program = Program::from_instructions(&[ + Program::from_instructions(&[ // [0]_1 <- 0 Instruction::from_isize(VmOpcode::with_default_offset(ADD), 0, 0, 0, 1, 0), // [1]_1 <- 0 @@ -481,8 +479,13 @@ fn test_vm_continuations() { 0, 0, ), - ]); + ]) +} +#[test] +fn test_vm_continuations() { + let n = 200000; + let program = gen_continuation_test_program(n); let config = NativeConfig { system: SystemConfig::new(3, MemoryConfig::default(), 0).with_max_segment_len(200000), native: Default::default(), @@ -509,6 +512,49 @@ fn test_vm_continuations() { assert_eq!(pv_proof.public_values[0], expected_output); } +#[test] +fn test_vm_continuations_recover_state() { + let n = 2000; + let program = gen_continuation_test_program(n); + let config = NativeConfig { + system: SystemConfig::new(3, MemoryConfig::default(), 0).with_max_segment_len(500), + native: Default::default(), + } + .with_continuations(); + let engine = BabyBearPoseidon2Engine::new(FriParameters::standard_fast()); + let vm = VirtualMachine::new(engine, config.clone()); + let pk = vm.keygen(); + let segments = vm + .executor + .execute_segments(program.clone(), Streams::default()) + .unwrap(); + // Simulate remote proving which chip complex state needs to be serialized then deserialized. + let states: Vec<_> = segments + .iter() + .map(|s| bitcode::serialize(&s.store_chip_complex_state()).unwrap()) + .collect(); + let proof_inputs_per_seg = states + .into_iter() + .map(|s| { + ExecutionSegment::new_for_proving( + &config, + program.clone(), + bitcode::deserialize(&s).unwrap(), + ) + .generate_proof_input(None) + }) + .collect(); + let proofs = vm.prove( + &pk, + VmExecutorResult { + per_segment: proof_inputs_per_seg, + final_memory: None, + }, + ); + vm.verify(&pk.get_vk(), proofs) + .expect("Verification failed"); +} + #[test] fn test_vm_without_field_arithmetic() { /* diff --git a/extensions/algebra/circuit/src/fp2_chip/addsub.rs b/extensions/algebra/circuit/src/fp2_chip/addsub.rs index 76dd6a575d..b1c5fba9a4 100644 --- a/extensions/algebra/circuit/src/fp2_chip/addsub.rs +++ b/extensions/algebra/circuit/src/fp2_chip/addsub.rs @@ -6,9 +6,11 @@ use std::{ use openvm_algebra_transpiler::Fp2Opcode; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, }; @@ -19,7 +21,7 @@ use crate::Fp2; // Input: Fp2 * 2 // Output: Fp2 -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct Fp2AddSubChip( pub VmChipWrapper< F, @@ -35,7 +37,7 @@ impl adapter: Rv32VecHeapAdapterChip, config: ExprBuilderConfig, offset: usize, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, offline_memory: Arc>>, ) -> Self { let (expr, is_add_flag, is_sub_flag) = fp2_addsub_expr(config, range_checker.bus()); diff --git a/extensions/algebra/circuit/src/fp2_chip/muldiv.rs b/extensions/algebra/circuit/src/fp2_chip/muldiv.rs index 8701dc0aa5..34fc0c7f67 100644 --- a/extensions/algebra/circuit/src/fp2_chip/muldiv.rs +++ b/extensions/algebra/circuit/src/fp2_chip/muldiv.rs @@ -6,9 +6,11 @@ use std::{ use openvm_algebra_transpiler::Fp2Opcode; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, SymbolicExpr, }; @@ -19,7 +21,7 @@ use crate::Fp2; // Input: Fp2 * 2 // Output: Fp2 -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct Fp2MulDivChip( pub VmChipWrapper< F, @@ -35,7 +37,7 @@ impl adapter: Rv32VecHeapAdapterChip, config: ExprBuilderConfig, offset: usize, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, offline_memory: Arc>>, ) -> Self { let (expr, is_mul_flag, is_div_flag) = fp2_muldiv_expr(config, range_checker.bus()); diff --git a/extensions/algebra/circuit/src/fp2_extension.rs b/extensions/algebra/circuit/src/fp2_extension.rs index 6ab21eaa3e..49c56b53c6 100644 --- a/extensions/algebra/circuit/src/fp2_extension.rs +++ b/extensions/algebra/circuit/src/fp2_extension.rs @@ -5,11 +5,11 @@ use openvm_circuit::{ arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_instructions::{UsizeOpcode, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; use openvm_rv32_adapters::Rv32VecHeapAdapterChip; @@ -27,7 +27,7 @@ pub struct Fp2Extension { pub supported_modulus: Vec, } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From, Stateful)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From, BytesStateful)] pub enum Fp2ExtensionExecutor { // 32 limbs prime Fp2AddSubRv32_32(Fp2AddSubChip), @@ -37,7 +37,7 @@ pub enum Fp2ExtensionExecutor { Fp2MulDivRv32_48(Fp2MulDivChip), } -#[derive(ChipUsageGetter, Chip, AnyEnum, From, Stateful)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, BytesStateful)] pub enum Fp2ExtensionPeriphery { BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), // We put this only to get the generic to work diff --git a/extensions/algebra/circuit/src/modular_chip/addsub.rs b/extensions/algebra/circuit/src/modular_chip/addsub.rs index 846c50d075..4fdc19c40f 100644 --- a/extensions/algebra/circuit/src/modular_chip/addsub.rs +++ b/extensions/algebra/circuit/src/modular_chip/addsub.rs @@ -6,9 +6,11 @@ use std::{ use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, }; @@ -41,7 +43,7 @@ pub fn addsub_expr( ) } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct ModularAddSubChip( pub VmChipWrapper< F, @@ -57,7 +59,7 @@ impl adapter: Rv32VecHeapAdapterChip, config: ExprBuilderConfig, offset: usize, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, offline_memory: Arc>>, ) -> Self { let (expr, is_add_flag, is_sub_flag) = addsub_expr(config, range_checker.bus()); diff --git a/extensions/algebra/circuit/src/modular_chip/muldiv.rs b/extensions/algebra/circuit/src/modular_chip/muldiv.rs index c24c7284be..d520bf4b4a 100644 --- a/extensions/algebra/circuit/src/modular_chip/muldiv.rs +++ b/extensions/algebra/circuit/src/modular_chip/muldiv.rs @@ -6,9 +6,11 @@ use std::{ use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, SymbolicExpr, }; @@ -55,7 +57,7 @@ pub fn muldiv_expr( ) } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct ModularMulDivChip( pub VmChipWrapper< F, @@ -71,7 +73,7 @@ impl adapter: Rv32VecHeapAdapterChip, config: ExprBuilderConfig, offset: usize, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, offline_memory: Arc>>, ) -> Self { let (expr, is_mul_flag, is_div_flag) = muldiv_expr(config, range_checker.bus()); diff --git a/extensions/algebra/circuit/src/modular_extension.rs b/extensions/algebra/circuit/src/modular_extension.rs index fc0798e020..9acf66e4ee 100644 --- a/extensions/algebra/circuit/src/modular_extension.rs +++ b/extensions/algebra/circuit/src/modular_extension.rs @@ -6,11 +6,11 @@ use openvm_circuit::{ arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_instructions::{UsizeOpcode, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; use openvm_rv32_adapters::{Rv32IsEqualModAdapterChip, Rv32VecHeapAdapterChip}; @@ -30,7 +30,7 @@ pub struct ModularExtension { pub supported_modulus: Vec, } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From, Stateful)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From, BytesStateful)] pub enum ModularExtensionExecutor { // 32 limbs prime ModularAddSubRv32_32(ModularAddSubChip), @@ -42,7 +42,7 @@ pub enum ModularExtensionExecutor { ModularIsEqualRv32_48(ModularIsEqualChip), } -#[derive(ChipUsageGetter, Chip, AnyEnum, From, Stateful)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, BytesStateful)] pub enum ModularExtensionPeriphery { BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), // We put this only to get the generic to work diff --git a/extensions/bigint/circuit/src/extension.rs b/extensions/bigint/circuit/src/extension.rs index fcfffd860a..3c7d950827 100644 --- a/extensions/bigint/circuit/src/extension.rs +++ b/extensions/bigint/circuit/src/extension.rs @@ -10,12 +10,12 @@ use openvm_circuit::{ }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_instructions::{program::DEFAULT_PC_STEP, UsizeOpcode, VmOpcode}; use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, @@ -70,7 +70,7 @@ fn default_range_tuple_checker_sizes() -> [u32; 2] { [1 << 8, 32 * (1 << 8)] } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, BytesStateful)] pub enum Int256Executor { BaseAlu256(Rv32BaseAlu256Chip), LessThan256(Rv32LessThan256Chip), @@ -80,7 +80,7 @@ pub enum Int256Executor { Shift256(Rv32Shift256Chip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, BytesStateful)] pub enum Int256Periphery { BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), /// Only needed for multiplication extension diff --git a/extensions/ecc/circuit/src/weierstrass_chip/double.rs b/extensions/ecc/circuit/src/weierstrass_chip/double.rs index e9a4f56367..911acd3fd1 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/double.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/double.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, iter, rc::Rc, sync::Arc}; +use std::{cell::RefCell, iter, rc::Rc}; use itertools::{zip_eq, Itertools}; use num_bigint_dig::BigUint; @@ -9,7 +9,7 @@ use openvm_circuit::arch::{ }; use openvm_circuit_primitives::{ bigint::utils::big_uint_to_num_limbs, - var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, SubAir, TraceSubRowGenerator, }; use openvm_ecc_transpiler::Rv32WeierstrassOpcode; @@ -161,13 +161,13 @@ where pub struct EcDoubleCoreChip { pub air: EcDoubleCoreAir, - pub range_checker: Arc, + pub range_checker: SharedVariableRangeCheckerChip, } impl EcDoubleCoreChip { pub fn new( config: ExprBuilderConfig, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, a_biguint: BigUint, offset: usize, ) -> Self { @@ -258,7 +258,7 @@ where fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { self.air.expr.generate_subrow( ( - &self.range_checker, + self.range_checker.as_ref(), vec![record.x, record.y], vec![record.is_double_flag], ), diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs index 70c681dda0..b9f8425acc 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs @@ -13,9 +13,9 @@ use std::sync::Mutex; use num_bigint_dig::BigUint; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::VariableRangeCheckerChip; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::SharedVariableRangeCheckerChip; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_ecc_transpiler::Rv32WeierstrassOpcode; use openvm_mod_circuit_builder::{ExprBuilderConfig, FieldExpressionCoreChip}; use openvm_rv32_adapters::Rv32VecHeapAdapterChip; @@ -25,7 +25,7 @@ use openvm_stark_backend::p3_field::PrimeField32; /// BLOCKS: how many blocks do we need to represent one input or output /// For example, for bls12_381, BLOCK_SIZE = 16, each element has 3 blocks and with two elements per input AffinePoint, BLOCKS = 6. /// For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct EcAddNeChip( VmChipWrapper< F, @@ -41,7 +41,7 @@ impl adapter: Rv32VecHeapAdapterChip, config: ExprBuilderConfig, offset: usize, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, offline_memory: Arc>>, ) -> Self { let expr = ec_add_ne_expr(config, range_checker.bus()); @@ -61,7 +61,7 @@ impl } } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct EcDoubleChip( VmChipWrapper< F, @@ -75,7 +75,7 @@ impl { pub fn new( adapter: Rv32VecHeapAdapterChip, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, config: ExprBuilderConfig, offset: usize, a: BigUint, diff --git a/extensions/ecc/circuit/src/weierstrass_extension.rs b/extensions/ecc/circuit/src/weierstrass_extension.rs index 0e4920f20c..2c1649c67c 100644 --- a/extensions/ecc/circuit/src/weierstrass_extension.rs +++ b/extensions/ecc/circuit/src/weierstrass_extension.rs @@ -7,11 +7,11 @@ use openvm_circuit::{ arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_ecc_guest::{ k256::{SECP256K1_MODULUS, SECP256K1_ORDER}, p256::{CURVE_A as P256_A, CURVE_B as P256_B, P256_MODULUS, P256_ORDER}, @@ -63,7 +63,7 @@ pub struct WeierstrassExtension { pub supported_curves: Vec, } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, BytesStateful)] pub enum WeierstrassExtensionExecutor { // 32 limbs prime EcAddNeRv32_32(EcAddNeChip), @@ -73,7 +73,7 @@ pub enum WeierstrassExtensionExecutor { EcDoubleRv32_48(EcDoubleChip), } -#[derive(ChipUsageGetter, Chip, AnyEnum, From, Stateful)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, BytesStateful)] pub enum WeierstrassExtensionPeriphery { BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), Phantom(PhantomChip), diff --git a/extensions/keccak256/circuit/src/extension.rs b/extensions/keccak256/circuit/src/extension.rs index 131c59bf18..4b8d9c0fee 100644 --- a/extensions/keccak256/circuit/src/extension.rs +++ b/extensions/keccak256/circuit/src/extension.rs @@ -6,9 +6,9 @@ use openvm_circuit::{ }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::bitwise_op_lookup::BitwiseOperationLookupBus; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_instructions::*; use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, @@ -49,12 +49,12 @@ impl Default for Keccak256Rv32Config { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Keccak256; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, BytesStateful)] pub enum Keccak256Executor { Keccak256(KeccakVmChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, BytesStateful)] pub enum Keccak256Periphery { BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), Phantom(PhantomChip), diff --git a/extensions/native/circuit/src/castf/core.rs b/extensions/native/circuit/src/castf/core.rs index 39b17751cd..aca23aebe9 100644 --- a/extensions/native/circuit/src/castf/core.rs +++ b/extensions/native/circuit/src/castf/core.rs @@ -1,13 +1,12 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - sync::Arc, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::arch::{ AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, VmCoreAir, VmCoreChip, }; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::instruction::Instruction; use openvm_native_compiler::CastfOpcode; @@ -105,11 +104,11 @@ pub struct CastFRecord { pub struct CastFCoreChip { pub air: CastFCoreAir, - pub range_checker_chip: Arc, + pub range_checker_chip: SharedVariableRangeCheckerChip, } impl CastFCoreChip { - pub fn new(range_checker_chip: Arc, offset: usize) -> Self { + pub fn new(range_checker_chip: SharedVariableRangeCheckerChip, offset: usize) -> Self { Self { air: CastFCoreAir { bus: range_checker_chip.bus(), diff --git a/extensions/native/circuit/src/extension.rs b/extensions/native/circuit/src/extension.rs index cf2af0b0bd..db1a3dc98e 100644 --- a/extensions/native/circuit/src/extension.rs +++ b/extensions/native/circuit/src/extension.rs @@ -10,8 +10,8 @@ use openvm_circuit::{ }, system::{native_adapter::NativeAdapterChip, phantom::PhantomChip}, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful, VmConfig}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_instructions::{ program::DEFAULT_PC_STEP, PhantomDiscriminant, Poseidon2Opcode, UsizeOpcode, VmOpcode, }; @@ -77,7 +77,7 @@ impl NativeConfig { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Native; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, BytesStateful)] pub enum NativeExecutor { LoadStore(NativeLoadStoreChip), BlockLoadStore(NativeLoadStoreChip), @@ -89,7 +89,7 @@ pub enum NativeExecutor { FriReducedOpening(FriReducedOpeningChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, BytesStateful)] pub enum NativePeriphery { Phantom(PhantomChip), } @@ -320,12 +320,12 @@ pub(crate) mod phantom { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct CastFExtension; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, BytesStateful)] pub enum CastFExtensionExecutor { CastF(CastFChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, BytesStateful)] pub enum CastFExtensionPeriphery { Placeholder(CastFChip), } diff --git a/extensions/native/circuit/src/poseidon2/mod.rs b/extensions/native/circuit/src/poseidon2/mod.rs index d16256e850..8f33d6aa9e 100644 --- a/extensions/native/circuit/src/poseidon2/mod.rs +++ b/extensions/native/circuit/src/poseidon2/mod.rs @@ -1,9 +1,14 @@ -use std::sync::Arc; +use std::sync::{Arc, Mutex}; +pub use columns::*; use openvm_circuit::{ arch::{ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor}, - system::{memory::MemoryController, program::ProgramBus}, + system::{ + memory::{offline_checker::MemoryBridge, MemoryController, OfflineMemory}, + program::ProgramBus, + }, }; +use openvm_circuit_primitives_derive::BytesStateful; use openvm_instructions::instruction::Instruction; use openvm_poseidon2_air::Poseidon2Config; use openvm_stark_backend::{ @@ -19,11 +24,6 @@ pub use air::*; mod chip; pub use chip::*; mod columns; -use std::sync::Mutex; - -pub use columns::*; -use openvm_circuit::system::memory::{offline_checker::MemoryBridge, OfflineMemory}; -use openvm_circuit_derive::Stateful; mod trace; @@ -33,7 +33,7 @@ mod tests; pub const NATIVE_POSEIDON2_WIDTH: usize = 16; pub const NATIVE_POSEIDON2_CHUNK_SIZE: usize = 8; -#[derive(Stateful)] +#[derive(BytesStateful)] pub enum NativePoseidon2Chip { Register0(NativePoseidon2BaseChip), Register1(NativePoseidon2BaseChip), diff --git a/extensions/pairing/circuit/src/fp12_chip/mul.rs b/extensions/pairing/circuit/src/fp12_chip/mul.rs index fcb44c3044..9ec637fa15 100644 --- a/extensions/pairing/circuit/src/fp12_chip/mul.rs +++ b/extensions/pairing/circuit/src/fp12_chip/mul.rs @@ -5,9 +5,11 @@ use std::{ }; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, }; @@ -18,7 +20,7 @@ use openvm_stark_backend::p3_field::PrimeField32; use crate::Fp12; // Input: Fp12 * 2 // Output: Fp12 -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct Fp12MulChip( pub VmChipWrapper< F, @@ -35,7 +37,7 @@ impl config: ExprBuilderConfig, xi: [isize; 2], offset: usize, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, offline_memory: Arc>>, ) -> Self { let expr = fp12_mul_expr(config, range_checker.bus(), xi); diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs index 77e53c25cc..e2f3bcbe25 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs @@ -6,9 +6,11 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, }; @@ -18,7 +20,7 @@ use openvm_stark_backend::p3_field::PrimeField32; // Input: line0.b, line0.c, line1.b, line1.c : 2 x 4 field elements // Output: 5 Fp2 coefficients -> 10 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct EcLineMul013By013Chip< F: PrimeField32, const INPUT_BLOCKS: usize, @@ -41,7 +43,7 @@ impl< { pub fn new( adapter: Rv32VecHeapAdapterChip, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, config: ExprBuilderConfig, xi: [isize; 2], offset: usize, diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs index e846b42339..5d4057bb5f 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs @@ -6,9 +6,11 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, }; @@ -20,7 +22,7 @@ use crate::Fp12; // Input: Fp12 (12 field elements), [Fp2; 5] (5 x 2 field elements) // Output: Fp12 (12 field elements) -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct EcLineMulBy01234Chip< F: PrimeField32, const INPUT_BLOCKS1: usize, @@ -62,7 +64,7 @@ impl< config: ExprBuilderConfig, xi: [isize; 2], offset: usize, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, offline_memory: Arc>>, ) -> Self { assert!( diff --git a/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs b/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs index 5ceda2bc57..f58d0550ff 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs @@ -6,9 +6,11 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, }; @@ -18,7 +20,7 @@ use openvm_stark_backend::p3_field::PrimeField32; // Input: UnevaluatedLine, (Fp, Fp) // Output: EvaluatedLine -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct EvaluateLineChip< F: PrimeField32, const INPUT_BLOCKS1: usize, @@ -59,7 +61,7 @@ impl< >, config: ExprBuilderConfig, offset: usize, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, offline_memory: Arc>>, ) -> Self { let expr = evaluate_line_expr(config, range_checker.bus()); diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs index d0e4f56c4d..2182702456 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs @@ -6,9 +6,11 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, }; @@ -18,7 +20,7 @@ use openvm_stark_backend::p3_field::PrimeField32; // Input: line0.b, line0.c, line1.b, line1.c : 2 x 4 field elements // Output: 5 Fp2 coefficients -> 10 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct EcLineMul023By023Chip< F: PrimeField32, const INPUT_BLOCKS: usize, @@ -41,7 +43,7 @@ impl< { pub fn new( adapter: Rv32VecHeapAdapterChip, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, config: ExprBuilderConfig, xi: [isize; 2], offset: usize, diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs index 170a8a4c3d..04154d42cd 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs @@ -6,9 +6,11 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, }; @@ -20,7 +22,7 @@ use crate::Fp12; // Input: 2 Fp12: 2 x 12 field elements // Output: Fp12 -> 12 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct EcLineMulBy02345Chip< F: PrimeField32, const INPUT_BLOCKS1: usize, @@ -59,7 +61,7 @@ impl< BLOCK_SIZE, BLOCK_SIZE, >, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, config: ExprBuilderConfig, xi: [isize; 2], offset: usize, diff --git a/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs b/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs index 7e7a602cfe..17adb213fe 100644 --- a/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs +++ b/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs @@ -6,9 +6,11 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, }; @@ -18,7 +20,7 @@ use openvm_stark_backend::p3_field::PrimeField32; // Input: two AffinePoint: 4 field elements each // Output: (AffinePoint, UnevaluatedLine, UnevaluatedLine) -> 2*2 + 2*2 + 2*2 = 12 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct MillerDoubleAndAddStepChip< F: PrimeField32, const INPUT_BLOCKS: usize, @@ -43,7 +45,7 @@ impl< adapter: Rv32VecHeapAdapterChip, config: ExprBuilderConfig, offset: usize, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, offline_memory: Arc>>, ) -> Self { let expr = miller_double_and_add_step_expr(config, range_checker.bus()); diff --git a/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs b/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs index 14bc7d29e5..5793a188c7 100644 --- a/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs +++ b/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs @@ -6,9 +6,11 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, }; @@ -18,7 +20,7 @@ use openvm_stark_backend::p3_field::PrimeField32; // Input: AffinePoint: 4 field elements // Output: (AffinePoint, Fp2, Fp2) -> 8 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, BytesStateful)] pub struct MillerDoubleStepChip< F: PrimeField32, const INPUT_BLOCKS: usize, @@ -43,7 +45,7 @@ impl< adapter: Rv32VecHeapAdapterChip, config: ExprBuilderConfig, offset: usize, - range_checker: Arc, + range_checker: SharedVariableRangeCheckerChip, offline_memory: Arc>>, ) -> Self { let expr = miller_double_step_expr(config, range_checker.bus()); diff --git a/extensions/pairing/circuit/src/pairing_extension.rs b/extensions/pairing/circuit/src/pairing_extension.rs index e08633ab1d..51b2c9fdd8 100644 --- a/extensions/pairing/circuit/src/pairing_extension.rs +++ b/extensions/pairing/circuit/src/pairing_extension.rs @@ -5,11 +5,11 @@ use openvm_circuit::{ arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_ecc_circuit::CurveConfig; use openvm_instructions::{PhantomDiscriminant, UsizeOpcode, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; @@ -64,7 +64,7 @@ pub struct PairingExtension { pub supported_curves: Vec, } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, Stateful)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, BytesStateful)] pub enum PairingExtensionExecutor { // bn254 (32 limbs) MillerDoubleStepRv32_32(MillerDoubleStepChip), @@ -82,7 +82,7 @@ pub enum PairingExtensionExecutor { EcLineMulBy02345(EcLineMulBy02345Chip), } -#[derive(ChipUsageGetter, Chip, AnyEnum, From, Stateful)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, BytesStateful)] pub enum PairingExtensionPeriphery { BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), Phantom(PhantomChip), diff --git a/extensions/rv32im/circuit/src/adapters/hintstore.rs b/extensions/rv32im/circuit/src/adapters/hintstore.rs index 14b143f424..fe658e46cd 100644 --- a/extensions/rv32im/circuit/src/adapters/hintstore.rs +++ b/extensions/rv32im/circuit/src/adapters/hintstore.rs @@ -2,7 +2,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, marker::PhantomData, - sync::Arc, }; use openvm_circuit::{ @@ -19,7 +18,9 @@ use openvm_circuit::{ program::ProgramBus, }, }; -use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +}; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, @@ -40,7 +41,7 @@ use crate::adapters::RV32_CELL_BITS; /// It writes to the memory at the intermediate pointer. pub struct Rv32HintStoreAdapterChip { pub air: Rv32HintStoreAdapterAir, - pub range_checker_chip: Arc, + pub range_checker_chip: SharedVariableRangeCheckerChip, _marker: PhantomData, } @@ -50,7 +51,7 @@ impl Rv32HintStoreAdapterChip { program_bus: ProgramBus, memory_bridge: MemoryBridge, pointer_max_bits: usize, - range_checker_chip: Arc, + range_checker_chip: SharedVariableRangeCheckerChip, ) -> Self { assert!(range_checker_chip.range_max_bits() >= 16); Self { diff --git a/extensions/rv32im/circuit/src/adapters/loadstore.rs b/extensions/rv32im/circuit/src/adapters/loadstore.rs index 005742c9c8..fdb4c3b62d 100644 --- a/extensions/rv32im/circuit/src/adapters/loadstore.rs +++ b/extensions/rv32im/circuit/src/adapters/loadstore.rs @@ -2,7 +2,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, marker::PhantomData, - sync::Arc, }; use openvm_circuit::{ @@ -22,7 +21,7 @@ use openvm_circuit::{ }; use openvm_circuit_primitives::{ utils::select, - var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -94,7 +93,7 @@ impl VmAdapterInterface for Rv32LoadStoreAdapt /// In case of Stores, reads from rs2 and writes to the shifted intermediate pointer. pub struct Rv32LoadStoreAdapterChip { pub air: Rv32LoadStoreAdapterAir, - pub range_checker_chip: Arc, + pub range_checker_chip: SharedVariableRangeCheckerChip, offset: usize, _marker: PhantomData, } @@ -105,7 +104,7 @@ impl Rv32LoadStoreAdapterChip { program_bus: ProgramBus, memory_bridge: MemoryBridge, pointer_max_bits: usize, - range_checker_chip: Arc, + range_checker_chip: SharedVariableRangeCheckerChip, offset: usize, ) -> Self { assert!(range_checker_chip.range_max_bits() >= 15); diff --git a/extensions/rv32im/circuit/src/extension.rs b/extensions/rv32im/circuit/src/extension.rs index 7eb1ac7d69..d737dfb6bf 100644 --- a/extensions/rv32im/circuit/src/extension.rs +++ b/extensions/rv32im/circuit/src/extension.rs @@ -6,12 +6,12 @@ use openvm_circuit::{ }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_instructions::{program::DEFAULT_PC_STEP, PhantomDiscriminant, UsizeOpcode, VmOpcode}; use openvm_rv32im_transpiler::{ BaseAluOpcode, BranchEqualOpcode, BranchLessThanOpcode, DivRemOpcode, LessThanOpcode, @@ -150,7 +150,7 @@ fn default_range_tuple_checker_sizes() -> [u32; 2] { // ============ Executor and Periphery Enums for Extension ============ /// RISC-V 32-bit Base (RV32I) Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, BytesStateful)] pub enum Rv32IExecutor { // Rv32 (for standard 32-bit integers): BaseAlu(Rv32BaseAluChip), @@ -166,7 +166,7 @@ pub enum Rv32IExecutor { } /// RISC-V 32-bit Multiplication Extension (RV32M) Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, BytesStateful)] pub enum Rv32MExecutor { Multiplication(Rv32MultiplicationChip), MultiplicationHigh(Rv32MulHChip), @@ -174,19 +174,19 @@ pub enum Rv32MExecutor { } /// RISC-V 32-bit Io Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, BytesStateful)] pub enum Rv32IoExecutor { HintStore(Rv32HintStoreChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, BytesStateful)] pub enum Rv32IPeriphery { BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), // We put this only to get the generic to work Phantom(PhantomChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, BytesStateful)] pub enum Rv32MPeriphery { BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), /// Only needed for multiplication extension @@ -195,7 +195,7 @@ pub enum Rv32MPeriphery { Phantom(PhantomChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, BytesStateful)] pub enum Rv32IoPeriphery { BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), // We put this only to get the generic to work diff --git a/extensions/rv32im/circuit/src/jalr/core.rs b/extensions/rv32im/circuit/src/jalr/core.rs index 7938f77c91..7da98539a7 100644 --- a/extensions/rv32im/circuit/src/jalr/core.rs +++ b/extensions/rv32im/circuit/src/jalr/core.rs @@ -1,7 +1,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use openvm_circuit::arch::{ @@ -10,7 +9,7 @@ use openvm_circuit::arch::{ }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, - var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -175,13 +174,13 @@ where pub struct Rv32JalrCoreChip { pub air: Rv32JalrCoreAir, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - pub range_checker_chip: Arc, + pub range_checker_chip: SharedVariableRangeCheckerChip, } impl Rv32JalrCoreChip { pub fn new( bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - range_checker_chip: Arc, + range_checker_chip: SharedVariableRangeCheckerChip, offset: usize, ) -> Self { assert!(range_checker_chip.range_max_bits() >= 16); diff --git a/extensions/rv32im/circuit/src/load_sign_extend/core.rs b/extensions/rv32im/circuit/src/load_sign_extend/core.rs index e4a224bb0e..da07d16838 100644 --- a/extensions/rv32im/circuit/src/load_sign_extend/core.rs +++ b/extensions/rv32im/circuit/src/load_sign_extend/core.rs @@ -1,7 +1,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use openvm_circuit::arch::{ @@ -9,7 +8,7 @@ use openvm_circuit::arch::{ }; use openvm_circuit_primitives::{ utils::select, - var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, UsizeOpcode}; @@ -176,11 +175,11 @@ where pub struct LoadSignExtendCoreChip { pub air: LoadSignExtendCoreAir, - pub range_checker_chip: Arc, + pub range_checker_chip: SharedVariableRangeCheckerChip, } impl LoadSignExtendCoreChip { - pub fn new(range_checker_chip: Arc, offset: usize) -> Self { + pub fn new(range_checker_chip: SharedVariableRangeCheckerChip, offset: usize) -> Self { Self { air: LoadSignExtendCoreAir:: { range_bus: range_checker_chip.bus(), diff --git a/extensions/rv32im/circuit/src/shift/core.rs b/extensions/rv32im/circuit/src/shift/core.rs index 28eca7833a..f7de777964 100644 --- a/extensions/rv32im/circuit/src/shift/core.rs +++ b/extensions/rv32im/circuit/src/shift/core.rs @@ -1,7 +1,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use openvm_circuit::arch::{ @@ -11,7 +10,7 @@ use openvm_circuit::arch::{ use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, - var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, UsizeOpcode}; @@ -252,13 +251,13 @@ pub struct ShiftCoreRecord { pub struct ShiftCoreChip { pub air: ShiftCoreAir, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - pub range_checker_chip: Arc, + pub range_checker_chip: SharedVariableRangeCheckerChip, } impl ShiftCoreChip { pub fn new( bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - range_checker_chip: Arc, + range_checker_chip: SharedVariableRangeCheckerChip, offset: usize, ) -> Self { assert_eq!(NUM_LIMBS % 2, 0, "Number of limbs must be divisible by 2"); diff --git a/extensions/sha256/circuit/src/extension.rs b/extensions/sha256/circuit/src/extension.rs index baeeccf76b..d2772c6b8d 100644 --- a/extensions/sha256/circuit/src/extension.rs +++ b/extensions/sha256/circuit/src/extension.rs @@ -6,11 +6,11 @@ use openvm_circuit::{ }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_primitives_derive::{BytesStateful, Chip, ChipUsageGetter}; use openvm_instructions::*; use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, @@ -52,12 +52,12 @@ impl Default for Sha256Rv32Config { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Sha256; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, BytesStateful)] pub enum Sha256Executor { Sha256(Sha256VmChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, BytesStateful)] pub enum Sha256Periphery { BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), Phantom(PhantomChip),