From a08da074cecee804005bfd7eb3e46f02ac2ee5cc Mon Sep 17 00:00:00 2001 From: Ho Date: Thu, 18 Apr 2024 16:15:55 +0900 Subject: [PATCH] Update zktrie with native rust lib (#1198) * update dep for new zktrie (rust-zktrie) Signed-off-by: noelwei * update dep * udpate zktrie Signed-off-by: noelwei * update zktrie dep --------- Signed-off-by: noelwei Co-authored-by: Zhuo Zhang --- Cargo.lock | 30 ++++++++++++++++++++++++-- zktrie/Cargo.toml | 3 +-- zktrie/src/state/builder.rs | 43 ++++++++----------------------------- 3 files changed, 38 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 91a08f68cd..34ed5ceb86 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3295,6 +3295,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-derive" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "num-integer" version = "0.1.46" @@ -6125,10 +6136,25 @@ dependencies = [ [[package]] name = "zktrie" -version = "0.2.0" -source = "git+https://github.com/scroll-tech/zktrie.git?tag=v0.7.1#a12f2f262ad3e82301e39ecdf9bfe235befc7074" +version = "0.3.0" +source = "git+https://github.com/scroll-tech/zktrie.git?tag=v0.8.0#dd9316ef82571c599ddd7ca8ffffa496270008c6" dependencies = [ "gobuild", + "zktrie_rust", +] + +[[package]] +name = "zktrie_rust" +version = "0.3.0" +source = "git+https://github.com/scroll-tech/zktrie.git?tag=v0.8.0#dd9316ef82571c599ddd7ca8ffffa496270008c6" +dependencies = [ + "hex", + "lazy_static", + "num", + "num-derive", + "num-traits", + "strum 0.24.1", + "strum_macros 0.24.3", ] [[package]] diff --git a/zktrie/Cargo.toml b/zktrie/Cargo.toml index 9845c7f7f6..63eac7d216 100644 --- a/zktrie/Cargo.toml +++ b/zktrie/Cargo.toml @@ -9,7 +9,7 @@ license.workspace = true [dependencies] halo2_proofs.workspace = true mpt-circuits = { package = "halo2-mpt-circuits", git = "https://github.com/scroll-tech/mpt-circuit.git", branch = "v0.7", default-features=false } -zktrie = { git = "https://github.com/scroll-tech/zktrie.git", tag = "v0.7.1" } +zktrie = { git = "https://github.com/scroll-tech/zktrie.git", tag = "v0.8.0", features= ["rs_zktrie"] } hash-circuit.workspace = true eth-types = { path = "../eth-types" } num-bigint.workspace = true @@ -24,4 +24,3 @@ serde_json.workspace = true [features] default = [] parallel_syn = ["mpt-circuits/parallel_syn"] - diff --git a/zktrie/src/state/builder.rs b/zktrie/src/state/builder.rs index 07b7840e14..f8311534e3 100644 --- a/zktrie/src/state/builder.rs +++ b/zktrie/src/state/builder.rs @@ -16,55 +16,30 @@ use hash_circuit::hash::Hashable; pub fn init_hash_scheme() { static INIT: Once = Once::new(); INIT.call_once(|| { - zktrie::init_hash_scheme(hash_scheme); + zktrie::init_hash_scheme_simple(poseidon_hash_scheme); }); } -static FILED_ERROR_READ: &str = "invalid input field"; -static FILED_ERROR_OUT: &str = "output field fail"; - -extern "C" fn hash_scheme( - a: *const u8, - b: *const u8, - domain: *const u8, - out: *mut u8, -) -> *const i8 { - use std::slice; - let a: [u8; 32] = - TryFrom::try_from(unsafe { slice::from_raw_parts(a, 32) }).expect("length specified"); - let b: [u8; 32] = - TryFrom::try_from(unsafe { slice::from_raw_parts(b, 32) }).expect("length specified"); - let domain: [u8; 32] = - TryFrom::try_from(unsafe { slice::from_raw_parts(domain, 32) }).expect("length specified"); - let out = unsafe { slice::from_raw_parts_mut(out, 32) }; - - let fa = Fr::from_bytes(&a); +fn poseidon_hash_scheme(a: &[u8; 32], b: &[u8; 32], domain: &[u8; 32]) -> Option<[u8; 32]> { + let fa = Fr::from_bytes(a); let fa = if fa.is_some().into() { fa.unwrap() } else { - return FILED_ERROR_READ.as_ptr().cast(); + return None; }; - let fb = Fr::from_bytes(&b); + let fb = Fr::from_bytes(b); let fb = if fb.is_some().into() { fb.unwrap() } else { - return FILED_ERROR_READ.as_ptr().cast(); + return None; }; - let fdomain = Fr::from_bytes(&domain); + let fdomain = Fr::from_bytes(domain); let fdomain = if fdomain.is_some().into() { fdomain.unwrap() } else { - return FILED_ERROR_READ.as_ptr().cast(); + return None; }; - - let h = Fr::hash_with_domain([fa, fb], fdomain); - let repr_h = h.to_repr(); - if repr_h.len() == 32 { - out.copy_from_slice(repr_h.as_ref()); - std::ptr::null() - } else { - FILED_ERROR_OUT.as_ptr().cast() - } + Some(Fr::hash_with_domain([fa, fb], fdomain).to_repr()) } pub(crate) const NODE_TYPE_MIDDLE_0: u8 = 6;