Skip to content

Commit

Permalink
Feat: fast interpolation bin search closest+bug fix (#286)
Browse files Browse the repository at this point in the history
- implement a gas efficient implementation of interpolation
- implement binary search for closest element in sorted array
- fix interpolation error in the case of `Interpolation::ConstantLeft`
- additional tests are provided

## Pull Request type

Please check the type of change your PR introduces:

- [ ] Bugfix
- [X] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no API changes)
- [ ] Build-related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

Interpolation is gas inefficient with large arrays

closes #285

## What is the new behavior?

Fast interpolation complexity is now logarithmic (vs linear) in the
number of array elements.

## Does this introduce a breaking change?

- [ ] Yes
- [X] No

<!-- If this does introduce a breaking change, please describe the
impact and migration path for existing applications below. -->

## Other information

Bug fix in the original interpolation is corrected and more tests cases
are provided.
  • Loading branch information
tekkac authored Mar 19, 2024
1 parent c1a604e commit 116c348
Show file tree
Hide file tree
Showing 10 changed files with 467 additions and 105 deletions.
1 change: 1 addition & 0 deletions Scarb.lock
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ name = "alexandria_numeric"
version = "0.1.0"
dependencies = [
"alexandria_math",
"alexandria_searching",
]

[[package]]
Expand Down
3 changes: 2 additions & 1 deletion src/numeric/Scarb.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ homepage = "https://github.com/keep-starknet-strange/alexandria/tree/main/src/nu
fmt.workspace = true

[dependencies]
alexandria_math = { path = "../math" }
alexandria_math = { path = "../math" }
alexandria_searching = { path = "../searching" }
113 changes: 96 additions & 17 deletions src/numeric/src/interpolate.cairo
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use alexandria_searching::binary_search::binary_search_closest as search;
//! One-dimensional linear interpolation for monotonically increasing sample points.

#[derive(Serde, Copy, Drop, PartialEq)]
enum Interpolation {
Linear: (),
Nearest: (),
ConstantLeft: (),
ConstantRight: (),
Linear,
Nearest,
ConstantLeft,
ConstantRight,
}

#[derive(Serde, Copy, Drop, PartialEq)]
enum Extrapolation {
Null: (),
Constant: (),
Null,
Constant,
}

/// Interpolate y(x) at x.
Expand All @@ -38,39 +39,39 @@ fn interpolate<
x: T, xs: Span<T>, ys: Span<T>, interpolation: Interpolation, extrapolation: Extrapolation
) -> T {
// [Check] Inputs
assert(xs.len() == ys.len(), 'Arrays must have the same len');
assert(xs.len() >= 2, 'Array must have at least 2 elts');
assert!(xs.len() == ys.len(), "Arrays must have the same len");
assert!(xs.len() >= 2, "Array must have at least 2 elts");

// [Check] Extrapolation
if x <= *xs[0] {
return match extrapolation {
Extrapolation::Null(()) => Zeroable::zero(),
Extrapolation::Constant(()) => *ys[0],
Extrapolation::Null => Zeroable::zero(),
Extrapolation::Constant => *ys[0],
};
}
if x >= *xs[xs.len() - 1] {
return match extrapolation {
Extrapolation::Null(()) => Zeroable::zero(),
Extrapolation::Constant(()) => *ys[xs.len() - 1],
Extrapolation::Null => Zeroable::zero(),
Extrapolation::Constant => *ys[xs.len() - 1],
};
}

// [Compute] Interpolation, could be optimized with binary search
let mut index = 0;
loop {
assert(*xs[index + 1] > *xs[index], 'Abscissa must be sorted');
assert!(*xs[index + 1] > *xs[index], "Abscissa must be sorted");

if x < *xs[index + 1] {
break match interpolation {
Interpolation::Linear(()) => {
Interpolation::Linear => {
// y = [(xb - x) * ya + (x - xa) * yb] / (xb - xa)
// y = [alpha * ya + beta * yb] / den
let den = *xs[index + 1] - *xs[index];
let alpha = *xs[index + 1] - x;
let beta = x - *xs[index];
(alpha * *ys[index] + beta * *ys[index + 1]) / den
},
Interpolation::Nearest(()) => {
Interpolation::Nearest => {
// y = ya or yb
let alpha = *xs[index + 1] - x;
let beta = x - *xs[index];
Expand All @@ -80,11 +81,89 @@ fn interpolate<
*ys[index + 1]
}
},
Interpolation::ConstantLeft(()) => *ys[index + 1],
Interpolation::ConstantRight(()) => *ys[index],
Interpolation::ConstantLeft => {
// Handle equality case: x == *xs[index]
if x <= *xs[index] {
*ys[index]
} else {
*ys[index + 1]
}
},
Interpolation::ConstantRight => *ys[index],
};
}

index += 1;
}
}

fn interpolate_fast<
T,
impl TPartialOrd: PartialOrd<T>,
impl TNumericLiteral: NumericLiteral<T>,
impl TAdd: Add<T>,
impl TSub: Sub<T>,
impl TMul: Mul<T>,
impl TDiv: Div<T>,
impl TZeroable: Zeroable<T>,
impl TCopy: Copy<T>,
impl TDrop: Drop<T>,
>(
x: T, xs: Span<T>, ys: Span<T>, interpolation: Interpolation, extrapolation: Extrapolation
) -> T {
// [Check] Inputs
assert!(xs.len() == ys.len(), "Arrays must have the same len");
assert!(xs.len() >= 2, "Array must have at least 2 elts");

// [Check] Extrapolation
if x <= *xs[0] {
let y = match extrapolation {
Extrapolation::Null => Zeroable::zero(),
Extrapolation::Constant => *ys[0],
};
return y;
}
if x >= *xs[xs.len() - 1] {
let y = match extrapolation {
Extrapolation::Null => Zeroable::zero(),
Extrapolation::Constant => *ys[xs.len() - 1],
};
return y;
}

// [Compute] Interpolation with binary search
let index: u32 = search(xs, x).expect('search error');

assert!(*xs[index + 1] > *xs[index], "Abscissa must be sorted");
assert!(x < *xs[index + 1], "search error");

match interpolation {
Interpolation::Linear => {
// y = [(xb - x) * ya + (x - xa) * yb] / (xb - xa)
// y = [alpha * ya + beta * yb] / den
let den = *xs[index + 1] - *xs[index];
let alpha = *xs[index + 1] - x;
let beta = x - *xs[index];
(alpha * *ys[index] + beta * *ys[index + 1]) / den
},
Interpolation::Nearest => {
// y = ya or yb
let alpha = *xs[index + 1] - x;
let beta = x - *xs[index];
if alpha >= beta {
*ys[index]
} else {
*ys[index + 1]
}
},
Interpolation::ConstantLeft => {
// Handle equality case: x == *xs[index]
if x <= *xs[index] {
*ys[index]
} else {
*ys[index + 1]
}
},
Interpolation::ConstantRight => *ys[index],
}
}
1 change: 1 addition & 0 deletions src/numeric/src/tests.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ mod cumprod_test;
mod cumsum_test;
mod diff_test;
mod integers_test;
mod interpolate_fast_test;
mod interpolate_test;
mod trapezoidal_rule_test;
122 changes: 122 additions & 0 deletions src/numeric/src/tests/interpolate_fast_test.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
use alexandria_numeric::interpolate::{
interpolate_fast as interpolate, Interpolation, Extrapolation
};

#[test]
#[available_gas(2000000)]
fn interp_extrapolation_test() {
let xs: Array::<u64> = array![3, 5, 7];
let ys = array![11, 13, 17];
assert_eq!(
interpolate(0, xs.span(), ys.span(), Interpolation::Linear, Extrapolation::Constant), 11
);
assert_eq!(
interpolate(9, xs.span(), ys.span(), Interpolation::Linear, Extrapolation::Constant), 17
);
assert_eq!(interpolate(0, xs.span(), ys.span(), Interpolation::Linear, Extrapolation::Null), 0);
assert_eq!(interpolate(9, xs.span(), ys.span(), Interpolation::Linear, Extrapolation::Null), 0);
}

#[test]
#[available_gas(2000000)]
fn interp_linear_test() {
let xs: Array::<u64> = array![3, 5, 7];
let ys = array![11, 13, 17];
assert_eq!(
interpolate(4, xs.span(), ys.span(), Interpolation::Linear, Extrapolation::Constant), 12
);
assert_eq!(
interpolate(4, xs.span(), ys.span(), Interpolation::Linear, Extrapolation::Constant), 12
);
}

#[test]
#[available_gas(2000000)]
fn interp_nearest_test() {
let xs: Array::<u64> = array![3, 5, 7];
let ys = array![11, 13, 17];
assert_eq!(
interpolate(4, xs.span(), ys.span(), Interpolation::Nearest, Extrapolation::Constant), 11
);
assert_eq!(
interpolate(6, xs.span(), ys.span(), Interpolation::Nearest, Extrapolation::Constant), 13
);
assert_eq!(
interpolate(7, xs.span(), ys.span(), Interpolation::Nearest, Extrapolation::Constant), 17
);
}

#[test]
#[available_gas(2000000)]
fn interp_constant_left_test() {
let xs: Array::<u64> = array![3, 5, 7];
let ys = array![11, 13, 17];
assert_eq!(
interpolate(4, xs.span(), ys.span(), Interpolation::ConstantLeft, Extrapolation::Constant),
13
);
assert_eq!(
interpolate(6, xs.span(), ys.span(), Interpolation::ConstantLeft, Extrapolation::Constant),
17
);
assert_eq!(
interpolate(7, xs.span(), ys.span(), Interpolation::ConstantLeft, Extrapolation::Constant),
17
);
}

#[test]
#[available_gas(2000000)]
fn interp_constant_left_diff() {
let xs: Span<u64> = array![0, 2, 4, 6, 8].span();
let ys: Span<u64> = array![0, 2, 4, 6, 8].span();
let inter = Interpolation::ConstantLeft;
let extra = Extrapolation::Constant;
assert_eq!(@interpolate(0, xs, ys, inter, extra), @0);
assert_eq!(@interpolate(1, xs, ys, inter, extra), @2);
assert_eq!(@interpolate(2, xs, ys, inter, extra), @2);
assert_eq!(@interpolate(3, xs, ys, inter, extra), @4);
assert_eq!(@interpolate(4, xs, ys, inter, extra), @4);
assert_eq!(@interpolate(5, xs, ys, inter, extra), @6);
assert_eq!(@interpolate(6, xs, ys, inter, extra), @6);
assert_eq!(@interpolate(7, xs, ys, inter, extra), @8);
assert_eq!(@interpolate(8, xs, ys, inter, extra), @8);
assert_eq!(@interpolate(9, xs, ys, inter, extra), @8);
}

#[test]
#[available_gas(2000000)]
fn interp_constant_right_test() {
let xs: Array::<u64> = array![3, 5, 8];
let ys = array![11, 13, 17];
assert_eq!(
interpolate(4, xs.span(), ys.span(), Interpolation::ConstantRight, Extrapolation::Constant),
11
);
assert_eq!(
interpolate(6, xs.span(), ys.span(), Interpolation::ConstantRight, Extrapolation::Constant),
13
);
assert_eq!(
interpolate(7, xs.span(), ys.span(), Interpolation::ConstantRight, Extrapolation::Constant),
13
);
}

#[test]
#[should_panic(expected: ("Arrays must have the same len",))]
#[available_gas(2000000)]
fn interp_revert_len_mismatch() {
let xs: Array::<u64> = array![3, 5];
let ys = array![11];
interpolate(4, xs.span(), ys.span(), Interpolation::Linear, Extrapolation::Constant);
}

#[test]
#[should_panic(expected: ("Array must have at least 2 elts",))]
#[available_gas(2000000)]
fn interp_revert_len_too_short() {
let xs: Array::<u64> = array![3];
let ys = array![11];
interpolate(4, xs.span(), ys.span(), Interpolation::Linear, Extrapolation::Constant);
}
Loading

0 comments on commit 116c348

Please sign in to comment.