From c2cbfa195d803c97bbdabae09a5210fd9e7619d2 Mon Sep 17 00:00:00 2001 From: chachaleo Date: Mon, 8 Jul 2024 14:00:12 +0200 Subject: [PATCH] span math trait f32x32 --- packages/orion-algo/src/algo/linear_fit.cairo | 7 +- packages/orion-algo/src/lib.cairo | 2 +- packages/orion-algo/src/span_math.cairo | 16 ++- packages/orion-algo/src/span_math/core.cairo | 38 ------ .../{math.cairo => span_f16x16.cairo} | 38 +++++- .../src/span_math/span_f32x32.cairo | 126 ++++++++++++++++++ 6 files changed, 177 insertions(+), 50 deletions(-) delete mode 100644 packages/orion-algo/src/span_math/core.cairo rename packages/orion-algo/src/span_math/{math.cairo => span_f16x16.cairo} (85%) create mode 100644 packages/orion-algo/src/span_math/span_f32x32.cairo diff --git a/packages/orion-algo/src/algo/linear_fit.cairo b/packages/orion-algo/src/algo/linear_fit.cairo index 4e68f6aaf..d953a98fe 100644 --- a/packages/orion-algo/src/algo/linear_fit.cairo +++ b/packages/orion-algo/src/algo/linear_fit.cairo @@ -1,6 +1,6 @@ -use orion_numbers::f16x16::core::{f16x16, FixedTrait}; -use orion_algo::span_math::core::SpanMathTrait; -use orion_numbers::f16x16::core_trait::I32Div; +use orion_numbers::{f16x16::core::{f16x16}, FixedTrait}; +use orion_algo::span_math::SpanMathTrait; +use orion_numbers::core_trait::I32Div; pub fn linear_fit(x: Span, y: Span) -> (f16x16, f16x16) { if x.len() != y.len() || x.len() == 0 { @@ -24,7 +24,6 @@ pub fn linear_fit(x: Span, y: Span) -> (f16x16, f16x16) { (a, b) } - #[cfg(test)] mod tests { use super::linear_fit; diff --git a/packages/orion-algo/src/lib.cairo b/packages/orion-algo/src/lib.cairo index 84b2d51cd..24839e0c9 100644 --- a/packages/orion-algo/src/lib.cairo +++ b/packages/orion-algo/src/lib.cairo @@ -1,2 +1,2 @@ pub mod span_math; -pub mod algo; \ No newline at end of file +pub mod algo; diff --git a/packages/orion-algo/src/span_math.cairo b/packages/orion-algo/src/span_math.cairo index 2bfce71aa..626876302 100644 --- a/packages/orion-algo/src/span_math.cairo +++ b/packages/orion-algo/src/span_math.cairo @@ -1,2 +1,14 @@ -pub mod core; -mod math; \ No newline at end of file +pub mod span_f32x32; +pub mod span_f16x16; + +use span_f16x16::F16x16SpanMath; +use span_f32x32::F32x32SpanMath; + +pub trait SpanMathTrait { + fn arange(n: u32) -> Span; + fn dot(self: Span, other: Span) -> T; + fn max(self: Span) -> T; + fn min(self: Span) -> T; + fn prod(self: Span) -> T; + fn sum(self: Span) -> T; +} \ No newline at end of file diff --git a/packages/orion-algo/src/span_math/core.cairo b/packages/orion-algo/src/span_math/core.cairo deleted file mode 100644 index 70f2e440c..000000000 --- a/packages/orion-algo/src/span_math/core.cairo +++ /dev/null @@ -1,38 +0,0 @@ -use core::array::ArrayTrait; -use core::option::OptionTrait; -use core::traits::TryInto; -use orion_numbers::f16x16::core::{f16x16, FixedTrait, ONE}; -use orion_numbers::f16x16::core_trait::{I32Rem, I32Div}; - - -use orion_algo::span_math::math; - - - -#[generate_trait] -pub impl F16x16SpanMath of SpanMathTrait { - fn arange(n: u32) -> Span { - math::arange(n) - } - - fn dot(self: Span, other: Span) -> f16x16 { - math::dot(self, other) - } - - fn max(self: Span) -> f16x16 { - math::max(self) - } - - fn min(self: Span) -> f16x16 { - math::min(self) - } - - fn prod(self: Span) -> f16x16 { - math::prod(self) - } - - fn sum(self: Span) -> f16x16 { - math::sum(self) - } - -} \ No newline at end of file diff --git a/packages/orion-algo/src/span_math/math.cairo b/packages/orion-algo/src/span_math/span_f16x16.cairo similarity index 85% rename from packages/orion-algo/src/span_math/math.cairo rename to packages/orion-algo/src/span_math/span_f16x16.cairo index 5469b9cd0..32d34f559 100644 --- a/packages/orion-algo/src/span_math/math.cairo +++ b/packages/orion-algo/src/span_math/span_f16x16.cairo @@ -1,8 +1,36 @@ -use core::array::ArrayTrait; -use core::option::OptionTrait; -use core::traits::TryInto; -use orion_numbers::f16x16::core::{f16x16, FixedTrait, ONE}; -use orion_numbers::f16x16::core_trait::{I32Rem, I32Div}; +use orion_numbers::{f16x16::core::{f16x16, ONE}, FixedTrait}; +use orion_numbers::core_trait::{I32Rem, I32Div}; + +use orion_algo::span_math::SpanMathTrait; + + +pub impl F16x16SpanMath of SpanMathTrait { + fn arange(n: u32) -> Span { + arange(n) + } + + fn dot(self: Span, other: Span) -> f16x16 { + dot(self, other) + } + + fn max(self: Span) -> f16x16 { + max(self) + } + + fn min(self: Span) -> f16x16 { + min(self) + } + + fn prod(self: Span) -> f16x16 { + prod(self) + } + + fn sum(self: Span) -> f16x16 { + sum(self) + } + +} + pub(crate) fn arange(n: u32) -> Span { let mut i = 0; diff --git a/packages/orion-algo/src/span_math/span_f32x32.cairo b/packages/orion-algo/src/span_math/span_f32x32.cairo new file mode 100644 index 000000000..b6b014d62 --- /dev/null +++ b/packages/orion-algo/src/span_math/span_f32x32.cairo @@ -0,0 +1,126 @@ +use orion_numbers::{core_trait::{I64Rem, I64Div}, FixedTrait}; +use orion_numbers::f32x32::core::{f32x32, ONE}; + +use orion_algo::span_math::SpanMathTrait; + + +pub impl F32x32SpanMath of SpanMathTrait { + fn arange(n: u32) -> Span { + arange(n) + } + + fn dot(self: Span, other: Span) -> f32x32 { + dot(self, other) + } + + fn max(self: Span) -> f32x32 { + max(self) + } + + fn min(self: Span) -> f32x32 { + min(self) + } + + fn prod(self: Span) -> f32x32 { + prod(self) + } + + fn sum(self: Span) -> f32x32 { + sum(self) + } +} + +fn arange(n: u32) -> Span { + let mut i = 0; + let mut arr = array![]; + while i < n { + arr.append(i.try_into().unwrap() * ONE); + i += 1; + }; + + arr.span() +} + +fn dot(a: Span, b: Span) -> f32x32 { + let mut i = 0; + let mut acc = 0; + while i != a.len() { + acc += FixedTrait::mul(*a.at(i), *b.at(i)); + i += 1; + }; + + acc +} + +fn max(mut a: Span) -> f32x32 { + assert(a.len() > 0, 'span cannot be empty'); + + let mut max = FixedTrait::MIN(); + + loop { + match a.pop_front() { + Option::Some(item) => { if *item > max { + max = *item; + } }, + Option::None => { break max; }, + } + } +} + +fn min(mut a: Span) -> f32x32 { + assert(a.len() > 0, 'span cannot be empty'); + + let mut min = FixedTrait::MAX(); + + loop { + match a.pop_front() { + Option::Some(item) => { if *item < min { + min = *item; + } }, + Option::None => { break min; }, + } + } +} + +fn prod(mut a: Span) -> f32x32 { + let mut prod = 1; + loop { + match a.pop_front() { + Option::Some(v) => { prod = prod.mul(*v); }, + Option::None => { break prod; } + }; + } +} + +fn sum(mut a: Span) -> f32x32 { + let mut prod = 1; + loop { + match a.pop_front() { + Option::Some(v) => { prod = prod + *v; }, + Option::None => { break prod; } + }; + } +} + + +pub fn linear_fit(x: Span, y: Span) -> (f32x32, f32x32) { + if x.len() != y.len() || x.len() == 0 { + panic!("x and y should be of the same lenght") + } + + let n: f32x32 = x.len().try_into().unwrap(); + let sum_x = x.sum(); + let sum_y = y.sum(); + let sum_xx = x.dot(x); + let sum_xy = x.dot(y); + + let denominator = n * sum_xx - (sum_x.mul(sum_x)); + if denominator == 0 { + panic!("division by zero exception") + } + + let a = ((n * sum_xy) - sum_x.mul(sum_y)).div(denominator); + let b = (sum_y - a.mul(sum_x)) / n; + + (a, b) +}