Skip to content

Commit

Permalink
refactor: Turn pow into a generic function (keep-starknet-strange#198)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

## Pull Request type

<!-- Please try to limit your pull request to one type; submit multiple
pull requests if needed. -->

Please check the type of change your PR introduces:

- [ ] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [x] Refactoring (no functional changes, no API changes)
- [ ] Build-related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

<!-- Please describe the current behavior that you are modifying, or
link to a relevant issue. -->

Issue Number: N/A

## What is the new behavior?

<!-- Please describe the behavior or changes that are being added by
this PR. -->

-
-
-

## Does this introduce a breaking change?

- [ ] Yes
- [ ] No

<!-- If this does introduce a breaking change, please describe the
impact and migration path for existing applications below. -->

## Other information

<!-- Any other information that is important to this PR, such as
screenshots of how the component looks before and after the change. -->
  • Loading branch information
maciejka authored Oct 20, 2023
1 parent 1965197 commit ae1d514
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 243 deletions.
115 changes: 22 additions & 93 deletions src/math/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ use debug::PrintTrait;
/// * `base` - The number to raise.
/// * `exp` - The exponent.
/// # Returns
/// * `u128` - The result of base raised to the power of exp.
fn pow(base: u128, exp: u128) -> u128 {
if exp == 0 {
1
} else if exp == 1 {
/// * `T` - The result of base raised to the power of exp.
fn pow<T, +Sub<T>, +Mul<T>, +Div<T>, +Rem<T>, +PartialEq<T>, +Into<u8, T>, +Drop<T>, +Copy<T>>(
base: T, exp: T
) -> T {
if exp == 0_u8.into() {
1_u8.into()
} else if exp == 1_u8.into() {
base
} else if exp % 2 == 0 {
pow(base * base, exp / 2)
} else if exp % 2_u8.into() == 0_u8.into() {
pow(base * base, exp / 2_u8.into())
} else {
base * pow(base * base, (exp - 1) / 2)
base * pow(base * base, exp / 2_u8.into())
}
}

Expand All @@ -43,142 +45,69 @@ fn count_digits_of_base(mut num: u128, base: u128) -> u128 {
}

trait BitShift<T> {
fn fpow(x: T, n: T) -> T;
fn shl(x: T, n: T) -> T;
fn shr(x: T, n: T) -> T;
}

impl U8BitShift of BitShift<u8> {
fn fpow(x: u8, n: u8) -> u8 {
if n == 0 {
1
} else if n == 1 {
x
} else if (n & 1) == 1 {
x * BitShift::fpow(x * x, n / 2)
} else {
BitShift::fpow(x * x, n / 2)
}
}
fn shl(x: u8, n: u8) -> u8 {
(u8_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::<u8>::max().into()).try_into().unwrap()
(u8_wide_mul(x, pow(2, n)) & BoundedInt::<u8>::max().into()).try_into().unwrap()
}

fn shr(x: u8, n: u8) -> u8 {
x / BitShift::fpow(2, n)
x / pow(2, n)
}
}

impl U16BitShift of BitShift<u16> {
fn fpow(x: u16, n: u16) -> u16 {
if n == 0 {
1
} else if n == 1 {
x
} else if (n & 1) == 1 {
x * BitShift::fpow(x * x, n / 2)
} else {
BitShift::fpow(x * x, n / 2)
}
}
fn shl(x: u16, n: u16) -> u16 {
(u16_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::<u16>::max().into())
.try_into()
.unwrap()
(u16_wide_mul(x, pow(2, n)) & BoundedInt::<u16>::max().into()).try_into().unwrap()
}

fn shr(x: u16, n: u16) -> u16 {
x / BitShift::fpow(2, n)
x / pow(2, n)
}
}

impl U32BitShift of BitShift<u32> {
fn fpow(x: u32, n: u32) -> u32 {
if n == 0 {
1
} else if n == 1 {
x
} else if (n & 1) == 1 {
x * BitShift::fpow(x * x, n / 2)
} else {
BitShift::fpow(x * x, n / 2)
}
}
fn shl(x: u32, n: u32) -> u32 {
(u32_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::<u32>::max().into())
.try_into()
.unwrap()
(u32_wide_mul(x, pow(2, n)) & BoundedInt::<u32>::max().into()).try_into().unwrap()
}

fn shr(x: u32, n: u32) -> u32 {
x / BitShift::fpow(2, n)
x / pow(2, n)
}
}

impl U64BitShift of BitShift<u64> {
fn fpow(x: u64, n: u64) -> u64 {
if n == 0 {
1
} else if n == 1 {
x
} else if (n & 1) == 1 {
x * BitShift::fpow(x * x, n / 2)
} else {
BitShift::fpow(x * x, n / 2)
}
}
fn shl(x: u64, n: u64) -> u64 {
(u64_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::<u64>::max().into())
.try_into()
.unwrap()
(u64_wide_mul(x, pow(2, n)) & BoundedInt::<u64>::max().into()).try_into().unwrap()
}

fn shr(x: u64, n: u64) -> u64 {
x / BitShift::fpow(2, n)
x / pow(2, n)
}
}

impl U128BitShift of BitShift<u128> {
fn fpow(x: u128, n: u128) -> u128 {
if n == 0 {
1
} else if n == 1 {
x
} else if (n & 1) == 1 {
x * BitShift::fpow(x * x, n / 2)
} else {
BitShift::fpow(x * x, n / 2)
}
}
fn shl(x: u128, n: u128) -> u128 {
let (_, bottom_word) = u128_wide_mul(x, BitShift::fpow(2, n));
let (_, bottom_word) = u128_wide_mul(x, pow(2, n));
bottom_word
}

fn shr(x: u128, n: u128) -> u128 {
x / BitShift::fpow(2, n)
x / pow(2, n)
}
}

impl U256BitShift of BitShift<u256> {
fn fpow(x: u256, n: u256) -> u256 {
if n == 0 {
1
} else if n == 1 {
x
} else if (n & 1) == 1 {
x * BitShift::fpow(x * x, n / 2)
} else {
BitShift::fpow(x * x, n / 2)
}
}
fn shl(x: u256, n: u256) -> u256 {
let (r, _) = u256_overflow_mul(x, BitShift::fpow(2, n));
let (r, _) = u256_overflow_mul(x, pow(2, n));
r
}

fn shr(x: u256, n: u256) -> u256 {
x / BitShift::fpow(2, n)
x / pow(2, n)
}
}

Expand Down
Loading

0 comments on commit ae1d514

Please sign in to comment.