Skip to content

Commit

Permalink
Change one of the arrays back to elems (?)
Browse files Browse the repository at this point in the history
  • Loading branch information
TlatoaniHJ committed Nov 1, 2024
1 parent b8f9ab6 commit 05e047a
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 30 deletions.
4 changes: 2 additions & 2 deletions toolchain/native-compiler/src/ir/fri.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::ir::{Array, Builder, Config, Ext};
use crate::ir::{Array, Builder, Config, Ext, Felt};

impl<C: Config> Builder<C> {
pub fn fri_fold(
&mut self,
alpha: Ext<C::F, C::EF>,
curr_alpha_pow: Ext<C::F, C::EF>,
at_x_array: Array<C, Ext<C::F, C::EF>>,
at_x_array: Array<C, Felt<C::F>>,
at_z_array: Array<C, Ext<C::F, C::EF>>,
) -> Ext<C::F, C::EF> {
let result = self.uninit();
Expand Down
2 changes: 1 addition & 1 deletion toolchain/native-compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ pub enum DslIr<C: Config> {
FriFold(
Ext<C::F, C::EF>,
Ext<C::F, C::EF>,
Array<C, Ext<C::F, C::EF>>,
Array<C, Felt<C::F>>,
Array<C, Ext<C::F, C::EF>>,
Ext<C::F, C::EF>,
),
Expand Down
8 changes: 4 additions & 4 deletions toolchain/native-compiler/tests/fri_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use axvm_circuit::system::program::util::execute_program;
use axvm_native_compiler::{
asm::{AsmBuilder, AsmCompiler},
conversion::{convert_program, CompilerOptions},
ir::{Array, Ext},
ir::{Array, Ext, Felt},
};
use p3_baby_bear::BabyBear;
use p3_field::{extension::BinomialExtensionField, AbstractField};
Expand All @@ -23,13 +23,13 @@ fn test_fri_fold() {
let x_value = rng.gen::<EF>();
let z_value = rng.gen::<EF>();

let mat_opening: Array<_, Ext<_, _>> = builder.dyn_array(n);
let ps_at_z: Array<_, Ext<_, _>> = builder.dyn_array(n);
let mat_opening: Array<_, Felt<_>> = builder.dyn_array(n);

for i in 0..n {
let a_value = rng.gen::<EF>();
let a_value = rng.gen::<F>();
let b_value = rng.gen::<EF>();
let val = builder.constant::<Ext<_, _>>(a_value);
let val = builder.constant::<Felt<_>>(a_value);
builder.set(&mat_opening, i, val);
let val = builder.constant::<Ext<_, _>>(b_value);
builder.set(&ps_at_z, i, val);
Expand Down
39 changes: 22 additions & 17 deletions vm/src/kernels/fri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub struct FriFoldCols<T> {

pub a_pointer_aux: MemoryReadAuxCols<T, 1>,
pub b_pointer_aux: MemoryReadAuxCols<T, 1>,
pub a_aux: MemoryReadAuxCols<T, EXT_DEG>,
pub a_aux: MemoryReadAuxCols<T, 1>,
pub b_aux: MemoryReadAuxCols<T, EXT_DEG>,
pub result_aux: MemoryWriteAuxCols<T, EXT_DEG>,
pub length_aux: MemoryReadAuxCols<T, 1>,
Expand All @@ -68,7 +68,7 @@ pub struct FriFoldCols<T> {

pub a_pointer: T,
pub b_pointer: T,
pub a: [T; EXT_DEG],
pub a: T,
pub b: [T; EXT_DEG],
pub alpha: [T; EXT_DEG],
pub alpha_pow_original: [T; EXT_DEG],
Expand Down Expand Up @@ -158,12 +158,12 @@ impl<AB: InteractionBuilder> Air<AB> for FriFoldAir {

let mut when_is_not_last = builder.when(not(is_last));

let next_alpha_pow_times_a = FieldExtension::multiply(next.alpha_pow_current, next.a);
let next_alpha_pow_times_b = FieldExtension::multiply(next.alpha_pow_current, next.b);
for i in 0..EXT_DEG {
when_is_not_last.assert_eq(
next.current[i],
next_alpha_pow_times_b[i].clone() - next_alpha_pow_times_a[i].clone() + current[i],
next_alpha_pow_times_b[i].clone() - (next.alpha_pow_current[i] * next.a)
+ current[i],
);
}

Expand Down Expand Up @@ -191,12 +191,11 @@ impl<AB: InteractionBuilder> Air<AB> for FriFoldAir {
alpha_pow_original,
);

let alpha_pow_times_a = FieldExtension::multiply(alpha_pow_current, a);
let alpha_pow_times_b = FieldExtension::multiply(alpha_pow_current, b);
for i in 0..EXT_DEG {
builder.when(is_first).assert_eq(
current[i],
alpha_pow_times_b[i].clone() - alpha_pow_times_a[i].clone(),
alpha_pow_times_b[i].clone() - (alpha_pow_current[i] * a),
);
}

Expand Down Expand Up @@ -277,11 +276,8 @@ impl<AB: InteractionBuilder> Air<AB> for FriFoldAir {

self.memory_bridge
.read(
MemoryAddress::new(
address_space,
a_pointer + (index * AB::F::from_canonical_usize(4)),
),
a,
MemoryAddress::new(address_space, a_pointer + index),
[a],
start_timestamp + num_initial_accesses + (index * AB::F::two()),
&a_aux,
)
Expand Down Expand Up @@ -330,7 +326,7 @@ pub struct FriFoldRecord<F: Field> {
pub length_read: MemoryReadRecord<F, 1>,
pub a_pointer_read: MemoryReadRecord<F, 1>,
pub b_pointer_read: MemoryReadRecord<F, 1>,
pub a_reads: Vec<MemoryReadRecord<F, EXT_DEG>>,
pub a_reads: Vec<MemoryReadRecord<F, 1>>,
pub b_reads: Vec<MemoryReadRecord<F, EXT_DEG>>,
pub alpha_pow_write: MemoryWriteRecord<F, EXT_DEG>,
pub result_write: MemoryWriteRecord<F, EXT_DEG>,
Expand Down Expand Up @@ -363,6 +359,12 @@ impl<F: PrimeField32> FriFoldChip<F> {
}
}

fn elem_to_ext<F: Field>(elem: F) -> [F; EXT_DEG] {
let mut ret = [F::zero(); EXT_DEG];
ret[0] = elem;
ret
}

impl<F: PrimeField32> InstructionExecutor<F> for FriFoldChip<F> {
fn execute(
&mut self,
Expand Down Expand Up @@ -404,15 +406,15 @@ impl<F: PrimeField32> InstructionExecutor<F> for FriFoldChip<F> {
let mut result = [F::zero(); EXT_DEG];

for i in 0..length {
let a_read = memory.read(address_space, a_pointer + F::from_canonical_usize(4 * i));
let a_read = memory.read_cell(address_space, a_pointer + F::from_canonical_usize(i));
let b_read = memory.read(address_space, b_pointer + F::from_canonical_usize(4 * i));
a_reads.push(a_read);
b_reads.push(b_read);
let a = a_read.data;
let a = a_read.data[0];
let b = b_read.data;
result = FieldExtension::add(
result,
FieldExtension::multiply(FieldExtension::subtract(b, a), alpha_pow),
FieldExtension::multiply(FieldExtension::subtract(b, elem_to_ext(a)), alpha_pow),
);
alpha_pow = FieldExtension::multiply(alpha, alpha_pow);
}
Expand Down Expand Up @@ -502,11 +504,14 @@ impl<F: PrimeField32> FriFoldChip<F> {
let result_aux = aux_cols_factory.make_write_aux_cols(record.result_write);

for i in 0..length {
let a = record.a_reads[i].data;
let a = record.a_reads[i].data[0];
let b = record.b_reads[i].data;
current = FieldExtension::add(
current,
FieldExtension::multiply(FieldExtension::subtract(b, a), alpha_pow_current),
FieldExtension::multiply(
FieldExtension::subtract(b, elem_to_ext(a)),
alpha_pow_current,
),
);

let mut index_is_zero = F::zero();
Expand Down
14 changes: 8 additions & 6 deletions vm/src/kernels/fri/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@ use crate::{
arch::testing::{memory::gen_pointer, VmChipTestBuilder},
kernels::{
field_extension::FieldExtension,
fri::{FriFoldChip, FriFoldCols, EXT_DEG},
fri::{elem_to_ext, FriFoldChip, FriFoldCols, EXT_DEG},
},
};

fn compute_fri_fold<F: Field>(
alpha: [F; EXT_DEG],
mut alpha_pow: [F; EXT_DEG],
a: &[[F; EXT_DEG]],
a: &[F],
b: &[[F; EXT_DEG]],
) -> ([F; EXT_DEG], [F; EXT_DEG]) {
let mut result = [F::zero(); EXT_DEG];
for (&a, &b) in a.iter().zip_eq(b) {
result = FieldExtension::add(
result,
FieldExtension::multiply(FieldExtension::subtract(b, a), alpha_pow),
FieldExtension::multiply(FieldExtension::subtract(b, elem_to_ext(a)), alpha_pow),
);
alpha_pow = FieldExtension::multiply(alpha, alpha_pow);
}
Expand Down Expand Up @@ -59,7 +59,9 @@ fn fri_fold_air_test() {
let alpha = gen_ext!();
let length = rng.gen_range(length_range());
let alpha_pow_initial = gen_ext!();
let a = (0..length).map(|_| gen_ext!()).collect_vec();
let a = (0..length)
.map(|_| BabyBear::from_canonical_u32(rng.gen_range(elem_range())))
.collect_vec();
let b = (0..length).map(|_| gen_ext!()).collect_vec();

let (alpha_pow_final, result) = compute_fri_fold(alpha, alpha_pow_initial, &a, &b);
Expand All @@ -70,7 +72,7 @@ fn fri_fold_air_test() {
let b_pointer_pointer = gen_pointer(&mut rng, 1);
let alpha_pow_pointer = gen_pointer(&mut rng, 4);
let result_pointer = gen_pointer(&mut rng, 4);
let a_pointer = gen_pointer(&mut rng, 4);
let a_pointer = gen_pointer(&mut rng, 1);
let b_pointer = gen_pointer(&mut rng, 4);

let address_space = rng.gen_range(address_space_range());
Expand Down Expand Up @@ -98,7 +100,7 @@ fn fri_fold_air_test() {
);
tester.write(address_space, alpha_pow_pointer, alpha_pow_initial);
for i in 0..length {
tester.write(address_space, a_pointer + (4 * i), a[i]);
tester.write_cell(address_space, a_pointer + i, a[i]);
tester.write(address_space, b_pointer + (4 * i), b[i]);
}

Expand Down

0 comments on commit 05e047a

Please sign in to comment.