Skip to content

Commit

Permalink
feat: add Kronecker product (#264)
Browse files Browse the repository at this point in the history
This PR adds Kronecker product.

## Pull Request type

<!-- Please try to limit your pull request to one type; submit multiple
pull requests if needed. -->

Please check the type of change your PR introduces:

- [ ] Bugfix
- [x] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no API changes)
- [ ] Build-related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

<!-- Please describe the current behavior that you are modifying, or
link to a relevant issue. -->

Issue Number: N/A

## What is the new behavior?

<!-- Please describe the behavior or changes that are being added by
this PR. -->

- calculating Kronecker product of two lists.
- all tests are passed.

## Does this introduce a breaking change?

- [ ] Yes
- [x] No

<!-- If this does introduce a breaking change, please describe the
impact and migration path for existing applications below. -->

## Other information

<!-- Any other information that is important to this PR, such as
screenshots of how the component looks before and after the change. -->
  • Loading branch information
Soptq authored Jan 25, 2024
1 parent 4327786 commit d307aca
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/linalg/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
## [Dot product](./src/dot.cairo)

The dot product or scalar product is an algebraic operation that takes two equal-length sequences of numbers (usually coordinate vectors), and returns a single number. Algebraically, the dot product is the sum of the products of the corresponding entries of the two sequences of numbers ([see also](https://en.wikipedia.org/wiki/Dot_product)).

## [Kronecker product](./src/kron.cairo)

The Kronecker product is an an algebraic operation that takes two equal-length sequences of numbers and returns an array of numbers([see also](https://numpy.org/doc/stable/reference/generated/numpy.kron.html)).
41 changes: 41 additions & 0 deletions src/linalg/src/kron.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use core::array::SpanTrait;
//! Kronecker product of two arrays

#[derive(Drop, Copy, PartialEq)]
enum KronError {
UnequalLength,
}

/// Compute the Kronecker product for 2 given arrays.
/// # Arguments
/// * `xs` - The first sequence of len L.
/// * `ys` - The second sequence of len L.
/// # Returns
/// * `Result<Array<T>, KronError>` - The Kronecker product.
fn kron<T, +Mul<T>, +AddEq<T>, +Zeroable<T>, +Copy<T>, +Drop<T>,>(
mut xs: Span<T>, mut ys: Span<T>
) -> Result<Array<T>, KronError> {
// [Check] Inputs
if xs.len() != ys.len() {
return Result::Err(KronError::UnequalLength);
}

// [Compute] Kronecker product in a loop
let mut array = array![];
loop {
match xs.pop_front() {
Option::Some(x_value) => {
let mut ys_clone = ys;
loop {
match ys_clone.pop_front() {
Option::Some(y_value) => { array.append(*x_value * *y_value); },
Option::None => { break; },
};
};
},
Option::None => { break; },
};
};

Result::Ok(array)
}
1 change: 1 addition & 0 deletions src/linalg/src/lib.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod dot;
mod kron;

#[cfg(test)]
mod tests;
1 change: 1 addition & 0 deletions src/linalg/src/tests.cairo
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
mod dot_test;
mod kron_test;
29 changes: 29 additions & 0 deletions src/linalg/src/tests/kron_test.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use alexandria_linalg::kron::{kron, KronError};

#[test]
#[available_gas(2000000)]
fn kron_product_test() {
let mut xs: Array<u64> = array![1, 10, 100];
let mut ys = array![5, 6, 7];
let zs = kron(xs.span(), ys.span()).unwrap();
assert(*zs[0] == 5, 'wrong value at index 0');
assert(*zs[1] == 6, 'wrong value at index 1');
assert(*zs[2] == 7, 'wrong value at index 2');
assert(*zs[3] == 50, 'wrong value at index 3');
assert(*zs[4] == 60, 'wrong value at index 4');
assert(*zs[5] == 70, 'wrong value at index 5');
assert(*zs[6] == 500, 'wrong value at index 6');
assert(*zs[7] == 600, 'wrong value at index 7');
assert(*zs[8] == 700, 'wrong value at index 8');
}

#[test]
#[available_gas(2000000)]
fn kron_product_test_check_len() {
let mut xs: Array<u64> = array![1];
let mut ys = array![];
assert(
kron(xs.span(), ys.span()) == Result::Err(KronError::UnequalLength),
'Arrays must have the same len'
);
}

0 comments on commit d307aca

Please sign in to comment.