Skip to content

Commit

Permalink
feat: add Kronecker product
Browse files Browse the repository at this point in the history
  • Loading branch information
Soptq committed Jan 24, 2024
1 parent 7474cca commit f20bc8f
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/linalg/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
## [Dot product](./src/dot.cairo)

The dot product or scalar product is an algebraic operation that takes two equal-length sequences of numbers (usually coordinate vectors), and returns a single number. Algebraically, the dot product is the sum of the products of the corresponding entries of the two sequences of numbers ([see also](https://en.wikipedia.org/wiki/Dot_product)).

## [Kronecker product](./src/kron.cairo)

The Kronecker product is an an algebraic operation that takes two equal-length sequences of numbers and returns an array of numbers([see also](https://numpy.org/doc/stable/reference/generated/numpy.kron.html)).
33 changes: 33 additions & 0 deletions src/linalg/src/kron.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//! Kronecker product of two arrays

/// Compute the Kronecker product for 2 given arrays.
/// # Arguments
/// * `xs` - The first sequence of len L.
/// * `ys` - The second sequence of len L.
/// # Returns
/// * `Array<T>` - The Kronecker product.
fn kron<T, +Mul<T>, +AddEq<T>, +Zeroable<T>, +Copy<T>, +Drop<T>,>(
mut xs: Span<T>, mut ys: Span<T>
) -> Array<T> {
// [Check] Inputs
assert(xs.len() == ys.len(), 'Arrays must have the same len');

// [Compute] Kronecker product in a loop
let mut array = array![];
loop {
match xs.pop_front() {
Option::Some(x_value) => {
let mut ys_clone = ys.clone();
loop {
match ys_clone.pop_front() {
Option::Some(y_value) => { array.append(*x_value * *y_value); },
Option::None => { break; },
};
};
},
Option::None => { break; },
};
};

array
}
1 change: 1 addition & 0 deletions src/linalg/src/lib.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod dot;
mod kron;

#[cfg(test)]
mod tests;
1 change: 1 addition & 0 deletions src/linalg/src/tests.cairo
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
mod dot_test;
mod kron_test;
27 changes: 27 additions & 0 deletions src/linalg/src/tests/kron_test.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use alexandria_linalg::kron::kron;

#[test]
#[available_gas(2000000)]
fn kron_product_test() {
let mut xs: Array<u64> = array![1, 10, 100];
let mut ys = array![5, 6, 7];
let zs = kron(xs.span(), ys.span());
assert(*zs[0] == 5, 'wrong value at index 0');
assert(*zs[1] == 6, 'wrong value at index 1');
assert(*zs[2] == 7, 'wrong value at index 2');
assert(*zs[3] == 50, 'wrong value at index 3');
assert(*zs[4] == 60, 'wrong value at index 4');
assert(*zs[5] == 70, 'wrong value at index 5');
assert(*zs[6] == 500, 'wrong value at index 6');
assert(*zs[7] == 600, 'wrong value at index 7');
assert(*zs[8] == 700, 'wrong value at index 8');
}

#[test]
#[should_panic(expected: ('Arrays must have the same len',))]
#[available_gas(2000000)]
fn kron_product_test_check_len() {
let mut xs: Array<u64> = array![1];
let mut ys = array![];
kron(xs.span(), ys.span());
}

0 comments on commit f20bc8f

Please sign in to comment.