Skip to content

Commit

Permalink
feat(math): binary exp (keep-starknet-strange#175)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

Updating the `pow` impl to use binary exponentiation, which is cheaper. 

<!-- Please try to limit your pull request to one type; submit multiple
pull requests if needed. -->

Please check the type of change your PR introduces:

- [ ] Bugfix
- [x] Feature
- [ ] Code style update (formatting, renaming)
- [x] Refactoring (no functional changes, no API changes)
- [ ] Build-related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

`pow` is O(n)

## What is the new behavior?

`pow` is O(log n)

## Does this introduce a breaking change?

- [ ] Yes
- [x] No

<!-- If this does introduce a breaking change, please describe the
impact and migration path for existing applications below. -->

## Other information

I did a simple
[benchmark](https://gist.github.com/milancermak/adf6c70cdd550155a754e335d92fcbf5).
Stepwise performance is the same for a simple case, the greater the
exponent the more efficient this new implementation is.

Added a testcase to check for the whole 2^N range.

FWIW, it's pretty much the same as `BitShift::fpow` 😁 but without the
use of Bitwise ops.
  • Loading branch information
milancermak authored Sep 11, 2023
1 parent a3052ff commit e863e5b
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 8 deletions.
10 changes: 7 additions & 3 deletions src/math/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@ use option::OptionTrait;
use traits::Into;

/// Raise a number to a power.
/// O(n) time complexity.
/// O(log n) time complexity.
/// * `base` - The number to raise.
/// * `exp` - The exponent.
/// # Returns
/// * `u128` - The result of base raised to the power of exp.
fn pow(base: u128, mut exp: u128) -> u128 {
fn pow(base: u128, exp: u128) -> u128 {
if exp == 0 {
1
} else if exp == 1 {
base
} else if exp % 2 == 0 {
pow(base * base, exp / 2)
} else {
base * pow(base, exp - 1)
base * pow(base * base, (exp - 1) / 2)
}
}

Expand Down
143 changes: 138 additions & 5 deletions src/math/src/tests/math_test.cairo
Original file line number Diff line number Diff line change
@@ -1,13 +1,146 @@
use alexandria_math::{pow, BitShift, count_digits_of_base};

// Test power function
#[test]
#[available_gas(1000000000)]
fn test_pow_power_2_all() {
assert(pow(2, 0) == 1, '0');
assert(pow(2, 1) == 2, '1');
assert(pow(2, 2) == 4, '2');
assert(pow(2, 3) == 8, '3');
assert(pow(2, 4) == 16, '4');
assert(pow(2, 5) == 32, '5');
assert(pow(2, 6) == 64, '6');
assert(pow(2, 7) == 128, '7');
assert(pow(2, 8) == 256, '8');
assert(pow(2, 9) == 512, '9');
assert(pow(2, 10) == 1024, '10');
assert(pow(2, 11) == 2048, '11');
assert(pow(2, 12) == 4096, '12');
assert(pow(2, 13) == 8192, '13');
assert(pow(2, 14) == 16384, '14');
assert(pow(2, 15) == 32768, '15');
assert(pow(2, 16) == 65536, '16');
assert(pow(2, 17) == 131072, '17');
assert(pow(2, 18) == 262144, '18');
assert(pow(2, 19) == 524288, '19');
assert(pow(2, 20) == 1048576, '20');
assert(pow(2, 21) == 2097152, '21');
assert(pow(2, 22) == 4194304, '22');
assert(pow(2, 23) == 8388608, '23');
assert(pow(2, 24) == 16777216, '24');
assert(pow(2, 25) == 33554432, '25');
assert(pow(2, 26) == 67108864, '26');
assert(pow(2, 27) == 134217728, '27');
assert(pow(2, 28) == 268435456, '28');
assert(pow(2, 29) == 536870912, '29');
assert(pow(2, 30) == 1073741824, '30');
assert(pow(2, 31) == 2147483648, '31');
assert(pow(2, 32) == 4294967296, '32');
assert(pow(2, 33) == 8589934592, '33');
assert(pow(2, 34) == 17179869184, '34');
assert(pow(2, 35) == 34359738368, '35');
assert(pow(2, 36) == 68719476736, '36');
assert(pow(2, 37) == 137438953472, '37');
assert(pow(2, 38) == 274877906944, '38');
assert(pow(2, 39) == 549755813888, '39');
assert(pow(2, 40) == 1099511627776, '40');
assert(pow(2, 41) == 2199023255552, '41');
assert(pow(2, 42) == 4398046511104, '42');
assert(pow(2, 43) == 8796093022208, '43');
assert(pow(2, 44) == 17592186044416, '44');
assert(pow(2, 45) == 35184372088832, '45');
assert(pow(2, 46) == 70368744177664, '46');
assert(pow(2, 47) == 140737488355328, '47');
assert(pow(2, 48) == 281474976710656, '48');
assert(pow(2, 49) == 562949953421312, '49');
assert(pow(2, 50) == 1125899906842624, '50');
assert(pow(2, 51) == 2251799813685248, '51');
assert(pow(2, 52) == 4503599627370496, '52');
assert(pow(2, 53) == 9007199254740992, '53');
assert(pow(2, 54) == 18014398509481984, '54');
assert(pow(2, 55) == 36028797018963968, '55');
assert(pow(2, 56) == 72057594037927936, '56');
assert(pow(2, 57) == 144115188075855872, '57');
assert(pow(2, 58) == 288230376151711744, '58');
assert(pow(2, 59) == 576460752303423488, '59');
assert(pow(2, 60) == 1152921504606846976, '60');
assert(pow(2, 61) == 2305843009213693952, '61');
assert(pow(2, 62) == 4611686018427387904, '62');
assert(pow(2, 63) == 9223372036854775808, '63');
assert(pow(2, 64) == 18446744073709551616, '64');
assert(pow(2, 65) == 36893488147419103232, '65');
assert(pow(2, 66) == 73786976294838206464, '66');
assert(pow(2, 67) == 147573952589676412928, '67');
assert(pow(2, 68) == 295147905179352825856, '68');
assert(pow(2, 69) == 590295810358705651712, '69');
assert(pow(2, 70) == 1180591620717411303424, '70');
assert(pow(2, 71) == 2361183241434822606848, '71');
assert(pow(2, 72) == 4722366482869645213696, '72');
assert(pow(2, 73) == 9444732965739290427392, '73');
assert(pow(2, 74) == 18889465931478580854784, '74');
assert(pow(2, 75) == 37778931862957161709568, '75');
assert(pow(2, 76) == 75557863725914323419136, '76');
assert(pow(2, 77) == 151115727451828646838272, '77');
assert(pow(2, 78) == 302231454903657293676544, '78');
assert(pow(2, 79) == 604462909807314587353088, '79');
assert(pow(2, 80) == 1208925819614629174706176, '80');
assert(pow(2, 81) == 2417851639229258349412352, '81');
assert(pow(2, 82) == 4835703278458516698824704, '82');
assert(pow(2, 83) == 9671406556917033397649408, '83');
assert(pow(2, 84) == 19342813113834066795298816, '84');
assert(pow(2, 85) == 38685626227668133590597632, '85');
assert(pow(2, 86) == 77371252455336267181195264, '86');
assert(pow(2, 87) == 154742504910672534362390528, '87');
assert(pow(2, 88) == 309485009821345068724781056, '88');
assert(pow(2, 89) == 618970019642690137449562112, '89');
assert(pow(2, 90) == 1237940039285380274899124224, '90');
assert(pow(2, 91) == 2475880078570760549798248448, '91');
assert(pow(2, 92) == 4951760157141521099596496896, '92');
assert(pow(2, 93) == 9903520314283042199192993792, '93');
assert(pow(2, 94) == 19807040628566084398385987584, '94');
assert(pow(2, 95) == 39614081257132168796771975168, '95');
assert(pow(2, 96) == 79228162514264337593543950336, '96');
assert(pow(2, 97) == 158456325028528675187087900672, '97');
assert(pow(2, 98) == 316912650057057350374175801344, '98');
assert(pow(2, 99) == 633825300114114700748351602688, '99');
assert(pow(2, 100) == 1267650600228229401496703205376, '100');
assert(pow(2, 101) == 2535301200456458802993406410752, '101');
assert(pow(2, 102) == 5070602400912917605986812821504, '102');
assert(pow(2, 103) == 10141204801825835211973625643008, '103');
assert(pow(2, 104) == 20282409603651670423947251286016, '104');
assert(pow(2, 105) == 40564819207303340847894502572032, '105');
assert(pow(2, 106) == 81129638414606681695789005144064, '106');
assert(pow(2, 107) == 162259276829213363391578010288128, '107');
assert(pow(2, 108) == 324518553658426726783156020576256, '108');
assert(pow(2, 109) == 649037107316853453566312041152512, '109');
assert(pow(2, 110) == 1298074214633706907132624082305024, '110');
assert(pow(2, 111) == 2596148429267413814265248164610048, '111');
assert(pow(2, 112) == 5192296858534827628530496329220096, '112');
assert(pow(2, 113) == 10384593717069655257060992658440192, '113');
assert(pow(2, 114) == 20769187434139310514121985316880384, '114');
assert(pow(2, 115) == 41538374868278621028243970633760768, '115');
assert(pow(2, 116) == 83076749736557242056487941267521536, '116');
assert(pow(2, 117) == 166153499473114484112975882535043072, '117');
assert(pow(2, 118) == 332306998946228968225951765070086144, '118');
assert(pow(2, 119) == 664613997892457936451903530140172288, '119');
assert(pow(2, 120) == 1329227995784915872903807060280344576, '120');
assert(pow(2, 121) == 2658455991569831745807614120560689152, '121');
assert(pow(2, 122) == 5316911983139663491615228241121378304, '122');
assert(pow(2, 123) == 10633823966279326983230456482242756608, '123');
assert(pow(2, 124) == 21267647932558653966460912964485513216, '124');
assert(pow(2, 125) == 42535295865117307932921825928971026432, '125');
assert(pow(2, 126) == 85070591730234615865843651857942052864, '126');
assert(pow(2, 127) == 170141183460469231731687303715884105728, '127');
}


#[test]
#[available_gas(2000000)]
fn pow_test() {
assert(pow(2, 0) == 1, 'invalid result');
assert(pow(2, 1) == 2, 'invalid result');
assert(pow(2, 12) == 4096, 'invalid result');
assert(pow(5, 9) == 1953125, 'invalid result');
assert(pow(200, 0) == 1, '200^0');
assert(pow(5, 9) == 1953125, '5^9');
assert(pow(14, 30) == 24201432355484595421941037243826176, '14^30');
}

// Test counting of number of digits function
Expand All @@ -32,6 +165,6 @@ fn fpow_test() {

#[test]
#[available_gas(2000000)]
fn fpow_test_u2156() {
fn fpow_test_u256() {
assert(BitShift::fpow(3_u256, 8) == 6561, 'invalid result');
}

0 comments on commit e863e5b

Please sign in to comment.