Skip to content

Commit

Permalink
feat: Pass hasher to MemoryChip finalize (#506)
Browse files Browse the repository at this point in the history
  • Loading branch information
zlangley authored Oct 8, 2024
1 parent 4949e78 commit 21140b5
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 106 deletions.
18 changes: 5 additions & 13 deletions compiler/tests/keccak256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use p3_field::{extension::BinomialExtensionField, AbstractField};
use stark_vm::{
arch::ExecutorName,
hashes::keccak::hasher::{utils::keccak256, KECCAK_DIGEST_BYTES},
vm::config::{VmConfig, DEFAULT_MAX_SEGMENT_LEN, DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE},
vm::config::VmConfig,
};
use tracing::Level;

Expand Down Expand Up @@ -56,18 +56,10 @@ fn run_e2e_keccak_test(inputs: Vec<Vec<u8>>, expected_outputs: Vec<[u8; 32]>) {
execute_and_prove_program(
program,
vec![],
VmConfig::from_parameters(
Some(DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE),
Default::default(),
4,
DEFAULT_MAX_SEGMENT_LEN,
false,
8,
vec![],
)
.add_default_executor(ExecutorName::FieldArithmetic)
.add_default_executor(ExecutorName::FieldExtension)
.add_default_executor(ExecutorName::Keccak256),
VmConfig::default_with_no_executors()
.add_default_executor(ExecutorName::FieldArithmetic)
.add_default_executor(ExecutorName::FieldExtension)
.add_default_executor(ExecutorName::Keccak256),
BabyBearPoseidon2Engine::new(FriParameters::standard_fast()),
)
.unwrap();
Expand Down
10 changes: 7 additions & 3 deletions vm/src/arch/testing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub use execution::ExecutionTester;
pub use memory::MemoryTester;

use super::{ExecutionBus, InstructionExecutor};
use crate::memory::MemoryChipRef;
use crate::{hashes::poseidon2::Poseidon2Chip, memory::MemoryChipRef, vm::config::PersistenceType};

#[derive(Clone, Debug)]
pub struct MachineChipTestBuilder<F: PrimeField32> {
Expand Down Expand Up @@ -117,6 +117,10 @@ impl<F: PrimeField32> MachineChipTestBuilder<F> {

impl MachineChipTestBuilder<BabyBear> {
pub fn build(self) -> MachineChipTester {
self.memory
.chip
.borrow_mut()
.finalize(None::<&mut Poseidon2Chip<BabyBear>>);
let tester = MachineChipTester {
memory: Some(self.memory),
..Default::default()
Expand All @@ -128,12 +132,12 @@ impl MachineChipTestBuilder<BabyBear> {

impl<F: PrimeField32> Default for MachineChipTestBuilder<F> {
fn default() -> Self {
let mem_config = MemoryConfig::new(2, 29, 29, 17); // smaller testing config with smaller decomp_bits
let mem_config = MemoryConfig::new(2, 29, 29, 17, PersistenceType::Volatile);
let range_checker = Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new(
RANGE_CHECKER_BUS,
mem_config.decomp,
)));
let memory_chip = MemoryChip::with_volatile_memory(MemoryBus(1), mem_config, range_checker);
let memory_chip = MemoryChip::new(MemoryBus(1), mem_config, range_checker);
Self {
memory: MemoryTester::new(Rc::new(RefCell::new(memory_chip))),
execution: ExecutionTester::new(ExecutionBus(0)),
Expand Down
61 changes: 29 additions & 32 deletions vm/src/memory/manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ use crate::{
MemoryBridge, MemoryBus, MemoryReadAuxCols, MemoryReadOrImmediateAuxCols,
MemoryWriteAuxCols, AUX_LEN,
},
tree::Hasher,
},
vm::config::MemoryConfig,
vm::config::{MemoryConfig, PersistenceType},
};

pub mod dimensions;
Expand All @@ -43,6 +44,7 @@ mod memory;
mod trace;

const NUM_WORDS: usize = 16;
pub const CHUNK: usize = 8;

#[derive(Clone, Copy, Debug)]
pub struct TimestampedValue<T> {
Expand Down Expand Up @@ -160,27 +162,14 @@ pub struct MemoryChip<F: PrimeField32> {
}

impl<F: PrimeField32> MemoryChip<F> {
// pub fn with_persistent_memory(
// memory_dimensions: MemoryDimensions,
// memory: HashMap<(F, F), AccessCell<WORD_SIZE, F>>,
// ) -> Self {
// Self {
// interface_chip: MemoryInterface::Persistent(MemoryExpandInterfaceChip::new(
// memory_dimensions,
// )),
// clk: F::one(),
// memory,
// }
// }

pub fn with_volatile_memory(
pub fn new(
memory_bus: MemoryBus,
mem_config: MemoryConfig,
range_checker: Arc<VariableRangeCheckerChip>,
) -> Self {
Self {
memory_bus,
mem_config,
mem_config: mem_config.clone(),
interface_chip: MemoryInterface::Volatile(MemoryAuditChip::new(
memory_bus,
mem_config.addr_space_max_bits,
Expand Down Expand Up @@ -379,19 +368,29 @@ impl<F: PrimeField32> MemoryChip<F> {
}
}

fn finalize(&mut self) {
let all_addresses = self.interface_chip.all_addresses();
for (address_space, pointer) in all_addresses {
let records = self.memory.access(
AddressSpace(address_space.as_canonical_u32()),
pointer.as_canonical_u32() as usize,
1,
pub fn finalize(&mut self, hasher: Option<&mut impl Hasher<CHUNK, F>>) {
if let Some(_hasher) = hasher {
assert_eq!(
self.mem_config.persistence_type,
PersistenceType::Persistent
);
for record in records {
self.adapter_records
.entry(record.data.len())
.or_default()
.push(record);
todo!("finalize persistent memory");
} else {
assert_eq!(self.mem_config.persistence_type, PersistenceType::Volatile);

let all_addresses = self.interface_chip.all_addresses();
for (address_space, pointer) in all_addresses {
let records = self.memory.access(
AddressSpace(address_space.as_canonical_u32()),
pointer.as_canonical_u32() as usize,
1,
);
for record in records {
self.adapter_records
.entry(record.data.len())
.or_default()
.push(record);
}
}
}
}
Expand All @@ -401,9 +400,7 @@ impl<F: PrimeField32> MachineChip<F> for MemoryChip<F> {
fn generate_trace(self) -> RowMajorMatrix<F> {
panic!("cannot call generate_trace on MemoryChip, which has more than one trace");
}
fn generate_traces(mut self) -> Vec<RowMajorMatrix<F>> {
self.finalize();

fn generate_traces(self) -> Vec<RowMajorMatrix<F>> {
vec![
self.generate_memory_interface_trace(),
self.generate_access_adapter_trace::<2>(),
Expand Down Expand Up @@ -612,7 +609,7 @@ mod tests {
let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus));

let mut memory_chip =
MemoryChip::with_volatile_memory(memory_bus, memory_config, range_checker.clone());
MemoryChip::new(memory_bus, memory_config.clone(), range_checker.clone());

let mut rng = thread_rng();
for _ in 0..1000 {
Expand Down
9 changes: 6 additions & 3 deletions vm/src/memory/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ use super::{MemoryChip, MemoryReadRecord};
use crate::{
arch::MachineChip,
core::RANGE_CHECKER_BUS,
hashes::poseidon2::Poseidon2Chip,
memory::{
offline_checker::{MemoryBridge, MemoryBus, MemoryReadAuxCols, MemoryWriteAuxCols},
MemoryAddress, MemoryWriteRecord,
},
vm::config::MemoryConfig,
vm::config::{MemoryConfig, PersistenceType},
};

const MAX: usize = 64;
Expand Down Expand Up @@ -151,12 +152,12 @@ fn test_memory_chip() {
pointer_max_bits: 15,
clk_max_bits: 15,
decomp: 8,
persistence_type: PersistenceType::Volatile,
};
let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, memory_config.decomp);
let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus));

let mut memory_chip =
MemoryChip::with_volatile_memory(memory_bus, memory_config, range_checker.clone());
let mut memory_chip = MemoryChip::new(memory_bus, memory_config.clone(), range_checker.clone());
let aux_factory = memory_chip.aux_cols_factory();

#[allow(clippy::large_enum_variant)]
Expand Down Expand Up @@ -276,6 +277,8 @@ fn test_memory_chip() {
))
.collect();

memory_chip.finalize(None::<&mut Poseidon2Chip<BabyBear>>);

let traces = memory_chip
.generate_traces()
.into_iter()
Expand Down
21 changes: 14 additions & 7 deletions vm/src/vm/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,24 @@ use crate::{
pub const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 25) - 100;
pub const DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE: usize = 7; // the sbox degree used for Poseidon2

#[derive(Debug, Serialize, Deserialize, Clone, Copy, new)]
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
pub enum PersistenceType {
Persistent,
Volatile,
}

#[derive(Debug, Serialize, Deserialize, Clone, new)]
pub struct MemoryConfig {
pub addr_space_max_bits: usize,
pub pointer_max_bits: usize,
pub clk_max_bits: usize,
pub decomp: usize,
pub persistence_type: PersistenceType,
}

impl Default for MemoryConfig {
fn default() -> Self {
Self::new(29, 29, 29, 16)
Self::new(29, 29, 29, 15, PersistenceType::Volatile)
}
}

Expand Down Expand Up @@ -116,7 +123,7 @@ pub struct VmConfig {
pub executors: Vec<(Range<usize>, ExecutorName, usize)>, // (range of opcodes, who executes, offset)
pub modular_executors: Vec<(Range<usize>, ExecutorName, usize, BigUint)>, // (range of opcodes, who executes, offset, modulus)

pub poseidon2_max_constraint_degree: Option<usize>,
pub poseidon2_max_constraint_degree: usize,
pub memory_config: MemoryConfig,
pub num_public_values: usize,
pub max_segment_len: usize,
Expand All @@ -128,7 +135,7 @@ pub struct VmConfig {

impl VmConfig {
pub fn from_parameters(
poseidon2_max_constraint_degree: Option<usize>,
poseidon2_max_constraint_degree: usize,
memory_config: MemoryConfig,
num_public_values: usize,
max_segment_len: usize,
Expand Down Expand Up @@ -224,7 +231,7 @@ impl Default for VmConfig {
impl VmConfig {
pub fn default_with_no_executors() -> Self {
Self::from_parameters(
Some(DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE),
DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE,
Default::default(),
0,
DEFAULT_MAX_SEGMENT_LEN,
Expand All @@ -242,7 +249,7 @@ impl VmConfig {

pub fn core() -> Self {
Self::from_parameters(
None,
DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE,
Default::default(),
0,
DEFAULT_MAX_SEGMENT_LEN,
Expand All @@ -255,7 +262,7 @@ impl VmConfig {

pub fn aggregation(poseidon2_max_constraint_degree: usize) -> Self {
VmConfig {
poseidon2_max_constraint_degree: Some(poseidon2_max_constraint_degree),
poseidon2_max_constraint_degree,
num_public_values: 4,
..VmConfig::default()
}
Expand Down
10 changes: 4 additions & 6 deletions vm/src/vm/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{cell::RefCell, collections::VecDeque, mem::take, rc::Rc};
use std::{collections::VecDeque, mem::take};

use cycle_tracker::CycleTracker;
use metrics::VmMetrics;
Expand Down Expand Up @@ -73,11 +73,9 @@ impl<F: PrimeField32> VirtualMachine<F> {
self.segments.len() + 1
);
let program = self.program.clone();
let segment = ExecutionSegment::new(self.config.clone(), program, state);
let segment_rc = Rc::new(RefCell::new(segment));
segment_rc.borrow_mut().cycle_tracker = cycle_tracker;
self.segments
.push(Rc::try_unwrap(segment_rc).unwrap().into_inner());
let mut segment = ExecutionSegment::new(self.config.clone(), program, state);
segment.cycle_tracker = cycle_tracker;
self.segments.push(segment);
}

/// Retrieves the current state of the VM by querying the last segment.
Expand Down
Loading

0 comments on commit 21140b5

Please sign in to comment.