diff --git a/src/linalg/README.md b/src/linalg/README.md index e1583a6d..5a4703a1 100644 --- a/src/linalg/README.md +++ b/src/linalg/README.md @@ -3,3 +3,7 @@ ## [Dot product](./src/dot.cairo) The dot product or scalar product is an algebraic operation that takes two equal-length sequences of numbers (usually coordinate vectors), and returns a single number. Algebraically, the dot product is the sum of the products of the corresponding entries of the two sequences of numbers ([see also](https://en.wikipedia.org/wiki/Dot_product)). + +## [Kronecker product](./src/kron.cairo) + +The Kronecker product is an an algebraic operation that takes two equal-length sequences of numbers and returns an array of numbers([see also](https://numpy.org/doc/stable/reference/generated/numpy.kron.html)). diff --git a/src/linalg/src/kron.cairo b/src/linalg/src/kron.cairo new file mode 100644 index 00000000..febf091a --- /dev/null +++ b/src/linalg/src/kron.cairo @@ -0,0 +1,42 @@ +use core::array::SpanTrait; +//! Kronecker product of two arrays + +#[derive(Drop, Copy, PartialEq)] +enum KronError { + UnequalLength, +} + +/// Compute the Kronecker product for 2 given arrays. +/// # Arguments +/// * `xs` - The first sequence of len L. +/// * `ys` - The second sequence of len L. +/// # Returns +/// * `Result, KronError>` - The Kronecker product. +fn kron, +AddEq, +Zeroable, +Copy, +Drop,>( + mut xs: Span, mut ys: Span +) -> Result, KronError> { + // [Check] Inputs + if xs.len() != ys.len() { + return Result::Err(KronError::UnequalLength); + } + assert(xs.len() == ys.len(), 'Arrays must have the same len'); + + // [Compute] Kronecker product in a loop + let mut array = array![]; + loop { + match xs.pop_front() { + Option::Some(x_value) => { + let mut ys_clone = ys; + loop { + match ys_clone.pop_front() { + Option::Some(y_value) => { array.append(*x_value * *y_value); }, + Option::None => { break; }, + }; + }; + }, + Option::None => { break; }, + }; + }; + + Result::Ok(array) +} diff --git a/src/linalg/src/lib.cairo b/src/linalg/src/lib.cairo index e0f8508c..145dbdaa 100644 --- a/src/linalg/src/lib.cairo +++ b/src/linalg/src/lib.cairo @@ -1,4 +1,5 @@ mod dot; +mod kron; #[cfg(test)] mod tests; diff --git a/src/linalg/src/tests.cairo b/src/linalg/src/tests.cairo index 8b51795c..f48b4abd 100644 --- a/src/linalg/src/tests.cairo +++ b/src/linalg/src/tests.cairo @@ -1 +1,2 @@ mod dot_test; +mod kron_test; diff --git a/src/linalg/src/tests/kron_test.cairo b/src/linalg/src/tests/kron_test.cairo new file mode 100644 index 00000000..35e66e85 --- /dev/null +++ b/src/linalg/src/tests/kron_test.cairo @@ -0,0 +1,29 @@ +use alexandria_linalg::kron::{kron, KronError}; + +#[test] +#[available_gas(2000000)] +fn kron_product_test() { + let mut xs: Array = array![1, 10, 100]; + let mut ys = array![5, 6, 7]; + let zs = kron(xs.span(), ys.span()).unwrap(); + assert(*zs[0] == 5, 'wrong value at index 0'); + assert(*zs[1] == 6, 'wrong value at index 1'); + assert(*zs[2] == 7, 'wrong value at index 2'); + assert(*zs[3] == 50, 'wrong value at index 3'); + assert(*zs[4] == 60, 'wrong value at index 4'); + assert(*zs[5] == 70, 'wrong value at index 5'); + assert(*zs[6] == 500, 'wrong value at index 6'); + assert(*zs[7] == 600, 'wrong value at index 7'); + assert(*zs[8] == 700, 'wrong value at index 8'); +} + +#[test] +#[available_gas(2000000)] +fn kron_product_test_check_len() { + let mut xs: Array = array![1]; + let mut ys = array![]; + assert( + kron(xs.span(), ys.span()) == Result::Err(KronError::UnequalLength), + 'Arrays must have the same len' + ); +}