diff --git a/sway-lib-std/src/u128.sw b/sway-lib-std/src/u128.sw index 020445a79ed..1b8280dcd5c 100644 --- a/sway-lib-std/src/u128.sw +++ b/sway-lib-std/src/u128.sw @@ -412,7 +412,7 @@ impl core::ops::Subtract for U128 { } } impl core::ops::Multiply for U128 { - /// Multiply a `U128` with a `U128`. Reverts of overflow. + /// Multiply a `U128` with a `U128`. Panics if overflow. fn multiply(self, other: Self) -> Self { // in case both of the `U128` upper parts are bigger than zero, // it automatically means overflow, as any `U128` value diff --git a/sway-lib-std/src/u256.sw b/sway-lib-std/src/u256.sw index 1521889e87e..a82e606b530 100644 --- a/sway-lib-std/src/u256.sw +++ b/sway-lib-std/src/u256.sw @@ -5,6 +5,7 @@ use ::assert::assert; use ::convert::From; use ::result::Result::{self, *}; use ::u128::U128; +use ::math::Power; /// Left shift a `u64` and preserve the overflow amount if any. fn lsh_with_carry(word: u64, shift_amount: u64) -> (u64, u64) { @@ -100,6 +101,45 @@ impl U256 { } } + /// Initializes a new `U256` with a value of 1. + /// + /// ### Examples + /// + /// ```sway + /// use std::u256::U256; + /// + /// let init_one = U256::one(); + /// let one_u256 = U256 { a: 0, b: 0, c: 0, d: 1 }; + /// + /// assert(init_one == one_u256); + /// ``` + pub fn one() -> Self { + Self { + a: 0, + b: 0, + c: 0, + d: 1, + } + } + + /// Returns true if value is zero. + /// + /// ### Examples + /// + /// ```sway + /// use std::u256::U256 + /// + /// let zero_u256 = U256::new(); + /// assert(zero_u256.is_zero()); + /// ``` + pub fn is_zero(self) -> bool { + self.a == 0 && self.b == 0 && self.c == 0 && self.d == 0 + } + + pub fn low_u64(self) -> u64 { + self.a + } + /// Safely downcast to `u64` without loss of precision. /// /// # Additional Information @@ -638,3 +678,82 @@ impl core::ops::Divide for U256 { quotient } } + +impl Power for U256 { + /// Fast exponentiation by squaring + /// https://en.wikipedia.org/wiki/Exponentiation_by_squaring + /// + /// # Panics + /// + /// Panics if the result overflows the type. + fn pow(self, expon: Self) -> Self { + if expon.is_zero() { + return Self::one() + } + + let u_one = Self::one(); + let mut y = u_one; + let mut n = expon; + let mut x = self; + while n > u_one { + if is_even(n) { + x = x * x; + n >>= 1; + } else { + y = x * y; + x = x * x; + // to reduce odd number by 1 we should just clear the last bit + n.d = n.d & ((!0u64)>>1); + n >>= 1; + } + } + x * y + } +} + + +fn is_even(x: U256) -> bool { + x.low_u64() & 1 == 0 +} + +#[test] +fn test_five_pow_two_u256() { + let five = U256::from((0, 0, 0, 5)); + let two = U256::from((0, 0, 0, 2)); + + let five_pow_two = five.pow(two); + assert(five_pow_two.a == 0); + assert(five_pow_two.b == 0); + assert(five_pow_two.c == 0); + assert(five_pow_two.d == 25); +} + +#[test] +fn test_five_pow_three_u256() { + let five = U256::from((0, 0, 0, 5)); + let three = U256::from((0, 0, 0, 3)); + + let five_pow_three = five.pow(three); + assert(five_pow_three.a == 0); + assert(five_pow_three.b == 0); + assert(five_pow_three.c == 0); + assert(five_pow_three.d == 125); +} + +#[test] +fn test_five_pow_28_u256() { + let five = U256::from((0, 0, 0, 5)); + let twenty_eight = U256::from((0, 0, 0, 28)); + + let five_pow_28 = five.pow(twenty_eight); + assert(five_pow_28.a == 0); + assert(five_pow_28.b == 0); + assert(five_pow_28.c == 2); + assert(five_pow_28.d == 359414837200037395); +} + +#[test] +fn test_is_zero() { + let zero_u256 = U256::new(); + assert(zero_u256.is_zero()); +}