Skip to content

Commit

Permalink
feat: add tests for iter and zip
Browse files Browse the repository at this point in the history
  • Loading branch information
yi-sun committed Jan 11, 2025
1 parent e0886a8 commit 4a1d614
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 5 deletions.
3 changes: 0 additions & 3 deletions extensions/native/compiler/src/ir/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,6 @@ impl<C: Config> Builder<C> {
) -> ZippedPointerIteratorBuilder<'a, C> {
assert!(!arrays.is_empty());
if arrays.iter().all(|array| array.is_fixed()) {
assert!(arrays
.windows(2)
.all(|array| array[0].len() == array[1].len()));
ZippedPointerIteratorBuilder {
starts: vec![RVar::zero(); arrays.len()],
end0: arrays[0].len().into(),
Expand Down
8 changes: 6 additions & 2 deletions extensions/native/compiler/src/ir/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,9 @@ impl<C: Config> Builder<C> {
var.clone(),
Ptr {
address: match ptr {
RVar::Const(_) => unimplemented!(),
RVar::Const(_) => panic!(
"iter_ptr_get on dynamic array not supported for constant ptr"
),
RVar::Val(v) => v,
},
},
Expand Down Expand Up @@ -336,7 +338,9 @@ impl<C: Config> Builder<C> {
self.store(
Ptr {
address: match ptr {
RVar::Const(_) => unimplemented!(),
RVar::Const(_) => panic!(
"iter_ptr_set on dynamic array not supported for constant ptr"
),
RVar::Val(v) => v,
},
},
Expand Down
126 changes: 126 additions & 0 deletions extensions/native/compiler/tests/for_loops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use openvm_native_circuit::execute_program;
use openvm_native_compiler::{
asm::{AsmBuilder, AsmConfig},
ir::{Array, RVar, SymbolicVar, Var},
prelude::ArrayLike,
};
use openvm_native_compiler_derive::compile_zip;
use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, FieldAlgebra};
use openvm_stark_sdk::p3_baby_bear::BabyBear;

Expand Down Expand Up @@ -47,6 +49,130 @@ fn test_compiler_for_loops() {
execute_program(program, vec![]);
}

#[test]
fn test_compiler_iter_fixed() {
let mut builder = AsmBuilder::<F, EF>::default();
let zero: Var<_> = builder.eval(F::ZERO);
let one: Var<_> = builder.eval(F::ONE);
let two: Var<_> = builder.eval(F::TWO);
let arr = builder.vec(vec![zero, one, two]);
let x: Var<_> = builder.eval(F::ZERO);
let count: Var<_> = builder.eval(F::ZERO);
builder.iter(&arr).for_each(|val: Var<_>, builder| {
builder.assign(&x, x + val);
builder.assign(&count, count + F::ONE);
});
builder.assert_var_eq(count, F::from_canonical_usize(3));
builder.assert_var_eq(x, F::from_canonical_usize(3));
builder.halt();

let program = builder.compile_isa();
execute_program(program, vec![]);
}

#[test]
fn test_compiler_iter_dyn() {
let mut builder = AsmBuilder::<F, EF>::default();
let zero: Var<_> = builder.eval(F::ZERO);
let one: Var<_> = builder.eval(F::ONE);
let two: Var<_> = builder.eval(F::TWO);
let arr = builder.dyn_array(3);
builder.set(&arr, 0, zero);
builder.set(&arr, 1, one);
builder.set(&arr, 2, two);
let x: Var<_> = builder.eval(F::ZERO);
let count: Var<_> = builder.eval(F::ZERO);
builder.iter(&arr).for_each(|val: Var<_>, builder| {
builder.assign(&x, x + val);
builder.assign(&count, count + F::ONE);
});
builder.assert_var_eq(count, F::from_canonical_usize(3));
builder.assert_var_eq(x, F::from_canonical_usize(3));
builder.halt();

let program = builder.compile_isa();
execute_program(program, vec![]);
}

#[test]
fn test_compiler_zip_fixed() {
let mut builder = AsmBuilder::<F, EF>::default();
let zero: Var<_> = builder.eval(F::ZERO);
let one: Var<_> = builder.eval(F::ONE);
let three: Var<_> = builder.eval(F::TWO + F::ONE);
let four: Var<_> = builder.eval(F::TWO + F::TWO);
let five: Var<_> = builder.eval(F::TWO + F::TWO + F::ONE);
let arr1 = builder.vec(vec![zero, one]);
let arr2 = builder.vec(vec![three, four, five]);

let x1: Var<_> = builder.eval(F::ZERO);
let x2: Var<_> = builder.eval(F::ZERO);
let count: Var<_> = builder.eval(F::ZERO);
let ptr1_cache: Var<_> = builder.eval(F::ZERO);
let ptr2_cache: Var<_> = builder.eval(F::ZERO);

compile_zip!(builder, arr1, arr2).for_each(|ptr_vec, builder| {
let val1 = builder.iter_ptr_get(&arr1, ptr_vec[0]);
let val2 = builder.iter_ptr_get(&arr2, ptr_vec[1]);
builder.assign(&x1, x1 + val1);
builder.assign(&x2, x2 + val2);
builder.assign(&count, count + F::ONE);
builder.assign(&ptr1_cache, ptr_vec[0]);
builder.assign(&ptr2_cache, ptr_vec[1]);
});
builder.assert_var_eq(count, F::from_canonical_usize(2));
builder.assert_var_eq(x1, F::from_canonical_usize(1));
builder.assert_var_eq(x2, F::from_canonical_usize(7));
builder.assert_var_eq(ptr1_cache, F::from_canonical_usize(1));
builder.assert_var_eq(ptr2_cache, F::from_canonical_usize(1));
builder.halt();

let program = builder.compile_isa();
execute_program(program, vec![]);
}

#[test]
fn test_compiler_zip_dyn() {
let mut builder = AsmBuilder::<F, EF>::default();
let zero: Var<_> = builder.eval(F::ZERO);
let one: Var<_> = builder.eval(F::ONE);
let three: Var<_> = builder.eval(F::TWO + F::ONE);
let four: Var<_> = builder.eval(F::TWO + F::TWO);
let five: Var<_> = builder.eval(F::TWO + F::TWO + F::ONE);
let arr1 = builder.dyn_array(2);
let arr2 = builder.dyn_array(3);
builder.set(&arr1, 0, zero);
builder.set(&arr1, 1, one);
builder.set(&arr2, 0, three);
builder.set(&arr2, 1, four);
builder.set(&arr2, 2, five);

let x1: Var<_> = builder.eval(F::ZERO);
let x2: Var<_> = builder.eval(F::ZERO);
let count: Var<_> = builder.eval(F::ZERO);
let ptr1_cache: Var<_> = builder.eval(F::ZERO);
let ptr2_cache: Var<_> = builder.eval(F::ZERO);

compile_zip!(builder, arr1, arr2).for_each(|ptr_vec, builder| {
let val1: Var<_> = builder.iter_ptr_get(&arr1, ptr_vec[0]);
let val2: Var<_> = builder.iter_ptr_get(&arr2, ptr_vec[1]);
builder.assign(&x1, x1 + val1);
builder.assign(&x2, x2 + val2);
builder.assign(&count, count + F::ONE);
builder.assign(&ptr1_cache, ptr_vec[0]);
builder.assign(&ptr2_cache, ptr_vec[1]);
});
builder.assert_var_eq(count, F::from_canonical_usize(2));
builder.assert_var_eq(x1, F::from_canonical_usize(1));
builder.assert_var_eq(x2, F::from_canonical_usize(7));
builder.assert_var_eq(ptr1_cache, arr1.ptr().address + F::from_canonical_usize(1));
builder.assert_var_eq(ptr2_cache, arr2.ptr().address + F::from_canonical_usize(1));
builder.halt();

let program = builder.compile_isa();
execute_program(program, vec![]);
}

#[test]
fn test_compiler_nested_array_loop() {
let mut builder = AsmBuilder::<F, EF>::default();
Expand Down

0 comments on commit 4a1d614

Please sign in to comment.