diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index b5d6486..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.gitignore b/.gitignore index 640ba97..88ec6ff 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,5 @@ db/ *.lock *.tgz .vscode/ + +.DS_Store \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 194fc60..f4fe275 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,9 +36,9 @@ dependencies = [ "ark-std 0.3.0", "bitstream-io", "c-kzg", - "ctor 0.1.26", + "ctor", "encoder", - "env_logger 0.10.2", + "env_logger", "eth-types", "ethers-core 2.0.7 (git+https://github.com/scroll-tech/ethers-rs.git?branch=v2.0.7)", "gadgets", @@ -518,22 +518,22 @@ dependencies = [ [[package]] name = "bindgen" -version = "0.64.0" +version = "0.69.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4243e6031260db77ede97ad86c27e501d646a27ab57b59a574f725d98ab1fb4" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.5.0", "cexpr", "clang-sys", + "itertools 0.12.1", "lazy_static", "lazycell", - "peeking_take_while", "proc-macro2", "quote", "regex", "rustc-hash", "shlex", - "syn 1.0.109", + "syn 2.0.66", ] [[package]] @@ -1062,16 +1062,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "ctor" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" -dependencies = [ - "quote", - "syn 2.0.66", -] - [[package]] name = "ctr" version = "0.9.2" @@ -1280,16 +1270,6 @@ dependencies = [ "syn 2.0.66", ] -[[package]] -name = "env_filter" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea" -dependencies = [ - "log", - "regex", -] - [[package]] name = "env_logger" version = "0.10.2" @@ -1303,19 +1283,6 @@ dependencies = [ "termcolor", ] -[[package]] -name = "env_logger" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9" -dependencies = [ - "anstream", - "anstyle", - "env_filter", - "humantime", - "log", -] - [[package]] name = "equivalent" version = "1.0.1" @@ -1667,16 +1634,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "fs2" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "funty" version = "2.0.0" @@ -1819,7 +1776,7 @@ name = "geth-utils" version = "0.13.0" source = "git+https://github.com/scroll-tech/zkevm-circuits.git?tag=v0.13.1#4009e5593f13ba73f64f556011ee5ef47bc4ebf3" dependencies = [ - "env_logger 0.10.2", + "env_logger", "gobuild", "log", ] @@ -2565,14 +2522,13 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "librocksdb-sys" -version = "0.10.0+7.9.2" +version = "0.17.1+9.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fe4d5874f5ff2bc616e55e8c6086d478fcda13faf9495768a4aa1c22042d30b" +checksum = "2b7869a512ae9982f4d46ba482c2a304f1efd80c6412a3d4bf57bb79a619679f" dependencies = [ "bindgen", "bzip2-sys", "cc", - "glob", "libc", "libz-sys", "lz4-sys", @@ -3152,12 +3108,6 @@ dependencies = [ "hmac", ] -[[package]] -name = "peeking_take_while" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" - [[package]] name = "percent-encoding" version = "2.3.1" @@ -3798,9 +3748,9 @@ dependencies = [ [[package]] name = "rocksdb" -version = "0.20.1" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "015439787fce1e75d55f279078d33ff14b4af5d93d995e8838ee4631301c8a99" +checksum = "26ec73b20525cb235bad420f911473b69f9fe27cc856c5461bccd7e4af037f43" dependencies = [ "libc", "librocksdb-sys", @@ -4021,20 +3971,13 @@ dependencies = [ "anyhow", "async-trait", "axum", - "base64 0.13.1", "clap", - "ctor 0.2.8", "dotenv", - "env_logger 0.11.3", - "eth-keystore", "ethers-core 2.0.7 (git+https://github.com/scroll-tech/ethers-rs.git?branch=v2.0.7)", "ethers-providers 2.0.7 (git+https://github.com/scroll-tech/ethers-rs.git?branch=v2.0.7)", - "futures", - "halo2_proofs", "hex", "http 1.1.0", "log", - "once_cell", "prover", "rand", "reqwest 0.12.4", @@ -4044,8 +3987,6 @@ dependencies = [ "rocksdb", "serde", "serde_json", - "sled", - "snark-verifier-sdk", "tiny-keccak", "tokio", "tracing", @@ -4375,22 +4316,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "sled" -version = "0.34.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f96b4737c2ce5987354855aed3797279def4ebf734436c6aa4552cf8e169935" -dependencies = [ - "crc32fast", - "crossbeam-epoch", - "crossbeam-utils", - "fs2", - "fxhash", - "libc", - "log", - "parking_lot 0.11.2", -] - [[package]] name = "smallvec" version = "1.13.2" @@ -5491,7 +5416,7 @@ dependencies = [ "array-init", "bus-mapping", "either", - "env_logger 0.10.2", + "env_logger", "eth-types", "ethers-core 2.0.7 (git+https://github.com/scroll-tech/ethers-rs.git?branch=v2.0.7)", "ethers-signers", diff --git a/Cargo.toml b/Cargo.toml index 5d5ec8b..d3194d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,34 +16,24 @@ bls12_381 = { git = "https://github.com/scroll-tech/bls12_381", branch = "feat/i [dependencies] anyhow = "1.0" log = "0.4" -env_logger = "0.11.3" serde = { version = "1.0.198", features = ["derive"] } serde_json = "1.0.116" -futures = "0.3.30" ethers-core = { git = "https://github.com/scroll-tech/ethers-rs.git", branch = "v2.0.7" } ethers-providers = { git = "https://github.com/scroll-tech/ethers-rs.git", branch = "v2.0.7" } -halo2_proofs = { git = "https://github.com/scroll-tech/halo2.git", branch = "v1.1" } -snark-verifier-sdk = { git = "https://github.com/scroll-tech/snark-verifier", branch = "develop", default-features = false, features = ["loader_halo2", "loader_evm", "halo2-pse"] } -# prover_darwin = { git = "https://github.com/scroll-tech/zkevm-circuits.git", tag = "v0.12.2", package = "prover", default-features = false, features = ["parallel_syn", "scroll"] } prover_darwin_v2 = { git = "https://github.com/scroll-tech/zkevm-circuits.git", tag = "v0.13.1", package = "prover", default-features = false, features = ["parallel_syn", "scroll"] } -base64 = "0.13.1" reqwest = { version = "0.12.4", features = ["gzip"] } reqwest-middleware = "0.3" reqwest-retry = "0.5" -once_cell = "1.19.0" hex = "0.4.3" tiny-keccak = { version = "2.0.0", features = ["sha3", "keccak"] } rand = "0.8.5" -eth-keystore = "0.5.0" rlp = "0.5.2" tokio = { version = "1.37.0", features = ["full"] } async-trait = "0.1" -sled = "0.34.7" http = "1.1.0" clap = { version = "4.5", features = ["derive"] } -ctor = "0.2.8" tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } axum = "0.6.0" dotenv = "0.15" -rocksdb = "0.20" +rocksdb = "0.23.0" diff --git a/conf/config.json b/conf/config.json index 657207b..bf9002d 100644 --- a/conf/config.json +++ b/conf/config.json @@ -1,26 +1,38 @@ { - "prover_name_prefix": "cloud_prover_", - "keys_dir": "keys", - "coordinator": { - "base_url": "https://coordinator-api.scrollsdk", - "retry_count": 3, - "retry_wait_time_sec": 5, - "connection_timeout_sec": 60 - }, - "l2geth": { - "endpoint": "https://l2-rpc.scrollsdk" - }, - "prover": { - "circuit_type": 3, - "circuit_version": "v0.13.1", - "n_workers": 1, - "cloud": { - "base_url": "", - "api_key": "", + "prover_name_prefix": "prover_name", + "keys_dir": "keys", + "coordinator": { + "base_url": "https://coordinator-api.scrollsdk", "retry_count": 3, "retry_wait_time_sec": 5, "connection_timeout_sec": 60 - } - }, - "db_path": "db" -} + }, + "l2geth": { + "endpoint": "https://l2-rpc.scrollsdk" + }, + "prover": { + "circuit_type": [1,2,3], + "circuit_version": "v0.13.1", + "cloud": { + "base_url": "", + "api_key": "", + "retry_count": 3, + "retry_wait_time_sec": 5, + "connection_timeout_sec": 60 + }, + "local": { + "low_version_circuit": { + "hard_fork_name": "bernoulli", + "params_path": "params", + "assets_path": "assets" + }, + "high_version_circuit": { + "hard_fork_name": "curie", + "params_path": "params", + "assets_path": "assets" + } + } + }, + "db_path": "db" + } + \ No newline at end of file diff --git a/examples/cloud.rs b/examples/cloud.rs index 8150089..fbc312e 100644 --- a/examples/cloud.rs +++ b/examples/cloud.rs @@ -1,9 +1,11 @@ +use anyhow::{anyhow, Result}; use async_trait::async_trait; use clap::Parser; use reqwest::Url; +use std::fs::File; use scroll_proving_sdk::{ - config::{CloudProverConfig, Config}, + config::Config as SdkConfig, prover::{ proving_service::{ GetVkRequest, GetVkResponse, ProveRequest, ProveResponse, QueryTaskRequest, @@ -13,6 +15,7 @@ use scroll_proving_sdk::{ }, utils::init_tracing, }; +use serde::{Deserialize, Serialize}; #[derive(Parser, Debug)] #[clap(disable_version_flag = true)] @@ -22,6 +25,52 @@ struct Args { config_file: String, } +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct CloudProverConfig { + pub sdk_config: SdkConfig, + pub base_url: String, + pub api_key: String, +} + +impl CloudProverConfig { + pub fn from_reader(reader: R) -> Result + where + R: std::io::Read, + { + serde_json::from_reader(reader).map_err(|e| anyhow!(e)) + } + + pub fn from_file(file_name: String) -> Result { + let file = File::open(file_name)?; + Self::from_reader(&file) + } + + fn get_env_var(key: &str) -> Result> { + std::env::var_os(key) + .map(|val| { + val.to_str() + .ok_or_else(|| anyhow!("{key} env var is not valid UTF-8")) + .map(String::from) + }) + .transpose() + } + + pub fn from_file_and_env(file_name: String) -> Result { + let mut cfg = Self::from_file(file_name)?; + cfg.sdk_config.override_with_env()?; + + if let Some(val) = Self::get_env_var("PROVING_SERVICE_BASE_URL")? { + cfg.base_url = val; + } + + if let Some(val) = Self::get_env_var("PROVING_SERVICE_API_KEY")? { + cfg.api_key = val; + } + + Ok(cfg) + } +} + struct CloudProver { base_url: Url, api_key: String, @@ -32,7 +81,7 @@ impl ProvingService for CloudProver { fn is_local(&self) -> bool { false } - async fn get_vk(&self, req: GetVkRequest) -> GetVkResponse { + async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse { todo!() } async fn prove(&self, req: ProveRequest) -> ProveResponse { @@ -58,14 +107,10 @@ async fn main() -> anyhow::Result<()> { init_tracing(); let args = Args::parse(); - let cfg: Config = Config::from_file_and_env(args.config_file)?; - let cloud_prover = CloudProver::new( - cfg.prover - .cloud - .clone() - .ok_or_else(|| anyhow::anyhow!("Missing cloud prover configuration"))?, - ); - let prover = ProverBuilder::new(cfg) + let cfg = CloudProverConfig::from_file_and_env(args.config_file)?; + let sdk_config = cfg.sdk_config.clone(); + let cloud_prover = CloudProver::new(cfg); + let prover = ProverBuilder::new(sdk_config) .with_proving_service(Box::new(cloud_prover)) .build() .await?; diff --git a/examples/local.rs b/examples/local.rs index 0757a0e..0cb2726 100644 --- a/examples/local.rs +++ b/examples/local.rs @@ -1,8 +1,8 @@ +use anyhow::{anyhow, Result}; use async_trait::async_trait; use clap::Parser; - use scroll_proving_sdk::{ - config::{Config, LocalProverConfig}, + config::Config as SdkConfig, prover::{ proving_service::{ GetVkRequest, GetVkResponse, ProveRequest, ProveResponse, QueryTaskRequest, @@ -12,6 +12,8 @@ use scroll_proving_sdk::{ }, utils::init_tracing, }; +use serde::{Deserialize, Serialize}; +use std::fs::File; #[derive(Parser, Debug)] #[clap(disable_version_flag = true)] @@ -21,6 +23,33 @@ struct Args { config_file: String, } +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct LocalProverConfig { + pub sdk_config: SdkConfig, + pub conf1: String, + pub conf2: String, +} + +impl LocalProverConfig { + pub fn from_reader(reader: R) -> Result + where + R: std::io::Read, + { + serde_json::from_reader(reader).map_err(|e| anyhow!(e)) + } + + pub fn from_file(file_name: String) -> Result { + let file = File::open(file_name)?; + Self::from_reader(&file) + } + + pub fn from_file_and_env(file_name: String) -> Result { + let mut cfg = Self::from_file(file_name)?; + cfg.sdk_config.override_with_env()?; + Ok(cfg) + } +} + struct LocalProver {} #[async_trait] @@ -28,7 +57,7 @@ impl ProvingService for LocalProver { fn is_local(&self) -> bool { true } - async fn get_vk(&self, req: GetVkRequest) -> GetVkResponse { + async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse { todo!() } async fn prove(&self, req: ProveRequest) -> ProveResponse { @@ -50,9 +79,10 @@ async fn main() -> anyhow::Result<()> { init_tracing(); let args = Args::parse(); - let cfg: Config = Config::from_file(args.config_file)?; - let local_prover = LocalProver::new(cfg.prover.local.clone().unwrap()); - let prover = ProverBuilder::new(cfg) + let cfg = LocalProverConfig::from_file_and_env(args.config_file)?; + let sdk_config = cfg.sdk_config.clone(); + let local_prover = LocalProver::new(cfg); + let prover = ProverBuilder::new(sdk_config) .with_proving_service(Box::new(local_prover)) .build() .await?; diff --git a/src/.DS_Store b/src/.DS_Store deleted file mode 100644 index 595abd6..0000000 Binary files a/src/.DS_Store and /dev/null differ diff --git a/src/config.rs b/src/config.rs index 8829b02..5c37278 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,5 @@ use crate::prover::CircuitType; +use anyhow::{anyhow, Result}; use dotenv::dotenv; use serde::{Deserialize, Serialize}; use serde_json; @@ -16,10 +17,6 @@ pub struct Config { pub health_listener_addr: String, } -fn default_health_listener_addr() -> String { - "0.0.0.0:80".to_string() -} - #[derive(Debug, Serialize, Deserialize, Clone)] pub struct CoordinatorConfig { pub base_url: String, @@ -35,63 +32,52 @@ pub struct L2GethConfig { #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ProverConfig { - pub circuit_type: CircuitType, + pub circuit_types: Vec, pub circuit_version: String, + #[serde(default = "default_n_workers")] pub n_workers: usize, - pub cloud: Option, - pub local: Option, } - #[derive(Debug, Serialize, Deserialize, Clone)] -pub struct CloudProverConfig { - pub base_url: String, - pub api_key: String, - pub retry_count: u32, - pub retry_wait_time_sec: u64, - pub connection_timeout_sec: u64, -} +pub struct DbConfig {} -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct LocalProverConfig { - // TODO: - // params path - // assets path - // DB config +fn default_health_listener_addr() -> String { + "0.0.0.0:80".to_string() } -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct DbConfig {} +fn default_n_workers() -> usize { + 1 +} impl Config { - pub fn from_reader(reader: R) -> anyhow::Result + pub fn from_reader(reader: R) -> Result where R: std::io::Read, { - serde_json::from_reader(reader).map_err(|e| anyhow::anyhow!(e)) + serde_json::from_reader(reader).map_err(|e| anyhow!(e)) } - pub fn from_file(file_name: String) -> anyhow::Result { + pub fn from_file(file_name: String) -> Result { let file = File::open(file_name)?; Config::from_reader(&file) } - pub fn from_file_and_env(file_name: String) -> anyhow::Result { + pub fn from_file_and_env(file_name: String) -> Result { let mut cfg = Config::from_file(file_name)?; cfg.override_with_env()?; Ok(cfg) } - fn get_env_var(key: &str) -> anyhow::Result> { - Ok(std::env::var_os(key) + fn get_env_var(key: &str) -> Result> { + std::env::var_os(key) .map(|val| { val.to_str() - .ok_or_else(|| anyhow::anyhow!("{key} env var is not valid UTF-8")) + .ok_or_else(|| anyhow!("{key} env var is not valid UTF-8")) .map(String::from) }) - .transpose()?) + .transpose() } - fn override_with_env(&mut self) -> anyhow::Result<()> { + pub fn override_with_env(&mut self) -> Result<()> { dotenv().ok(); if let Some(val) = Self::get_env_var("PROVER_NAME_PREFIX")? { @@ -108,22 +94,28 @@ impl Config { l2geth.endpoint = val; } } - if let Some(val) = Self::get_env_var("CIRCUIT_TYPE")? { - self.prover.circuit_type = CircuitType::from_u8(val.parse()?); + if let Some(val) = Self::get_env_var("CIRCUIT_TYPES")? { + let values_vec: Vec<&str> = val + .trim_matches(|c| c == '[' || c == ']') + .split(',') + .map(|s| s.trim()) + .collect(); + + self.prover.circuit_types = values_vec + .iter() + .map(|value| match value.parse::() { + Ok(num) => CircuitType::from_u8(num), + Err(e) => { + panic!("Failed to parse circuit type: {}", e); + } + }) + .collect::>(); } + if let Some(val) = Self::get_env_var("N_WORKERS")? { self.prover.n_workers = val.parse()?; } - if let Some(val) = Self::get_env_var("PROVING_SERVICE_BASE_URL")? { - if let Some(cloud) = &mut self.prover.cloud { - cloud.base_url = val; - } - } - if let Some(val) = Self::get_env_var("PROVING_SERVICE_API_KEY")? { - if let Some(cloud) = &mut self.prover.cloud { - cloud.api_key = val; - } - } + if let Some(val) = Self::get_env_var("DB_PATH")? { self.db_path = Option::from(val); } diff --git a/src/coordinator_handler/coordinator_client.rs b/src/coordinator_handler/coordinator_client.rs index 82c8cbb..df48f03 100644 --- a/src/coordinator_handler/coordinator_client.rs +++ b/src/coordinator_handler/coordinator_client.rs @@ -2,14 +2,18 @@ use super::{ api::Api, error::ErrorCode, GetTaskRequest, GetTaskResponseData, KeySigner, LoginMessage, LoginRequest, Response, SubmitProofRequest, SubmitProofResponseData, }; -use crate::{config::CoordinatorConfig, prover::CircuitType, utils::get_version}; +use crate::{ + config::CoordinatorConfig, + prover::{CircuitType, ProverProviderType}, + utils::get_version, +}; use tokio::sync::{Mutex, MutexGuard}; pub struct CoordinatorClient { - circuit_type: CircuitType, + circuit_types: Vec, vks: Vec, - circuit_version: String, pub prover_name: String, + pub prover_provider_type: ProverProviderType, pub key_signer: KeySigner, api: Api, token: Mutex>, @@ -18,18 +22,18 @@ pub struct CoordinatorClient { impl CoordinatorClient { pub fn new( cfg: CoordinatorConfig, - circuit_type: CircuitType, + circuit_types: Vec, vks: Vec, - circuit_version: String, prover_name: String, + prover_provider_type: ProverProviderType, key_signer: KeySigner, ) -> anyhow::Result { let api = Api::new(cfg)?; let client = Self { - circuit_type, + circuit_types, vks, - circuit_version, prover_name, + prover_provider_type, key_signer, api, token: Mutex::new(None), @@ -107,15 +111,21 @@ impl CoordinatorClient { .as_ref() .ok_or_else(|| anyhow::anyhow!("Missing challenge token"))?; - let prover_types = match self.circuit_type { - CircuitType::Batch | CircuitType::Bundle => vec![CircuitType::Batch], // to conform to coordinator logic - _ => vec![self.circuit_type], - }; + let mut prover_types = vec![]; + if self.circuit_types.contains(&CircuitType::Bundle) + || self.circuit_types.contains(&CircuitType::Batch) + { + prover_types.push(CircuitType::Batch) + } + if self.circuit_types.contains(&CircuitType::Chunk) { + prover_types.push(CircuitType::Chunk) + } let login_message = LoginMessage { challenge: login_response_data.token.clone(), + prover_version: get_version().to_string(), prover_name: self.prover_name.clone(), - prover_version: get_version(&self.circuit_version).to_string(), + prover_provider_type: self.prover_provider_type, prover_types, vks: self.vks.clone(), }; diff --git a/src/coordinator_handler/key_signer.rs b/src/coordinator_handler/key_signer.rs index be93e54..c5ab112 100644 --- a/src/coordinator_handler/key_signer.rs +++ b/src/coordinator_handler/key_signer.rs @@ -58,6 +58,16 @@ impl KeySigner { }) } + pub fn new_from_secret_key(secret_key: &str) -> anyhow::Result { + let secret = hex::decode(secret_key).unwrap(); + let secret_key = SecretKey::from_bytes(secret.as_slice().into())?; + let signing_key = SigningKey::from(secret_key.clone()); + Ok(Self { + public_key: secret_key.public_key(), + signing_key, + }) + } + pub fn get_public_key(&self) -> String { let v: Vec = Vec::from(self.public_key.to_encoded_point(true).as_bytes()); buffer_to_hex(&v, false) diff --git a/src/coordinator_handler/types.rs b/src/coordinator_handler/types.rs index 85981f6..b93c594 100644 --- a/src/coordinator_handler/types.rs +++ b/src/coordinator_handler/types.rs @@ -1,5 +1,8 @@ use super::error::ErrorCode; -use crate::{prover::CircuitType, tracing_handler::CommonHash}; +use crate::{ + prover::{CircuitType, ProverProviderType}, + tracing_handler::CommonHash, +}; use rlp::{Encodable, RlpStream}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -13,19 +16,21 @@ pub struct Response { #[derive(Serialize, Deserialize)] pub struct LoginMessage { pub challenge: String, - pub prover_name: String, pub prover_version: String, + pub prover_name: String, + pub prover_provider_type: ProverProviderType, pub prover_types: Vec, pub vks: Vec, } impl Encodable for LoginMessage { fn rlp_append(&self, s: &mut RlpStream) { - let num_fields = 5; + let num_fields = 6; s.begin_list(num_fields); s.append(&self.challenge); s.append(&self.prover_version); s.append(&self.prover_name); + s.append(&(self.prover_provider_type as u8)); // The ProverType in go side is an type alias of uint8 // A uint8 slice is treated as a string when doing the rlp encoding let prover_types = self @@ -172,3 +177,44 @@ impl<'de> Deserialize<'de> for ProofStatus { Ok(ProofStatus::from_u8(v)) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{coordinator_handler::KeySigner, prover::types::ProverProviderType}; + + #[test] + fn test_prover_provider_type_encoding() { + // Test that ProverProviderType values match the coordinator's values + assert_eq!(ProverProviderType::Undefined as u8, 0); + assert_eq!(ProverProviderType::Internal as u8, 1); + assert_eq!(ProverProviderType::External as u8, 2); + } + + // This test uses the same private key as the coordinator's TestGenerateSignature + // to verify signature generation compatibility + #[test] + fn test_signature_compatibility() { + let private_key_hex = "8b8df68fddf7ee2724b79ccbd07799909d59b4dd4f4df3f6ecdc4fb8d56bdf4c"; + let key_signer = KeySigner::new_from_secret_key(private_key_hex).unwrap(); + + let login_message = LoginMessage { + challenge: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MjQ4Mzg0ODUsIm9yaWdfaWF0IjoxNzI0ODM0ODg1LCJyYW5kb20iOiJ6QmdNZGstNGc4UzNUNTFrVEFsYk1RTXg2TGJ4SUs4czY3ejM2SlNuSFlJPSJ9.x9PvihhNx2w4_OX5uCrv8QJCNYVQkIi-K2k8XFXYmik".to_string(), + prover_version: "v4.4.45-37af5ef5-38a68e2-1c5093c".to_string(), + prover_name: "test".to_string(), + prover_provider_type: ProverProviderType::Internal, + prover_types: vec![CircuitType::Chunk], + vks: vec!["mock_vk".to_string()], + }; + + let buffer = rlp::encode(&login_message); + let signature = key_signer + .sign_buffer(&buffer) + .map_err(|e| anyhow::anyhow!("Failed to sign the login message: {e}")) + .unwrap(); + + // expected signature from coordinator's TestGenerateSignature + let expected_signature = "0xb8659f094fde9ed697bd86b8d8a0a1cff902710d7750463858c8a9ff9e851b152240054f256ce9ea8a3eaf5f0d56ceed894b358d3505926dc6cfc36548f7001a01".to_string(); + assert_eq!(signature, expected_signature); + } +} diff --git a/src/db.rs b/src/db.rs index ad90f76..39c02b3 100644 --- a/src/db.rs +++ b/src/db.rs @@ -11,7 +11,29 @@ impl Db { Ok(Self { db }) } - pub fn get_coordinator_task_by_public_key( + pub fn get_task(&self, public_key: String) -> (Option, Option) { + ( + self.get_coordinator_task_by_public_key(public_key.clone()), + self.get_proving_task_id_by_public_key(public_key), + ) + } + + pub fn set_task( + &self, + public_key: String, + coordinator_task: &GetTaskResponseData, + proving_task_id: String, + ) { + self.set_coordinator_task_by_public_key(public_key.clone(), coordinator_task); + self.set_proving_task_id_by_public_key(public_key, proving_task_id); + } + + pub fn delete_task(&self, public_key: String) { + self.delete_coordinator_task_by_public_key(public_key.clone()); + self.delete_proving_task_id_by_public_key(public_key); + } + + fn get_coordinator_task_by_public_key( &self, public_key: String, ) -> Option { @@ -19,19 +41,17 @@ impl Db { .get(fmt_coordinator_task_key(public_key)) .ok()? .as_ref() - .map(|v| serde_json::from_slice(v).ok()) - .flatten() + .and_then(|v| serde_json::from_slice(v).ok()) } - pub fn get_proving_task_id_by_public_key(&self, public_key: String) -> Option { + fn get_proving_task_id_by_public_key(&self, public_key: String) -> Option { self.db .get(fmt_proving_task_id_key(public_key)) .ok()? - .map(|v| String::from_utf8(v).ok()) - .flatten() + .and_then(|v| String::from_utf8(v).ok()) } - pub fn set_coordinator_task_by_public_key( + fn set_coordinator_task_by_public_key( &self, public_key: String, coordinator_task: &GetTaskResponseData, @@ -40,18 +60,18 @@ impl Db { .map(|bytes| self.db.put(fmt_coordinator_task_key(public_key), bytes)); } - pub fn set_proving_task_id_by_public_key(&self, public_key: String, proving_task_id: String) { + fn set_proving_task_id_by_public_key(&self, public_key: String, proving_task_id: String) { let _ = self.db.put( fmt_proving_task_id_key(public_key), proving_task_id.as_bytes(), ); } - pub fn delete_coordinator_task_by_public_key(&self, public_key: String) { + fn delete_coordinator_task_by_public_key(&self, public_key: String) { let _ = self.db.delete(fmt_coordinator_task_key(public_key)); } - pub fn delete_proving_task_id_by_public_key(&self, public_key: String) { + fn delete_proving_task_id_by_public_key(&self, public_key: String) { let _ = self.db.delete(fmt_proving_task_id_key(public_key)); } } diff --git a/src/prover/builder.rs b/src/prover/builder.rs index 8a9cb4a..5c8f238 100644 --- a/src/prover/builder.rs +++ b/src/prover/builder.rs @@ -1,4 +1,4 @@ -use super::CircuitType; +use super::{CircuitType, ProverProviderType}; use crate::{ config::Config, coordinator_handler::{CoordinatorClient, KeySigner}, @@ -8,6 +8,7 @@ use crate::{ Prover, }, tracing_handler::L2gethClient, + utils::format_cloud_prover_name, }; use std::path::PathBuf; @@ -40,24 +41,31 @@ impl ProverBuilder { anyhow::bail!("cannot use multiple workers with local proving service"); } - if self.cfg.prover.circuit_type == CircuitType::Chunk && self.cfg.l2geth.is_none() { + if self.cfg.prover.circuit_types.contains(&CircuitType::Chunk) && self.cfg.l2geth.is_none() + { anyhow::bail!("circuit_type is chunk but l2geth config is not provided"); } let get_vk_request = GetVkRequest { - circuit_type: self.cfg.prover.circuit_type, + circuit_types: self.cfg.prover.circuit_types.clone(), circuit_version: self.cfg.prover.circuit_version.clone(), }; let get_vk_response = self .proving_service .as_ref() .unwrap() - .get_vk(get_vk_request) + .get_vks(get_vk_request) .await; if let Some(error) = get_vk_response.error { anyhow::bail!("failed to get vk: {}", error); } + let prover_provider_type = if self.proving_service.as_ref().unwrap().is_local() { + ProverProviderType::Internal + } else { + ProverProviderType::External + }; + let key_signers: Result, _> = (0..self.cfg.prover.n_workers) .map(|i| { let key_path = PathBuf::from(&self.cfg.keys_dir).join(i.to_string()); @@ -69,12 +77,18 @@ impl ProverBuilder { let coordinator_clients: Result, _> = (0..self.cfg.prover.n_workers) .map(|i| { + let prover_name = if self.proving_service.as_ref().unwrap().is_local() { + self.cfg.prover_name_prefix.clone() + } else { + format_cloud_prover_name(self.cfg.prover_name_prefix.clone(), i) + }; + CoordinatorClient::new( self.cfg.coordinator.clone(), - self.cfg.prover.circuit_type, - vec![get_vk_response.vk.clone()], - self.cfg.prover.circuit_version.clone(), - format!("{}{}", self.cfg.prover_name_prefix, i), + self.cfg.prover.circuit_types.clone(), + get_vk_response.vks.clone(), + prover_name, + prover_provider_type, key_signers[i].clone(), ) }) @@ -91,7 +105,7 @@ impl ProverBuilder { }); Ok(Prover { - circuit_type: self.cfg.prover.circuit_type, + circuit_types: self.cfg.prover.circuit_types.clone(), circuit_version: self.cfg.prover.circuit_version, coordinator_clients, l2geth_client, diff --git a/src/prover/mod.rs b/src/prover/mod.rs index 9c8ce73..188503b 100644 --- a/src/prover/mod.rs +++ b/src/prover/mod.rs @@ -22,7 +22,7 @@ pub use {builder::ProverBuilder, proving_service::ProvingService, types::*}; const WORKER_SLEEP_SEC: u64 = 20; pub struct Prover { - circuit_type: CircuitType, + circuit_types: Vec, circuit_version: String, coordinator_clients: Vec, l2geth_client: Option, @@ -35,7 +35,7 @@ pub struct Prover { impl Prover { pub async fn run(self) { assert!(self.n_workers == self.coordinator_clients.len()); - if self.circuit_type == CircuitType::Chunk { + if self.circuit_types.contains(&CircuitType::Chunk) { assert!(self.l2geth_client.is_some()); } @@ -86,12 +86,14 @@ impl Prover { } async fn handle_task(&self, coordinator_client: &CoordinatorClient) -> anyhow::Result<()> { - let public_key = coordinator_client.key_signer.get_public_key(); - if let (Some(coordinator_task), Some(proving_task_id)) = ( - self.db - .get_coordinator_task_by_public_key(public_key.clone()), - self.db.get_proving_task_id_by_public_key(public_key), - ) { + if let (Some(coordinator_task), Some(mut proving_task_id)) = self + .db + .get_task(coordinator_client.key_signer.get_public_key()) + { + if self.proving_service.is_local() { + let proving_task = self.request_proving(&coordinator_task).await?; + proving_task_id = proving_task.task_id + } return self .handle_proving_progress(coordinator_client, &coordinator_task, proving_task_id) .await; @@ -174,10 +176,9 @@ impl Prover { status = ?task.status, "Task status update" ); - self.db - .set_coordinator_task_by_public_key(public_key.clone(), coordinator_task); - self.db.set_proving_task_id_by_public_key( + self.db.set_task( public_key.clone(), + coordinator_task, proving_service_task_id.clone(), ); sleep(Duration::from_secs(WORKER_SLEEP_SEC)).await; @@ -199,10 +200,7 @@ impl Prover { None, ) .await?; - self.db - .delete_coordinator_task_by_public_key(public_key.clone()); - self.db - .delete_proving_task_id_by_public_key(public_key.clone()); + self.db.delete_task(public_key.clone()); break; } TaskStatus::Failed => { @@ -224,10 +222,7 @@ impl Prover { Some(task_err), ) .await?; - self.db - .delete_coordinator_task_by_public_key(public_key.clone()); - self.db - .delete_proving_task_id_by_public_key(public_key.clone()); + self.db.delete_task(public_key.clone()); break; } } @@ -271,12 +266,15 @@ impl Prover { None => None, Some(l2geth_client) => match l2geth_client.block_number().await { Ok(block_number) => block_number.as_number().map(|num| num.as_u64()), - Err(_) => None, + Err(_) => { + log::info!("Failed to get block number"); + None + } }, }; GetTaskRequest { - task_types: vec![self.circuit_type], + task_types: self.circuit_types.clone(), prover_height, } } @@ -286,9 +284,9 @@ impl Prover { task: &GetTaskResponseData, ) -> anyhow::Result { anyhow::ensure!( - task.task_type == self.circuit_type, - "task type mismatch. self: {:?}, task: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}", - self.circuit_type, + self.circuit_types.contains(&task.task_type), + "unsupported task type. self: {:?}, task: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}", + self.circuit_types, task.task_type, task.uuid, task.task_id @@ -300,13 +298,15 @@ impl Prover { } CircuitType::Chunk => { let chunk_task_detail: ChunkTaskDetail = serde_json::from_str(&task.task_data)?; - let traces = self + let serialized_traces = self .l2geth_client .as_ref() .unwrap() .get_sorted_traces_by_hashes(&chunk_task_detail.block_hashes) .await?; - let input = serde_json::to_string(&traces)?; + // Note: Manually join pre-serialized traces since they are already in JSON format. + // Using serde_json::to_string would escape the JSON strings, creating invalid nested JSON. + let input = format!("[{}]", serialized_traces.join(",")); Ok(ProveRequest { circuit_type: task.task_type, diff --git a/src/prover/proving_service.rs b/src/prover/proving_service.rs index 2625828..8eccebd 100644 --- a/src/prover/proving_service.rs +++ b/src/prover/proving_service.rs @@ -4,21 +4,22 @@ use async_trait::async_trait; #[async_trait] pub trait ProvingService { fn is_local(&self) -> bool; - async fn get_vk(&self, req: GetVkRequest) -> GetVkResponse; + async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse; async fn prove(&self, req: ProveRequest) -> ProveResponse; async fn query_task(&self, req: QueryTaskRequest) -> QueryTaskResponse; } pub struct GetVkRequest { - pub circuit_type: CircuitType, + pub circuit_types: Vec, pub circuit_version: String, } pub struct GetVkResponse { - pub vk: String, + pub vks: Vec, pub error: Option, } +#[derive(Clone)] pub struct ProveRequest { pub circuit_type: CircuitType, pub circuit_version: String, @@ -26,6 +27,7 @@ pub struct ProveRequest { pub input: String, } +#[derive(Default)] pub struct ProveResponse { pub task_id: String, pub circuit_type: CircuitType, @@ -46,6 +48,7 @@ pub struct QueryTaskRequest { pub task_id: String, } +#[derive(Default)] pub struct QueryTaskResponse { pub task_id: String, pub circuit_type: CircuitType, @@ -62,8 +65,9 @@ pub struct QueryTaskResponse { pub error: Option, } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Default)] pub enum TaskStatus { + #[default] Queued, Proving, Success, diff --git a/src/prover/types.rs b/src/prover/types.rs index e00e912..24fa157 100644 --- a/src/prover/types.rs +++ b/src/prover/types.rs @@ -1,7 +1,8 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] pub enum CircuitType { + #[default] Undefined, Chunk, Batch, @@ -51,3 +52,45 @@ impl<'de> Deserialize<'de> for CircuitType { Ok(CircuitType::from_u8(v)) } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +#[repr(u8)] +pub enum ProverProviderType { + #[default] + Undefined, + Internal, + External, +} + +impl ProverProviderType { + pub fn from_u8(v: u8) -> Self { + match v { + 1 => ProverProviderType::Internal, + 2 => ProverProviderType::External, + _ => ProverProviderType::Undefined, + } + } +} + +impl Serialize for ProverProviderType { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match *self { + ProverProviderType::Undefined => serializer.serialize_u8(0), + ProverProviderType::Internal => serializer.serialize_u8(1), + ProverProviderType::External => serializer.serialize_u8(2), + } + } +} + +impl<'de> Deserialize<'de> for ProverProviderType { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let v: u8 = u8::deserialize(deserializer)?; + Ok(ProverProviderType::from_u8(v)) + } +} diff --git a/src/tracing_handler.rs b/src/tracing_handler.rs index b1a980c..3522446 100644 --- a/src/tracing_handler.rs +++ b/src/tracing_handler.rs @@ -3,9 +3,7 @@ use ethers_core::types::BlockNumber; use ethers_core::types::H256; use ethers_providers::{Http, Provider}; use prover_darwin_v2::BlockTrace; -use serde::{de::DeserializeOwned, Serialize}; use std::cmp::Ordering; -use std::fmt::Debug; pub type CommonHash = H256; @@ -19,19 +17,18 @@ impl L2gethClient { Ok(Self { provider }) } - pub async fn get_block_trace_by_hash(&self, hash: &CommonHash) -> anyhow::Result - where - T: Serialize + DeserializeOwned + Debug + Send, - { + pub async fn get_block_trace_by_hash(&self, hash: &CommonHash) -> anyhow::Result { log::info!( "l2geth_client calling get_block_trace_by_hash, hash: {:#?}", hash ); - let trace = self + let trace: serde_json::Value = self .provider .request("scroll_getBlockTraceByNumberOrHash", [format!("{hash:#x}")]) .await?; + + let trace = serde_json::to_string(&trace)?; Ok(trace) } @@ -45,7 +42,7 @@ impl L2gethClient { pub async fn get_sorted_traces_by_hashes( &self, block_hashes: &[CommonHash], - ) -> anyhow::Result> { + ) -> anyhow::Result> { if block_hashes.is_empty() { log::error!("failed to get sorted traces: block_hashes are empty"); anyhow::bail!("block_hashes are empty") @@ -94,6 +91,7 @@ impl L2gethClient { } } -fn get_block_number_from_trace(block_trace: &BlockTrace) -> Option { +fn get_block_number_from_trace(block_trace: &String) -> Option { + let block_trace: BlockTrace = serde_json::from_str(block_trace).unwrap(); block_trace.header.number.map(|n| n.as_u64()) } diff --git a/src/utils.rs b/src/utils.rs index 760e018..7248f45 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,9 +1,22 @@ use tracing_subscriber::filter::{EnvFilter, LevelFilter}; -const SDK_VERSION: &str = env!("CARGO_PKG_VERSION"); +use std::cell::OnceCell; -pub fn get_version(circuit_version: &str) -> String { - format!("sdk-v{}-{}", SDK_VERSION, circuit_version) +static DEFAULT_COMMIT: &str = "unknown"; +static mut VERSION: OnceCell = OnceCell::new(); + +pub const TAG: &str = "v0.0.0"; +pub const DEFAULT_ZK_VERSION: &str = "000000-000000"; + +fn init_version() -> String { + let commit = option_env!("GIT_REV").unwrap_or(DEFAULT_COMMIT); + let tag = option_env!("GO_TAG").unwrap_or(TAG); + let zk_version = option_env!("ZK_VERSION").unwrap_or(DEFAULT_ZK_VERSION); + format!("{tag}-{commit}-{zk_version}") +} + +pub fn get_version() -> String { + unsafe { VERSION.get_or_init(init_version).clone() } } pub fn init_tracing() { @@ -19,3 +32,8 @@ pub fn init_tracing() { .try_init() .expect("Failed to initialize tracing subscriber"); } + +pub fn format_cloud_prover_name(provider_name: String, index: usize) -> String { + // note the name of cloud prover is in fact in the format of "cloud_prover_{provider-name}_index", + format!("cloud_prover_{}_{}", provider_name, index) +}