From 8ccdf01a3f4d86d834090f556946321ec2f1dbdb Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Mon, 13 Jan 2025 17:16:17 +0800 Subject: [PATCH] Implement Stateful for all Chips (#1199) --- Cargo.lock | 40 +++++---- Cargo.toml | 4 +- benchmarks/Cargo.toml | 1 + crates/circuits/primitives/Cargo.toml | 1 + .../primitives/src/bitwise_op_lookup/mod.rs | 85 +++++++++++++++++- .../primitives/src/range_tuple/mod.rs | 76 +++++++++++++++- crates/circuits/sha256-air/src/tests.rs | 10 +-- crates/circuits/sha256-air/src/trace.rs | 8 +- crates/toolchain/instructions/src/lib.rs | 14 ++- crates/toolchain/tests/Cargo.toml | 1 + crates/vm/Cargo.toml | 1 + crates/vm/derive/src/lib.rs | 90 ++++++++++++++++++- crates/vm/src/arch/extensions.rs | 6 +- crates/vm/src/arch/integration_api.rs | 14 ++- crates/vm/src/system/phantom/mod.rs | 17 +++- crates/vm/src/system/poseidon2/chip.rs | 14 ++- crates/vm/src/system/poseidon2/mod.rs | 2 + .../algebra/circuit/src/fp2_chip/addsub.rs | 11 +-- .../algebra/circuit/src/fp2_chip/muldiv.rs | 11 +-- .../algebra/circuit/src/fp2_extension.rs | 20 ++--- .../algebra/circuit/src/modular_chip/is_eq.rs | 7 +- .../algebra/circuit/src/modular_chip/tests.rs | 14 ++- .../algebra/circuit/src/modular_extension.rs | 20 ++--- extensions/bigint/circuit/src/extension.rs | 28 +++--- extensions/bigint/circuit/src/tests.rs | 32 ++----- .../ecc/circuit/src/weierstrass_chip/mod.rs | 6 +- .../ecc/circuit/src/weierstrass_chip/tests.rs | 16 ++-- .../ecc/circuit/src/weierstrass_extension.rs | 20 ++--- extensions/keccak256/circuit/Cargo.toml | 2 + extensions/keccak256/circuit/src/extension.rs | 16 ++-- extensions/keccak256/circuit/src/lib.rs | 25 ++++-- extensions/keccak256/circuit/src/tests.rs | 6 +- extensions/native/circuit/Cargo.toml | 1 + extensions/native/circuit/src/extension.rs | 6 +- extensions/native/circuit/src/fri/mod.rs | 15 +++- .../native/circuit/src/poseidon2/chip.rs | 20 ++++- .../native/circuit/src/poseidon2/mod.rs | 2 + .../pairing/circuit/src/fp12_chip/mul.rs | 12 +-- .../pairing/circuit/src/fp12_chip/tests.rs | 8 +- .../line/d_type/mul_013_by_013.rs | 4 +- .../pairing_chip/line/d_type/mul_by_01234.rs | 4 +- .../src/pairing_chip/line/d_type/tests.rs | 16 +--- .../src/pairing_chip/line/evaluate_line.rs | 4 +- .../line/m_type/mul_023_by_023.rs | 4 +- .../pairing_chip/line/m_type/mul_by_02345.rs | 4 +- .../src/pairing_chip/line/m_type/tests.rs | 12 +-- .../miller_double_and_add_step.rs | 12 +-- .../src/pairing_chip/miller_double_step.rs | 16 ++-- .../pairing/circuit/src/pairing_extension.rs | 20 ++--- extensions/rv32-adapters/src/eq_mod.rs | 7 +- extensions/rv32-adapters/src/heap.rs | 9 +- extensions/rv32-adapters/src/heap_branch.rs | 7 +- extensions/rv32-adapters/src/vec_heap.rs | 11 ++- .../rv32-adapters/src/vec_heap_two_reads.rs | 11 ++- extensions/rv32im/circuit/src/auipc/core.rs | 7 +- extensions/rv32im/circuit/src/auipc/tests.rs | 16 ++-- .../rv32im/circuit/src/base_alu/core.rs | 7 +- .../rv32im/circuit/src/base_alu/tests.rs | 12 +-- .../rv32im/circuit/src/branch_lt/core.rs | 7 +- .../rv32im/circuit/src/branch_lt/tests.rs | 16 ++-- extensions/rv32im/circuit/src/divrem/core.rs | 13 ++- extensions/rv32im/circuit/src/divrem/tests.rs | 18 ++-- extensions/rv32im/circuit/src/extension.rs | 56 ++++++------ .../rv32im/circuit/src/hintstore/core.rs | 6 +- .../rv32im/circuit/src/hintstore/tests.rs | 14 +-- extensions/rv32im/circuit/src/jal_lui/core.rs | 7 +- .../rv32im/circuit/src/jal_lui/tests.rs | 16 ++-- extensions/rv32im/circuit/src/jalr/core.rs | 6 +- extensions/rv32im/circuit/src/jalr/tests.rs | 16 ++-- .../rv32im/circuit/src/less_than/core.rs | 7 +- .../rv32im/circuit/src/less_than/tests.rs | 12 +-- extensions/rv32im/circuit/src/mul/core.rs | 7 +- extensions/rv32im/circuit/src/mul/tests.rs | 8 +- extensions/rv32im/circuit/src/mulh/core.rs | 13 ++- extensions/rv32im/circuit/src/mulh/tests.rs | 18 ++-- extensions/rv32im/circuit/src/shift/core.rs | 6 +- extensions/rv32im/circuit/src/shift/tests.rs | 12 +-- extensions/sha256/circuit/Cargo.toml | 1 + extensions/sha256/circuit/src/extension.rs | 20 ++--- .../sha256/circuit/src/sha256_chip/mod.rs | 23 +++-- .../sha256/circuit/src/sha256_chip/tests.rs | 12 +-- .../sha256/circuit/src/sha256_chip/trace.rs | 2 +- 82 files changed, 699 insertions(+), 482 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 64ff70abe5..75e6ec0dcd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1213,9 +1213,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.24" +version = "4.5.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9560b07a799281c7e0958b9296854d6fafd4c5f31444a7e5bb1ad6dde5ccf1bd" +checksum = "b95dca1b68188a08ca6af9d96a6576150f598824bdb528c1190460c2940a0b48" dependencies = [ "clap_builder", "clap_derive", @@ -1223,9 +1223,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.24" +version = "4.5.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "874e0dd3eb68bf99058751ac9712f622e61e6f393a94f7128fa26e3f02f5c7cd" +checksum = "9ab52925392148efd3f7562f2136a81ffb778076bcc85727c6e020d6dd57cf15" dependencies = [ "anstream", "anstyle", @@ -3499,6 +3499,7 @@ dependencies = [ "openvm-algebra-transpiler", "openvm-build", "openvm-circuit", + "openvm-circuit-derive", "openvm-circuit-primitives-derive", "openvm-ecc-circuit", "openvm-ecc-transpiler", @@ -3615,6 +3616,7 @@ dependencies = [ "ark-ff 0.4.2", "async-trait", "backtrace", + "bitcode", "cfg-if", "derivative", "derive-new", @@ -3675,6 +3677,7 @@ dependencies = [ name = "openvm-circuit-primitives" version = "0.2.0-alpha" dependencies = [ + "bitcode", "derive-new", "itertools 0.13.0", "lazy_static", @@ -3846,6 +3849,7 @@ dependencies = [ name = "openvm-keccak256-circuit" version = "0.2.0-alpha" dependencies = [ + "bitcode", "derive-new", "derive_more 1.0.0", "eyre", @@ -3865,6 +3869,7 @@ dependencies = [ "p3-keccak-air", "rand", "serde", + "serde-big-array", "strum", "test-case", "test-log", @@ -3947,6 +3952,7 @@ dependencies = [ name = "openvm-native-circuit" version = "0.2.0-alpha" dependencies = [ + "bitcode", "derive-new", "derive_more 1.0.0", "eyre", @@ -4353,6 +4359,7 @@ dependencies = [ name = "openvm-sha256-circuit" version = "0.2.0-alpha" dependencies = [ + "bitcode", "derive-new", "derive_more 1.0.0", "hex", @@ -4445,7 +4452,7 @@ dependencies = [ [[package]] name = "openvm-stark-backend" version = "0.2.0-alpha" -source = "git+https://github.com/openvm-org/stark-backend.git?rev=8c9f94#8c9f94b71607eacfc740ab3d677f676d20796ace" +source = "git+https://github.com/openvm-org/stark-backend.git?rev=25265e7#25265e7abd59c21ad110e2c673d549801bbf9a17" dependencies = [ "async-trait", "cfg-if", @@ -4472,7 +4479,7 @@ dependencies = [ [[package]] name = "openvm-stark-sdk" version = "0.2.0-alpha" -source = "git+https://github.com/openvm-org/stark-backend.git?rev=8c9f94#8c9f94b71607eacfc740ab3d677f676d20796ace" +source = "git+https://github.com/openvm-org/stark-backend.git?rev=25265e7#25265e7abd59c21ad110e2c673d549801bbf9a17" dependencies = [ "derive_more 0.99.18", "ff 0.13.0", @@ -4517,6 +4524,7 @@ dependencies = [ "openvm-bigint-transpiler", "openvm-build", "openvm-circuit", + "openvm-circuit-derive", "openvm-circuit-primitives-derive", "openvm-ecc-circuit", "openvm-ecc-guest", @@ -5047,7 +5055,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" dependencies = [ "memchr", - "thiserror 2.0.9", + "thiserror 2.0.10", "ucd-trie", ] @@ -6312,9 +6320,9 @@ dependencies = [ [[package]] name = "symbolic-common" -version = "12.12.4" +version = "12.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd33e73f154e36ec223c18013f7064a2c120f1162fc086ac9933542def186b00" +checksum = "bf08b42a6f9469bd8584daee39a1352c8133ccabc5151ccccb15896ef047d99a" dependencies = [ "debugid", "memmap2", @@ -6324,9 +6332,9 @@ dependencies = [ [[package]] name = "symbolic-demangle" -version = "12.12.4" +version = "12.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89e51191290147f071777e37fe111800bb82a9059f9c95b19d2dd41bfeddf477" +checksum = "32f73b5a5bd4da72720c45756a2d11edf110116b87f998bda59b97be8c2c7cf1" dependencies = [ "cpp_demangle", "rustc-demangle", @@ -6480,11 +6488,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.9" +version = "2.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" +checksum = "a3ac7f54ca534db81081ef1c1e7f6ea8a3ef428d2fc069097c079443d24124d3" dependencies = [ - "thiserror-impl 2.0.9", + "thiserror-impl 2.0.10", ] [[package]] @@ -6500,9 +6508,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.9" +version = "2.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" +checksum = "9e9465d30713b56a37ede7185763c3492a91be2f5fa68d958c44e41ab9248beb" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index ee5ebc186d..f06dff4aae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -103,8 +103,8 @@ lto = "thin" [workspace.dependencies] # Stark Backend -openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", rev = "8c9f94", default-features = false } -openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", rev = "8c9f94", default-features = false } +openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", rev = "25265e7", default-features = false } +openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", rev = "25265e7", default-features = false } # OpenVM openvm-sdk = { path = "crates/sdk", default-features = false } diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index bce3002295..456fb54b34 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -15,6 +15,7 @@ openvm-sdk.workspace = true openvm-stark-backend.workspace = true openvm-stark-sdk.workspace = true openvm-transpiler.workspace = true +openvm-circuit-derive.workspace = true openvm-algebra-circuit.workspace = true openvm-algebra-transpiler.workspace = true diff --git a/crates/circuits/primitives/Cargo.toml b/crates/circuits/primitives/Cargo.toml index 1197cff018..70f3c700a8 100644 --- a/crates/circuits/primitives/Cargo.toml +++ b/crates/circuits/primitives/Cargo.toml @@ -19,6 +19,7 @@ num-bigint-dig.workspace = true num-traits.workspace = true lazy_static.workspace = true tracing.workspace = true +bitcode.workspace = true [dev-dependencies] p3-dft = { workspace = true } diff --git a/crates/circuits/primitives/src/bitwise_op_lookup/mod.rs b/crates/circuits/primitives/src/bitwise_op_lookup/mod.rs index 0cc8245dc7..3d214e2b12 100644 --- a/crates/circuits/primitives/src/bitwise_op_lookup/mod.rs +++ b/crates/circuits/primitives/src/bitwise_op_lookup/mod.rs @@ -1,9 +1,13 @@ use std::{ borrow::{Borrow, BorrowMut}, mem::size_of, - sync::{atomic::AtomicU32, Arc}, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, }; +use itertools::Itertools; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, @@ -13,7 +17,7 @@ use openvm_stark_backend::{ p3_matrix::{dense::RowMajorMatrix, Matrix}, prover::types::AirProofInput, rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, ChipUsageGetter, + Chip, ChipUsageGetter, Stateful, }; mod bus; @@ -109,6 +113,11 @@ pub struct BitwiseOperationLookupChip { count_xor: Vec, } +#[derive(Clone)] +pub struct SharedBitwiseOperationLookupChip( + Arc>, +); + impl BitwiseOperationLookupChip { pub fn new(bus: BitwiseOperationLookupBus) -> Self { let num_rows = (1 << NUM_BITS) * (1 << NUM_BITS); @@ -169,6 +178,35 @@ impl BitwiseOperationLookupChip { } } +impl SharedBitwiseOperationLookupChip { + pub fn new(bus: BitwiseOperationLookupBus) -> Self { + Self(Arc::new(BitwiseOperationLookupChip::new(bus))) + } + pub fn bus(&self) -> BitwiseOperationLookupBus { + self.0.bus() + } + + pub fn air_width(&self) -> usize { + self.0.air_width() + } + + pub fn request_range(&self, x: u32, y: u32) { + self.0.request_range(x, y); + } + + pub fn request_xor(&self, x: u32, y: u32) -> u32 { + self.0.request_xor(x, y) + } + + pub fn clear(&self) { + self.0.clear() + } + + pub fn generate_trace(&self) -> RowMajorMatrix { + self.0.generate_trace() + } +} + impl Chip for BitwiseOperationLookupChip { @@ -182,6 +220,18 @@ impl Chip } } +impl Chip + for SharedBitwiseOperationLookupChip +{ + fn air(&self) -> Arc> { + self.0.air() + } + + fn generate_air_proof_input(self) -> AirProofInput { + self.0.generate_air_proof_input() + } +} + impl ChipUsageGetter for BitwiseOperationLookupChip { fn air_name(&self) -> String { get_air_name(&self.air) @@ -196,3 +246,34 @@ impl ChipUsageGetter for BitwiseOperationLookupChip ChipUsageGetter for SharedBitwiseOperationLookupChip { + fn air_name(&self) -> String { + self.0.air_name() + } + + fn current_trace_height(&self) -> usize { + self.0.current_trace_height() + } + + fn trace_width(&self) -> usize { + self.0.trace_width() + } +} + +impl Stateful> for SharedBitwiseOperationLookupChip { + fn load_state(&mut self, state: Vec) { + // AtomicU32 can be deserialized as u32 + let (count_range, count_xor): (Vec, Vec) = bitcode::deserialize(&state).unwrap(); + for (x, v) in self.0.count_range.iter().zip_eq(count_range) { + x.store(v, Ordering::Relaxed); + } + for (x, v) in self.0.count_xor.iter().zip_eq(count_xor) { + x.store(v, Ordering::Relaxed); + } + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&(&self.0.count_range, &self.0.count_xor)).unwrap() + } +} diff --git a/crates/circuits/primitives/src/range_tuple/mod.rs b/crates/circuits/primitives/src/range_tuple/mod.rs index 16644c7944..d4f5859d15 100644 --- a/crates/circuits/primitives/src/range_tuple/mod.rs +++ b/crates/circuits/primitives/src/range_tuple/mod.rs @@ -5,9 +5,13 @@ use std::{ mem::size_of, - sync::{atomic::AtomicU32, Arc}, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, }; +use itertools::Itertools; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, @@ -16,7 +20,7 @@ use openvm_stark_backend::{ p3_matrix::{dense::RowMajorMatrix, Matrix}, prover::types::AirProofInput, rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, ChipUsageGetter, + Chip, ChipUsageGetter, Stateful, }; mod bus; @@ -100,6 +104,9 @@ pub struct RangeTupleCheckerChip { count: Vec>, } +#[derive(Debug, Clone)] +pub struct SharedRangeTupleCheckerChip(Arc>); + impl RangeTupleCheckerChip { pub fn new(bus: RangeTupleCheckerBus) -> Self { let range_max = bus.sizes.iter().product(); @@ -152,6 +159,31 @@ impl RangeTupleCheckerChip { } } +impl SharedRangeTupleCheckerChip { + pub fn new(bus: RangeTupleCheckerBus) -> Self { + Self(Arc::new(RangeTupleCheckerChip::new(bus))) + } + pub fn bus(&self) -> &RangeTupleCheckerBus { + self.0.bus() + } + + pub fn sizes(&self) -> &[u32; N] { + self.0.sizes() + } + + pub fn add_count(&self, ids: &[u32]) { + self.0.add_count(ids); + } + + pub fn clear(&self) { + self.0.clear(); + } + + pub fn generate_trace(&self) -> RowMajorMatrix { + self.0.generate_trace() + } +} + impl Chip for RangeTupleCheckerChip where Val: PrimeField32, @@ -166,6 +198,19 @@ where } } +impl Chip for SharedRangeTupleCheckerChip +where + Val: PrimeField32, +{ + fn air(&self) -> Arc> { + self.0.air() + } + + fn generate_air_proof_input(self) -> AirProofInput { + self.0.generate_air_proof_input() + } +} + impl ChipUsageGetter for RangeTupleCheckerChip { fn air_name(&self) -> String { get_air_name(&self.air) @@ -180,3 +225,30 @@ impl ChipUsageGetter for RangeTupleCheckerChip { NUM_RANGE_TUPLE_COLS } } + +impl ChipUsageGetter for SharedRangeTupleCheckerChip { + fn air_name(&self) -> String { + self.0.air_name() + } + + fn current_trace_height(&self) -> usize { + self.0.current_trace_height() + } + + fn trace_width(&self) -> usize { + self.0.trace_width() + } +} + +impl Stateful> for SharedRangeTupleCheckerChip { + fn load_state(&mut self, state: Vec) { + let vals: Vec = bitcode::deserialize(&state).unwrap(); + for (x, v) in self.0.count.iter().zip_eq(vals) { + x.store(v, Ordering::Relaxed); + } + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&self.0.count).unwrap() + } +} diff --git a/crates/circuits/sha256-air/src/tests.rs b/crates/circuits/sha256-air/src/tests.rs index 3f8ead2736..92fee92f69 100644 --- a/crates/circuits/sha256-air/src/tests.rs +++ b/crates/circuits/sha256-air/src/tests.rs @@ -4,7 +4,7 @@ use openvm_circuit::arch::{ instructions::riscv::RV32_CELL_BITS, testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, SubAir, }; use openvm_stark_backend::{ @@ -46,7 +46,7 @@ impl Air for Sha256TestAir { // A wrapper Chip purely for testing purposes pub struct Sha256TestChip { pub air: Sha256TestAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, pub records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, } @@ -62,7 +62,7 @@ where let air = self.air(); let trace = crate::generate_trace::>( &self.air.sub_air, - &self.bitwise_lookup_chip, + self.bitwise_lookup_chip.clone(), self.records, ); AirProofInput::simple(air, trace, vec![]) @@ -88,9 +88,7 @@ fn rand_sha256_test() { let mut rng = create_seeded_rng(); let tester = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let len = rng.gen_range(1..100); let random_records: Vec<_> = (0..len) .map(|_| (array::from_fn(|_| rng.gen::()), true)) diff --git a/crates/circuits/sha256-air/src/trace.rs b/crates/circuits/sha256-air/src/trace.rs index 45477a472c..734f724ec6 100644 --- a/crates/circuits/sha256-air/src/trace.rs +++ b/crates/circuits/sha256-air/src/trace.rs @@ -1,7 +1,7 @@ use std::{array, borrow::BorrowMut, ops::Range}; use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupChip, utils::next_power_of_two_or_zero, + bitwise_op_lookup::SharedBitwiseOperationLookupChip, utils::next_power_of_two_or_zero, }; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, @@ -50,7 +50,7 @@ impl Sha256Air { trace_width: usize, trace_start_col: usize, input: &[u32; SHA256_BLOCK_WORDS], - bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, prev_hash: &[u32; SHA256_HASH_WORDS], is_last_block: bool, global_block_idx: u32, @@ -463,7 +463,7 @@ impl Sha256Air { /// `records` consists of pairs of `(input_block, is_last_block)`. pub fn generate_trace( sub_air: &Sha256Air, - bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, ) -> RowMajorMatrix { let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK; @@ -521,7 +521,7 @@ pub fn generate_trace( width, 0, &input_words, - bitwise_lookup_chip, + bitwise_lookup_chip.clone(), &prev_hash, is_last_block, global_block_idx, diff --git a/crates/toolchain/instructions/src/lib.rs b/crates/toolchain/instructions/src/lib.rs index d91df2c013..0a46753e63 100644 --- a/crates/toolchain/instructions/src/lib.rs +++ b/crates/toolchain/instructions/src/lib.rs @@ -89,7 +89,19 @@ pub enum PublishOpcode { } #[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, UsizeOpcode, + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + EnumCount, + EnumIter, + FromRepr, + UsizeOpcode, + Serialize, + Deserialize, )] #[opcode_offset = 0x150] #[repr(usize)] diff --git a/crates/toolchain/tests/Cargo.toml b/crates/toolchain/tests/Cargo.toml index a17e18e0d2..e96f7b8159 100644 --- a/crates/toolchain/tests/Cargo.toml +++ b/crates/toolchain/tests/Cargo.toml @@ -44,6 +44,7 @@ tempfile.workspace = true serde = { workspace = true, features = ["alloc"] } rand = { workspace = true } derive_more = { workspace = true, features = ["from"] } +openvm-circuit-derive.workspace = true [target.'cfg(not(target_os = "zkvm"))'.dependencies] num-bigint-dig.workspace = true diff --git a/crates/vm/Cargo.toml b/crates/vm/Cargo.toml index 25708a8a1b..f9c9b60145 100644 --- a/crates/vm/Cargo.toml +++ b/crates/vm/Cargo.toml @@ -19,6 +19,7 @@ openvm-circuit-derive = { workspace = true } openvm-instructions = { workspace = true } openvm-stark-sdk = { workspace = true, optional = true } +bitcode.workspace = true itertools.workspace = true tracing.workspace = true derive-new.workspace = true diff --git a/crates/vm/derive/src/lib.rs b/crates/vm/derive/src/lib.rs index 8c09c43f07..385b99696c 100644 --- a/crates/vm/derive/src/lib.rs +++ b/crates/vm/derive/src/lib.rs @@ -113,6 +113,94 @@ pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { } } +#[proc_macro_derive(Stateful)] +pub fn stateful_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let generics = &ast.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(inner) => { + // Check if the struct has only one unnamed field + let inner_ty = match &inner.fields { + Fields::Unnamed(fields) => { + if fields.unnamed.len() != 1 { + panic!("Only one unnamed field is supported"); + } + fields.unnamed.first().unwrap().ty.clone() + } + _ => panic!("Only unnamed fields are supported"), + }; + // Use full path ::openvm_circuit... so it can be used either within or outside the vm crate. + // Assume F is already generic of the field. + let mut new_generics = generics.clone(); + let where_clause = new_generics.make_where_clause(); + where_clause + .predicates + .push(syn::parse_quote! { #inner_ty: ::openvm_stark_backend::Stateful> }); + + quote! { + impl #impl_generics ::openvm_stark_backend::Stateful> for #name #ty_generics #where_clause { + fn load_state(&mut self, state: Vec) { + self.0.load_state(state) + } + + fn store_state(&self) -> Vec { + self.0.store_state() + } + } + } + .into() + } + Data::Enum(e) => { + let variants = e + .variants + .iter() + .map(|variant| { + let variant_name = &variant.ident; + + let mut fields = variant.fields.iter(); + let field = fields.next().unwrap(); + assert!(fields.next().is_none(), "Only one field is supported"); + (variant_name, field) + }) + .collect::>(); + // Use full path ::openvm_stark_backend... so it can be used either within or outside the vm crate. + let (load_state_arms, store_state_arms): (Vec<_>, Vec<_>) = + multiunzip(variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + let load_state_arm = quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful>>::load_state(x, state) + }; + let store_state_arm = quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful>>::store_state(x) + }; + + (load_state_arm, store_state_arm) + })); + quote! { + impl #impl_generics ::openvm_stark_backend::Stateful> for #name #ty_generics { + fn load_state(&mut self, state: Vec) { + match self { + #(#load_state_arms,)* + } + } + + fn store_state(&self) -> Vec { + match self { + #(#store_state_arms,)* + } + } + } + } + .into() + } + _ => unimplemented!(), + } +} + /// Derives `AnyEnum` trait on an enum type. /// By default an enum arm will just return `self` as `&dyn Any`. /// @@ -308,7 +396,7 @@ pub fn vm_generic_config_derive(input: proc_macro::TokenStream) -> proc_macro::T let executor_type = Ident::new(&format!("{}Executor", name), name.span()); let periphery_type = Ident::new(&format!("{}Periphery", name), name.span()); TokenStream::from(quote! { - #[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] + #[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, ::openvm_circuit_derive::Stateful)] pub enum #executor_type { #[any_enum] #system_name_upper(SystemExecutor), diff --git a/crates/vm/src/arch/extensions.rs b/crates/vm/src/arch/extensions.rs index 55d68fe4c3..3018c30b0b 100644 --- a/crates/vm/src/arch/extensions.rs +++ b/crates/vm/src/arch/extensions.rs @@ -9,7 +9,7 @@ use derive_more::derive::From; use getset::Getters; #[cfg(feature = "bench-metrics")] use metrics::counter; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful}; use openvm_circuit_primitives::{ utils::next_power_of_two_or_zero, var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, @@ -489,13 +489,13 @@ impl SystemBase { } } -#[derive(ChipUsageGetter, Chip, AnyEnum, From, InstructionExecutor)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, InstructionExecutor, Stateful)] pub enum SystemExecutor { PublicValues(PublicValuesChip), Phantom(RefCell>), } -#[derive(ChipUsageGetter, Chip, AnyEnum, From)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, Stateful)] pub enum SystemPeriphery { /// Poseidon2 chip with direct compression interactions Poseidon2(Poseidon2PeripheryChip), diff --git a/crates/vm/src/arch/integration_api.rs b/crates/vm/src/arch/integration_api.rs index c58436813a..9dc4124dfa 100644 --- a/crates/vm/src/arch/integration_api.rs +++ b/crates/vm/src/arch/integration_api.rs @@ -17,7 +17,7 @@ use openvm_stark_backend::{ p3_maybe_rayon::prelude::*, prover::types::AirProofInput, rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, ChipUsageGetter, + Chip, ChipUsageGetter, Stateful, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -209,6 +209,18 @@ where } } +impl, C: VmCoreChip> Stateful> + for VmChipWrapper +{ + fn load_state(&mut self, state: Vec) { + self.records = bitcode::deserialize(&state).unwrap(); + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&self.records).unwrap() + } +} + impl InstructionExecutor for VmChipWrapper where F: PrimeField32, diff --git a/crates/vm/src/system/phantom/mod.rs b/crates/vm/src/system/phantom/mod.rs index 84c9ce763a..2bdf4deb26 100644 --- a/crates/vm/src/system/phantom/mod.rs +++ b/crates/vm/src/system/phantom/mod.rs @@ -17,9 +17,11 @@ use openvm_stark_backend::{ p3_maybe_rayon::prelude::*, prover::types::AirProofInput, rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, ChipUsageGetter, + Chip, ChipUsageGetter, Stateful, }; use rustc_hash::FxHashMap; +use serde::{Deserialize, Serialize}; +use serde_big_array::BigArray; use super::memory::MemoryController; use crate::{ @@ -44,9 +46,10 @@ pub struct PhantomAir { pub phantom_opcode: VmOpcode, } -#[derive(AlignedBorrow, Copy, Clone)] +#[derive(AlignedBorrow, Copy, Clone, Serialize, Deserialize)] pub struct PhantomCols { pub pc: T, + #[serde(with = "BigArray")] pub operands: [T; NUM_PHANTOM_OPERANDS], pub timestamp: T, pub is_valid: T, @@ -217,3 +220,13 @@ where AirProofInput::simple(self.air(), trace, vec![]) } } + +impl Stateful> for PhantomChip { + fn load_state(&mut self, state: Vec) { + self.rows = bitcode::deserialize(&state).unwrap(); + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&self.rows).unwrap() + } +} diff --git a/crates/vm/src/system/poseidon2/chip.rs b/crates/vm/src/system/poseidon2/chip.rs index 335ed43a84..6296907eed 100644 --- a/crates/vm/src/system/poseidon2/chip.rs +++ b/crates/vm/src/system/poseidon2/chip.rs @@ -4,7 +4,7 @@ use std::{ }; use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip}; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_backend::{p3_field::PrimeField32, Stateful}; use rustc_hash::FxHashMap; use super::{ @@ -71,3 +71,15 @@ impl HasherChip Stateful> + for Poseidon2PeripheryBaseChip +{ + fn load_state(&mut self, state: Vec) { + self.records = bitcode::deserialize(&state).unwrap(); + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&self.records).unwrap() + } +} diff --git a/crates/vm/src/system/poseidon2/mod.rs b/crates/vm/src/system/poseidon2/mod.rs index f24dd496ed..2c31feae4e 100644 --- a/crates/vm/src/system/poseidon2/mod.rs +++ b/crates/vm/src/system/poseidon2/mod.rs @@ -24,6 +24,7 @@ pub mod tests; pub mod air; mod chip; pub use chip::*; +use openvm_circuit_derive::Stateful; use crate::arch::hasher::{Hasher, HasherChip}; pub mod columns; @@ -32,6 +33,7 @@ pub mod trace; pub const PERIPHERY_POSEIDON2_WIDTH: usize = 16; pub const PERIPHERY_POSEIDON2_CHUNK_SIZE: usize = 8; +#[derive(Stateful)] pub enum Poseidon2PeripheryChip { Register0(Poseidon2PeripheryBaseChip), Register1(Poseidon2PeripheryBaseChip), diff --git a/extensions/algebra/circuit/src/fp2_chip/addsub.rs b/extensions/algebra/circuit/src/fp2_chip/addsub.rs index 1a89dbbd8f..76dd6a575d 100644 --- a/extensions/algebra/circuit/src/fp2_chip/addsub.rs +++ b/extensions/algebra/circuit/src/fp2_chip/addsub.rs @@ -6,7 +6,7 @@ use std::{ use openvm_algebra_transpiler::Fp2Opcode; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ @@ -19,7 +19,7 @@ use crate::Fp2; // Input: Fp2 * 2 // Output: Fp2 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct Fp2AddSubChip( pub VmChipWrapper< F, @@ -85,14 +85,13 @@ pub fn fp2_addsub_expr( #[cfg(test)] mod tests { - use std::sync::Arc; use halo2curves_axiom::{bn256::Fq2, ff::Field}; use itertools::Itertools; use openvm_algebra_transpiler::Fp2Opcode; use openvm_circuit::arch::{testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{riscv::RV32_CELL_BITS, UsizeOpcode}; use openvm_mod_circuit_builder::{ @@ -121,9 +120,7 @@ mod tests { limb_bits: LIMB_BITS, }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), tester.program_bus(), diff --git a/extensions/algebra/circuit/src/fp2_chip/muldiv.rs b/extensions/algebra/circuit/src/fp2_chip/muldiv.rs index 727d6a192f..8701dc0aa5 100644 --- a/extensions/algebra/circuit/src/fp2_chip/muldiv.rs +++ b/extensions/algebra/circuit/src/fp2_chip/muldiv.rs @@ -6,7 +6,7 @@ use std::{ use openvm_algebra_transpiler::Fp2Opcode; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ @@ -19,7 +19,7 @@ use crate::Fp2; // Input: Fp2 * 2 // Output: Fp2 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct Fp2MulDivChip( pub VmChipWrapper< F, @@ -124,14 +124,13 @@ pub fn fp2_muldiv_expr( #[cfg(test)] mod tests { - use std::sync::Arc; use halo2curves_axiom::{bn256::Fq2, ff::Field}; use itertools::Itertools; use openvm_algebra_transpiler::Fp2Opcode; use openvm_circuit::arch::{testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{riscv::RV32_CELL_BITS, UsizeOpcode}; use openvm_mod_circuit_builder::{ @@ -160,9 +159,7 @@ mod tests { limb_bits: LIMB_BITS, }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), tester.program_bus(), diff --git a/extensions/algebra/circuit/src/fp2_extension.rs b/extensions/algebra/circuit/src/fp2_extension.rs index 9d1628c9ec..6ab21eaa3e 100644 --- a/extensions/algebra/circuit/src/fp2_extension.rs +++ b/extensions/algebra/circuit/src/fp2_extension.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use derive_more::derive::From; use num_bigint_dig::BigUint; use openvm_algebra_transpiler::Fp2Opcode; @@ -7,9 +5,9 @@ use openvm_circuit::{ arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{UsizeOpcode, VmOpcode}; @@ -29,7 +27,7 @@ pub struct Fp2Extension { pub supported_modulus: Vec, } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From, Stateful)] pub enum Fp2ExtensionExecutor { // 32 limbs prime Fp2AddSubRv32_32(Fp2AddSubChip), @@ -39,9 +37,9 @@ pub enum Fp2ExtensionExecutor { Fp2MulDivRv32_48(Fp2MulDivChip), } -#[derive(ChipUsageGetter, Chip, AnyEnum, From)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, Stateful)] pub enum Fp2ExtensionPeriphery { - BitwiseOperationLookup(Arc>), + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), // We put this only to get the generic to work Phantom(PhantomChip), } @@ -60,14 +58,14 @@ impl VmExtension for Fp2Extension { program_bus, memory_bridge, } = builder.system_port(); - let bitwise_lu_chip = if let Some(chip) = builder - .find_chip::>>() + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() .first() { - Arc::clone(chip) + chip.clone() } else { let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = Arc::new(BitwiseOperationLookupChip::new(bitwise_lu_bus)); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); inventory.add_periphery_chip(chip.clone()); chip }; diff --git a/extensions/algebra/circuit/src/modular_chip/is_eq.rs b/extensions/algebra/circuit/src/modular_chip/is_eq.rs index 7c4d4491e5..c950fb47fd 100644 --- a/extensions/algebra/circuit/src/modular_chip/is_eq.rs +++ b/extensions/algebra/circuit/src/modular_chip/is_eq.rs @@ -1,7 +1,6 @@ use std::{ array::{self, from_fn}, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use num_bigint_dig::BigUint; @@ -12,7 +11,7 @@ use openvm_circuit::arch::{ }; use openvm_circuit_primitives::{ bigint::utils::big_uint_to_limbs, - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, is_equal_array::{IsEqArrayIo, IsEqArraySubAir}, SubAir, TraceSubRowGenerator, }; @@ -255,7 +254,7 @@ pub struct ModularIsEqualCoreChip< const LIMB_BITS: usize, > { pub air: ModularIsEqualCoreAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } impl @@ -263,7 +262,7 @@ impl { pub fn new( modulus: BigUint, - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, offset: usize, ) -> Self { Self { diff --git a/extensions/algebra/circuit/src/modular_chip/tests.rs b/extensions/algebra/circuit/src/modular_chip/tests.rs index b870fffebe..729cfe5461 100644 --- a/extensions/algebra/circuit/src/modular_chip/tests.rs +++ b/extensions/algebra/circuit/src/modular_chip/tests.rs @@ -1,4 +1,4 @@ -use std::{array::from_fn, sync::Arc}; +use std::array::from_fn; use num_bigint_dig::BigUint; use num_traits::Zero; @@ -10,7 +10,7 @@ use openvm_circuit_primitives::{ bigint::utils::{ big_uint_mod_inverse, big_uint_to_limbs, secp256k1_coord_prime, secp256k1_scalar_prime, }, - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, }; use openvm_instructions::{instruction::Instruction, riscv::RV32_CELL_BITS, VmOpcode}; use openvm_mod_circuit_builder::{ @@ -65,9 +65,7 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { Rv32ModularArithmeticOpcode::default_offset() + opcode_offset, ); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); // doing 1xNUM_LIMBS reads and writes let adapter = Rv32VecHeapAdapterChip::::new( @@ -195,9 +193,7 @@ fn test_muldiv(opcode_offset: usize, modulus: BigUint) { Rv32ModularArithmeticOpcode::default_offset() + opcode_offset, ); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); // doing 1xNUM_LIMBS reads and writes let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), @@ -305,7 +301,7 @@ fn test_is_equal::new(bitwise_bus)); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); let mut chip = ModularIsEqualChip::::new( diff --git a/extensions/algebra/circuit/src/modular_extension.rs b/extensions/algebra/circuit/src/modular_extension.rs index 7b51798957..8fdc9770b1 100644 --- a/extensions/algebra/circuit/src/modular_extension.rs +++ b/extensions/algebra/circuit/src/modular_extension.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use derive_more::derive::From; use num_bigint_dig::BigUint; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; @@ -8,9 +6,9 @@ use openvm_circuit::{ arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{UsizeOpcode, VmOpcode}; @@ -33,7 +31,7 @@ pub struct ModularExtension { pub supported_modulus: Vec, } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From, Stateful)] pub enum ModularExtensionExecutor { // 32 limbs prime ModularAddSubRv32_32(ModularAddSubChip), @@ -45,9 +43,9 @@ pub enum ModularExtensionExecutor { ModularIsEqualRv32_48(ModularIsEqualChip), } -#[derive(ChipUsageGetter, Chip, AnyEnum, From)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, Stateful)] pub enum ModularExtensionPeriphery { - BitwiseOperationLookup(Arc>), + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), // We put this only to get the generic to work Phantom(PhantomChip), } @@ -67,14 +65,14 @@ impl VmExtension for ModularExtension { memory_bridge, } = builder.system_port(); let range_checker = builder.system_base().range_checker_chip.clone(); - let bitwise_lu_chip = if let Some(chip) = builder - .find_chip::>>() + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() .first() { - Arc::clone(chip) + chip.clone() } else { let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = Arc::new(BitwiseOperationLookupChip::new(bitwise_lu_bus)); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); inventory.add_periphery_chip(chip.clone()); chip }; diff --git a/extensions/bigint/circuit/src/extension.rs b/extensions/bigint/circuit/src/extension.rs index f1410d7d0b..fcfffd860a 100644 --- a/extensions/bigint/circuit/src/extension.rs +++ b/extensions/bigint/circuit/src/extension.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use derive_more::derive::From; use openvm_bigint_transpiler::{ Rv32BaseAlu256Opcode, Rv32BranchEqual256Opcode, Rv32BranchLessThan256Opcode, @@ -12,10 +10,10 @@ use openvm_circuit::{ }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful, VmConfig}; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{program::DEFAULT_PC_STEP, UsizeOpcode, VmOpcode}; @@ -72,7 +70,7 @@ fn default_range_tuple_checker_sizes() -> [u32; 2] { [1 << 8, 32 * (1 << 8)] } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] pub enum Int256Executor { BaseAlu256(Rv32BaseAlu256Chip), LessThan256(Rv32LessThan256Chip), @@ -82,11 +80,11 @@ pub enum Int256Executor { Shift256(Rv32Shift256Chip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] pub enum Int256Periphery { - BitwiseOperationLookup(Arc>), + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), /// Only needed for multiplication extension - RangeTupleChecker(Arc>), + RangeTupleChecker(SharedRangeTupleCheckerChip<2>), Phantom(PhantomChip), } @@ -105,14 +103,14 @@ impl VmExtension for Int256 { memory_bridge, } = builder.system_port(); let range_checker_chip = builder.system_base().range_checker_chip.clone(); - let bitwise_lu_chip = if let Some(chip) = builder - .find_chip::>>() + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() .first() { - Arc::clone(chip) + chip.clone() } else { let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = Arc::new(BitwiseOperationLookupChip::new(bitwise_lu_bus)); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); inventory.add_periphery_chip(chip.clone()); chip }; @@ -120,7 +118,7 @@ impl VmExtension for Int256 { let address_bits = builder.system_config().memory_config.pointer_max_bits; let range_tuple_chip = if let Some(chip) = builder - .find_chip::>>() + .find_chip::>() .into_iter() .find(|c| { c.bus().sizes[0] >= self.range_tuple_checker_sizes[0] @@ -130,7 +128,7 @@ impl VmExtension for Int256 { } else { let range_tuple_bus = RangeTupleCheckerBus::new(builder.new_bus_idx(), self.range_tuple_checker_sizes); - let chip = Arc::new(RangeTupleCheckerChip::new(range_tuple_bus)); + let chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); inventory.add_periphery_chip(chip.clone()); chip }; diff --git a/extensions/bigint/circuit/src/tests.rs b/extensions/bigint/circuit/src/tests.rs index 10a240d73f..8727956990 100644 --- a/extensions/bigint/circuit/src/tests.rs +++ b/extensions/bigint/circuit/src/tests.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use openvm_circuit::{ arch::{ testing::VmChipTestBuilder, InstructionExecutor, BITWISE_OP_LOOKUP_BUS, @@ -8,8 +6,8 @@ use openvm_circuit::{ utils::generate_long_number, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_instructions::{program::PC_BITS, riscv::RV32_CELL_BITS, UsizeOpcode}; use openvm_rv32_adapters::{ @@ -85,9 +83,7 @@ fn run_int_256_rand_execute>( fn run_alu_256_rand_test(opcode: BaseAluOpcode, num_ops: usize) { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); @@ -135,9 +131,7 @@ fn alu_256_and_rand_test() { fn run_lt_256_rand_test(opcode: LessThanOpcode, num_ops: usize) { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32LessThan256Chip::::new( @@ -175,11 +169,9 @@ fn run_mul_256_rand_test(num_ops: usize) { (INT256_NUM_LIMBS * (1 << RV32_CELL_BITS)) as u32, ], ); - let range_tuple_checker = Arc::new(RangeTupleCheckerChip::new(range_tuple_bus)); + let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32Multiplication256Chip::::new( @@ -217,9 +209,7 @@ fn mul_256_rand_test() { fn run_shift_256_rand_test(opcode: ShiftOpcode, num_ops: usize) { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32Shift256Chip::::new( @@ -261,9 +251,7 @@ fn shift_256_sra_rand_test() { fn run_beq_256_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { let mut tester = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut chip = Rv32BranchEqual256Chip::::new( Rv32HeapBranchAdapterChip::::new( tester.execution_bus(), @@ -306,9 +294,7 @@ fn beq_256_bne_rand_test() { fn run_blt_256_rand_test(opcode: BranchLessThanOpcode, num_ops: usize) { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32BranchLessThan256Chip::::new( diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs index 71bbdea2bf..70c681dda0 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs @@ -13,7 +13,7 @@ use std::sync::Mutex; use num_bigint_dig::BigUint; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::var_range::VariableRangeCheckerChip; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_ecc_transpiler::Rv32WeierstrassOpcode; @@ -25,7 +25,7 @@ use openvm_stark_backend::p3_field::PrimeField32; /// BLOCKS: how many blocks do we need to represent one input or output /// For example, for bls12_381, BLOCK_SIZE = 16, each element has 3 blocks and with two elements per input AffinePoint, BLOCKS = 6. /// For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct EcAddNeChip( VmChipWrapper< F, @@ -61,7 +61,7 @@ impl } } -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct EcDoubleChip( VmChipWrapper< F, diff --git a/extensions/ecc/circuit/src/weierstrass_chip/tests.rs b/extensions/ecc/circuit/src/weierstrass_chip/tests.rs index 82d91aa7a6..211b52dad9 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/tests.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/tests.rs @@ -1,11 +1,11 @@ -use std::{str::FromStr, sync::Arc}; +use std::str::FromStr; use num_bigint_dig::BigUint; use num_traits::{FromPrimitive, Num, Zero}; use openvm_circuit::arch::{testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::{ bigint::utils::{secp256k1_coord_prime, secp256r1_coord_prime}, - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, }; use openvm_ecc_transpiler::Rv32WeierstrassOpcode; use openvm_instructions::{riscv::RV32_CELL_BITS, UsizeOpcode}; @@ -86,9 +86,7 @@ fn test_add_ne() { limb_bits: LIMB_BITS, }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), tester.program_bus(), @@ -160,9 +158,7 @@ fn test_double() { limb_bits: LIMB_BITS, }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), tester.program_bus(), @@ -229,9 +225,7 @@ fn test_p256_double() { ) .unwrap(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), tester.program_bus(), diff --git a/extensions/ecc/circuit/src/weierstrass_extension.rs b/extensions/ecc/circuit/src/weierstrass_extension.rs index 835ca88094..0e4920f20c 100644 --- a/extensions/ecc/circuit/src/weierstrass_extension.rs +++ b/extensions/ecc/circuit/src/weierstrass_extension.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use derive_more::derive::From; use num_bigint_dig::BigUint; use num_traits::{FromPrimitive, Zero}; @@ -9,9 +7,9 @@ use openvm_circuit::{ arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_ecc_guest::{ @@ -65,7 +63,7 @@ pub struct WeierstrassExtension { pub supported_curves: Vec, } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, Stateful)] pub enum WeierstrassExtensionExecutor { // 32 limbs prime EcAddNeRv32_32(EcAddNeChip), @@ -75,9 +73,9 @@ pub enum WeierstrassExtensionExecutor { EcDoubleRv32_48(EcDoubleChip), } -#[derive(ChipUsageGetter, Chip, AnyEnum, From)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, Stateful)] pub enum WeierstrassExtensionPeriphery { - BitwiseOperationLookup(Arc>), + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), Phantom(PhantomChip), } @@ -95,14 +93,14 @@ impl VmExtension for WeierstrassExtension { program_bus, memory_bridge, } = builder.system_port(); - let bitwise_lu_chip = if let Some(chip) = builder - .find_chip::>>() + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() .first() { - Arc::clone(chip) + chip.clone() } else { let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = Arc::new(BitwiseOperationLookupChip::new(bitwise_lu_bus)); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); inventory.add_periphery_chip(chip.clone()); chip }; diff --git a/extensions/keccak256/circuit/Cargo.toml b/extensions/keccak256/circuit/Cargo.toml index a8b8865cd6..aa9888858f 100644 --- a/extensions/keccak256/circuit/Cargo.toml +++ b/extensions/keccak256/circuit/Cargo.toml @@ -30,6 +30,8 @@ derive_more = { workspace = true, features = ["from"] } rand.workspace = true eyre.workspace = true serde.workspace = true +serde-big-array.workspace = true +bitcode.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } diff --git a/extensions/keccak256/circuit/src/extension.rs b/extensions/keccak256/circuit/src/extension.rs index 7ce2c6d7c4..131c59bf18 100644 --- a/extensions/keccak256/circuit/src/extension.rs +++ b/extensions/keccak256/circuit/src/extension.rs @@ -6,7 +6,7 @@ use openvm_circuit::{ }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful, VmConfig}; use openvm_circuit_primitives::bitwise_op_lookup::BitwiseOperationLookupBus; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::*; @@ -49,14 +49,14 @@ impl Default for Keccak256Rv32Config { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Keccak256; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] pub enum Keccak256Executor { Keccak256(KeccakVmChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] pub enum Keccak256Periphery { - BitwiseOperationLookup(Arc>), + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), Phantom(PhantomChip), } @@ -74,14 +74,14 @@ impl VmExtension for Keccak256 { program_bus, memory_bridge, } = builder.system_port(); - let bitwise_lu_chip = if let Some(chip) = builder - .find_chip::>>() + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() .first() { - Arc::clone(chip) + chip.clone() } else { let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = Arc::new(BitwiseOperationLookupChip::new(bitwise_lu_bus)); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); inventory.add_periphery_chip(chip.clone()); chip }; diff --git a/extensions/keccak256/circuit/src/lib.rs b/extensions/keccak256/circuit/src/lib.rs index 017a760885..10a217bace 100644 --- a/extensions/keccak256/circuit/src/lib.rs +++ b/extensions/keccak256/circuit/src/lib.rs @@ -6,8 +6,10 @@ use std::{ sync::{Arc, Mutex}, }; -use openvm_circuit_primitives::bitwise_op_lookup::BitwiseOperationLookupChip; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; +use openvm_stark_backend::{p3_field::PrimeField32, Stateful}; +use serde::{Deserialize, Serialize}; +use serde_big_array::BigArray; use tiny_keccak::{Hasher, Keccak}; use utils::num_keccak_f; @@ -71,14 +73,14 @@ pub struct KeccakVmChip { pub air: KeccakVmAir, /// IO and memory data necessary for each opcode call pub records: Vec>, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, offset: usize, offline_memory: Arc>>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct KeccakRecord { pub pc: F, pub dst_read: RecordId, @@ -88,13 +90,14 @@ pub struct KeccakRecord { pub digest_writes: [RecordId; KECCAK_DIGEST_WRITES], } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct KeccakInputBlock { /// Memory reads for non-padding bytes in this block. Length is at most [KECCAK_RATE_BYTES / KECCAK_WORD_SIZE]. pub reads: Vec, /// Index in `reads` of the memory read for < KECCAK_WORD_SIZE bytes, if any. pub partial_read_idx: Option, /// Bytes with padding. Can be derived from `bytes_read` but we store for convenience. + #[serde(with = "BigArray")] pub padded_bytes: [u8; KECCAK_RATE_BYTES], pub remaining_len: usize, pub src: usize, @@ -107,7 +110,7 @@ impl KeccakVmChip { program_bus: ProgramBus, memory_bridge: MemoryBridge, address_bits: usize, - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, offset: usize, offline_memory: Arc>>, ) -> Self { @@ -280,3 +283,13 @@ impl Default for KeccakInputBlock { } } } + +impl Stateful> for KeccakVmChip { + fn load_state(&mut self, state: Vec) { + self.records = bitcode::deserialize(&state).unwrap(); + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&self.records).unwrap() + } +} diff --git a/extensions/keccak256/circuit/src/tests.rs b/extensions/keccak256/circuit/src/tests.rs index 4774a18da8..5a47b21ecc 100644 --- a/extensions/keccak256/circuit/src/tests.rs +++ b/extensions/keccak256/circuit/src/tests.rs @@ -1,4 +1,4 @@ -use std::{borrow::BorrowMut, sync::Arc}; +use std::borrow::BorrowMut; use hex::FromHex; use openvm_circuit::arch::{ @@ -6,7 +6,7 @@ use openvm_circuit::arch::{ BITWISE_OP_LOOKUP_BUS, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, VmOpcode}; use openvm_keccak256_transpiler::Rv32KeccakOpcode; @@ -31,7 +31,7 @@ fn build_keccak256_test( io: Vec<(Vec, Option<[u8; 32]>, Option<[u8; 32]>)>, ) -> VmChipTester { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::<8>::new(bitwise_bus)); + let bitwise_chip = SharedBitwiseOperationLookupChip::<8>::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = KeccakVmChip::new( diff --git a/extensions/native/circuit/Cargo.toml b/extensions/native/circuit/Cargo.toml index 8fe5e29ad3..3a806040ac 100644 --- a/extensions/native/circuit/Cargo.toml +++ b/extensions/native/circuit/Cargo.toml @@ -30,6 +30,7 @@ serde.workspace = true serde-big-array.workspace = true serde_with.workspace = true rayon.workspace = true +bitcode.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } diff --git a/extensions/native/circuit/src/extension.rs b/extensions/native/circuit/src/extension.rs index e83bf1f90a..9334b2e0c3 100644 --- a/extensions/native/circuit/src/extension.rs +++ b/extensions/native/circuit/src/extension.rs @@ -10,7 +10,7 @@ use openvm_circuit::{ }, system::{native_adapter::NativeAdapterChip, phantom::PhantomChip}, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful, VmConfig}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{ program::DEFAULT_PC_STEP, PhantomDiscriminant, Poseidon2Opcode, UsizeOpcode, VmOpcode, @@ -70,7 +70,7 @@ impl NativeConfig { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Native; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] pub enum NativeExecutor { LoadStore(NativeLoadStoreChip), BlockLoadStore(NativeLoadStoreChip), @@ -82,7 +82,7 @@ pub enum NativeExecutor { FriReducedOpening(FriReducedOpeningChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] pub enum NativePeriphery { Phantom(PhantomChip), } diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 0ddcb6002e..256beb2bb7 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -33,8 +33,9 @@ use openvm_stark_backend::{ p3_maybe_rayon::prelude::*, prover::types::AirProofInput, rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, ChipUsageGetter, + Chip, ChipUsageGetter, Stateful, }; +use serde::{Deserialize, Serialize}; use super::field_extension::{FieldExtension, EXT_DEG}; @@ -294,6 +295,8 @@ impl Air for FriReducedOpeningAir { } } +#[derive(Serialize, Deserialize)] +#[serde(bound = "F: Field")] pub struct FriReducedOpeningRecord { pub pc: F, pub start_timestamp: F, @@ -586,3 +589,13 @@ where AirProofInput::simple_no_pis(self.air(), self.generate_trace()) } } + +impl Stateful> for FriReducedOpeningChip { + fn load_state(&mut self, state: Vec) { + self.records = bitcode::deserialize(&state).unwrap(); + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&self.records).unwrap() + } +} diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index f995af715c..e53975a0ec 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -17,7 +17,11 @@ use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, Poseidon2Opcode, UsizeOpcode, }; use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip}; -use openvm_stark_backend::p3_field::{Field, PrimeField32}; +use openvm_stark_backend::{ + p3_field::{Field, PrimeField32}, + Stateful, +}; +use serde::{Deserialize, Serialize}; use super::{ NativePoseidon2Air, NativePoseidon2MemoryCols, NATIVE_POSEIDON2_CHUNK_SIZE, @@ -31,7 +35,7 @@ pub struct NativePoseidon2BaseChip { pub offline_memory: Arc>>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct NativePoseidon2ChipRecord { pub from_state: ExecutionState, pub opcode: Poseidon2Opcode, @@ -216,3 +220,15 @@ impl NativePoseidon2ChipRecord { } } } + +impl Stateful> + for NativePoseidon2BaseChip +{ + fn load_state(&mut self, state: Vec) { + self.records = bitcode::deserialize(&state).unwrap() + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&self.records).unwrap() + } +} diff --git a/extensions/native/circuit/src/poseidon2/mod.rs b/extensions/native/circuit/src/poseidon2/mod.rs index 2f319a9435..d16256e850 100644 --- a/extensions/native/circuit/src/poseidon2/mod.rs +++ b/extensions/native/circuit/src/poseidon2/mod.rs @@ -23,6 +23,7 @@ use std::sync::Mutex; pub use columns::*; use openvm_circuit::system::memory::{offline_checker::MemoryBridge, OfflineMemory}; +use openvm_circuit_derive::Stateful; mod trace; @@ -32,6 +33,7 @@ mod tests; pub const NATIVE_POSEIDON2_WIDTH: usize = 16; pub const NATIVE_POSEIDON2_CHUNK_SIZE: usize = 8; +#[derive(Stateful)] pub enum NativePoseidon2Chip { Register0(NativePoseidon2BaseChip), Register1(NativePoseidon2BaseChip), diff --git a/extensions/pairing/circuit/src/fp12_chip/mul.rs b/extensions/pairing/circuit/src/fp12_chip/mul.rs index f7129c4244..fcb44c3044 100644 --- a/extensions/pairing/circuit/src/fp12_chip/mul.rs +++ b/extensions/pairing/circuit/src/fp12_chip/mul.rs @@ -5,7 +5,7 @@ use std::{ }; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ @@ -18,7 +18,7 @@ use openvm_stark_backend::p3_field::PrimeField32; use crate::Fp12; // Input: Fp12 * 2 // Output: Fp12 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct Fp12MulChip( pub VmChipWrapper< F, @@ -72,13 +72,11 @@ pub fn fp12_mul_expr( #[cfg(test)] mod tests { - use std::sync::Arc; - use halo2curves_axiom::{bn256::Fq12, ff::Field}; use itertools::Itertools; use openvm_circuit::arch::{testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_ecc_guest::algebra::field::FieldExtension; use openvm_instructions::{riscv::RV32_CELL_BITS, UsizeOpcode}; @@ -109,9 +107,7 @@ mod tests { limb_bits: LIMB_BITS, }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), tester.program_bus(), diff --git a/extensions/pairing/circuit/src/fp12_chip/tests.rs b/extensions/pairing/circuit/src/fp12_chip/tests.rs index 747639475c..1e58f03181 100644 --- a/extensions/pairing/circuit/src/fp12_chip/tests.rs +++ b/extensions/pairing/circuit/src/fp12_chip/tests.rs @@ -1,9 +1,7 @@ -use std::sync::Arc; - use num_bigint_dig::BigUint; use openvm_circuit::arch::{testing::VmChipTestBuilder, VmChipWrapper, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{riscv::RV32_CELL_BITS, UsizeOpcode}; use openvm_mod_circuit_builder::{ @@ -54,9 +52,7 @@ fn test_fp12_fn< false, ); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs index 36d6ce9dba..77e53c25cc 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs @@ -6,7 +6,7 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ @@ -18,7 +18,7 @@ use openvm_stark_backend::p3_field::PrimeField32; // Input: line0.b, line0.c, line1.b, line1.c : 2 x 4 field elements // Output: 5 Fp2 coefficients -> 10 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct EcLineMul013By013Chip< F: PrimeField32, const INPUT_BLOCKS: usize, diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs index 7a5e636f24..e846b42339 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs @@ -6,7 +6,7 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ @@ -20,7 +20,7 @@ use crate::Fp12; // Input: Fp12 (12 field elements), [Fp2; 5] (5 x 2 field elements) // Output: Fp12 (12 field elements) -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct EcLineMulBy01234Chip< F: PrimeField32, const INPUT_BLOCKS1: usize, diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs index c05c45fa34..bc28441cbf 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs @@ -1,12 +1,10 @@ -use std::sync::Arc; - use halo2curves_axiom::{ bn256::{Fq, Fq12, Fq2, G1Affine}, ff::Field, }; use openvm_circuit::arch::{testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_ecc_guest::AffinePoint; use openvm_instructions::{riscv::RV32_CELL_BITS, UsizeOpcode}; @@ -41,9 +39,7 @@ const BLOCK_SIZE: usize = 32; fn test_mul_013_by_013() { let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), tester.program_bus(), @@ -139,9 +135,7 @@ fn test_mul_013_by_013() { fn test_mul_by_01234() { let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( tester.execution_bus(), tester.program_bus(), @@ -234,9 +228,7 @@ fn test_evaluate_line() { num_limbs: BN254_NUM_LIMBS, }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( tester.execution_bus(), tester.program_bus(), diff --git a/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs b/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs index 5677e72836..5ceda2bc57 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs @@ -6,7 +6,7 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ @@ -18,7 +18,7 @@ use openvm_stark_backend::p3_field::PrimeField32; // Input: UnevaluatedLine, (Fp, Fp) // Output: EvaluatedLine -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct EvaluateLineChip< F: PrimeField32, const INPUT_BLOCKS1: usize, diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs index 144192cd9a..d0e4f56c4d 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs @@ -6,7 +6,7 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ @@ -18,7 +18,7 @@ use openvm_stark_backend::p3_field::PrimeField32; // Input: line0.b, line0.c, line1.b, line1.c : 2 x 4 field elements // Output: 5 Fp2 coefficients -> 10 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct EcLineMul023By023Chip< F: PrimeField32, const INPUT_BLOCKS: usize, diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs index 04dc1640a0..170a8a4c3d 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs @@ -6,7 +6,7 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ @@ -20,7 +20,7 @@ use crate::Fp12; // Input: 2 Fp12: 2 x 12 field elements // Output: Fp12 -> 12 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct EcLineMulBy02345Chip< F: PrimeField32, const INPUT_BLOCKS1: usize, diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs index 2cba9b81da..8e8cabd544 100644 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs +++ b/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs @@ -1,12 +1,10 @@ -use std::sync::Arc; - use halo2curves_axiom::{ bls12_381::{Fq, Fq12, Fq2, G1Affine}, ff::Field, }; use openvm_circuit::arch::{testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_ecc_guest::AffinePoint; use openvm_instructions::{riscv::RV32_CELL_BITS, UsizeOpcode}; @@ -35,9 +33,7 @@ const BLOCK_SIZE: usize = 16; fn test_mul_023_by_023() { let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), tester.program_bus(), @@ -136,9 +132,7 @@ fn test_mul_023_by_023() { fn test_mul_by_02345() { let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( tester.execution_bus(), tester.program_bus(), diff --git a/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs b/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs index 191ca7dc4f..7e7a602cfe 100644 --- a/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs +++ b/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs @@ -6,7 +6,7 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ @@ -18,7 +18,7 @@ use openvm_stark_backend::p3_field::PrimeField32; // Input: two AffinePoint: 4 field elements each // Output: (AffinePoint, UnevaluatedLine, UnevaluatedLine) -> 2*2 + 2*2 + 2*2 = 12 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct MillerDoubleAndAddStepChip< F: PrimeField32, const INPUT_BLOCKS: usize, @@ -103,12 +103,10 @@ pub fn miller_double_and_add_step_expr( #[cfg(test)] mod tests { - use std::sync::Arc; - use halo2curves_axiom::bn256::G2Affine; use openvm_circuit::arch::{testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_ecc_guest::AffinePoint; use openvm_instructions::{riscv::RV32_CELL_BITS, UsizeOpcode}; @@ -134,9 +132,7 @@ mod tests { fn test_miller_double_and_add() { let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), tester.program_bus(), diff --git a/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs b/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs index d8daf258f1..14bc7d29e5 100644 --- a/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs +++ b/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs @@ -6,7 +6,7 @@ use std::{ use openvm_algebra_circuit::Fp2; use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ @@ -18,7 +18,7 @@ use openvm_stark_backend::p3_field::PrimeField32; // Input: AffinePoint: 4 field elements // Output: (AffinePoint, Fp2, Fp2) -> 8 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] pub struct MillerDoubleStepChip< F: PrimeField32, const INPUT_BLOCKS: usize, @@ -90,11 +90,9 @@ pub fn miller_double_step_expr( #[cfg(test)] mod tests { - use std::sync::Arc; - use openvm_circuit::arch::{testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_ecc_guest::AffinePoint; use openvm_instructions::{riscv::RV32_CELL_BITS, UsizeOpcode}; @@ -132,9 +130,7 @@ mod tests { num_limbs: BN254_NUM_LIMBS, }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), tester.program_bus(), @@ -201,9 +197,7 @@ mod tests { num_limbs: BLS12_381_NUM_LIMBS, }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let adapter = Rv32VecHeapAdapterChip::::new( tester.execution_bus(), tester.program_bus(), diff --git a/extensions/pairing/circuit/src/pairing_extension.rs b/extensions/pairing/circuit/src/pairing_extension.rs index b9d63ef05c..e08633ab1d 100644 --- a/extensions/pairing/circuit/src/pairing_extension.rs +++ b/extensions/pairing/circuit/src/pairing_extension.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use derive_more::derive::From; use num_bigint_dig::BigUint; use num_traits::{FromPrimitive, Zero}; @@ -7,9 +5,9 @@ use openvm_circuit::{ arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_ecc_circuit::CurveConfig; @@ -66,7 +64,7 @@ pub struct PairingExtension { pub supported_curves: Vec, } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, Stateful)] pub enum PairingExtensionExecutor { // bn254 (32 limbs) MillerDoubleStepRv32_32(MillerDoubleStepChip), @@ -84,9 +82,9 @@ pub enum PairingExtensionExecutor { EcLineMulBy02345(EcLineMulBy02345Chip), } -#[derive(ChipUsageGetter, Chip, AnyEnum, From)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, Stateful)] pub enum PairingExtensionPeriphery { - BitwiseOperationLookup(Arc>), + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), Phantom(PhantomChip), } @@ -104,14 +102,14 @@ impl VmExtension for PairingExtension { program_bus, memory_bridge, } = builder.system_port(); - let bitwise_lu_chip = if let Some(chip) = builder - .find_chip::>>() + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() .first() { - Arc::clone(chip) + chip.clone() } else { let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = Arc::new(BitwiseOperationLookupChip::new(bitwise_lu_bus)); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); inventory.add_periphery_chip(chip.clone()); chip }; diff --git a/extensions/rv32-adapters/src/eq_mod.rs b/extensions/rv32-adapters/src/eq_mod.rs index d6b2d4ff3b..81883ea9b1 100644 --- a/extensions/rv32-adapters/src/eq_mod.rs +++ b/extensions/rv32-adapters/src/eq_mod.rs @@ -2,7 +2,6 @@ use std::{ array::{self, from_fn}, borrow::{Borrow, BorrowMut}, marker::PhantomData, - sync::Arc, }; use itertools::izip; @@ -21,7 +20,7 @@ use openvm_circuit::{ }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -237,7 +236,7 @@ pub struct Rv32IsEqualModAdapterChip< const TOTAL_READ_SIZE: usize, > { pub air: Rv32IsEqualModAdapterAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, _marker: PhantomData, } @@ -254,7 +253,7 @@ impl< program_bus: ProgramBus, memory_bridge: MemoryBridge, address_bits: usize, - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE); diff --git a/extensions/rv32-adapters/src/heap.rs b/extensions/rv32-adapters/src/heap.rs index 857c9b6caa..903c633f6c 100644 --- a/extensions/rv32-adapters/src/heap.rs +++ b/extensions/rv32-adapters/src/heap.rs @@ -2,7 +2,6 @@ use std::{ array::{self, from_fn}, borrow::Borrow, marker::PhantomData, - sync::Arc, }; use openvm_circuit::{ @@ -17,7 +16,7 @@ use openvm_circuit::{ }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{ instruction::Instruction, @@ -109,7 +108,7 @@ pub struct Rv32HeapAdapterChip< const WRITE_SIZE: usize, > { pub air: Rv32HeapAdapterAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, _marker: PhantomData, } @@ -121,7 +120,7 @@ impl>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert!( @@ -226,7 +225,7 @@ impl VmA pub struct Rv32HeapBranchAdapterChip { pub air: Rv32HeapBranchAdapterAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, _marker: PhantomData, } @@ -182,7 +181,7 @@ impl program_bus: ProgramBus, memory_bridge: MemoryBridge, address_bits: usize, - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert!( diff --git a/extensions/rv32-adapters/src/vec_heap.rs b/extensions/rv32-adapters/src/vec_heap.rs index 96377c8be4..a237cdef93 100644 --- a/extensions/rv32-adapters/src/vec_heap.rs +++ b/extensions/rv32-adapters/src/vec_heap.rs @@ -3,7 +3,6 @@ use std::{ borrow::{Borrow, BorrowMut}, iter::{once, zip}, marker::PhantomData, - sync::Arc, }; use itertools::izip; @@ -21,7 +20,7 @@ use openvm_circuit::{ }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -59,7 +58,7 @@ pub struct Rv32VecHeapAdapterChip< > { pub air: Rv32VecHeapAdapterAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, _marker: PhantomData, } @@ -78,7 +77,7 @@ impl< program_bus: ProgramBus, memory_bridge: MemoryBridge, address_bits: usize, - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert!( @@ -458,7 +457,7 @@ impl< row_slice, &read_record, &write_record, - &self.bitwise_lookup_chip, + self.bitwise_lookup_chip.clone(), self.air.address_bits, memory, ) @@ -480,7 +479,7 @@ pub(super) fn vec_heap_generate_trace_row_impl< row_slice: &mut [F], read_record: &Rv32VecHeapReadRecord, write_record: &Rv32VecHeapWriteRecord, - bitwise_lookup_chip: &BitwiseOperationLookupChip, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, address_bits: usize, memory: &OfflineMemory, ) { diff --git a/extensions/rv32-adapters/src/vec_heap_two_reads.rs b/extensions/rv32-adapters/src/vec_heap_two_reads.rs index 70ed6249b0..6f896c3392 100644 --- a/extensions/rv32-adapters/src/vec_heap_two_reads.rs +++ b/extensions/rv32-adapters/src/vec_heap_two_reads.rs @@ -3,7 +3,6 @@ use std::{ borrow::{Borrow, BorrowMut}, iter::zip, marker::PhantomData, - sync::Arc, }; use itertools::izip; @@ -21,7 +20,7 @@ use openvm_circuit::{ }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -63,7 +62,7 @@ pub struct Rv32VecHeapTwoReadsAdapterChip< READ_SIZE, WRITE_SIZE, >, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, _marker: PhantomData, } @@ -89,7 +88,7 @@ impl< program_bus: ProgramBus, memory_bridge: MemoryBridge, address_bits: usize, - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!( RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, @@ -487,7 +486,7 @@ impl< row_slice, &read_record, &write_record, - &self.bitwise_lookup_chip, + self.bitwise_lookup_chip.clone(), self.air.address_bits, memory, ) @@ -509,7 +508,7 @@ pub(super) fn vec_heap_two_reads_generate_trace_row_impl< row_slice: &mut [F], read_record: &Rv32VecHeapTwoReadsReadRecord, write_record: &Rv32VecHeapTwoReadsWriteRecord, - bitwise_lookup_chip: &BitwiseOperationLookupChip, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, address_bits: usize, memory: &OfflineMemory, ) { diff --git a/extensions/rv32im/circuit/src/auipc/core.rs b/extensions/rv32im/circuit/src/auipc/core.rs index 354285ab23..b95d5197dc 100644 --- a/extensions/rv32im/circuit/src/auipc/core.rs +++ b/extensions/rv32im/circuit/src/auipc/core.rs @@ -1,7 +1,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use openvm_circuit::arch::{ @@ -9,7 +8,7 @@ use openvm_circuit::arch::{ VmCoreAir, VmCoreChip, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, UsizeOpcode}; @@ -135,12 +134,12 @@ pub struct Rv32AuipcCoreRecord { pub struct Rv32AuipcCoreChip { pub air: Rv32AuipcCoreAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } impl Rv32AuipcCoreChip { pub fn new( - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, offset: usize, ) -> Self { Self { diff --git a/extensions/rv32im/circuit/src/auipc/tests.rs b/extensions/rv32im/circuit/src/auipc/tests.rs index 3955a9e0b9..f3fa26b459 100644 --- a/extensions/rv32im/circuit/src/auipc/tests.rs +++ b/extensions/rv32im/circuit/src/auipc/tests.rs @@ -1,8 +1,8 @@ -use std::{borrow::BorrowMut, sync::Arc}; +use std::borrow::BorrowMut; use openvm_circuit::arch::{testing::VmChipTestBuilder, VmAdapterChip}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, program::PC_BITS, UsizeOpcode, VmOpcode}; use openvm_rv32im_transpiler::Rv32AuipcOpcode::{self, *}; @@ -59,9 +59,7 @@ fn set_and_execute( fn rand_auipc_test() { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let adapter = Rv32RdWriteAdapterChip::::new( @@ -100,9 +98,7 @@ fn run_negative_auipc_test( ) { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let adapter = Rv32RdWriteAdapterChip::::new( @@ -248,9 +244,7 @@ fn overflow_negative_tests() { fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let adapter = Rv32RdWriteAdapterChip::::new( diff --git a/extensions/rv32im/circuit/src/base_alu/core.rs b/extensions/rv32im/circuit/src/base_alu/core.rs index 334fd5d341..54e4b68a8a 100644 --- a/extensions/rv32im/circuit/src/base_alu/core.rs +++ b/extensions/rv32im/circuit/src/base_alu/core.rs @@ -1,7 +1,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use openvm_circuit::arch::{ @@ -9,7 +8,7 @@ use openvm_circuit::arch::{ VmCoreAir, VmCoreChip, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -173,12 +172,12 @@ pub struct BaseAluCoreRecord pub struct BaseAluCoreChip { pub air: BaseAluCoreAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } impl BaseAluCoreChip { pub fn new( - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, offset: usize, ) -> Self { Self { diff --git a/extensions/rv32im/circuit/src/base_alu/tests.rs b/extensions/rv32im/circuit/src/base_alu/tests.rs index a691ac27ff..30210abe94 100644 --- a/extensions/rv32im/circuit/src/base_alu/tests.rs +++ b/extensions/rv32im/circuit/src/base_alu/tests.rs @@ -1,4 +1,4 @@ -use std::{borrow::BorrowMut, sync::Arc}; +use std::borrow::BorrowMut; use openvm_circuit::{ arch::{ @@ -8,7 +8,7 @@ use openvm_circuit::{ utils::generate_long_number, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, VmOpcode}; use openvm_rv32im_transpiler::BaseAluOpcode; @@ -45,9 +45,7 @@ type F = BabyBear; fn run_rv32_alu_rand_test(opcode: BaseAluOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32BaseAluChip::::new( @@ -130,9 +128,7 @@ fn run_rv32_alu_negative_test( interaction_error: bool, ) { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); let mut chip = Rv32BaseAluTestChip::::new( diff --git a/extensions/rv32im/circuit/src/branch_lt/core.rs b/extensions/rv32im/circuit/src/branch_lt/core.rs index e9dd9af7d9..68b5e36c8f 100644 --- a/extensions/rv32im/circuit/src/branch_lt/core.rs +++ b/extensions/rv32im/circuit/src/branch_lt/core.rs @@ -1,7 +1,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use openvm_circuit::arch::{ @@ -9,7 +8,7 @@ use openvm_circuit::arch::{ VmCoreAir, VmCoreChip, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -193,12 +192,12 @@ pub struct BranchLessThanCoreRecord { pub air: BranchLessThanCoreAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } impl BranchLessThanCoreChip { pub fn new( - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, offset: usize, ) -> Self { Self { diff --git a/extensions/rv32im/circuit/src/branch_lt/tests.rs b/extensions/rv32im/circuit/src/branch_lt/tests.rs index a4d8aed889..893a7cbdf3 100644 --- a/extensions/rv32im/circuit/src/branch_lt/tests.rs +++ b/extensions/rv32im/circuit/src/branch_lt/tests.rs @@ -1,4 +1,4 @@ -use std::{borrow::BorrowMut, sync::Arc}; +use std::borrow::BorrowMut; use openvm_circuit::{ arch::{ @@ -9,7 +9,7 @@ use openvm_circuit::{ utils::{generate_long_number, i32_to_f}, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, program::PC_BITS, UsizeOpcode, VmOpcode}; use openvm_rv32im_transpiler::BranchLessThanOpcode; @@ -88,9 +88,7 @@ fn run_rv32_branch_lt_rand_test(opcode: BranchLessThanOpcode, num_ops: usize) { const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32BranchLessThanChip::::new( @@ -191,9 +189,7 @@ fn run_rv32_blt_negative_test( ) { let imm = 16u32; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); let mut chip = Rv32BranchLessThanTestChip::::new( @@ -496,9 +492,7 @@ fn rv32_blt_unsigned_wrong_b_msb_sign_negative_test() { #[test] fn execute_pc_increment_sanity_test() { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let core = BranchLessThanCoreChip::::new(bitwise_chip, 0); diff --git a/extensions/rv32im/circuit/src/divrem/core.rs b/extensions/rv32im/circuit/src/divrem/core.rs index 8bdd7a3f6e..2f70021bfd 100644 --- a/extensions/rv32im/circuit/src/divrem/core.rs +++ b/extensions/rv32im/circuit/src/divrem/core.rs @@ -1,7 +1,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use num_bigint::BigUint; @@ -11,8 +10,8 @@ use openvm_circuit::arch::{ VmCoreAir, VmCoreChip, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, utils::{not, select}, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -308,14 +307,14 @@ where pub struct DivRemCoreChip { pub air: DivRemCoreAir, - pub bitwise_lookup_chip: Arc>, - pub range_tuple_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } impl DivRemCoreChip { pub fn new( - bitwise_lookup_chip: Arc>, - range_tuple_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + range_tuple_chip: SharedRangeTupleCheckerChip<2>, offset: usize, ) -> Self { // The RangeTupleChecker is used to range check (a[i], carry[i]) pairs where 0 <= i diff --git a/extensions/rv32im/circuit/src/divrem/tests.rs b/extensions/rv32im/circuit/src/divrem/tests.rs index 446f207683..94162a86eb 100644 --- a/extensions/rv32im/circuit/src/divrem/tests.rs +++ b/extensions/rv32im/circuit/src/divrem/tests.rs @@ -1,4 +1,4 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; +use std::{array, borrow::BorrowMut}; use openvm_circuit::{ arch::{ @@ -9,8 +9,8 @@ use openvm_circuit::{ utils::generate_long_number, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_instructions::{instruction::Instruction, VmOpcode}; use openvm_rv32im_transpiler::DivRemOpcode; @@ -98,10 +98,8 @@ fn run_rv32_divrem_rand_test(opcode: DivRemOpcode, num_ops: usize) { [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], ); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); - let range_tuple_checker = Arc::new(RangeTupleCheckerChip::new(range_tuple_bus)); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32DivRemChip::::new( @@ -240,10 +238,8 @@ fn run_rv32_divrem_negative_test( [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], ); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); - let range_tuple_chip = Arc::new(RangeTupleCheckerChip::new(range_tuple_bus)); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32DivRemTestChip::::new( diff --git a/extensions/rv32im/circuit/src/extension.rs b/extensions/rv32im/circuit/src/extension.rs index fe912fa0e2..7eb1ac7d69 100644 --- a/extensions/rv32im/circuit/src/extension.rs +++ b/extensions/rv32im/circuit/src/extension.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use derive_more::derive::From; use openvm_circuit::{ arch::{ @@ -8,10 +6,10 @@ use openvm_circuit::{ }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful, VmConfig}; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{program::DEFAULT_PC_STEP, PhantomDiscriminant, UsizeOpcode, VmOpcode}; @@ -152,7 +150,7 @@ fn default_range_tuple_checker_sizes() -> [u32; 2] { // ============ Executor and Periphery Enums for Extension ============ /// RISC-V 32-bit Base (RV32I) Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] pub enum Rv32IExecutor { // Rv32 (for standard 32-bit integers): BaseAlu(Rv32BaseAluChip), @@ -168,7 +166,7 @@ pub enum Rv32IExecutor { } /// RISC-V 32-bit Multiplication Extension (RV32M) Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] pub enum Rv32MExecutor { Multiplication(Rv32MultiplicationChip), MultiplicationHigh(Rv32MulHChip), @@ -176,30 +174,30 @@ pub enum Rv32MExecutor { } /// RISC-V 32-bit Io Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] pub enum Rv32IoExecutor { HintStore(Rv32HintStoreChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] pub enum Rv32IPeriphery { - BitwiseOperationLookup(Arc>), + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), // We put this only to get the generic to work Phantom(PhantomChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] pub enum Rv32MPeriphery { - BitwiseOperationLookup(Arc>), + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), /// Only needed for multiplication extension - RangeTupleChecker(Arc>), + RangeTupleChecker(SharedRangeTupleCheckerChip<2>), // We put this only to get the generic to work Phantom(PhantomChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] pub enum Rv32IoPeriphery { - BitwiseOperationLookup(Arc>), + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), // We put this only to get the generic to work Phantom(PhantomChip), } @@ -225,14 +223,14 @@ impl VmExtension for Rv32I { let offline_memory = builder.system_base().offline_memory(); let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; - let bitwise_lu_chip = if let Some(chip) = builder - .find_chip::>>() + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() .first() { - Arc::clone(chip) + chip.clone() } else { let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = Arc::new(BitwiseOperationLookupChip::new(bitwise_lu_bus)); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); inventory.add_periphery_chip(chip.clone()); chip }; @@ -398,20 +396,20 @@ impl VmExtension for Rv32M { } = builder.system_port(); let offline_memory = builder.system_base().offline_memory(); - let bitwise_lu_chip = if let Some(chip) = builder - .find_chip::>>() + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() .first() { - Arc::clone(chip) + chip.clone() } else { let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = Arc::new(BitwiseOperationLookupChip::new(bitwise_lu_bus)); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); inventory.add_periphery_chip(chip.clone()); chip }; let range_tuple_checker = if let Some(chip) = builder - .find_chip::>>() + .find_chip::>() .into_iter() .find(|c| { c.bus().sizes[0] >= self.range_tuple_checker_sizes[0] @@ -421,7 +419,7 @@ impl VmExtension for Rv32M { } else { let range_tuple_bus = RangeTupleCheckerBus::new(builder.new_bus_idx(), self.range_tuple_checker_sizes); - let chip = Arc::new(RangeTupleCheckerChip::new(range_tuple_bus)); + let chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); inventory.add_periphery_chip(chip.clone()); chip }; @@ -486,14 +484,14 @@ impl VmExtension for Rv32Io { let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; let range_checker = builder.system_base().range_checker_chip.clone(); - let bitwise_lu_chip = if let Some(chip) = builder - .find_chip::>>() + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() .first() { - Arc::clone(chip) + chip.clone() } else { let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = Arc::new(BitwiseOperationLookupChip::new(bitwise_lu_bus)); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); inventory.add_periphery_chip(chip.clone()); chip }; diff --git a/extensions/rv32im/circuit/src/hintstore/core.rs b/extensions/rv32im/circuit/src/hintstore/core.rs index 013e35cffb..103fd57d3a 100644 --- a/extensions/rv32im/circuit/src/hintstore/core.rs +++ b/extensions/rv32im/circuit/src/hintstore/core.rs @@ -9,7 +9,7 @@ use openvm_circuit::arch::{ VmAdapterInterface, VmCoreAir, VmCoreChip, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, UsizeOpcode}; @@ -95,12 +95,12 @@ where pub struct Rv32HintStoreCoreChip { pub air: Rv32HintStoreCoreAir, pub streams: OnceLock>>>, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } impl Rv32HintStoreCoreChip { pub fn new( - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, offset: usize, ) -> Self { Self { diff --git a/extensions/rv32im/circuit/src/hintstore/tests.rs b/extensions/rv32im/circuit/src/hintstore/tests.rs index 8aad30fab6..c703f90492 100644 --- a/extensions/rv32im/circuit/src/hintstore/tests.rs +++ b/extensions/rv32im/circuit/src/hintstore/tests.rs @@ -12,7 +12,7 @@ use openvm_circuit::{ utils::{u32_into_limbs, u32_sign_extend}, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, UsizeOpcode, VmOpcode}; use openvm_rv32im_transpiler::Rv32HintStoreOpcode::{self, *}; @@ -106,9 +106,7 @@ fn rand_hintstore_test() { let mut tester = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); let adapter = Rv32HintStoreAdapterChip::::new( @@ -152,9 +150,7 @@ fn run_negative_hintstore_test( let mut tester = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); let adapter = Rv32HintStoreAdapterChip::::new( @@ -211,9 +207,7 @@ fn execute_roundtrip_sanity_test() { let mut tester = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); let adapter = Rv32HintStoreAdapterChip::::new( diff --git a/extensions/rv32im/circuit/src/jal_lui/core.rs b/extensions/rv32im/circuit/src/jal_lui/core.rs index 6ba1d5d57e..457c7b4cbc 100644 --- a/extensions/rv32im/circuit/src/jal_lui/core.rs +++ b/extensions/rv32im/circuit/src/jal_lui/core.rs @@ -1,7 +1,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use openvm_circuit::arch::{ @@ -9,7 +8,7 @@ use openvm_circuit::arch::{ VmCoreAir, VmCoreChip, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -144,12 +143,12 @@ pub struct Rv32JalLuiCoreRecord { pub struct Rv32JalLuiCoreChip { pub air: Rv32JalLuiCoreAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } impl Rv32JalLuiCoreChip { pub fn new( - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, offset: usize, ) -> Self { Self { diff --git a/extensions/rv32im/circuit/src/jal_lui/tests.rs b/extensions/rv32im/circuit/src/jal_lui/tests.rs index a16b20e1a7..9ac8373576 100644 --- a/extensions/rv32im/circuit/src/jal_lui/tests.rs +++ b/extensions/rv32im/circuit/src/jal_lui/tests.rs @@ -1,8 +1,8 @@ -use std::{borrow::BorrowMut, sync::Arc}; +use std::borrow::BorrowMut; use openvm_circuit::arch::{testing::VmChipTestBuilder, VmAdapterChip, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, program::PC_BITS, UsizeOpcode, VmOpcode}; use openvm_rv32im_transpiler::Rv32JalLuiOpcode::{self, *}; @@ -82,9 +82,7 @@ fn set_and_execute( fn rand_jal_lui_test() { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let adapter = Rv32CondRdWriteAdapterChip::::new( @@ -126,9 +124,7 @@ fn run_negative_jal_lui_test( ) { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let adapter = Rv32CondRdWriteAdapterChip::::new( @@ -311,9 +307,7 @@ fn overflow_negative_tests() { fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let adapter = Rv32CondRdWriteAdapterChip::::new( diff --git a/extensions/rv32im/circuit/src/jalr/core.rs b/extensions/rv32im/circuit/src/jalr/core.rs index 4cf495931e..7938f77c91 100644 --- a/extensions/rv32im/circuit/src/jalr/core.rs +++ b/extensions/rv32im/circuit/src/jalr/core.rs @@ -9,7 +9,7 @@ use openvm_circuit::arch::{ VmCoreAir, VmCoreChip, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -174,13 +174,13 @@ where pub struct Rv32JalrCoreChip { pub air: Rv32JalrCoreAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_checker_chip: Arc, } impl Rv32JalrCoreChip { pub fn new( - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker_chip: Arc, offset: usize, ) -> Self { diff --git a/extensions/rv32im/circuit/src/jalr/tests.rs b/extensions/rv32im/circuit/src/jalr/tests.rs index 7637e54f1a..4b92a08ea5 100644 --- a/extensions/rv32im/circuit/src/jalr/tests.rs +++ b/extensions/rv32im/circuit/src/jalr/tests.rs @@ -1,8 +1,8 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; +use std::{array, borrow::BorrowMut}; use openvm_circuit::arch::{testing::VmChipTestBuilder, VmAdapterChip, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, program::PC_BITS, UsizeOpcode, VmOpcode}; use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *}; @@ -85,9 +85,7 @@ fn set_and_execute( fn rand_jalr_test() { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); @@ -136,9 +134,7 @@ fn run_negative_jalr_test( ) { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); @@ -278,9 +274,7 @@ fn overflow_negative_tests() { fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); diff --git a/extensions/rv32im/circuit/src/less_than/core.rs b/extensions/rv32im/circuit/src/less_than/core.rs index ef2b341090..2b173ec316 100644 --- a/extensions/rv32im/circuit/src/less_than/core.rs +++ b/extensions/rv32im/circuit/src/less_than/core.rs @@ -1,7 +1,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use openvm_circuit::arch::{ @@ -9,7 +8,7 @@ use openvm_circuit::arch::{ VmCoreAir, VmCoreChip, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -170,12 +169,12 @@ pub struct LessThanCoreRecord pub struct LessThanCoreChip { pub air: LessThanCoreAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } impl LessThanCoreChip { pub fn new( - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, offset: usize, ) -> Self { Self { diff --git a/extensions/rv32im/circuit/src/less_than/tests.rs b/extensions/rv32im/circuit/src/less_than/tests.rs index f82420710f..9dac17b262 100644 --- a/extensions/rv32im/circuit/src/less_than/tests.rs +++ b/extensions/rv32im/circuit/src/less_than/tests.rs @@ -1,4 +1,4 @@ -use std::{borrow::BorrowMut, sync::Arc}; +use std::borrow::BorrowMut; use openvm_circuit::{ arch::{ @@ -8,7 +8,7 @@ use openvm_circuit::{ utils::{generate_long_number, i32_to_f}, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, VmOpcode}; use openvm_rv32im_transpiler::LessThanOpcode; @@ -45,9 +45,7 @@ type F = BabyBear; fn run_rv32_lt_rand_test(opcode: LessThanOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32LessThanChip::::new( @@ -137,9 +135,7 @@ fn run_rv32_lt_negative_test( interaction_error: bool, ) { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); let mut chip = Rv32LessThanTestChip::::new( diff --git a/extensions/rv32im/circuit/src/mul/core.rs b/extensions/rv32im/circuit/src/mul/core.rs index 431079ee92..6f56b1807a 100644 --- a/extensions/rv32im/circuit/src/mul/core.rs +++ b/extensions/rv32im/circuit/src/mul/core.rs @@ -1,14 +1,13 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use openvm_circuit::arch::{ AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, VmCoreAir, VmCoreChip, }; -use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}; +use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, UsizeOpcode}; use openvm_rv32im_transpiler::MulOpcode; @@ -107,11 +106,11 @@ where #[derive(Debug)] pub struct MultiplicationCoreChip { pub air: MultiplicationCoreAir, - pub range_tuple_chip: Arc>, + pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } impl MultiplicationCoreChip { - pub fn new(range_tuple_chip: Arc>, offset: usize) -> Self { + pub fn new(range_tuple_chip: SharedRangeTupleCheckerChip<2>, offset: usize) -> Self { // The RangeTupleChecker is used to range check (a[i], carry[i]) pairs where 0 <= i // < NUM_LIMBS. a[i] must have LIMB_BITS bits and carry[i] is the sum of i + 1 bytes // (with LIMB_BITS bits). diff --git a/extensions/rv32im/circuit/src/mul/tests.rs b/extensions/rv32im/circuit/src/mul/tests.rs index 30fdfa5c4e..9cdbfa9d9b 100644 --- a/extensions/rv32im/circuit/src/mul/tests.rs +++ b/extensions/rv32im/circuit/src/mul/tests.rs @@ -1,4 +1,4 @@ -use std::{borrow::BorrowMut, sync::Arc}; +use std::borrow::BorrowMut; use openvm_circuit::{ arch::{ @@ -7,7 +7,7 @@ use openvm_circuit::{ }, utils::generate_long_number, }; -use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}; +use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}; use openvm_instructions::{instruction::Instruction, VmOpcode}; use openvm_rv32im_transpiler::MulOpcode; use openvm_stark_backend::{ @@ -48,7 +48,7 @@ fn run_rv32_mul_rand_test(num_ops: usize) { RANGE_TUPLE_CHECKER_BUS, [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], ); - let range_tuple_checker = Arc::new(RangeTupleCheckerChip::new(range_tuple_bus)); + let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32MultiplicationChip::::new( @@ -123,7 +123,7 @@ fn run_rv32_mul_negative_test( RANGE_TUPLE_CHECKER_BUS, [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], ); - let range_tuple_chip = Arc::new(RangeTupleCheckerChip::new(range_tuple_bus)); + let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32MultiplicationTestChip::::new( diff --git a/extensions/rv32im/circuit/src/mulh/core.rs b/extensions/rv32im/circuit/src/mulh/core.rs index bec6ae754e..a8c4a0273a 100644 --- a/extensions/rv32im/circuit/src/mulh/core.rs +++ b/extensions/rv32im/circuit/src/mulh/core.rs @@ -1,7 +1,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::Arc, }; use openvm_circuit::arch::{ @@ -9,8 +8,8 @@ use openvm_circuit::arch::{ VmCoreAir, VmCoreChip, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, UsizeOpcode}; @@ -180,14 +179,14 @@ where pub struct MulHCoreChip { pub air: MulHCoreAir, - pub bitwise_lookup_chip: Arc>, - pub range_tuple_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } impl MulHCoreChip { pub fn new( - bitwise_lookup_chip: Arc>, - range_tuple_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + range_tuple_chip: SharedRangeTupleCheckerChip<2>, offset: usize, ) -> Self { // The RangeTupleChecker is used to range check (a[i], carry[i]) pairs where 0 <= i diff --git a/extensions/rv32im/circuit/src/mulh/tests.rs b/extensions/rv32im/circuit/src/mulh/tests.rs index 62d64470d5..31a1fbc62e 100644 --- a/extensions/rv32im/circuit/src/mulh/tests.rs +++ b/extensions/rv32im/circuit/src/mulh/tests.rs @@ -1,4 +1,4 @@ -use std::{borrow::BorrowMut, sync::Arc}; +use std::borrow::BorrowMut; use openvm_circuit::{ arch::{ @@ -9,8 +9,8 @@ use openvm_circuit::{ utils::generate_long_number, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_instructions::{instruction::Instruction, VmOpcode}; use openvm_rv32im_transpiler::MulHOpcode; @@ -82,10 +82,8 @@ fn run_rv32_mulh_rand_test(opcode: MulHOpcode, num_ops: usize) { [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], ); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); - let range_tuple_checker = Arc::new(RangeTupleCheckerChip::new(range_tuple_bus)); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32MulHChip::::new( @@ -157,10 +155,8 @@ fn run_rv32_mulh_negative_test( [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], ); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); - let range_tuple_chip = Arc::new(RangeTupleCheckerChip::new(range_tuple_bus)); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32MulHTestChip::::new( diff --git a/extensions/rv32im/circuit/src/shift/core.rs b/extensions/rv32im/circuit/src/shift/core.rs index 2d05bcc9cb..28eca7833a 100644 --- a/extensions/rv32im/circuit/src/shift/core.rs +++ b/extensions/rv32im/circuit/src/shift/core.rs @@ -9,7 +9,7 @@ use openvm_circuit::arch::{ VmCoreAir, VmCoreChip, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, }; @@ -251,13 +251,13 @@ pub struct ShiftCoreRecord { pub struct ShiftCoreChip { pub air: ShiftCoreAir, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_checker_chip: Arc, } impl ShiftCoreChip { pub fn new( - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker_chip: Arc, offset: usize, ) -> Self { diff --git a/extensions/rv32im/circuit/src/shift/tests.rs b/extensions/rv32im/circuit/src/shift/tests.rs index 4d582440da..b6657de414 100644 --- a/extensions/rv32im/circuit/src/shift/tests.rs +++ b/extensions/rv32im/circuit/src/shift/tests.rs @@ -1,4 +1,4 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; +use std::{array, borrow::BorrowMut}; use openvm_circuit::{ arch::{ @@ -8,7 +8,7 @@ use openvm_circuit::{ utils::generate_long_number, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, VmOpcode}; use openvm_rv32im_transpiler::ShiftOpcode; @@ -45,9 +45,7 @@ type F = BabyBear; fn run_rv32_shift_rand_test(opcode: ShiftOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32ShiftChip::::new( @@ -138,9 +136,7 @@ fn run_rv32_shift_negative_test( interaction_error: bool, ) { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); let mut chip = Rv32ShiftTestChip::::new( diff --git a/extensions/sha256/circuit/Cargo.toml b/extensions/sha256/circuit/Cargo.toml index bdf8cb0d69..efb1d12d42 100644 --- a/extensions/sha256/circuit/Cargo.toml +++ b/extensions/sha256/circuit/Cargo.toml @@ -23,6 +23,7 @@ rand.workspace = true serde.workspace = true sha2 = { version = "0.10", default-features = false } strum = { workspace = true } +bitcode.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } diff --git a/extensions/sha256/circuit/src/extension.rs b/extensions/sha256/circuit/src/extension.rs index 8daceb2991..baeeccf76b 100644 --- a/extensions/sha256/circuit/src/extension.rs +++ b/extensions/sha256/circuit/src/extension.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use derive_more::derive::From; use openvm_circuit::{ arch::{ @@ -8,9 +6,9 @@ use openvm_circuit::{ }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InstructionExecutor, Stateful, VmConfig}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::*; @@ -54,14 +52,14 @@ impl Default for Sha256Rv32Config { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Sha256; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, Stateful)] pub enum Sha256Executor { Sha256(Sha256VmChip), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] +#[derive(From, ChipUsageGetter, Chip, AnyEnum, Stateful)] pub enum Sha256Periphery { - BitwiseOperationLookup(Arc>), + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), Phantom(PhantomChip), } @@ -74,14 +72,14 @@ impl VmExtension for Sha256 { builder: &mut VmInventoryBuilder, ) -> Result, VmInventoryError> { let mut inventory = VmInventory::new(); - let bitwise_lu_chip = if let Some(chip) = builder - .find_chip::>>() + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() .first() { - Arc::clone(chip) + chip.clone() } else { let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = Arc::new(BitwiseOperationLookupChip::new(bitwise_lu_bus)); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); inventory.add_periphery_chip(chip.clone()); chip }; diff --git a/extensions/sha256/circuit/src/sha256_chip/mod.rs b/extensions/sha256/circuit/src/sha256_chip/mod.rs index 58f780717e..983819348d 100644 --- a/extensions/sha256/circuit/src/sha256_chip/mod.rs +++ b/extensions/sha256/circuit/src/sha256_chip/mod.rs @@ -9,7 +9,9 @@ use std::{ use openvm_circuit::arch::{ ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, SystemPort, }; -use openvm_circuit_primitives::{bitwise_op_lookup::BitwiseOperationLookupChip, encoder::Encoder}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, +}; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, @@ -19,7 +21,8 @@ use openvm_instructions::{ use openvm_rv32im_circuit::adapters::read_rv32_register; use openvm_sha256_air::{Sha256Air, SHA256_BLOCK_BITS}; use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_backend::{p3_field::PrimeField32, Stateful}; +use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; mod air; @@ -49,12 +52,12 @@ pub struct Sha256VmChip { /// IO and memory data necessary for each opcode call pub records: Vec>, pub offline_memory: Arc>>, - pub bitwise_lookup_chip: Arc>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, offset: usize, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Sha256Record { pub from_state: ExecutionState, pub dst_read: RecordId, @@ -73,7 +76,7 @@ impl Sha256VmChip { memory_bridge, }: SystemPort, address_bits: usize, - bitwise_lookup_chip: Arc>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, self_bus_idx: usize, offset: usize, offline_memory: Arc>>, @@ -194,6 +197,16 @@ impl InstructionExecutor for Sha256VmChip { } } +impl Stateful> for Sha256VmChip { + fn load_state(&mut self, state: Vec) { + self.records = bitcode::deserialize(&state).unwrap(); + } + + fn store_state(&self) -> Vec { + bitcode::serialize(&self.records).unwrap() + } +} + pub fn sha256_solve(input_message: &[u8]) -> [u8; SHA256_WRITE_SIZE] { let mut hasher = Sha256::new(); hasher.update(input_message); diff --git a/extensions/sha256/circuit/src/sha256_chip/tests.rs b/extensions/sha256/circuit/src/sha256_chip/tests.rs index 99a2736b99..9f5a016874 100644 --- a/extensions/sha256/circuit/src/sha256_chip/tests.rs +++ b/extensions/sha256/circuit/src/sha256_chip/tests.rs @@ -1,11 +1,9 @@ -use std::sync::Arc; - use openvm_circuit::arch::{ testing::{memory::gen_pointer, VmChipTestBuilder}, SystemPort, BITWISE_OP_LOOKUP_BUS, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, riscv::RV32_CELL_BITS, UsizeOpcode, VmOpcode}; use openvm_sha256_air::get_random_message; @@ -78,9 +76,7 @@ fn rand_sha256_test() { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut chip = Sha256VmChip::new( SystemPort { execution_bus: tester.execution_bus(), @@ -113,9 +109,7 @@ fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut chip = Sha256VmChip::new( SystemPort { execution_bus: tester.execution_bus(), diff --git a/extensions/sha256/circuit/src/sha256_chip/trace.rs b/extensions/sha256/circuit/src/sha256_chip/trace.rs index a691d3496e..aae1ab0eab 100644 --- a/extensions/sha256/circuit/src/sha256_chip/trace.rs +++ b/extensions/sha256/circuit/src/sha256_chip/trace.rs @@ -125,7 +125,7 @@ where width, SHA256VM_CONTROL_WIDTH, &padded_message, - self.bitwise_lookup_chip.as_ref(), + self.bitwise_lookup_chip.clone(), &state.hash, is_last_block, global_block_idx as u32 + 1,