From 68526ecadfae2751b0a39fa1c9cac2790250509e Mon Sep 17 00:00:00 2001 From: Morgan Thomas Date: Mon, 1 Apr 2024 22:57:44 -0400 Subject: [PATCH] wip: static data chip: making the constraints pass --- basic/tests/test_static_data.rs | 6 +-- cpu/src/lib.rs | 46 +++++++++++----------- memory/src/columns.rs | 2 +- memory/src/lib.rs | 68 ++++++++++++++++++--------------- static_data/src/lib.rs | 13 ++++--- 5 files changed, 74 insertions(+), 61 deletions(-) diff --git a/basic/tests/test_static_data.rs b/basic/tests/test_static_data.rs index ca7705d..e42e605 100644 --- a/basic/tests/test_static_data.rs +++ b/basic/tests/test_static_data.rs @@ -36,14 +36,14 @@ use valida_machine::__internal::p3_commit::ExtensionMmcs; #[test] fn prove_static_data() { // _start: - // imm32 0(fp), 0, 0, 0, 0 + // imm32 0(fp), 0, 0, 0, 0x13 // load32 -4(fp), 0(fp), 0, 0, 0 // bnei _start, 0(fp), 0x25, 0, 1 // infinite loop unless static value is loaded // stop let program = vec![ InstructionWord { opcode: , Val>>::OPCODE, - operands: Operands([0, 0, 0, 0, 0]), + operands: Operands([0, 0, 0, 0, 0x13]), }, InstructionWord { opcode: , Val>>::OPCODE, @@ -61,7 +61,7 @@ fn prove_static_data() { let mut machine = BasicMachine::::default(); let rom = ProgramROM::new(program); - machine.static_data_mut().write(0, Word([0, 0, 0, 0x25])); + machine.static_data_mut().write(0x13, Word([0, 0, 0, 0x25])); machine.program_mut().set_program_rom(&rom); machine.cpu_mut().fp = 0x1000; machine.cpu_mut().save_register_state(); // TODO: Initial register state should be saved diff --git a/cpu/src/lib.rs b/cpu/src/lib.rs index cbbcd89..f781a6c 100644 --- a/cpu/src/lib.rs +++ b/cpu/src/lib.rs @@ -1,4 +1,3 @@ -#![no_std] extern crate alloc; @@ -91,23 +90,23 @@ where fn global_sends(&self, machine: &M) -> Vec> { // Memory bus channels - let mem_sends = (0..3).map(|i| { - let channel = &CPU_COL_MAP.mem_channels[i]; - let is_read = VirtualPairCol::single_main(channel.is_read); - let clk = VirtualPairCol::single_main(CPU_COL_MAP.clk); - let addr = VirtualPairCol::single_main(channel.addr); - let is_static_initial = VirtualPairCol::constant(SC::Val::zero()); - let value = channel.value.0.map(VirtualPairCol::single_main); - - let mut fields = vec![is_read, clk, addr, is_static_initial]; - fields.extend(value); - - Interaction { - fields, - count: VirtualPairCol::single_main(channel.used), - argument_index: machine.mem_bus(), - } - }); + // let mem_sends = (0..3).map(|i| { + // let channel = &CPU_COL_MAP.mem_channels[i]; + // let is_read = VirtualPairCol::single_main(channel.is_read); + // let clk = VirtualPairCol::single_main(CPU_COL_MAP.clk); + // let addr = VirtualPairCol::single_main(channel.addr); + // let is_static_initial = VirtualPairCol::constant(SC::Val::zero()); + // let value = channel.value.0.map(VirtualPairCol::single_main); + + // let mut fields = vec![is_read, clk, addr, is_static_initial]; + // fields.extend(value); + + // Interaction { + // fields, + // count: VirtualPairCol::single_main(channel.used), + // argument_index: machine.mem_bus(), + // } + // }); // General bus channel let mut fields = vec![VirtualPairCol::single_main(CPU_COL_MAP.instruction.opcode)]; @@ -145,10 +144,11 @@ where // argument_index: machine.program_bus(), // }; - mem_sends - .chain(iter::once(send_general)) - // .chain(iter::once(send_program)) - .collect() + vec![send_general] + //mem_sends + // .chain(iter::once(send_general)) + // // .chain(iter::once(send_program)) + // .collect() } } @@ -212,6 +212,8 @@ impl CpuChip { } } + std::println!("cpu row: {:?}", row.clone()); + row } diff --git a/memory/src/columns.rs b/memory/src/columns.rs index cc3d2e8..c66b1e9 100644 --- a/memory/src/columns.rs +++ b/memory/src/columns.rs @@ -4,7 +4,7 @@ use valida_derive::AlignedBorrow; use valida_machine::Word; use valida_util::indices_arr; -#[derive(AlignedBorrow, Default)] +#[derive(AlignedBorrow, Default, Debug)] pub struct MemoryCols { /// Memory address pub addr: T, diff --git a/memory/src/lib.rs b/memory/src/lib.rs index effbb88..04043c9 100644 --- a/memory/src/lib.rs +++ b/memory/src/lib.rs @@ -1,5 +1,3 @@ -#![no_std] - extern crate alloc; use crate::columns::{MemoryCols, MEM_COL_MAP, NUM_MEM_COLS}; @@ -107,25 +105,25 @@ where SC: StarkConfig, { fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { - let mut ops = self - .operations - .par_iter() - .map(|(clk, ops)| { - ops.iter() - .map(|op| (*clk, *op)) - .collect::>() - }) - .collect::>() - .into_iter() - .flatten() - .collect::>(); - - // Sort first by addr, then by clk - ops.sort_by_key(|(clk, op)| (op.get_address(), *clk)); - - // Consecutive sorted clock cycles for an address should differ no more - // than the length of the table (capped at 2^29) - Self::insert_dummy_reads(&mut ops); + // let mut ops = self + // .operations + // .par_iter() + // .map(|(clk, ops)| { + // ops.iter() + // .map(|op| (*clk, *op)) + // .collect::>() + // }) + // .collect::>() + // .into_iter() + // .flatten() + // .collect::>(); + + // // Sort first by addr, then by clk + // ops.sort_by_key(|(clk, op)| (op.get_address(), *clk)); + + // // Consecutive sorted clock cycles for an address should differ no more + // // than the length of the table (capped at 2^29) + // Self::insert_dummy_reads(&mut ops); let mut rows = self.static_data .iter() @@ -133,24 +131,30 @@ where .map(|(n, (addr, value))| self.static_data_to_row(n, *addr, *value)) .collect::>(); + let padding_row = [SC::Val::zero(); NUM_MEM_COLS]; + let n0 = rows.len(); - let ops_rows = ops - .par_iter() - .enumerate() - .map(|(n, (clk, op))| self.op_to_row(n0+n, *clk as usize, *op)) - .collect::>(); - rows.extend(ops_rows); + // let ops_rows = ops + // .par_iter() + // .enumerate() + // .map(|(n, (clk, op))| self.op_to_row(n0+n, *clk as usize, *op)) + // .collect::>(); + // rows.extend(ops_rows.clone()); // Compute address difference values - self.compute_address_diffs(ops, &mut rows); + // self.compute_address_diffs(ops, &mut rows); // Make sure the table length is a power of two - rows.resize(rows.len().next_power_of_two(), [SC::Val::zero(); NUM_MEM_COLS]); + rows.resize(rows.len().next_power_of_two(), padding_row); let trace = - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_MEM_COLS); + RowMajorMatrix::new(rows.clone().into_iter().flatten().collect::>(), NUM_MEM_COLS); + std::println!("static data = {:?}\nmemory trace rows = {:?}", + self.static_data, + rows); + // std::println!("static data = {:?}\nops = {:?}\nops rows = {:?}\nmemory trace = {:?}", self.static_data, ops, ops_rows, trace); trace } @@ -233,6 +237,10 @@ impl MemoryChip { cols.addr = F::from_canonical_u32(addr); cols.value = value.transform(F::from_canonical_u8); cols.is_write = F::one(); + cols.is_read = F::zero(); + cols.diff = F::zero(); + cols.diff_inv = F::zero(); + cols.addr_not_equal = F::zero(); row } diff --git a/static_data/src/lib.rs b/static_data/src/lib.rs index 81eb39c..056a923 100644 --- a/static_data/src/lib.rs +++ b/static_data/src/lib.rs @@ -1,11 +1,11 @@ -#![no_std] extern crate alloc; -use crate::columns::{NUM_STATIC_DATA_COLS, STATIC_DATA_COL_MAP}; +use crate::columns::{StaticDataCols, NUM_STATIC_DATA_COLS, STATIC_DATA_COL_MAP}; use alloc::collections::BTreeMap; use alloc::vec; use alloc::vec::Vec; +use core::mem::transmute; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; @@ -55,8 +55,12 @@ where fn generate_trace(&self, machine: &M) -> RowMajorMatrix { let mut rows = self.cells.iter() .map(|(addr, value)| { - let mut row: Vec = vec![SC::Val::from_canonical_u32(*addr)]; - row.extend(value.0.into_iter().map(SC::Val::from_canonical_u8).collect::>()); + let mut row = [SC::Val::zero(); NUM_STATIC_DATA_COLS]; + let cols: &mut StaticDataCols = unsafe { transmute(&mut row) }; + cols.addr = SC::Val::from_canonical_u32(*addr); + cols.value = value.transform(SC::Val::from_canonical_u8); + cols.is_real = SC::Val::one(); + std::println!("static data row: {:?}\n", row.clone()); row }) .flatten() @@ -66,7 +70,6 @@ where } fn global_sends(&self, machine: &M) -> Vec> { - // return vec![]; // TODO let addr = VirtualPairCol::single_main(STATIC_DATA_COL_MAP.addr); let value = STATIC_DATA_COL_MAP.value.0.map(VirtualPairCol::single_main); let is_read = VirtualPairCol::constant(SC::Val::zero());