diff --git a/Cargo.lock b/Cargo.lock index 14e5553b0a..2c5222a96f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8218,7 +8218,6 @@ dependencies = [ "similar-asserts", "starknet 0.12.0", "starknet-crypto 0.7.2", - "strum_macros 0.25.3", "thiserror", ] diff --git a/crates/katana/node-bindings/src/json.rs b/crates/katana/node-bindings/src/json.rs index 2140b7ee79..0c8db425ae 100644 --- a/crates/katana/node-bindings/src/json.rs +++ b/crates/katana/node-bindings/src/json.rs @@ -69,3 +69,12 @@ pub struct AccountInfo { pub struct RpcAddr { pub addr: SocketAddr, } + +/// { +/// "message": "Starting node.", +/// "chain": "SN_SEPOLIA" +/// } +#[derive(Deserialize, Debug)] +pub struct ChainId { + pub chain: String, +} diff --git a/crates/katana/node-bindings/src/lib.rs b/crates/katana/node-bindings/src/lib.rs index bcf3071902..f270e0d384 100644 --- a/crates/katana/node-bindings/src/lib.rs +++ b/crates/katana/node-bindings/src/lib.rs @@ -14,8 +14,8 @@ use std::process::{Child, Command}; use std::str::FromStr; use std::time::{Duration, Instant}; -use json::RpcAddr; use starknet::core::types::{Felt, FromStrError}; +use starknet::core::utils::cairo_short_string_to_felt; use starknet::macros::short_string; use starknet::signers::SigningKey; use thiserror::Error; @@ -130,6 +130,10 @@ pub enum Error { #[error("missing rpc server address")] MissingSocketAddr, + /// A line indicating the instance address was found but the actual value was not. + #[error("missing chain id")] + MissingChainId, + #[error("encountered unexpected format: {0}")] UnexpectedFormat(String), @@ -142,6 +146,8 @@ pub enum Error { /// The string indicator from which the RPC server address can be extracted from. const RPC_ADDR_LOG_SUBSTR: &str = "RPC server started."; +/// The string indicator from which the chain id can be extracted from. +const CHAIN_ID_LOG_SUBSTR: &str = "Starting node."; /// Builder for launching `katana`. /// @@ -494,10 +500,9 @@ impl Katana { let mut accounts = Vec::new(); // var to store the current account being processed let mut current_account: Option = None; - - // TODO: the chain id should be fetched from stdout as well but Katana doesn't display the - // chain id atm - let chain_id = self.chain_id.unwrap_or(short_string!("KATANA")); + // var to store the chain id parsed from the logs. default to KATANA (default katana chain + // id) if not specified + let mut chain_id: Felt = self.chain_id.unwrap_or(short_string!("KATANA")); loop { if start + Duration::from_millis(self.timeout.unwrap_or(KATANA_STARTUP_TIMEOUT_MILLIS)) @@ -514,13 +519,24 @@ impl Katana { // Because we using a concrete type for rpc addr log, we need to parse this first. // Otherwise if we were to inverse the if statements, the else block // would never be executed as all logs can be parsed as `JsonLog`. - if let Ok(log) = serde_json::from_str::>(&line) { + if let Ok(log) = serde_json::from_str::>(&line) { debug_assert!(log.fields.message.contains(RPC_ADDR_LOG_SUBSTR)); port = log.fields.other.addr.port(); // We can safely break here as we don't need any information after the rpc // address break; } + // Try parsing as chain id log + else if let Ok(log) = serde_json::from_str::>(&line) { + debug_assert!(log.fields.message.contains(CHAIN_ID_LOG_SUBSTR)); + let chain_raw = log.fields.other.chain; + chain_id = if chain_raw.starts_with("0x") { + Felt::from_str(&chain_raw)? + } else { + cairo_short_string_to_felt(&chain_raw) + .map_err(|_| Error::UnexpectedFormat("invalid chain id".to_string()))? + }; + } // Parse all logs as generic logs else if let Ok(info) = serde_json::from_str::(&line) { // Check if this log is a katana startup info log @@ -543,6 +559,10 @@ impl Katana { break; } + if line.contains(CHAIN_ID_LOG_SUBSTR) { + chain_id = parse_chain_id_log(&line)?; + } + const ACC_ADDRESS_PREFIX: &str = "| Account address |"; if line.starts_with(ACC_ADDRESS_PREFIX) { // If there is currently an account being handled, but we've reached the next @@ -609,6 +629,23 @@ fn parse_rpc_addr_log(log: &str) -> Result { Ok(SocketAddr::from_str(addr)?) } +// Example chain ID log format (ansi color codes removed): +// 2024-10-18T01:30:14.023880Z INFO katana_node: Starting node. chain=0x4b4154414e41 +fn parse_chain_id_log(log: &str) -> Result { + // remove any ANSI escape codes from the log. + let cleaned = clean_ansi_escape_codes(log)?; + + // This will separate the log into two parts as separated by `chain=` str and we take + // only the second part which is the chain ID. + let chain_part = cleaned.split("chain=").nth(1).ok_or(Error::MissingChainId)?.trim(); + if chain_part.starts_with("0x") { + Ok(Felt::from_str(chain_part)?) + } else { + Ok(cairo_short_string_to_felt(chain_part) + .map_err(|_| Error::UnexpectedFormat("invalid chain id".to_string()))?) + } +} + #[cfg(test)] mod tests { use starknet::providers::jsonrpc::HttpTransport; @@ -634,10 +671,10 @@ mod tests { #[test] fn can_launch_katana_with_json_log() { - let katana = Katana::new().json_log(true).spawn(); + let katana = Katana::new().json_log(true).chain_id(short_string!("SN_SEPOLIA")).spawn(); // Assert default values when using JSON logging assert_eq!(katana.accounts().len(), 10); - assert_eq!(katana.chain_id(), short_string!("KATANA")); + assert_eq!(katana.chain_id(), short_string!("SN_SEPOLIA")); // assert that all accounts have private key assert!(katana.accounts().iter().all(|a| a.private_key.is_some())); } @@ -705,4 +742,17 @@ mod tests { assert_eq!(addr.ip().to_string(), "127.0.0.1"); assert_eq!(addr.port(), 60817); } + + #[tokio::test] + async fn can_launch_katana_with_custom_chain_id() { + let custom_chain_id = Felt::from_str("0x1234").unwrap(); + let katana = Katana::new().chain_id(custom_chain_id).spawn(); + + assert_eq!(katana.chain_id(), custom_chain_id); + + let provider = JsonRpcClient::new(HttpTransport::new(katana.endpoint_url())); + let actual_chain_id = provider.chain_id().await.unwrap(); + + assert_eq!(custom_chain_id, actual_chain_id); + } } diff --git a/crates/katana/node/src/lib.rs b/crates/katana/node/src/lib.rs index 53f103b9c7..09ccd54d1e 100644 --- a/crates/katana/node/src/lib.rs +++ b/crates/katana/node/src/lib.rs @@ -96,6 +96,9 @@ impl Node { /// /// This method will start all the node process, running them until the node is stopped. pub async fn launch(self) -> Result { + let chain = self.backend.chain_spec.id; + info!(%chain, "Starting node."); + if let Some(ref cfg) = self.metrics_config { let addr = cfg.addr; let mut reports = Vec::new(); diff --git a/crates/katana/primitives/Cargo.toml b/crates/katana/primitives/Cargo.toml index 4c66693811..4f6de47088 100644 --- a/crates/katana/primitives/Cargo.toml +++ b/crates/katana/primitives/Cargo.toml @@ -18,7 +18,6 @@ serde_json.workspace = true serde_with.workspace = true starknet.workspace = true starknet-crypto.workspace = true -strum_macros.workspace = true thiserror.workspace = true alloy-primitives.workspace = true diff --git a/crates/katana/primitives/src/chain.rs b/crates/katana/primitives/src/chain.rs index d8c359892f..a1f8408d8e 100644 --- a/crates/katana/primitives/src/chain.rs +++ b/crates/katana/primitives/src/chain.rs @@ -4,7 +4,7 @@ use starknet::macros::short_string; use crate::{Felt, FromStrError}; /// Known chain ids that has been assigned a name. -#[derive(Debug, Clone, Copy, PartialEq, Eq, strum_macros::Display)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum NamedChainId { Mainnet, @@ -44,6 +44,12 @@ impl NamedChainId { } } +impl std::fmt::Display for NamedChainId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.name()) + } +} + /// This `struct` is created by the [`NamedChainId::try_from`] method. #[derive(Debug, thiserror::Error)] #[error("Unknown named chain id {0:#x}")] @@ -185,9 +191,9 @@ mod tests { assert_eq!(ChainId::from(sepolia_id), ChainId::SEPOLIA); assert_eq!(ChainId::from(felt!("0x1337")), ChainId::Id(felt!("0x1337"))); - assert_eq!(ChainId::MAINNET.to_string(), "Mainnet"); - assert_eq!(ChainId::GOERLI.to_string(), "Goerli"); - assert_eq!(ChainId::SEPOLIA.to_string(), "Sepolia"); + assert_eq!(ChainId::MAINNET.to_string(), "SN_MAIN"); + assert_eq!(ChainId::GOERLI.to_string(), "SN_GOERLI"); + assert_eq!(ChainId::SEPOLIA.to_string(), "SN_SEPOLIA"); assert_eq!(ChainId::Id(felt!("0x1337")).to_string(), "0x1337"); } diff --git a/crates/katana/runner/macro/src/config.rs b/crates/katana/runner/macro/src/config.rs index cb6c6b8651..32378cc4ef 100644 --- a/crates/katana/runner/macro/src/config.rs +++ b/crates/katana/runner/macro/src/config.rs @@ -18,6 +18,7 @@ pub struct Configuration { pub db_dir: Option, pub block_time: Option, pub log_path: Option, + pub chain_id: Option, } impl Configuration { @@ -32,6 +33,7 @@ impl Configuration { validation: None, block_time: None, crate_name: None, + chain_id: None, } } @@ -105,6 +107,19 @@ impl Configuration { self.accounts = Some(accounts); Ok(()) } + + fn set_chain_id( + &mut self, + chain_id: syn::Expr, + span: proc_macro2::Span, + ) -> Result<(), syn::Error> { + if self.chain_id.is_some() { + return Err(syn::Error::new(span, "`chain_id` set multiple times.")); + } + + self.chain_id = Some(chain_id); + Ok(()) + } } enum RunnerArg { @@ -113,6 +128,7 @@ enum RunnerArg { Validation, Accounts, DbDir, + ChainId, } impl std::str::FromStr for RunnerArg { @@ -125,9 +141,10 @@ impl std::str::FromStr for RunnerArg { "validation" => Ok(RunnerArg::Validation), "accounts" => Ok(RunnerArg::Accounts), "db_dir" => Ok(RunnerArg::DbDir), + "chain_id" => Ok(RunnerArg::ChainId), _ => Err(format!( "Unknown attribute {s} is specified; expected one of: `fee`, `validation`, \ - `accounts`, `db_dir`, `block_time`", + `accounts`, `db_dir`, `block_time`, `chain_id`", )), } } @@ -172,7 +189,9 @@ pub fn build_config( RunnerArg::DbDir => { config.set_db_dir(expr.clone(), Spanned::span(&namevalue))? } - + RunnerArg::ChainId => { + config.set_chain_id(expr.clone(), Spanned::span(&namevalue))? + } RunnerArg::Fee => config.set_fee(expr.clone(), Spanned::span(&namevalue))?, } } diff --git a/crates/katana/runner/macro/src/entry.rs b/crates/katana/runner/macro/src/entry.rs index 6311fbdf52..53f2d390ce 100644 --- a/crates/katana/runner/macro/src/entry.rs +++ b/crates/katana/runner/macro/src/entry.rs @@ -74,7 +74,11 @@ pub fn parse_knobs(input: ItemFn, is_test: bool, config: Configuration) -> Token } if let Some(value) = config.log_path { - cfg = quote_spanned! (last_stmt_start_span=> #cfg, log_path: Some(#value), ); + cfg = quote_spanned! (last_stmt_start_span=> #cfg log_path: Some(#value), ); + } + + if let Some(value) = config.chain_id { + cfg = quote_spanned! (last_stmt_start_span=> #cfg chain_id: Some(#value), ); } if config.dev { diff --git a/crates/katana/runner/src/lib.rs b/crates/katana/runner/src/lib.rs index fc063e01cc..c3633b0e5d 100644 --- a/crates/katana/runner/src/lib.rs +++ b/crates/katana/runner/src/lib.rs @@ -67,6 +67,8 @@ pub struct KatanaRunnerConfig { pub db_dir: Option, /// Whether to run the katana runner with the `dev` rpc endpoints. pub dev: bool, + /// The chain id to use. + pub chain_id: Option, } impl Default for KatanaRunnerConfig { @@ -82,6 +84,7 @@ impl Default for KatanaRunnerConfig { messaging: None, db_dir: None, dev: false, + chain_id: None, } } } @@ -123,6 +126,10 @@ impl KatanaRunner { .dev(config.dev) .fee(!config.disable_fee); + if let Some(id) = config.chain_id { + builder = builder.chain_id(id); + } + if let Some(block_time_ms) = config.block_time { builder = builder.block_time(block_time_ms); } diff --git a/crates/katana/runner/tests/runner.rs b/crates/katana/runner/tests/runner.rs index dcec61acd1..ba068658ff 100644 --- a/crates/katana/runner/tests/runner.rs +++ b/crates/katana/runner/tests/runner.rs @@ -1,4 +1,5 @@ use katana_runner::RunnerCtx; +use starknet::macros::short_string; use starknet::providers::Provider; #[katana_runner::test(fee = false, accounts = 7)] @@ -6,6 +7,14 @@ fn simple(runner: &RunnerCtx) { assert_eq!(runner.accounts().len(), 7); } +#[tokio::test(flavor = "multi_thread")] +#[katana_runner::test(chain_id = short_string!("SN_SEPOLIA"))] +async fn custom_chain_id(runner: &RunnerCtx) { + let provider = runner.provider(); + let id = provider.chain_id().await.unwrap(); + assert_eq!(id, short_string!("SN_SEPOLIA")); +} + #[katana_runner::test] fn with_return(_: &RunnerCtx) -> Result<(), Box> { Ok(()) @@ -15,6 +24,7 @@ fn with_return(_: &RunnerCtx) -> Result<(), Box> { #[katana_runner::test] async fn with_async(ctx: &RunnerCtx) -> Result<(), Box> { let provider = ctx.provider(); - let _ = provider.chain_id().await?; + let id = provider.chain_id().await?; + assert_eq!(id, short_string!("KATANA")); Ok(()) }