Skip to content

Commit

Permalink
add dummy test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Jan 8, 2025
1 parent 2ad340e commit 687a6a7
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
33 changes: 33 additions & 0 deletions tests/ui/batching/batch_const.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Problem: The user might want to pass either [f64; 4], (f64, f64, f64, f64), or S to the
// function. All of these are valid (modulo we have to force the user to set the right repr).
// Our current design doesn't allow users to specify those, so we will want at least one iteration.
// However, for the sake of similarity to the current autodiff (where we'd also want a change),
// leave it as is.

struct _S {
x1: f64,
x2: f64,
x3: f64,
x4: f64,
}

#[batch(bsquare4, 4, Const, Leaf(8))]
#[batch(vsquare4, 4, Const, Vector)]
fn square(multiplier: f64, x: f64) -> f64 {
x * x * multiplier
}

fn main() {
let vals = [23.1, 10.0, 100.0, 3.14];
let expected = [square(3.14, vals[0]), square(3.14, vals[1]), square(3.14, vals[2]), square(3.14, vals[3])];
let result1 = bsquare4(3.14, vals[0], vals[1], vals[2], vals[3]);
let result2 = vsquare4(3.14, vals);
assert_eq!(result.x1, expected[0]);
assert_eq!(result.x2, expected[1]);
assert_eq!(result.x3, expected[2]);
assert_eq!(result.x4, expected[3]);
assert_eq!(result2.x1, expected[0]);
assert_eq!(result2.x2, expected[1]);
assert_eq!(result2.x3, expected[2]);
assert_eq!(result2.x4, expected[3]);
}
34 changes: 34 additions & 0 deletions tests/ui/batching/slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// We want a batch size of 4.
// The original function processes 2 elements a 64 bit, so for our vfoo we have an offset of 16 bytes.
// Both vfoo and bfoo return [f64; 4].

#[batch(vfoo, 4, Leaf(16))]
#[batch(bfoo, 4, Batch)]
fn foo(x: &[f64]) -> f64 {
assert!(x.len() == 2);
x.iter().map(|&x| x * x).sum()
}

fn main() {
// 8 elements
let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

let x2 = vec![1.0, 2.0];
let x3 = vec![3.0, 4.0];
let x4 = vec![5.0, 6.0];
let x5 = vec![7.0, 8.0];

let mut res1 = [0.0;4];
for i in 0..4 {
res1[i] = foo(&x1[i..i + 1]);
}

let res2: [f64; 4] = bfoo(&x2, &x3, &x4, &x5);

let res3: [f64; 4] = vfoo(&x1);

for i in 0..4 {
assert_eq!(res1[i], res2[i]);
assert_eq!(res1[i], res3[i]);
}
}
47 changes: 47 additions & 0 deletions tests/ui/batching/vector_char-ptr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Showcasing a slightly more complex type.


#[repr(C, packed)]
struct Foo {
arr: [i32; 3],
x: f64,
y: f32,
res: f64,
}

//#pragma pack(1)
//struct Foo {
// int arr[3];
// double x;
// float y;
// double res;
//};

#[batch(df, 4, Vector)]
unsafe fn f(foo: *mut i32) {
let xptr = foo.add(3) as *mut f64;
let yptr = foo.add(5) as *mut f32;
let resptr = foo.add(6) as *mut f64;
let x: f64 = *xptr;
let y: f32 = *yptr;
*resptr = x * y;
}

fn main() {
let foo1: Foo = Foo { [0,0,0], 10.0, 9.0, 0.0 };
let foo2: Foo = Foo { [0,0,0], 99.0, 7.0, 0.0 };
let foo3: Foo = Foo { [0,0,0], 1.10, 9.0, 0.0 };
let foo4: Foo = Foo { [0,0,0], 3.14, 0.1, 0.0 };

let expected = [90.0, 693.0, 9.9, 0.314};

df(&foo1.as_ptr() as *mut i32,
&foo2.as_ptr() as *mut i32,
&foo3.as_ptr() as *mut i32,
&foo4.as_ptr() as *mut i32);

assert_eq!(foo1.res, expected[0]);
assert_eq!(foo2.res, expected[1]);
assert_eq!(foo3.res, expected[2]);
assert_eq!(foo4.res, expected[3]);
}

0 comments on commit 687a6a7

Please sign in to comment.