Skip to content

Commit

Permalink
perf: Move range checks to trace gen (#1194)
Browse files: browse the repository at this point in the history
  • Loading branch information
zlangley authored Jan 8, 2025
1 parent 5d7c7fc commit 3a1d8ec
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 48 deletions.
22 changes: 11 additions & 11 deletions extensions/native/circuit/src/castf/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ where
/// Record built during execution and consumed by `generate_trace_row`:
/// `in_val` is the value that was read and `out_val` holds the limbs
/// returned by `CastF::solve` for it.
#[derive(Debug)]
pub struct CastFRecord<F> {
// The input value, kept as a field element.
pub in_val: F,
// NOTE(review): this page is a rendered diff without +/- markers, so `out_val`
// is listed twice. The `[F; _]` line below looks like the removed declaration
// and the `[u32; _]` line after it like its replacement — confirm against the commit.
pub out_val: [F; RV32_REGISTER_NUM_LIMBS],
// Limbs stored as raw `u32`, which is the form `generate_trace_row` hands to the
// range checker before converting them to `F` for the row.
pub out_val: [u32; RV32_REGISTER_NUM_LIMBS],
}

#[derive(Debug)]
Expand Down Expand Up @@ -142,15 +142,7 @@ where
);

let y = reads.into()[0][0];

let x = CastF::solve(y.as_canonical_u32());
for (i, limb) in x.iter().enumerate() {
if i == 3 {
self.range_checker_chip.add_count(*limb, FINAL_LIMB_BITS);
} else {
self.range_checker_chip.add_count(*limb, LIMB_BITS);
}
}

let output = AdapterRuntimeContext {
to_pc: None,
Expand All @@ -159,7 +151,7 @@ where

let record = CastFRecord {
in_val: y,
out_val: x.map(F::from_canonical_u32),
out_val: x,
};

Ok((output, record))
Expand All @@ -170,9 +162,17 @@ where
}

/// Fills one trace row from `record` and registers the range-check counts
/// for the record's output limbs with `range_checker_chip`.
fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
// Count every output limb with the range checker: the limb at index 3 is
// counted with `FINAL_LIMB_BITS`, every other limb with `LIMB_BITS`.
for (i, limb) in record.out_val.iter().enumerate() {
if i == 3 {
self.range_checker_chip.add_count(*limb, FINAL_LIMB_BITS);
} else {
self.range_checker_chip.add_count(*limb, LIMB_BITS);
}
}

// Reinterpret the raw row slice as the typed column struct.
let cols: &mut CastFCoreCols<F> = row_slice.borrow_mut();
cols.in_val = record.in_val;
// NOTE(review): rendered diff without +/- markers — the plain assignment below
// looks like the removed line and the `map(F::from_canonical_u32)` assignment
// after it like its replacement; confirm against the commit.
cols.out_val = record.out_val;
// Convert the stored `u32` limbs to field elements when writing them to the row.
cols.out_val = record.out_val.map(F::from_canonical_u32);
// Flag the row as valid.
cols.is_valid = F::ONE;
}

Expand Down
26 changes: 14 additions & 12 deletions extensions/rv32im/circuit/src/adapters/hintstore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ impl<F: PrimeField32> Rv32HintStoreAdapterChip<F> {
pointer_max_bits: usize,
range_checker_chip: Arc<VariableRangeCheckerChip>,
) -> Self {
assert!(range_checker_chip.range_max_bits() >= 16);
Self {
air: Rv32HintStoreAdapterAir {
execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
Expand All @@ -72,7 +73,7 @@ pub struct Rv32HintStoreReadRecord<F: Field> {

pub imm: F,
pub imm_sign: bool,
pub mem_ptr_limbs: [F; 2],
pub mem_ptr_limbs: [u32; 2],
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -249,7 +250,6 @@ impl<F: PrimeField32> VmAdapterChip<F> for Rv32HintStoreAdapterChip<F> {
let Instruction { b, c, d, e, .. } = *instruction;
debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
assert!(self.range_checker_chip.range_max_bits() >= 16);

let rs1_record = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, b);
let rs1_val = compose(rs1_record.1);
Expand All @@ -260,12 +260,6 @@ impl<F: PrimeField32> VmAdapterChip<F> for Rv32HintStoreAdapterChip<F> {
let ptr_val = rs1_val.wrapping_add(imm_extended);
assert!(ptr_val < (1 << self.air.pointer_max_bits));
let mem_ptr_limbs = array::from_fn(|i| ((ptr_val >> (i * (RV32_CELL_BITS * 2))) & 0xffff));
self.range_checker_chip
.add_count(mem_ptr_limbs[0], RV32_CELL_BITS * 2);
self.range_checker_chip.add_count(
mem_ptr_limbs[1],
self.air.pointer_max_bits - RV32_CELL_BITS * 2,
);

Ok((
[],
Expand All @@ -274,7 +268,7 @@ impl<F: PrimeField32> VmAdapterChip<F> for Rv32HintStoreAdapterChip<F> {
rs1_ptr: b,
imm: c,
imm_sign: imm_sign == 1,
mem_ptr_limbs: mem_ptr_limbs.map(F::from_canonical_u32),
mem_ptr_limbs,
},
))
}
Expand All @@ -288,8 +282,9 @@ impl<F: PrimeField32> VmAdapterChip<F> for Rv32HintStoreAdapterChip<F> {
read_record: &Self::ReadRecord,
) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
let ptr = read_record.mem_ptr_limbs[0]
+ read_record.mem_ptr_limbs[1] * F::from_canonical_u32(1 << (RV32_CELL_BITS * 2));
let (write_record_id, _) = memory.write(instruction.e, ptr, output.writes[0]);
+ read_record.mem_ptr_limbs[1] * (1 << (RV32_CELL_BITS * 2));
let (write_record_id, _) =
memory.write(instruction.e, F::from_canonical_u32(ptr), output.writes[0]);

Ok((
ExecutionState {
Expand All @@ -310,6 +305,13 @@ impl<F: PrimeField32> VmAdapterChip<F> for Rv32HintStoreAdapterChip<F> {
write_record: Self::WriteRecord,
memory: &OfflineMemory<F>,
) {
self.range_checker_chip
.add_count(read_record.mem_ptr_limbs[0], RV32_CELL_BITS * 2);
self.range_checker_chip.add_count(
read_record.mem_ptr_limbs[1],
self.air.pointer_max_bits - RV32_CELL_BITS * 2,
);

let aux_cols_factory = memory.aux_cols_factory();
let adapter_cols: &mut Rv32HintStoreAdapterCols<_> = row_slice.borrow_mut();
adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32);
Expand All @@ -319,7 +321,7 @@ impl<F: PrimeField32> VmAdapterChip<F> for Rv32HintStoreAdapterChip<F> {
adapter_cols.rs1_ptr = read_record.rs1_ptr;
adapter_cols.imm = read_record.imm;
adapter_cols.imm_sign = F::from_bool(read_record.imm_sign);
adapter_cols.mem_ptr_limbs = read_record.mem_ptr_limbs;
adapter_cols.mem_ptr_limbs = read_record.mem_ptr_limbs.map(F::from_canonical_u32);

let rd = memory.record_by_id(write_record.record_id);
adapter_cols.write_aux = aux_cols_factory.make_write_aux_cols(rd);
Expand Down
21 changes: 12 additions & 9 deletions extensions/rv32im/circuit/src/adapters/loadstore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ impl<F: PrimeField32> Rv32LoadStoreAdapterChip<F> {
range_checker_chip: Arc<VariableRangeCheckerChip>,
offset: usize,
) -> Self {
assert!(range_checker_chip.range_max_bits() >= 15);
Self {
air: Rv32LoadStoreAdapterAir {
execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
Expand All @@ -133,6 +134,7 @@ pub struct Rv32LoadStoreReadRecord<F: Field> {
pub imm_sign: bool,
pub mem_ptr_limbs: [u32; 2],
pub mem_as: F,
pub shift_amount: u32,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -351,7 +353,6 @@ impl<F: PrimeField32> VmAdapterChip<F> for Rv32LoadStoreAdapterChip<F> {
} = *instruction;
debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
debug_assert!(e.as_canonical_u32() != RV32_IMM_AS);
assert!(self.range_checker_chip.range_max_bits() >= 15);

let local_opcode = Rv32LoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset));
let rs1_record = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, b);
Expand All @@ -370,14 +371,6 @@ impl<F: PrimeField32> VmAdapterChip<F> for Rv32LoadStoreAdapterChip<F> {
);

let mem_ptr_limbs = array::from_fn(|i| ((ptr_val >> (i * (RV32_CELL_BITS * 2))) & 0xffff));
self.range_checker_chip.add_count(
(mem_ptr_limbs[0] - shift_amount) / 4,
RV32_CELL_BITS * 2 - 2,
);
self.range_checker_chip.add_count(
mem_ptr_limbs[1],
self.air.pointer_max_bits - RV32_CELL_BITS * 2,
);

let ptr_val = ptr_val - shift_amount;
let read_record = match local_opcode {
Expand Down Expand Up @@ -408,6 +401,7 @@ impl<F: PrimeField32> VmAdapterChip<F> for Rv32LoadStoreAdapterChip<F> {
read: read_record.0,
imm: c,
imm_sign: imm_sign == 1,
shift_amount,
mem_ptr_limbs,
mem_as: e,
},
Expand Down Expand Up @@ -457,6 +451,15 @@ impl<F: PrimeField32> VmAdapterChip<F> for Rv32LoadStoreAdapterChip<F> {
write_record: Self::WriteRecord,
memory: &OfflineMemory<F>,
) {
self.range_checker_chip.add_count(
(read_record.mem_ptr_limbs[0] - read_record.shift_amount) / 4,
RV32_CELL_BITS * 2 - 2,
);
self.range_checker_chip.add_count(
read_record.mem_ptr_limbs[1],
self.air.pointer_max_bits - RV32_CELL_BITS * 2,
);

let aux_cols_factory = memory.aux_cols_factory();
let adapter_cols: &mut Rv32LoadStoreAdapterCols<_> = row_slice.borrow_mut();
adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32);
Expand Down
8 changes: 4 additions & 4 deletions extensions/rv32im/circuit/src/jalr/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ impl Rv32JalrCoreChip {
range_checker_chip: Arc<VariableRangeCheckerChip>,
offset: usize,
) -> Self {
assert!(range_checker_chip.bus().range_max_bits >= 15);
assert!(range_checker_chip.range_max_bits() >= 16);
Self {
air: Rv32JalrCoreAir {
bitwise_lookup_bus: bitwise_lookup_chip.bus(),
Expand Down Expand Up @@ -211,7 +211,6 @@ where
from_pc: u32,
reads: I::Reads,
) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
assert!(self.range_checker_chip.range_max_bits() >= 16);
let Instruction { opcode, c, .. } = *instruction;
let local_opcode = Rv32JalrOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));

Expand All @@ -235,8 +234,6 @@ where
let to_pc_least_sig_bit = rs1_val.wrapping_add(imm_extended) & 1;

let to_pc_limbs = array::from_fn(|i| ((to_pc >> (1 + i * 15)) & mask));
self.range_checker_chip.add_count(to_pc_limbs[0], 15);
self.range_checker_chip.add_count(to_pc_limbs[1], 14);

let rd_data = rd_data.map(F::from_canonical_u32);

Expand All @@ -263,6 +260,9 @@ where
}

fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
self.range_checker_chip.add_count(record.to_pc_limbs[0], 15);
self.range_checker_chip.add_count(record.to_pc_limbs[1], 14);

let core_cols: &mut Rv32JalrCoreCols<F> = row_slice.borrow_mut();
core_cols.imm = record.imm;
core_cols.rd_data = record.rd_data;
Expand Down
29 changes: 17 additions & 12 deletions extensions/rv32im/circuit/src/shift/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ pub struct ShiftCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
pub a: [T; NUM_LIMBS],
pub b: [T; NUM_LIMBS],
pub c: [T; NUM_LIMBS],
pub bit_shift_carry: [T; NUM_LIMBS],
pub bit_shift_carry: [u32; NUM_LIMBS],
pub bit_shift: usize,
pub limb_shift: usize,
pub b_sign: T,
Expand Down Expand Up @@ -304,27 +304,18 @@ where
.request_xor(b[NUM_LIMBS - 1], 1 << (LIMB_BITS - 1));
}

let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2();
self.range_checker_chip.add_count(
(((c[0] as usize) - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u32,
LIMB_BITS - num_bits_log as usize,
);

for i in 0..(NUM_LIMBS / 2) {
self.bitwise_lookup_chip
.request_range(a[i * 2], a[i * 2 + 1]);
}
for carry_val in bit_shift_carry {
self.range_checker_chip.add_count(carry_val, bit_shift);
}

let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]);
let record = ShiftCoreRecord {
opcode: shift_opcode,
a: a.map(F::from_canonical_u32),
b: data[0],
c: data[1],
bit_shift_carry: bit_shift_carry.map(F::from_canonical_u32),
bit_shift_carry,
bit_shift,
limb_shift,
b_sign: F::from_canonical_u32(b_sign),
Expand All @@ -338,6 +329,20 @@ where
}

fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
for carry_val in record.bit_shift_carry {
self.range_checker_chip
.add_count(carry_val, record.bit_shift);
}

let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2();
self.range_checker_chip.add_count(
(((record.c[0].as_canonical_u32() as usize)
- record.bit_shift
- record.limb_shift * LIMB_BITS)
>> num_bits_log) as u32,
LIMB_BITS - num_bits_log as usize,
);

let row_slice: &mut ShiftCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut();
row_slice.a = record.a;
row_slice.b = record.b;
Expand All @@ -353,7 +358,7 @@ where
row_slice.b_sign = record.b_sign;
row_slice.bit_shift_marker = array::from_fn(|i| F::from_bool(i == record.bit_shift));
row_slice.limb_shift_marker = array::from_fn(|i| F::from_bool(i == record.limb_shift));
row_slice.bit_shift_carry = record.bit_shift_carry;
row_slice.bit_shift_carry = record.bit_shift_carry.map(F::from_canonical_u32);
row_slice.opcode_sll_flag = F::from_bool(record.opcode == ShiftOpcode::SLL);
row_slice.opcode_srl_flag = F::from_bool(record.opcode == ShiftOpcode::SRL);
row_slice.opcode_sra_flag = F::from_bool(record.opcode == ShiftOpcode::SRA);
Expand Down

0 comments on commit 3a1d8ec

Please sign in to comment.