forked from rust-lang/rust
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
114 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Problem: The user might want to pass either [f64; 4], (f64, f64, f64, f64), or S to the | ||
// function. All of these are valid (modulo we have to force the user to set the right repr). | ||
// Our current design doesn't allow users to specify those, so we will want at least one iteration. | ||
// However, for the sake of similarity to the current autodiff (where we'd also want a change), | ||
// leave it as is. | ||
|
||
struct _S { | ||
x1: f64, | ||
x2: f64, | ||
x3: f64, | ||
x4: f64, | ||
} | ||
|
||
#[batch(bsquare4, 4, Const, Leaf(8))] | ||
#[batch(vsquare4, 4, Const, Vector)] | ||
fn square(multiplier: f64, x: f64) -> f64 { | ||
x * x * multiplier | ||
} | ||
|
||
fn main() { | ||
let vals = [23.1, 10.0, 100.0, 3.14]; | ||
let expected = [square(3.14, vals[0]), square(3.14, vals[1]), square(3.14, vals[2]), square(3.14, vals[3])]; | ||
let result1 = bsquare4(3.14, vals[0], vals[1], vals[2], vals[3]); | ||
let result2 = vsquare4(3.14, vals); | ||
assert_eq!(result.x1, expected[0]); | ||
assert_eq!(result.x2, expected[1]); | ||
assert_eq!(result.x3, expected[2]); | ||
assert_eq!(result.x4, expected[3]); | ||
assert_eq!(result2.x1, expected[0]); | ||
assert_eq!(result2.x2, expected[1]); | ||
assert_eq!(result2.x3, expected[2]); | ||
assert_eq!(result2.x4, expected[3]); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// We want a batch size of 4. | ||
// The original function processes 2 elements a 64 bit, so for our vfoo we have an offset of 16 bytes. | ||
// Both vfoo and bfoo return [f64; 4]. | ||
|
||
#[batch(vfoo, 4, Leaf(16))] | ||
#[batch(bfoo, 4, Batch)] | ||
fn foo(x: &[f64]) -> f64 { | ||
assert!(x.len() == 2); | ||
x.iter().map(|&x| x * x).sum() | ||
} | ||
|
||
fn main() { | ||
// 8 elements | ||
let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; | ||
|
||
let x2 = vec![1.0, 2.0]; | ||
let x3 = vec![3.0, 4.0]; | ||
let x4 = vec![5.0, 6.0]; | ||
let x5 = vec![7.0, 8.0]; | ||
|
||
let mut res1 = [0.0;4]; | ||
for i in 0..4 { | ||
res1[i] = foo(&x1[i..i + 1]); | ||
} | ||
|
||
let res2: [f64; 4] = bfoo(&x2, &x3, &x4, &x5); | ||
|
||
let res3: [f64; 4] = vfoo(&x1); | ||
|
||
for i in 0..4 { | ||
assert_eq!(res1[i], res2[i]); | ||
assert_eq!(res1[i], res3[i]); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// Showcasing a slightly more complex type. | ||
|
||
|
||
#[repr(C, packed)] | ||
struct Foo { | ||
arr: [i32; 3], | ||
x: f64, | ||
y: f32, | ||
res: f64, | ||
} | ||
|
||
//#pragma pack(1) | ||
//struct Foo { | ||
// int arr[3]; | ||
// double x; | ||
// float y; | ||
// double res; | ||
//}; | ||
|
||
#[batch(df, 4, Vector)] | ||
unsafe fn f(foo: *mut i32) { | ||
let xptr = foo.add(3) as *mut f64; | ||
let yptr = foo.add(5) as *mut f32; | ||
let resptr = foo.add(6) as *mut f64; | ||
let x: f64 = *xptr; | ||
let y: f32 = *yptr; | ||
*resptr = x * y; | ||
} | ||
|
||
fn main() { | ||
let foo1: Foo = Foo { [0,0,0], 10.0, 9.0, 0.0 }; | ||
let foo2: Foo = Foo { [0,0,0], 99.0, 7.0, 0.0 }; | ||
let foo3: Foo = Foo { [0,0,0], 1.10, 9.0, 0.0 }; | ||
let foo4: Foo = Foo { [0,0,0], 3.14, 0.1, 0.0 }; | ||
|
||
let expected = [90.0, 693.0, 9.9, 0.314}; | ||
|
||
df(&foo1.as_ptr() as *mut i32, | ||
&foo2.as_ptr() as *mut i32, | ||
&foo3.as_ptr() as *mut i32, | ||
&foo4.as_ptr() as *mut i32); | ||
|
||
assert_eq!(foo1.res, expected[0]); | ||
assert_eq!(foo2.res, expected[1]); | ||
assert_eq!(foo3.res, expected[2]); | ||
assert_eq!(foo4.res, expected[3]); | ||
} |