Skip to content

Commit

Permalink
Better use of generics
Browse files Browse the repository at this point in the history
This addresses some limitations that were getting in the way of finishing the prover.

- Removes `FE` from `Machine` - it's not really inherent to the machine, and we have it in `StarkConfig`.
- Give `Instruction` a `F: Field` generic, so we can have instructions which only support certain fields.
- Give `Chip` a `SC: StarkConfig` generic, so we can have chips which only support certain fields etc. (I think ideally it would be `F: Field` as above rather than `SC: StarkConfig`, but Rust doesn't support bounds like `Chip<...>: for<'a> for<SC> Air<ConstraintFolder<'a, M, SC>>`.)
  • Loading branch information
dlubarov committed Jan 14, 2024
1 parent 413bb4c commit 692e6d8
Show file tree
Hide file tree
Showing 27 changed files with 604 additions and 606 deletions.
28 changes: 15 additions & 13 deletions alu_u32/src/add/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ use valida_opcodes::ADD32;
use valida_range::MachineWithRangeChip;

use p3_air::VirtualPairCol;
use p3_field::PrimeField;
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use valida_machine::config::StarkConfig;
use valida_util::pad_to_power_of_two;

pub mod columns;
Expand All @@ -29,12 +30,12 @@ pub struct Add32Chip {
pub operations: Vec<Operation>,
}

impl<F, M> Chip<M> for Add32Chip
impl<M, SC> Chip<M, SC> for Add32Chip
where
F: PrimeField,
M: MachineWithGeneralBus<F = F> + MachineWithRangeBus8,
M: MachineWithGeneralBus<SC::Val> + MachineWithRangeBus8<SC::Val>,
SC: StarkConfig,
{
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<M::F> {
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<SC::Val> {
let rows = self
.operations
.par_iter()
Expand All @@ -44,12 +45,12 @@ where
let mut trace =
RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_ADD_COLS);

pad_to_power_of_two::<NUM_ADD_COLS, F>(&mut trace.values);
pad_to_power_of_two::<NUM_ADD_COLS, SC::Val>(&mut trace.values);

trace
}

fn global_sends(&self, machine: &M) -> Vec<Interaction<M::F>> {
fn global_sends(&self, machine: &M) -> Vec<Interaction<SC::Val>> {
let sends = ADD_COL_MAP
.output
.0
Expand All @@ -66,8 +67,8 @@ where
sends
}

fn global_receives(&self, machine: &M) -> Vec<Interaction<M::F>> {
let opcode = VirtualPairCol::constant(M::F::from_canonical_u32(ADD32));
fn global_receives(&self, machine: &M) -> Vec<Interaction<SC::Val>> {
let opcode = VirtualPairCol::constant(SC::Val::from_canonical_u32(ADD32));
let input_1 = ADD_COL_MAP.input_1.0.map(VirtualPairCol::single_main);
let input_2 = ADD_COL_MAP.input_2.0.map(VirtualPairCol::single_main);
let output = ADD_COL_MAP.output.0.map(VirtualPairCol::single_main);
Expand Down Expand Up @@ -120,21 +121,22 @@ impl Add32Chip {
}
}

pub trait MachineWithAdd32Chip: MachineWithCpuChip {
pub trait MachineWithAdd32Chip<F: Field>: MachineWithCpuChip<F> {
fn add_u32(&self) -> &Add32Chip;
fn add_u32_mut(&mut self) -> &mut Add32Chip;
}

instructions!(Add32Instruction);

impl<M> Instruction<M> for Add32Instruction
impl<M, F> Instruction<M, F> for Add32Instruction
where
M: MachineWithAdd32Chip + MachineWithRangeChip<256>,
M: MachineWithAdd32Chip<F> + MachineWithRangeChip<F, 256>,
F: Field,
{
const OPCODE: u32 = ADD32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M>>::OPCODE;
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
Expand Down
46 changes: 25 additions & 21 deletions alu_u32/src/bitwise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ use valida_machine::{instructions, Chip, Instruction, Interaction, Operands, Wor
use valida_opcodes::{AND32, OR32, XOR32};

use p3_air::VirtualPairCol;
use p3_field::PrimeField;
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use valida_machine::config::StarkConfig;
use valida_util::pad_to_power_of_two;

pub mod columns;
Expand All @@ -30,12 +31,12 @@ pub struct Bitwise32Chip {
pub operations: Vec<Operation>,
}

impl<F, M> Chip<M> for Bitwise32Chip
impl<M, SC> Chip<M, SC> for Bitwise32Chip
where
F: PrimeField,
M: MachineWithGeneralBus<F = F>,
M: MachineWithGeneralBus<SC::Val>,
SC: StarkConfig,
{
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<M::F> {
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<SC::Val> {
let rows = self
.operations
.par_iter()
Expand All @@ -47,19 +48,19 @@ where
NUM_BITWISE_COLS,
);

pad_to_power_of_two::<NUM_BITWISE_COLS, F>(&mut trace.values);
pad_to_power_of_two::<NUM_BITWISE_COLS, SC::Val>(&mut trace.values);

trace
}

fn global_receives(&self, machine: &M) -> Vec<Interaction<M::F>> {
fn global_receives(&self, machine: &M) -> Vec<Interaction<SC::Val>> {
let opcode = VirtualPairCol::new_main(
vec![
(COL_MAP.is_and, M::F::from_canonical_u32(AND32)),
(COL_MAP.is_or, M::F::from_canonical_u32(OR32)),
(COL_MAP.is_xor, M::F::from_canonical_u32(XOR32)),
(COL_MAP.is_and, SC::Val::from_canonical_u32(AND32)),
(COL_MAP.is_or, SC::Val::from_canonical_u32(OR32)),
(COL_MAP.is_xor, SC::Val::from_canonical_u32(XOR32)),
],
M::F::zero(),
SC::Val::zero(),
);
let input_1 = COL_MAP.input_1.0.map(VirtualPairCol::single_main);
let input_2 = COL_MAP.input_2.0.map(VirtualPairCol::single_main);
Expand Down Expand Up @@ -126,21 +127,22 @@ impl Bitwise32Chip {
}
}

pub trait MachineWithBitwise32Chip: MachineWithCpuChip {
pub trait MachineWithBitwise32Chip<F: Field>: MachineWithCpuChip<F> {
fn bitwise_u32(&self) -> &Bitwise32Chip;
fn bitwise_u32_mut(&mut self) -> &mut Bitwise32Chip;
}

instructions!(And32Instruction, Or32Instruction, Xor32Instruction);

impl<M> Instruction<M> for Xor32Instruction
impl<M, F> Instruction<M, F> for Xor32Instruction
where
M: MachineWithBitwise32Chip,
M: MachineWithBitwise32Chip<F>,
F: Field,
{
const OPCODE: u32 = XOR32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M>>::OPCODE;
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
Expand Down Expand Up @@ -171,14 +173,15 @@ where
}
}

impl<M> Instruction<M> for And32Instruction
impl<M, F> Instruction<M, F> for And32Instruction
where
M: MachineWithBitwise32Chip,
M: MachineWithBitwise32Chip<F>,
F: Field,
{
const OPCODE: u32 = AND32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M>>::OPCODE;
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
Expand Down Expand Up @@ -209,14 +212,15 @@ where
}
}

impl<M> Instruction<M> for Or32Instruction
impl<M, F> Instruction<M, F> for Or32Instruction
where
M: MachineWithBitwise32Chip,
M: MachineWithBitwise32Chip<F>,
F: Field,
{
const OPCODE: u32 = OR32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M>>::OPCODE;
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
Expand Down
37 changes: 20 additions & 17 deletions alu_u32/src/div/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ use valida_opcodes::{DIV32, SDIV32};
use valida_range::MachineWithRangeChip;

use p3_air::VirtualPairCol;
use p3_field::PrimeField;
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use valida_machine::config::StarkConfig;
use valida_util::pad_to_power_of_two;

pub mod columns;
Expand All @@ -31,12 +32,12 @@ pub struct Div32Chip {
pub operations: Vec<Operation>,
}

impl<F, M> Chip<M> for Div32Chip
impl<M, SC> Chip<M, SC> for Div32Chip
where
F: PrimeField,
M: MachineWithGeneralBus<F = F>,
M: MachineWithGeneralBus<SC::Val>,
SC: StarkConfig,
{
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<M::F> {
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<SC::Val> {
let rows = self
.operations
.par_iter()
Expand All @@ -46,18 +47,18 @@ where
let mut trace =
RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_DIV_COLS);

pad_to_power_of_two::<NUM_DIV_COLS, F>(&mut trace.values);
pad_to_power_of_two::<NUM_DIV_COLS, SC::Val>(&mut trace.values);

trace
}

fn global_receives(&self, machine: &M) -> Vec<Interaction<M::F>> {
fn global_receives(&self, machine: &M) -> Vec<Interaction<SC::Val>> {
let opcode = VirtualPairCol::new_main(
vec![
(DIV_COL_MAP.is_div, M::F::from_canonical_u32(DIV32)),
(DIV_COL_MAP.is_sdiv, M::F::from_canonical_u32(SDIV32)),
(DIV_COL_MAP.is_div, SC::Val::from_canonical_u32(DIV32)),
(DIV_COL_MAP.is_sdiv, SC::Val::from_canonical_u32(SDIV32)),
],
M::F::zero(),
SC::Val::zero(),
);
let input_1 = DIV_COL_MAP.input_1.0.map(VirtualPairCol::single_main);
let input_2 = DIV_COL_MAP.input_2.0.map(VirtualPairCol::single_main);
Expand Down Expand Up @@ -102,21 +103,22 @@ impl Div32Chip {
}
}

pub trait MachineWithDiv32Chip: MachineWithCpuChip {
pub trait MachineWithDiv32Chip<F: Field>: MachineWithCpuChip<F> {
fn div_u32(&self) -> &Div32Chip;
fn div_u32_mut(&mut self) -> &mut Div32Chip;
}

instructions!(Div32Instruction, SDiv32Instruction);

impl<M> Instruction<M> for Div32Instruction
impl<M, F> Instruction<M, F> for Div32Instruction
where
M: MachineWithDiv32Chip + MachineWithRangeChip<256>,
M: MachineWithDiv32Chip<F> + MachineWithRangeChip<F, 256>,
F: Field,
{
const OPCODE: u32 = DIV32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M>>::OPCODE;
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
Expand Down Expand Up @@ -149,14 +151,15 @@ where
}
}

impl<M> Instruction<M> for SDiv32Instruction
impl<M, F> Instruction<M, F> for SDiv32Instruction
where
M: MachineWithDiv32Chip + MachineWithRangeChip<256>,
M: MachineWithDiv32Chip<F> + MachineWithRangeChip<F, 256>,
F: Field,
{
const OPCODE: u32 = SDIV32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M>>::OPCODE;
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
Expand Down
28 changes: 15 additions & 13 deletions alu_u32/src/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ use valida_machine::{
use valida_opcodes::LT32;

use p3_air::VirtualPairCol;
use p3_field::PrimeField;
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use valida_machine::config::StarkConfig;
use valida_util::pad_to_power_of_two;

pub mod columns;
Expand All @@ -31,12 +32,12 @@ pub struct Lt32Chip {
pub operations: Vec<Operation>,
}

impl<F, M> Chip<M> for Lt32Chip
impl<M, SC> Chip<M, SC> for Lt32Chip
where
F: PrimeField,
M: MachineWithGeneralBus<F = F>,
M: MachineWithGeneralBus<SC::Val>,
SC: StarkConfig,
{
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<M::F> {
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<SC::Val> {
let rows = self
.operations
.par_iter()
Expand All @@ -46,17 +47,17 @@ where
let mut trace =
RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_LT_COLS);

pad_to_power_of_two::<NUM_LT_COLS, F>(&mut trace.values);
pad_to_power_of_two::<NUM_LT_COLS, SC::Val>(&mut trace.values);

trace
}

fn global_receives(&self, machine: &M) -> Vec<Interaction<M::F>> {
let opcode = VirtualPairCol::constant(M::F::from_canonical_u32(LT32));
fn global_receives(&self, machine: &M) -> Vec<Interaction<SC::Val>> {
let opcode = VirtualPairCol::constant(SC::Val::from_canonical_u32(LT32));
let input_1 = LT_COL_MAP.input_1.0.map(VirtualPairCol::single_main);
let input_2 = LT_COL_MAP.input_2.0.map(VirtualPairCol::single_main);
let output = (0..MEMORY_CELL_BYTES - 1)
.map(|_| VirtualPairCol::constant(M::F::zero()))
.map(|_| VirtualPairCol::constant(SC::Val::zero()))
.chain(iter::once(VirtualPairCol::single_main(LT_COL_MAP.output)));

let mut fields = vec![opcode];
Expand Down Expand Up @@ -107,21 +108,22 @@ impl Lt32Chip {
}
}

pub trait MachineWithLt32Chip: MachineWithCpuChip {
pub trait MachineWithLt32Chip<F: Field>: MachineWithCpuChip<F> {
fn lt_u32(&self) -> &Lt32Chip;
fn lt_u32_mut(&mut self) -> &mut Lt32Chip;
}

instructions!(Lt32Instruction);

impl<M> Instruction<M> for Lt32Instruction
impl<M, F> Instruction<M, F> for Lt32Instruction
where
M: MachineWithLt32Chip,
M: MachineWithLt32Chip<F>,
F: Field,
{
const OPCODE: u32 = LT32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M>>::OPCODE;
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
Expand Down
Loading

0 comments on commit 692e6d8

Please sign in to comment.