feat: vectorize native load store (#1138)
* feat: add vectorization to native load store circuit

* chore: update comment

* feat: replace LOAD/STORE immediates with ADD

* fix: lint

* feat: add LOADW4 STOREW4 SHINTW4

* chore: rename SHINTW -> HINT_STOREW

* feat: optimize out 2 columns after removing immediates

* chore: remove TODO

* chore: extract opcode arrays into constants

* chore: remove magic number 4

* fix: update integration tests

* fix: test_vm_override_executor_height
yi-sun authored Jan 12, 2025
1 parent fc0be52 commit ee534ee
Showing 11 changed files with 354 additions and 227 deletions.
203 changes: 112 additions & 91 deletions crates/vm/tests/integration_test.rs

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions docs/specs/ISA.md
@@ -51,7 +51,7 @@ addressable cells. Registers are represented using the [LIMB] format with `LIMB_
## Hints

The `input_stream` is a private non-interactive queue of vectors of field elements which is provided at the start of
runtime execution. The `hint_stream` is a queue of values that can be written to memory by calling the `HINTSTOREW_RV32` and `HINTSTORE` instructions. The `hint_stream` is populated via [phantom sub-instructions](#phantom-sub-instructions) such
runtime execution. The `hint_stream` is a queue of values that can be written to memory by calling the `HINT_STOREW_RV32` and `HINT_STORE_RV32` instructions. The `hint_stream` is populated via [phantom sub-instructions](#phantom-sub-instructions) such
as `HINT_INPUT` and `HINT_BITS`.

## Public Outputs
@@ -266,7 +266,7 @@ We use the same notation for `r32{c}(b) := i32([b:4]_1) + sign_extend(decompose(

| Name | Operands | Description |
| --------------- | ----------- | ----------------------------------------------------------------------------------------------------------------------------------- |
| HINTSTOREW_RV32 | `_,b,c,1,2` | `[r32{c}(b):4]_2 = next 4 bytes from hint stream`. Only valid if next 4 values in hint stream are bytes. |
| HINT_STOREW_RV32 | `_,b,c,1,2` | `[r32{c}(b):4]_2 = next 4 bytes from hint stream`. Only valid if next 4 values in hint stream are bytes. |
| REVEAL_RV32 | `a,b,c,1,3` | Pseudo-instruction for `STOREW_RV32 a,b,c,1,3` writing to the user IO address space `3`. Only valid when continuations are enabled. |
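
The validity rule for `HINT_STOREW_RV32` (the next 4 values in the hint stream must be bytes) can be sketched as follows. This is a minimal model, not the VM's implementation: `HintStream`, `next_four_bytes`, and `remaining` are hypothetical names, field elements are modeled as `u32`, and an invalid instruction is modeled as returning `None`.

```rust
use std::collections::VecDeque;

/// Hypothetical stand-in for the hint stream: a FIFO queue of field
/// elements, modeled here as plain `u32` values.
pub struct HintStream {
    queue: VecDeque<u32>,
}

impl HintStream {
    pub fn new(values: impl IntoIterator<Item = u32>) -> Self {
        Self { queue: values.into_iter().collect() }
    }

    /// Models the byte-valued hint store: returns the next four values only
    /// if all four exist and each fits in a byte. On failure the queue is
    /// left untouched (an assumption of this sketch, not of the spec).
    pub fn next_four_bytes(&mut self) -> Option<[u8; 4]> {
        if self.queue.len() < 4 {
            return None;
        }
        let mut out = [0u8; 4];
        for (i, slot) in out.iter_mut().enumerate() {
            // Peek without popping so a failed check leaves the queue intact.
            *slot = u8::try_from(self.queue[i]).ok()?;
        }
        for _ in 0..4 {
            self.queue.pop_front();
        }
        Some(out)
    }

    /// Number of values still waiting in the queue.
    pub fn remaining(&self) -> usize {
        self.queue.len()
    }
}
```

With a stream holding `1, 2, 255, 0, 300, ...`, the first call yields the four bytes and the second is rejected because `300` does not fit in a byte.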

### Hashes
@@ -425,16 +425,17 @@ instruction format suggested by Max Gillet to enable easier compatibility with o

In the instructions below, `d,e` may be any valid address space unless otherwise specified. In particular, the immediate address space `0` is allowed for non-vectorized reads but not allowed for writes. When using immediates, we interpret `[a]_0` as the immediate value `a`. Base kernel instructions enable memory movement between address spaces.

In some instructions below, `W` is a generic parameter for the block size.

| Name | Operands | Description |
| -------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| LOAD\<W\> | `a,b,c,d,e` | Set `[a:W]_d = [[c]_d + b:W]_e`. Both `d, e` must be non-zero. |
| STORE\<W\> | `a,b,c,d,e` | Set `[[c]_d + b:W]_e = [a:W]_d`. Both `d, e` must be non-zero. |
| LOADW | `a,b,c,d,e` | Set `[a]_d = [[c]_d + b]_e`. Both `d, e` must be non-zero. |
| STOREW | `a,b,c,d,e` | Set `[[c]_d + b]_e = [a]_d`. Both `d, e` must be non-zero. |
| LOADW4 | `a,b,c,d,e` | Set `[a:4]_d = [[c]_d + b:4]_e`. Both `d, e` must be non-zero. |
| STOREW4 | `a,b,c,d,e` | Set `[[c]_d + b:4]_e = [a:4]_d`. Both `d, e` must be non-zero. |
| JAL | `a,b,c,d` | Jump to address and link: set `[a]_d = (pc + DEFAULT_PC_STEP)` and `pc = pc + b`. Here `d` must be non-zero. |
| BEQ\<W\> | `a,b,c,d,e` | If `[a:W]_d == [b:W]_e`, then set `pc = pc + c`. |
| BNE\<W\> | `a,b,c,d,e` | If `[a:W]_d != [b:W]_e`, then set `pc = pc + c`. |
| HINTSTORE\<W\> | `_,b,c,d,e` | Set `[[c]_d + b:W]_e = next W elements from hint stream`. Both `d, e` must be non-zero. |
| HINT_STOREW | `_,b,c,d,e` | Set `[[c]_d + b]_e = next element from hint stream`. Both `d, e` must be non-zero. |
| HINT_STOREW4 | `_,b,c,d,e` | Set `[[c]_d + b:4]_e = next 4 elements from hint stream`. Both `d, e` must be non-zero. |
| PUBLISH | `a,b,_,d,e` | Set the user public output at index `[a]_d` to equal `[b]_e`. Invalid if `[a]_d` is greater than or equal to the configured length of user public outputs. Only valid when continuations are disabled. |
| CASTF | `a,b,_,d,e` | Cast a field element represented as `u32` into four bytes in little-endian: Set `[a:4]_d` to the unique array such that `sum_{i=0}^3 [a + i]_d * 2^{8i} = [b]_e` where `[a + i]_d < 2^8` for `i = 0..2` and `[a + 3]_d < 2^6`. This opcode constrains that `[b]_e` must be at most 30-bits. Both `d, e` must be non-zero. |
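
The load and store rows in the table above share one shape, parameterized by the block size (1 for `LOADW`/`STOREW`, 4 for `LOADW4`/`STOREW4`). A rough sketch of those semantics, assuming a toy memory keyed by `(address space, pointer)`: `ToyMemory` and its methods are hypothetical, field elements are modeled as `u32`, and pointer arithmetic uses plain integer addition rather than field arithmetic.

```rust
use std::collections::HashMap;

/// Hypothetical toy memory: maps (address space, pointer) to a cell value.
/// Cells that were never written read as 0 in this sketch.
#[derive(Default)]
pub struct ToyMemory {
    cells: HashMap<(u32, u32), u32>,
}

impl ToyMemory {
    /// Reads the block `[ptr:W]_space`.
    pub fn read<const W: usize>(&self, space: u32, ptr: u32) -> [u32; W] {
        std::array::from_fn(|i| {
            self.cells.get(&(space, ptr + i as u32)).copied().unwrap_or(0)
        })
    }

    /// Writes `data` to the block `[ptr:W]_space`.
    pub fn write<const W: usize>(&mut self, space: u32, ptr: u32, data: [u32; W]) {
        for (i, &value) in data.iter().enumerate() {
            self.cells.insert((space, ptr + i as u32), value);
        }
    }

    /// Load with block size W: `[a:W]_d = [[c]_d + b:W]_e`.
    pub fn load<const W: usize>(&mut self, a: u32, b: u32, c: u32, d: u32, e: u32) {
        assert!(d != 0 && e != 0, "both d and e must be non-zero");
        let [base] = self.read::<1>(d, c);
        let data = self.read::<W>(e, base + b);
        self.write(d, a, data);
    }

    /// Store with block size W: `[[c]_d + b:W]_e = [a:W]_d`.
    pub fn store<const W: usize>(&mut self, a: u32, b: u32, c: u32, d: u32, e: u32) {
        assert!(d != 0 && e != 0, "both d and e must be non-zero");
        let [base] = self.read::<1>(d, c);
        let data = self.read::<W>(d, a);
        self.write(e, base + b, data);
    }
}
```

In this model `load::<4>` plays the role of `LOADW4` and `store::<1>` the role of `STOREW`; the pointer `[c]_d` is always a single cell regardless of the block size.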

2 changes: 1 addition & 1 deletion docs/specs/RISCV.md
@@ -199,7 +199,7 @@ The transpilation will only be valid for programs where:
| RISC-V Inst | OpenVM Instruction |
| -------------- | ---------------------------------------------------------------- |
| terminate | TERMINATE `_, _, utof(imm)` |
| hintstorew | HINTSTOREW_RV32 `0, ind(rd), utof(sign_extend_16(imm)), 1, 2` |
| hintstorew | HINT_STOREW_RV32 `0, ind(rd), utof(sign_extend_16(imm)), 1, 2` |
| reveal | REVEAL_RV32 `0, ind(rd), utof(sign_extend_16(imm)), 1, 3` |
| hintinput | PHANTOM `_, _, HintInputRv32 as u16` |
| printstr | PHANTOM `ind(rd), ind(rs1), PrintStrRv32 as u16` |
88 changes: 34 additions & 54 deletions extensions/native/circuit/src/adapters/loadstore_native_adapter.rs
@@ -10,7 +10,7 @@ use openvm_circuit::{
},
system::{
memory::{
offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols},
offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
MemoryAddress, MemoryController, OfflineMemory, RecordId,
},
program::ProgramBus,
@@ -33,16 +33,15 @@ pub struct NativeLoadStoreInstruction<T> {
pub opcode: T,
pub is_loadw: T,
pub is_storew: T,
pub is_shintw: T,
pub is_hint_storew: T,
}

pub struct NativeLoadStoreAdapterInterface<T, const NUM_CELLS: usize>(PhantomData<T>);

impl<T, const NUM_CELLS: usize> VmAdapterInterface<T>
for NativeLoadStoreAdapterInterface<T, NUM_CELLS>
{
// TODO[yi]: Fix when vectorizing
type Reads = (T, T);
type Reads = (T, [T; NUM_CELLS]);
type Writes = [T; NUM_CELLS];
type ProcessedInstruction = NativeLoadStoreInstruction<T>;
}
@@ -104,16 +103,11 @@ pub struct NativeLoadStoreAdapterCols<T, const NUM_CELLS: usize> {
pub d: T,
pub e: T,

pub data_read_as: T,
pub data_read_pointer: T,

pub data_write_as: T,
pub data_write_pointer: T,

pub pointer_read_aux_cols: MemoryReadOrImmediateAuxCols<T>,
pub data_read_aux_cols: MemoryReadOrImmediateAuxCols<T>,
// TODO[yi]: Fix when vectorizing
// pub data_read_aux_cols: MemoryReadAuxCols<T, NUM_CELLS>,
pub pointer_read_aux_cols: MemoryReadAuxCols<T, 1>,
pub data_read_aux_cols: MemoryReadAuxCols<T, NUM_CELLS>,
pub data_write_aux_cols: MemoryWriteAuxCols<T, NUM_CELLS>,
}

@@ -140,63 +134,49 @@ impl<AB: InteractionBuilder, const NUM_CELLS: usize> VmAdapterAir<AB>
local: &[AB::Var],
ctx: AdapterAirContext<AB::Expr, Self::Interface>,
) {
// TODO[yi]: Remove when vectorizing
assert_eq!(NUM_CELLS, 1);

let cols: &NativeLoadStoreAdapterCols<_, NUM_CELLS> = local.borrow();
let timestamp = cols.from_state.timestamp;
let mut timestamp_delta = AB::Expr::from_canonical_usize(0);

let is_valid = ctx.instruction.is_valid;
let is_loadw = ctx.instruction.is_loadw;
let is_storew = ctx.instruction.is_storew;
let is_shintw = ctx.instruction.is_shintw;
let is_hint_storew = ctx.instruction.is_hint_storew;

// first pointer read is always [c]_d
self.memory_bridge
.read_or_immediate(
.read(
MemoryAddress::new(cols.d, cols.c),
ctx.reads.0.clone(),
[ctx.reads.0.clone()],
timestamp + timestamp_delta.clone(),
&cols.pointer_read_aux_cols,
)
.eval(builder, is_valid.clone());
timestamp_delta += is_valid.clone();

// TODO[yi]: Remove when vectorizing
// read data, disabled if SHINTW
// data pointer = [c]_d + [f]_d * g + b, degree 2
builder
.when(is_valid.clone() - is_shintw.clone())
.assert_eq(
cols.data_read_as,
utils::select::<AB::Expr>(is_loadw.clone(), cols.e, cols.d),
);
// TODO[yi]: Do we need to check for overflow?
builder.assert_eq(
(is_valid.clone() - is_shintw.clone()) * cols.data_read_pointer,
is_storew.clone() * cols.a + is_loadw.clone() * (ctx.reads.0.clone() + cols.b),
);
self.memory_bridge
.read_or_immediate(
MemoryAddress::new(cols.data_read_as, cols.data_read_pointer),
.read(
MemoryAddress::new(
utils::select::<AB::Expr>(is_loadw.clone(), cols.e, cols.d),
is_storew.clone() * cols.a + is_loadw.clone() * (ctx.reads.0.clone() + cols.b),
),
ctx.reads.1.clone(),
timestamp + timestamp_delta.clone(),
&cols.data_read_aux_cols,
)
.eval(builder, is_valid.clone() - is_shintw.clone());
timestamp_delta += is_valid.clone() - is_shintw.clone();
.eval(builder, is_valid.clone() - is_hint_storew.clone());
timestamp_delta += is_valid.clone() - is_hint_storew.clone();

// data write
builder.when(is_valid.clone()).assert_eq(
cols.data_write_as,
utils::select::<AB::Expr>(is_loadw.clone(), cols.d, cols.e),
);
// TODO[yi]: Do we need to check for overflow?

builder.assert_eq(
is_valid.clone() * cols.data_write_pointer,
is_loadw.clone() * cols.a
+ (is_storew.clone() + is_shintw.clone()) * (ctx.reads.0.clone() + cols.b),
+ (is_storew.clone() + is_hint_storew.clone()) * (ctx.reads.0.clone() + cols.b),
);
self.memory_bridge
.write(
@@ -228,8 +208,7 @@ impl<AB: InteractionBuilder, const NUM_CELLS: usize> VmAdapterAir<AB>
impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>
for NativeLoadStoreAdapterChip<F, NUM_CELLS>
{
// TODO[yi]: Fix when vectorizing
type ReadRecord = NativeLoadStoreReadRecord<F, 1>;
type ReadRecord = NativeLoadStoreReadRecord<F, NUM_CELLS>;
type WriteRecord = NativeLoadStoreWriteRecord<F, NUM_CELLS>;
type Air = NativeLoadStoreAdapterAir<NUM_CELLS>;
type Interface = NativeLoadStoreAdapterInterface<F, NUM_CELLS>;
@@ -259,22 +238,22 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>

let (data_read_as, data_write_as) = {
match local_opcode {
LOADW => (e, d),
STOREW | SHINTW => (d, e),
LOADW | LOADW4 => (e, d),
STOREW | STOREW4 | HINT_STOREW | HINT_STOREW4 => (d, e),
}
};
let (data_read_ptr, data_write_ptr) = {
match local_opcode {
LOADW => (read_cell.1 + b, a),
STOREW => (a, read_cell.1 + b),
SHINTW => (a, read_cell.1 + b),
LOADW | LOADW4 => (read_cell.1 + b, a),
STOREW | STOREW4 | HINT_STOREW | HINT_STOREW4 => (a, read_cell.1 + b),
}
};

// TODO[yi]: Fix when vectorizing
let data_read = match local_opcode {
SHINTW => None,
_ => Some(memory.read::<1>(data_read_as, data_read_ptr)),
HINT_STOREW | HINT_STOREW4 => None,
LOADW | LOADW4 | STOREW | STOREW4 => {
Some(memory.read::<NUM_CELLS>(data_read_as, data_read_ptr))
}
};
let record = NativeLoadStoreReadRecord {
pointer_read: read_cell.0,
@@ -288,7 +267,10 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>
e,
};

Ok(((read_cell.1, data_read.map_or(F::ZERO, |x| x.1[0])), record))
Ok((
(read_cell.1, data_read.map_or([F::ZERO; NUM_CELLS], |x| x.1)),
record,
))
}

fn postprocess(
@@ -331,19 +313,17 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>

let data_read = read_record.data_read.map(|read| memory.record_by_id(read));
if let Some(data_read) = data_read {
cols.data_read_as = data_read.address_space;
cols.data_read_pointer = data_read.pointer;
cols.data_read_aux_cols = aux_cols_factory.make_read_or_immediate_aux_cols(data_read);
cols.data_read_aux_cols = aux_cols_factory.make_read_aux_cols(data_read);
} else {
cols.data_read_aux_cols = MemoryReadOrImmediateAuxCols::disabled();
cols.data_read_aux_cols = MemoryReadAuxCols::disabled();
}

let write = memory.record_by_id(write_record.write_id);
cols.data_write_as = write.address_space;
cols.data_write_pointer = write.pointer;

cols.pointer_read_aux_cols = aux_cols_factory
.make_read_or_immediate_aux_cols(memory.record_by_id(read_record.pointer_read));
cols.pointer_read_aux_cols =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(read_record.pointer_read));
cols.data_write_aux_cols = aux_cols_factory.make_write_aux_cols(write);
}
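
The adapter changes in this file come down to one decision per opcode: where the data is read from, where it is written, and whether the data read happens at all. The sketch below restates that decision outside the circuit. It is an illustration under assumptions: `Op`, `Access`, and `plan_access` are hypothetical names, `ptr` stands for the value read from `[c]_d`, and `reads_before_write` counts only the two timestamp increments visible in the hunk above (the pointer read, then the data read, which is skipped for hint stores).

```rust
/// Hypothetical mirror of the six native load/store opcodes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Op {
    Loadw,
    Storew,
    HintStorew,
    Loadw4,
    Storew4,
    HintStorew4,
}

/// Where one instruction reads and writes, as (address space, pointer).
#[derive(Debug, PartialEq, Eq)]
pub struct Access {
    /// `None` for hint stores, whose data comes from the hint stream.
    pub data_read: Option<(u32, u32)>,
    pub data_write: (u32, u32),
    /// Memory reads issued before the write: pointer read plus data read.
    pub reads_before_write: u32,
}

/// `ptr` is the value already read from `[c]_d`.
pub fn plan_access(op: Op, a: u32, b: u32, d: u32, e: u32, ptr: u32) -> Access {
    // Integer addition stands in for field addition in this sketch.
    let target = ptr.wrapping_add(b);
    match op {
        Op::Loadw | Op::Loadw4 => Access {
            data_read: Some((e, target)),
            data_write: (d, a),
            reads_before_write: 2,
        },
        Op::Storew | Op::Storew4 => Access {
            data_read: Some((d, a)),
            data_write: (e, target),
            reads_before_write: 2,
        },
        Op::HintStorew | Op::HintStorew4 => Access {
            data_read: None,
            data_write: (e, target),
            reads_before_write: 1,
        },
    }
}
```

The block size does not appear here because it only changes how many cells each access covers, which is what the `NUM_CELLS` const generic carries in the adapter.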

29 changes: 27 additions & 2 deletions extensions/native/circuit/src/extension.rs
@@ -17,7 +17,8 @@ use openvm_instructions::{
};
use openvm_native_compiler::{
FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode,
NativeJalOpcode, NativeLoadStoreOpcode, NativePhantom,
NativeJalOpcode, NativeLoadStoreOpcode, NativePhantom, BLOCK_LOAD_STORE_OPCODES,
BLOCK_LOAD_STORE_SIZE, SINGLE_LOAD_STORE_OPCODES,
};
use openvm_poseidon2_air::Poseidon2Config;
use openvm_rv32im_circuit::BranchEqualCoreChip;
@@ -72,6 +73,7 @@ pub struct Native;
#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)]
pub enum NativeExecutor<F: PrimeField32> {
LoadStore(NativeLoadStoreChip<F, 1>),
BlockLoadStore(NativeLoadStoreChip<F, 4>),
BranchEqual(NativeBranchEqChip<F>),
Jal(NativeJalChip<F>),
FieldArithmetic(FieldArithmeticChip<F>),
@@ -115,7 +117,30 @@ impl<F: PrimeField32> VmExtension<F> for Native {

inventory.add_executor(
load_store_chip,
NativeLoadStoreOpcode::iter().map(VmOpcode::with_default_offset),
SINGLE_LOAD_STORE_OPCODES
.iter()
.map(|&opcode| VmOpcode::with_default_offset(opcode)),
)?;

let mut block_load_store_chip = NativeLoadStoreChip::<F, BLOCK_LOAD_STORE_SIZE>::new(
NativeLoadStoreAdapterChip::new(
execution_bus,
program_bus,
memory_bridge,
NativeLoadStoreOpcode::default_offset(),
),
NativeLoadStoreCoreChip::new(NativeLoadStoreOpcode::default_offset()),
offline_memory.clone(),
);
block_load_store_chip
.core
.set_streams(builder.streams().clone());

inventory.add_executor(
block_load_store_chip,
BLOCK_LOAD_STORE_OPCODES
.iter()
.map(|&opcode| VmOpcode::with_default_offset(opcode)),
)?;

let branch_equal_chip = NativeBranchEqChip::new(
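
The registration above splits the load/store opcodes between two instances of the same chip, one per block size. A small sketch of that routing follows. The definitions of `SINGLE_LOAD_STORE_OPCODES` and `BLOCK_LOAD_STORE_OPCODES` are not shown in this diff, so the array contents below are an assumption inferred from the opcode names. `LoadStoreOp`, `SINGLE_OPS`, `BLOCK_OPS`, and `cells_per_access` are hypothetical. `BLOCK_SIZE = 4` matches the `NativeLoadStoreChip<F, 4>` variant above.

```rust
/// Hypothetical mirror of the native load/store opcodes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LoadStoreOp {
    Loadw,
    Storew,
    HintStorew,
    Loadw4,
    Storew4,
    HintStorew4,
}

/// Assumed contents: opcodes served by the single-cell chip.
pub const SINGLE_OPS: [LoadStoreOp; 3] = [
    LoadStoreOp::Loadw,
    LoadStoreOp::Storew,
    LoadStoreOp::HintStorew,
];

/// Assumed contents: opcodes served by the block chip.
pub const BLOCK_OPS: [LoadStoreOp; 3] = [
    LoadStoreOp::Loadw4,
    LoadStoreOp::Storew4,
    LoadStoreOp::HintStorew4,
];

/// Cells moved per access by the block chip.
pub const BLOCK_SIZE: usize = 4;

/// Returns how many cells the chip responsible for `op` moves per access.
pub fn cells_per_access(op: LoadStoreOp) -> usize {
    if BLOCK_OPS.contains(&op) {
        BLOCK_SIZE
    } else {
        // Every remaining opcode is expected to be in the single-cell set.
        debug_assert!(SINGLE_OPS.contains(&op));
        1
    }
}
```

Keeping the two opcode sets disjoint is what lets each opcode be registered with exactly one executor.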