Skip to content

Commit

Permalink
feat: add vectorization to native load store circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
yi-sun committed Jan 10, 2025
1 parent a7ec44f commit 032f8f3
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 33 deletions.
42 changes: 18 additions & 24 deletions extensions/native/circuit/src/adapters/loadstore_native_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use openvm_circuit::{
},
system::{
memory::{
offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols},
offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
MemoryAddress, MemoryController, OfflineMemory, RecordId,
},
program::ProgramBus,
Expand Down Expand Up @@ -41,8 +41,7 @@ pub struct NativeLoadStoreAdapterInterface<T, const NUM_CELLS: usize>(PhantomDat
impl<T, const NUM_CELLS: usize> VmAdapterInterface<T>
for NativeLoadStoreAdapterInterface<T, NUM_CELLS>
{
// TODO[yi]: Fix when vectorizing
type Reads = (T, T);
type Reads = (T, [T; NUM_CELLS]);
type Writes = [T; NUM_CELLS];
type ProcessedInstruction = NativeLoadStoreInstruction<T>;
}
Expand Down Expand Up @@ -110,10 +109,8 @@ pub struct NativeLoadStoreAdapterCols<T, const NUM_CELLS: usize> {
pub data_write_as: T,
pub data_write_pointer: T,

pub pointer_read_aux_cols: MemoryReadOrImmediateAuxCols<T>,
pub data_read_aux_cols: MemoryReadOrImmediateAuxCols<T>,
// TODO[yi]: Fix when vectorizing
// pub data_read_aux_cols: MemoryReadAuxCols<T, NUM_CELLS>,
pub pointer_read_aux_cols: MemoryReadAuxCols<T, 1>,
pub data_read_aux_cols: MemoryReadAuxCols<T, NUM_CELLS>,
pub data_write_aux_cols: MemoryWriteAuxCols<T, NUM_CELLS>,
}

Expand All @@ -140,9 +137,6 @@ impl<AB: InteractionBuilder, const NUM_CELLS: usize> VmAdapterAir<AB>
local: &[AB::Var],
ctx: AdapterAirContext<AB::Expr, Self::Interface>,
) {
// TODO[yi]: Remove when vectorizing
assert_eq!(NUM_CELLS, 1);

let cols: &NativeLoadStoreAdapterCols<_, NUM_CELLS> = local.borrow();
let timestamp = cols.from_state.timestamp;
let mut timestamp_delta = AB::Expr::from_canonical_usize(0);
Expand All @@ -154,9 +148,9 @@ impl<AB: InteractionBuilder, const NUM_CELLS: usize> VmAdapterAir<AB>

// first pointer read is always [c]_d
self.memory_bridge
.read_or_immediate(
.read(
MemoryAddress::new(cols.d, cols.c),
ctx.reads.0.clone(),
[ctx.reads.0.clone()],
timestamp + timestamp_delta.clone(),
&cols.pointer_read_aux_cols,
)
Expand All @@ -178,7 +172,7 @@ impl<AB: InteractionBuilder, const NUM_CELLS: usize> VmAdapterAir<AB>
is_storew.clone() * cols.a + is_loadw.clone() * (ctx.reads.0.clone() + cols.b),
);
self.memory_bridge
.read_or_immediate(
.read(
MemoryAddress::new(cols.data_read_as, cols.data_read_pointer),
ctx.reads.1.clone(),
timestamp + timestamp_delta.clone(),
Expand Down Expand Up @@ -228,8 +222,7 @@ impl<AB: InteractionBuilder, const NUM_CELLS: usize> VmAdapterAir<AB>
impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>
for NativeLoadStoreAdapterChip<F, NUM_CELLS>
{
// TODO[yi]: Fix when vectorizing
type ReadRecord = NativeLoadStoreReadRecord<F, 1>;
type ReadRecord = NativeLoadStoreReadRecord<F, NUM_CELLS>;
type WriteRecord = NativeLoadStoreWriteRecord<F, NUM_CELLS>;
type Air = NativeLoadStoreAdapterAir<NUM_CELLS>;
type Interface = NativeLoadStoreAdapterInterface<F, NUM_CELLS>;
Expand Down Expand Up @@ -266,15 +259,13 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>
let (data_read_ptr, data_write_ptr) = {
match local_opcode {
LOADW => (read_cell.1 + b, a),
STOREW => (a, read_cell.1 + b),
SHINTW => (a, read_cell.1 + b),
STOREW | SHINTW => (a, read_cell.1 + b),
}
};

// TODO[yi]: Fix when vectorizing
let data_read = match local_opcode {
SHINTW => None,
_ => Some(memory.read::<1>(data_read_as, data_read_ptr)),
LOADW | STOREW => Some(memory.read::<NUM_CELLS>(data_read_as, data_read_ptr)),
};
let record = NativeLoadStoreReadRecord {
pointer_read: read_cell.0,
Expand All @@ -288,7 +279,10 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>
e,
};

Ok(((read_cell.1, data_read.map_or(F::ZERO, |x| x.1[0])), record))
Ok((
(read_cell.1, data_read.map_or([F::ZERO; NUM_CELLS], |x| x.1)),
record,
))
}

fn postprocess(
Expand Down Expand Up @@ -333,17 +327,17 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>
if let Some(data_read) = data_read {
cols.data_read_as = data_read.address_space;
cols.data_read_pointer = data_read.pointer;
cols.data_read_aux_cols = aux_cols_factory.make_read_or_immediate_aux_cols(data_read);
cols.data_read_aux_cols = aux_cols_factory.make_read_aux_cols(data_read);
} else {
cols.data_read_aux_cols = MemoryReadOrImmediateAuxCols::disabled();
cols.data_read_aux_cols = MemoryReadAuxCols::disabled();
}

let write = memory.record_by_id(write_record.write_id);
cols.data_write_as = write.address_space;
cols.data_write_pointer = write.pointer;

cols.pointer_read_aux_cols = aux_cols_factory
.make_read_or_immediate_aux_cols(memory.record_by_id(read_record.pointer_read));
cols.pointer_read_aux_cols =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(read_record.pointer_read));
cols.data_write_aux_cols = aux_cols_factory.make_write_aux_cols(write);
}

Expand Down
13 changes: 7 additions & 6 deletions extensions/native/circuit/src/loadstore/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub struct NativeLoadStoreCoreCols<T, const NUM_CELLS: usize> {
pub is_shintw: T,

pub pointer_read: T,
pub data_read: T,
pub data_read: [T; NUM_CELLS],
pub data_write: [T; NUM_CELLS],
}

Expand All @@ -40,7 +40,8 @@ pub struct NativeLoadStoreCoreRecord<F, const NUM_CELLS: usize> {
pub opcode: NativeLoadStoreOpcode,

pub pointer_read: F,
pub data_read: F,
#[serde(with = "BigArray")]
pub data_read: [F; NUM_CELLS],
#[serde(with = "BigArray")]
pub data_write: [F; NUM_CELLS],
}
Expand All @@ -65,7 +66,7 @@ impl<AB, I, const NUM_CELLS: usize> VmCoreAir<AB, I> for NativeLoadStoreCoreAir<
where
AB: InteractionBuilder,
I: VmAdapterInterface<AB::Expr>,
I::Reads: From<(AB::Expr, AB::Expr)>,
I::Reads: From<(AB::Expr, [AB::Expr; NUM_CELLS])>,
I::Writes: From<[AB::Expr; NUM_CELLS]>,
I::ProcessedInstruction: From<NativeLoadStoreInstruction<AB::Expr>>,
{
Expand All @@ -92,7 +93,7 @@ where

AdapterAirContext {
to_pc: None,
reads: (cols.pointer_read.into(), cols.data_read.into()).into(),
reads: (cols.pointer_read.into(), cols.data_read.map(Into::into)).into(),
writes: cols.data_write.map(Into::into).into(),
instruction: NativeLoadStoreInstruction {
is_valid,
Expand Down Expand Up @@ -127,7 +128,7 @@ impl<F: Field, const NUM_CELLS: usize> NativeLoadStoreCoreChip<F, NUM_CELLS> {
impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_CELLS: usize> VmCoreChip<F, I>
for NativeLoadStoreCoreChip<F, NUM_CELLS>
where
I::Reads: Into<(F, F)>,
I::Reads: Into<(F, [F; NUM_CELLS])>,
I::Writes: From<[F; NUM_CELLS]>,
{
type Record = NativeLoadStoreCoreRecord<F, NUM_CELLS>;
Expand All @@ -151,7 +152,7 @@ where
}
array::from_fn(|_| streams.hint_stream.pop_front().unwrap())
} else {
[data_read; NUM_CELLS]
data_read
};

let output = AdapterRuntimeContext::without_pc(data_write);
Expand Down
3 changes: 0 additions & 3 deletions extensions/native/circuit/src/loadstore/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,6 @@ fn rand_native_loadstore_test() {
set_and_execute(&mut tester, &mut chip, &mut rng, false, STOREW);
set_and_execute(&mut tester, &mut chip, &mut rng, false, SHINTW);
set_and_execute(&mut tester, &mut chip, &mut rng, false, LOADW);

set_and_execute(&mut tester, &mut chip, &mut rng, true, STOREW);
set_and_execute(&mut tester, &mut chip, &mut rng, true, SHINTW);
}
let tester = tester.build().load(chip).finalize();
tester.simple_test().expect("Verification failed");
Expand Down

0 comments on commit 032f8f3

Please sign in to comment.