Skip to content

Commit

Permalink
Add FriFold eDSL test and switch recursion to use FriFoldChip
Browse files Browse the repository at this point in the history
  • Loading branch information
TlatoaniHJ committed Nov 1, 2024
1 parent a86655e commit b8f9ab6
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 8 deletions.
9 changes: 5 additions & 4 deletions lib/recursion/src/fri/two_adic_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,18 @@ pub fn verify_two_adic_pcs<C: Config>(

builder.cycle_tracker_start("sp1-fri-fold");

builder.range(0, ps_at_z.len()).for_each(|t, builder| {
/*builder.range(0, ps_at_z.len()).for_each(|t, builder| {
let p_at_x = builder.get(&mat_opening, t);
let p_at_z = builder.get(&ps_at_z, t);
let quotient = (p_at_z - p_at_x) / (z - x);
builder.assign(&cur_ro, cur_ro + cur_alpha_pow * quotient);
builder.assign(&cur_alpha_pow, cur_alpha_pow * alpha);
});
});*/

//let fri_fold_result = builder.fri_fold(alpha, cur_alpha_pow, mat_opening, ps_at_z);
//builder.assign(&cur_ro, cur_ro + (fri_fold_result / (z - x)));
let fri_fold_result =
builder.fri_fold(alpha, cur_alpha_pow, mat_opening, ps_at_z);
builder.assign(&cur_ro, cur_ro + (fri_fold_result / (z - x)));

builder.cycle_tracker_end("sp1-fri-fold");
});
Expand Down
88 changes: 88 additions & 0 deletions toolchain/native-compiler/tests/fri_fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use axvm_circuit::system::program::util::execute_program;
use axvm_native_compiler::{
asm::{AsmBuilder, AsmCompiler},
conversion::{convert_program, CompilerOptions},
ir::{Array, Ext},
};
use p3_baby_bear::BabyBear;
use p3_field::{extension::BinomialExtensionField, AbstractField};
use rand::{thread_rng, Rng};

type F = BabyBear;
type EF = BinomialExtensionField<BabyBear, 4>;

#[test]
fn test_fri_fold() {
let mut builder = AsmBuilder::<F, EF>::default();

let mut rng = thread_rng();
let n = 3;

let alpha_value = rng.gen::<EF>();
let initial_alpha_pow_value = rng.gen::<EF>();
let x_value = rng.gen::<EF>();
let z_value = rng.gen::<EF>();

let mat_opening: Array<_, Ext<_, _>> = builder.dyn_array(n);
let ps_at_z: Array<_, Ext<_, _>> = builder.dyn_array(n);

for i in 0..n {
let a_value = rng.gen::<EF>();
let b_value = rng.gen::<EF>();
let val = builder.constant::<Ext<_, _>>(a_value);
builder.set(&mat_opening, i, val);
let val = builder.constant::<Ext<_, _>>(b_value);
builder.set(&ps_at_z, i, val);
}

let alpha: Ext<_, _> = builder.constant(alpha_value);
let initial_alpha_pow: Ext<_, _> = builder.constant(initial_alpha_pow_value);
let x: Ext<_, _> = builder.constant(x_value);
let z: Ext<_, _> = builder.constant(z_value);

let cur_ro: Ext<_, _> = builder.constant(EF::zero());
let cur_alpha_pow: Ext<_, _> = builder.uninit();
builder.assign(&cur_alpha_pow, initial_alpha_pow);
builder.range(0, ps_at_z.len()).for_each(|t, builder| {
let p_at_x = builder.get(&mat_opening, t);
let p_at_z = builder.get(&ps_at_z, t);
let quotient = (p_at_z - p_at_x) / (z - x);

builder.assign(&cur_ro, cur_ro + cur_alpha_pow * quotient);
builder.assign(&cur_alpha_pow, cur_alpha_pow * alpha);
});
let expected_result = cur_ro;
let expected_final_alpha_pow = cur_alpha_pow;

// prints don't work?
/*builder.print_e(expected_result);
builder.print_e(expected_final_alpha_pow);
let two = builder.constant(F::two());
builder.print_f(two);
let ext_1210 = builder.constant(EF::from_base_slice(&[F::one(), F::two(), F::one(), F::zero()]));
builder.print_e(ext_1210);*/

let cur_alpha_pow: Ext<_, _> = builder.uninit();
builder.assign(&cur_alpha_pow, initial_alpha_pow);
let fri_fold_result = builder.fri_fold(alpha, cur_alpha_pow, mat_opening, ps_at_z);
let actual_final_alpha_pow = cur_alpha_pow;
let actual_result: Ext<_, _> = builder.uninit();
builder.assign(&actual_result, fri_fold_result / (z - x));

//builder.print_e(actual_result);
//builder.print_e(actual_final_alpha_pow);

builder.assert_ext_eq(expected_result, actual_result);
builder.assert_ext_eq(expected_final_alpha_pow, actual_final_alpha_pow);

builder.halt();

let mut compiler = AsmCompiler::new(1);
compiler.build(builder.operations);
let asm_code = compiler.code();
// println!("{}", asm_code);

let program = convert_program::<F, EF>(asm_code, CompilerOptions::default());
execute_program(program, vec![]);
}
17 changes: 17 additions & 0 deletions vm/src/arch/chip_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ use crate::{
castf::{CastFChip, CastFCoreChip},
field_arithmetic::{FieldArithmeticChip, FieldArithmeticCoreChip},
field_extension::{FieldExtensionChip, FieldExtensionCoreChip},
fri::FriFoldChip,
jal::{JalCoreChip, KernelJalChip},
loadstore::{KernelLoadStoreChip, KernelLoadStoreCoreChip},
modular::{KernelModularAddSubChip, KernelModularMulDivChip},
Expand Down Expand Up @@ -403,6 +404,17 @@ impl VmConfig {
}
chips.push(AxVmChip::Keccak256(chip));
}
ExecutorName::FriFold => {
let chip = Rc::new(RefCell::new(FriFoldChip::new(
memory_controller.clone(),
execution_bus,
program_bus,
)));
for opcode in range {
executors.insert(opcode, chip.clone().into());
}
chips.push(AxVmChip::FriFold(chip));
}
ExecutorName::BaseAluRv32 => {
let chip = Rc::new(RefCell::new(Rv32BaseAluChip::new(
Rv32BaseAluAdapterChip::new(
Expand Down Expand Up @@ -1304,6 +1316,11 @@ fn default_executor_range(executor: ExecutorName) -> (Range<usize>, usize) {
Keccak256Opcode::COUNT,
Keccak256Opcode::default_offset(),
),
ExecutorName::FriFold => (
FriFoldOpcode::default_offset(),
FriFoldOpcode::COUNT,
FriFoldOpcode::default_offset(),
),
ExecutorName::BaseAluRv32 => (
BaseAluOpcode::default_offset(),
BaseAluOpcode::COUNT,
Expand Down
3 changes: 3 additions & 0 deletions vm/src/arch/chips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use crate::{
castf::CastFChip,
field_arithmetic::FieldArithmeticChip,
field_extension::FieldExtensionChip,
fri::FriFoldChip,
jal::KernelJalChip,
loadstore::KernelLoadStoreChip,
modular::{KernelModularAddSubChip, KernelModularMulDivChip},
Expand Down Expand Up @@ -87,6 +88,7 @@ pub enum AxVmInstructionExecutor<F: PrimeField32> {
PublicValues(Rc<RefCell<PublicValuesChip<F>>>),
Poseidon2(Rc<RefCell<Poseidon2Chip<F>>>),
Keccak256(Rc<RefCell<KeccakVmChip<F>>>),
FriFold(Rc<RefCell<FriFoldChip<F>>>),
/// Rv32 (for standard 32-bit integers):
BaseAluRv32(Rc<RefCell<Rv32BaseAluChip<F>>>),
LessThanRv32(Rc<RefCell<Rv32LessThanChip<F>>>),
Expand Down Expand Up @@ -143,6 +145,7 @@ pub enum AxVmChip<F: PrimeField32> {
RangeChecker(Arc<VariableRangeCheckerChip>),
RangeTupleChecker(Arc<RangeTupleCheckerChip<2>>),
Keccak256(Rc<RefCell<KeccakVmChip<F>>>),
FriFold(Rc<RefCell<FriFoldChip<F>>>),
BitwiseOperationLookup(Arc<BitwiseOperationLookupChip<8>>),
BaseAluRv32(Rc<RefCell<Rv32BaseAluChip<F>>>),
BaseAlu256Rv32(Rc<RefCell<Rv32BaseAlu256Chip<F>>>),
Expand Down
10 changes: 6 additions & 4 deletions vm/src/kernels/fri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use ax_stark_backend::{
rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir},
Chip, ChipUsageGetter,
};
use axvm_instructions::{instruction::Instruction, FriFoldOpcode::FRI_FOLD};
use axvm_instructions::{
instruction::Instruction, program::DEFAULT_PC_STEP, FriFoldOpcode::FRI_FOLD,
};
use itertools::Itertools;
use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::{AbstractField, Field, PrimeField32};
Expand Down Expand Up @@ -230,7 +232,7 @@ impl<AB: InteractionBuilder> Air<AB> for FriFoldAir {
],
ExecutionState::new(pc, start_timestamp),
ExecutionState::<AB::Expr>::new(
AB::Expr::one() + pc,
AB::Expr::from_canonical_u32(DEFAULT_PC_STEP) + pc,
total_accesses + start_timestamp - AB::F::one(),
),
)
Expand Down Expand Up @@ -343,7 +345,7 @@ pub struct FriFoldChip<F: Field> {

impl<F: PrimeField32> FriFoldChip<F> {
#[allow(dead_code)]
fn new(
pub(crate) fn new(
memory: MemoryControllerRef<F>,
execution_bus: ExecutionBus,
program_bus: ProgramBus,
Expand Down Expand Up @@ -436,7 +438,7 @@ impl<F: PrimeField32> InstructionExecutor<F> for FriFoldChip<F> {
self.height += length;

Ok(ExecutionState {
pc: from_state.pc + 1,
pc: from_state.pc + DEFAULT_PC_STEP,
timestamp: result_write.timestamp,
})
}
Expand Down
1 change: 1 addition & 0 deletions vm/src/system/program/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub fn execute_program(program: Program<BabyBear>, input_stream: Vec<Vec<BabyBea
.add_executor(ExecutorName::FieldArithmetic)
.add_executor(ExecutorName::FieldExtension)
.add_executor(ExecutorName::Poseidon2)
.add_executor(ExecutorName::FriFold)
.add_executor(ExecutorName::BaseAlu256Rv32)
.add_executor(ExecutorName::LessThan256Rv32)
.add_executor(ExecutorName::Multiplication256Rv32)
Expand Down

0 comments on commit b8f9ab6

Please sign in to comment.