diff --git a/Cargo.lock b/Cargo.lock index 273b3a7..57ddbe9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -256,7 +256,9 @@ dependencies = [ name = "cainome-cairo-serde" version = "0.1.0" dependencies = [ + "num-bigint", "serde", + "serde_with", "starknet", "thiserror", ] @@ -1655,9 +1657,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.9.0" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cecfa94848272156ea67b2b1a53f20fc7bc638c4a46d2f8abde08f05f4b857" +checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" dependencies = [ "base64 0.22.1", "chrono", @@ -1673,9 +1675,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.9.0" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8fee4991ef4f274617a51ad4af30519438dacb2f56ac773b08a1922ff743350" +checksum = "9d846214a9854ef724f3da161b426242d8de7c1fc7de2f89bb1efcb154dca79d" dependencies = [ "darling", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index a6937e7..ccaa231 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,13 +21,14 @@ cainome-rs = { path = "crates/rs" } anyhow = "1.0" async-trait = "0.1" -camino = { version = "1.1", features = [ "serde1" ] } +num-bigint = "0.4.6" +camino = { version = "1.1", features = ["serde1"] } convert_case = "0.6" serde = { version = "1.0", default-features = false, features = ["alloc"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } thiserror = "1.0" tracing = "0.1" -tracing-subscriber = { version = "0.3", features = [ "env-filter", "json" ] } +tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } url = "2.5" starknet = "0.12" starknet-types-core = "0.1.6" @@ -42,7 +43,7 @@ cainome-rs-macro = { path = "crates/rs-macro", optional = true } async-trait.workspace = true anyhow.workspace = true -clap = { version = "4.5", features = [ "derive" ] } +clap = { version = "4.5", features = ["derive"] } clap_complete = "4.5" convert_case.workspace = true serde.workspace = true diff --git a/crates/cairo-serde/Cargo.toml b/crates/cairo-serde/Cargo.toml index 7fb9f3f..7995d95 100644 --- a/crates/cairo-serde/Cargo.toml +++ b/crates/cairo-serde/Cargo.toml @@ -9,3 +9,5 @@ edition = "2021" starknet.workspace = true thiserror.workspace = true serde.workspace = true +serde_with = { version = "3.11.0", default-features = false } +num-bigint.workspace = true diff --git a/crates/cairo-serde/src/types/u256.rs b/crates/cairo-serde/src/types/u256.rs index f5e7f62..9b04479 100644 --- a/crates/cairo-serde/src/types/u256.rs +++ b/crates/cairo-serde/src/types/u256.rs @@ -1,8 +1,15 @@ use crate::CairoSerde; +use num_bigint::{BigInt, BigUint, ParseBigIntError}; +use serde_with::{DeserializeAs, DisplayFromStr, SerializeAs}; use starknet::core::types::Felt; -use std::cmp::Ordering; +use std::{ + cmp::Ordering, + fmt::Display, + ops::{Add, BitOr, Sub}, + str::FromStr, +}; -#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct U256 { pub low: u128, pub high: u128, @@ -17,6 +24,104 @@ impl PartialOrd for U256 { } } +impl Add for U256 { + type Output = Self; + fn add(mut self, other: Self) -> Self { + let (low, overflow_low) = self.low.overflowing_add(other.low); + if overflow_low { + self.high += 1; + } + let (high, _overflow_high) = self.high.overflowing_add(other.high); + U256 { low, high } + } +} + +impl Sub for U256 { + type Output = Self; + fn sub(self, other: Self) -> Self { + let (low, overflow_low) = self.low.overflowing_sub(other.low); + let (high, overflow_high) = self.high.overflowing_sub(other.high); + if overflow_high { + panic!("High underflow"); + } + let final_high = if overflow_low { + if high == 0 { + panic!("High underflow"); + } + high.wrapping_sub(1) + } else { + high + }; + U256 { + low, + high: final_high, + } + } +} + +impl BitOr for U256 { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + U256 { + low: self.low | other.low, + high: self.high | other.high, + } + } +} + +impl Display for U256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut num = BigUint::from(0u128); + num = num + BigUint::from(self.high); + num = num << 128; + num = num + BigUint::from(self.low); + write!(f, "{}", num) + } +} + +impl FromStr for U256 { + type Err = ParseBigIntError; + fn from_str(s: &str) -> Result { + let num = BigInt::from_str(s)?; + let num_big_uint = num.to_biguint().unwrap(); + let mask = (BigUint::from(1u128) << 128u32) - BigUint::from(1u128); + let b_low: BigUint = (num_big_uint.clone() >> 0) & mask.clone(); + let b_high: BigUint = (num_big_uint.clone() >> 128) & mask.clone(); + + let mut low = 0; + let mut high = 0; + + for (i, digit) in b_low.to_u64_digits().iter().take(2).enumerate() { + low |= (*digit as u128) << (i * 64); + } + + for (i, digit) in b_high.to_u64_digits().iter().take(2).enumerate() { + high |= (*digit as u128) << (i * 64); + } + + Ok(U256 { low, high }) + } +} + +impl serde::Serialize for U256 { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + DisplayFromStr::serialize_as(self, serializer) + } +} + +impl<'de> serde::Deserialize<'de> for U256 { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + DisplayFromStr::deserialize_as(deserializer) + } +} + impl CairoSerde for U256 { type RustType = Self; @@ -106,6 +211,111 @@ mod tests { assert_eq!(felts[1], Felt::from(8_u128)); } + #[test] + fn test_add_u256_low_overflow() { + let u256_1 = U256 { + low: u128::MAX, + high: 1_u128, + }; + let u256_2 = U256 { + low: 1_u128, + high: 2_u128, + }; + let u256_3 = u256_1 + u256_2; + assert_eq!(u256_3.low, 0_u128); + assert_eq!(u256_3.high, 4_u128); + } + + #[test] + fn test_add_u256_high_overflow() { + let u256_1 = U256 { + low: 0_u128, + high: u128::MAX, + }; + let u256_2 = U256 { + low: 0_u128, + high: 1_u128, + }; + + let u256_3 = u256_1 + u256_2; + + assert_eq!(u256_3.low, 0_u128); + assert_eq!(u256_3.high, 0_u128); + } + + #[test] + fn test_sub_u256() { + let u256_1 = U256 { + low: 1_u128, + high: 2_u128, + }; + let u256_2 = U256 { + low: 0_u128, + high: 1_u128, + }; + let u256_3 = u256_1 - u256_2; + assert_eq!(u256_3.low, 1_u128); + assert_eq!(u256_3.high, 1_u128); + } + + #[test] + fn test_sub_u256_underflow_low() { + let u256_1 = U256 { + low: 0_u128, + high: 1_u128, + }; + let u256_2 = U256 { + low: 2_u128, + high: 0_u128, + }; + let u256_3 = u256_1 - u256_2; + assert_eq!(u256_3.low, u128::MAX - 1); + assert_eq!(u256_3.high, 0_u128); + } + + #[test] + #[should_panic] + fn test_sub_u256_underflow_high() { + let u256_1 = U256 { + low: 0_u128, + high: 1_u128, + }; + let u256_2 = U256 { + low: 0_u128, + high: 2_u128, + }; + let _u256_3 = u256_1 - u256_2; + } + + #[test] + #[should_panic] + fn test_sub_u256_underflow_high_2() { + let u256_1 = U256 { + low: 10_u128, + high: 2_u128, + }; + let u256_2 = U256 { + low: 11_u128, + high: 2_u128, + }; + let _u256_3 = u256_1 - u256_2; + } + + #[test] + fn test_bit_or_u256() { + let u256_1 = U256 { + low: 0b1010_u128, + high: 0b1100_u128, + }; + let u256_2 = U256 { + low: 0b0110_u128, + high: 0b0011_u128, + }; + let u256_3 = u256_1 | u256_2; + assert_eq!(u256_3.low, 0b1110_u128); + assert_eq!(u256_3.high, 0b1111_u128) + } + #[test] fn test_serialize_u256_max() { let low = u128::MAX; @@ -126,6 +336,23 @@ mod tests { assert_eq!(felts[1], Felt::from(u128::MIN)); } + #[test] + fn test_display_u256() { + let u256 = U256 { + low: 12_u128, + high: 0_u128, + }; + println!("{}", u256); + assert_eq!(format!("{}", u256), "12"); + } + + #[test] + fn test_from_str() { + let u256 = U256::from_str("18446744073709551616").unwrap(); + assert_eq!(u256.low, 18446744073709551616_u128); + assert_eq!(u256.high, 0_u128); + } + #[test] fn test_deserialize_u256() { let felts = vec![Felt::from(9_u128), Felt::from(8_u128)];