diff --git a/src/math/src/lib.cairo b/src/math/src/lib.cairo index 08a6e47c..b29d91e1 100644 --- a/src/math/src/lib.cairo +++ b/src/math/src/lib.cairo @@ -125,3 +125,110 @@ impl U256BitShift of BitShift { x / pow(2, n) } } + +/// Rotate the bits of an unsigned integer of type T +trait BitRotate { + /// Take the bits of an unsigned integer and rotate in the left direction + /// # Arguments + /// * `x` - rotate its bit representation in the leftward direction + /// * `n` - number of steps to rotate + /// # Returns + /// * `T` - the result of rotating the bits of number `x` left, `n` number of steps + fn rotate_left(x: T, n: T) -> T; + /// Take the bits of an unsigned integer and rotate in the right direction + /// # Arguments + /// * `x` - rotate its bit representation in the rightward direction + /// * `n` - number of steps to rotate + /// # Returns + /// * `T` - the result of rotating the bits of number `x` right, `n` number of steps + fn rotate_right(x: T, n: T) -> T; +} + +impl U8BitRotate of BitRotate { + fn rotate_left(x: u8, n: u8) -> u8 { + let word = u8_wide_mul(x, pow(2, n)); + let (quotient, remainder) = DivRem::div_rem(word, 0x100_u16.try_into().unwrap()); + (quotient + remainder).try_into().unwrap() + } + + fn rotate_right(x: u8, n: u8) -> u8 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 8 - n) + quotient + } +} + +impl U16BitRotate of BitRotate { + fn rotate_left(x: u16, n: u16) -> u16 { + let word = u16_wide_mul(x, pow(2, n)); + let (quotient, remainder) = DivRem::div_rem(word, 0x10000_u32.try_into().unwrap()); + (quotient + remainder).try_into().unwrap() + } + + fn rotate_right(x: u16, n: u16) -> u16 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 16 - n) + quotient + } +} + +impl U32BitRotate of BitRotate { + fn rotate_left(x: u32, n: u32) -> u32 { + let word = u32_wide_mul(x, pow(2, n)); + let (quotient, remainder) = DivRem::div_rem(word, 0x100000000_u64.try_into().unwrap()); + (quotient + remainder).try_into().unwrap() + } + + fn rotate_right(x: u32, n: u32) -> u32 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 32 - n) + quotient + } +} + +impl U64BitRotate of BitRotate { + fn rotate_left(x: u64, n: u64) -> u64 { + let word = u64_wide_mul(x, pow(2, n)); + let (quotient, remainder) = DivRem::div_rem( + word, 0x10000000000000000_u128.try_into().unwrap() + ); + (quotient + remainder).try_into().unwrap() + } + + fn rotate_right(x: u64, n: u64) -> u64 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 64 - n) + quotient + } +} + +impl U128BitRotate of BitRotate { + fn rotate_left(x: u128, n: u128) -> u128 { + let (high, low) = u128_wide_mul(x, pow(2, n)); + let word = u256 { low, high }; + let (quotient, remainder) = DivRem::div_rem( + word, u256 { low: 0, high: 1 }.try_into().unwrap() + ); + (quotient + remainder).try_into().unwrap() + } + + fn rotate_right(x: u128, n: u128) -> u128 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 128 - n) + quotient + } +} + +impl U256BitRotate of BitRotate { + fn rotate_left(x: u256, n: u256) -> u256 { + // Alternative solution since we cannot divide u512 yet + BitShift::shl(x, n) + BitShift::shr(x, 256 - n) + } + + fn rotate_right(x: u256, n: u256) -> u256 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 256 - n) + quotient + } +} + diff --git a/src/math/src/tests/math_test.cairo b/src/math/src/tests/math_test.cairo index 0e07cfb8..d90bef4d 100644 --- a/src/math/src/tests/math_test.cairo +++ b/src/math/src/tests/math_test.cairo @@ -1,4 +1,4 @@ -use alexandria_math::{pow, BitShift, count_digits_of_base}; +use alexandria_math::{pow, BitShift, BitRotate, count_digits_of_base}; use integer::BoundedInt; // Test power function @@ -135,7 +135,6 @@ fn test_pow_power_2_all() { assert(pow::(2, 127) == 170141183460469231731687303715884105728, '127'); } - #[test] #[available_gas(2000000)] fn pow_test() { @@ -146,7 +145,6 @@ fn pow_test() { assert(pow::(3, 8) == 6561, '3^8_u256'); } - // Test counting of number of digits function #[test] #[available_gas(2000000)] @@ -170,3 +168,57 @@ fn shl_should_not_overflow() { assert(BitShift::shl(pow::(2, 127), 1) == 0, 'invalid result'); assert(BitShift::shl(pow::(2, 255), 1) == 0, 'invalid result'); } + +#[test] +#[available_gas(3000000)] +fn test_rotl_min() { + assert(BitRotate::rotate_left(pow::(2, 7) + 1, 1) == 3, 'invalid result'); + assert(BitRotate::rotate_left(pow::(2, 15) + 1, 1) == 3, 'invalid result'); + assert(BitRotate::rotate_left(pow::(2, 31) + 1, 1) == 3, 'invalid result'); + assert(BitRotate::rotate_left(pow::(2, 63) + 1, 1) == 3, 'invalid result'); + assert(BitRotate::rotate_left(pow::(2, 127) + 1, 1) == 3, 'invalid result'); + assert(BitRotate::rotate_left(pow::(2, 255) + 1, 1) == 3, 'invalid result'); +} + +#[test] +#[available_gas(3000000)] +fn test_rotl_max() { + assert(BitRotate::rotate_left(0b101, 7) == pow::(2, 7) + 0b10, 'invalid result'); + assert(BitRotate::rotate_left(0b101, 15) == pow::(2, 15) + 0b10, 'invalid result'); + assert(BitRotate::rotate_left(0b101, 31) == pow::(2, 31) + 0b10, 'invalid result'); + assert(BitRotate::rotate_left(0b101, 63) == pow::(2, 63) + 0b10, 'invalid result'); + assert(BitRotate::rotate_left(0b101, 127) == pow::(2, 127) + 0b10, 'invalid result'); + assert(BitRotate::rotate_left(0b101, 255) == pow::(2, 255) + 0b10, 'invalid result'); +} + +#[test] +#[available_gas(4000000)] +fn test_rotr_min() { + assert(BitRotate::rotate_right(pow::(2, 7) + 1, 1) == 0b11 * pow(2, 6), 'invalid result'); + assert( + BitRotate::rotate_right(pow::(2, 15) + 1, 1) == 0b11 * pow(2, 14), 'invalid result' + ); + assert( + BitRotate::rotate_right(pow::(2, 31) + 1, 1) == 0b11 * pow(2, 30), 'invalid result' + ); + assert( + BitRotate::rotate_right(pow::(2, 63) + 1, 1) == 0b11 * pow(2, 62), 'invalid result' + ); + assert( + BitRotate::rotate_right(pow::(2, 127) + 1, 1) == 0b11 * pow(2, 126), 'invalid result' + ); + assert( + BitRotate::rotate_right(pow::(2, 255) + 1, 1) == 0b11 * pow(2, 254), 'invalid result' + ); +} + +#[test] +#[available_gas(2000000)] +fn test_rotr_max() { + assert(BitRotate::rotate_right(0b101_u8, 7) == 0b1010, 'invalid result'); + assert(BitRotate::rotate_right(0b101_u16, 15) == 0b1010, 'invalid result'); + assert(BitRotate::rotate_right(0b101_u32, 31) == 0b1010, 'invalid result'); + assert(BitRotate::rotate_right(0b101_u64, 63) == 0b1010, 'invalid result'); + assert(BitRotate::rotate_right(0b101_u128, 127) == 0b1010, 'invalid result'); + assert(BitRotate::rotate_right(0b101_u256, 255) == 0b1010, 'invalid result'); +}