Skip to content

Commit

Permalink
feat(katana-runner): allow configuring chain id (#2554)
Browse files Browse the repository at this point in the history
  • Loading branch information
kariy authored Oct 18, 2024
1 parent bdb8fb5 commit 3508fb9
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 18 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions crates/katana/node-bindings/src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,12 @@ pub struct AccountInfo {
pub struct RpcAddr {
pub addr: SocketAddr,
}

/// {
/// "message": "Starting node.",
/// "chain": "SN_SEPOLIA"
/// }
#[derive(Deserialize, Debug)]
pub struct ChainId {
pub chain: String,
}
66 changes: 58 additions & 8 deletions crates/katana/node-bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use std::process::{Child, Command};
use std::str::FromStr;
use std::time::{Duration, Instant};

use json::RpcAddr;
use starknet::core::types::{Felt, FromStrError};
use starknet::core::utils::cairo_short_string_to_felt;
use starknet::macros::short_string;
use starknet::signers::SigningKey;
use thiserror::Error;
Expand Down Expand Up @@ -130,6 +130,10 @@ pub enum Error {
#[error("missing rpc server address")]
MissingSocketAddr,

/// A line indicating the instance address was found but the actual value was not.
#[error("missing chain id")]
MissingChainId,

#[error("encountered unexpected format: {0}")]
UnexpectedFormat(String),

Expand All @@ -142,6 +146,8 @@ pub enum Error {

/// The string indicator from which the RPC server address can be extracted from.
const RPC_ADDR_LOG_SUBSTR: &str = "RPC server started.";
/// The string indicator from which the chain id can be extracted from.
const CHAIN_ID_LOG_SUBSTR: &str = "Starting node.";

/// Builder for launching `katana`.
///
Expand Down Expand Up @@ -494,10 +500,9 @@ impl Katana {
let mut accounts = Vec::new();
// var to store the current account being processed
let mut current_account: Option<Account> = None;

// TODO: the chain id should be fetched from stdout as well but Katana doesn't display the
// chain id atm
let chain_id = self.chain_id.unwrap_or(short_string!("KATANA"));
// var to store the chain id parsed from the logs. default to KATANA (default katana chain
// id) if not specified
let mut chain_id: Felt = self.chain_id.unwrap_or(short_string!("KATANA"));

loop {
if start + Duration::from_millis(self.timeout.unwrap_or(KATANA_STARTUP_TIMEOUT_MILLIS))
Expand All @@ -514,13 +519,24 @@ impl Katana {
// Because we using a concrete type for rpc addr log, we need to parse this first.
// Otherwise if we were to inverse the if statements, the else block
// would never be executed as all logs can be parsed as `JsonLog`.
if let Ok(log) = serde_json::from_str::<JsonLog<RpcAddr>>(&line) {
if let Ok(log) = serde_json::from_str::<JsonLog<json::RpcAddr>>(&line) {
debug_assert!(log.fields.message.contains(RPC_ADDR_LOG_SUBSTR));
port = log.fields.other.addr.port();
// We can safely break here as we don't need any information after the rpc
// address
break;
}
// Try parsing as chain id log
else if let Ok(log) = serde_json::from_str::<JsonLog<json::ChainId>>(&line) {
debug_assert!(log.fields.message.contains(CHAIN_ID_LOG_SUBSTR));
let chain_raw = log.fields.other.chain;
chain_id = if chain_raw.starts_with("0x") {
Felt::from_str(&chain_raw)?
} else {
cairo_short_string_to_felt(&chain_raw)
.map_err(|_| Error::UnexpectedFormat("invalid chain id".to_string()))?
};
}
// Parse all logs as generic logs
else if let Ok(info) = serde_json::from_str::<JsonLog>(&line) {
// Check if this log is a katana startup info log
Expand All @@ -543,6 +559,10 @@ impl Katana {
break;
}

if line.contains(CHAIN_ID_LOG_SUBSTR) {
chain_id = parse_chain_id_log(&line)?;
}

const ACC_ADDRESS_PREFIX: &str = "| Account address |";
if line.starts_with(ACC_ADDRESS_PREFIX) {
// If there is currently an account being handled, but we've reached the next
Expand Down Expand Up @@ -609,6 +629,23 @@ fn parse_rpc_addr_log(log: &str) -> Result<SocketAddr, Error> {
Ok(SocketAddr::from_str(addr)?)
}

// Example chain ID log format (ansi color codes removed):
// 2024-10-18T01:30:14.023880Z INFO katana_node: Starting node. chain=0x4b4154414e41
fn parse_chain_id_log(log: &str) -> Result<Felt, Error> {
// remove any ANSI escape codes from the log.
let cleaned = clean_ansi_escape_codes(log)?;

// This will separate the log into two parts as separated by `chain=` str and we take
// only the second part which is the chain ID.
let chain_part = cleaned.split("chain=").nth(1).ok_or(Error::MissingChainId)?.trim();
if chain_part.starts_with("0x") {
Ok(Felt::from_str(chain_part)?)
} else {
Ok(cairo_short_string_to_felt(chain_part)
.map_err(|_| Error::UnexpectedFormat("invalid chain id".to_string()))?)
}
}

#[cfg(test)]
mod tests {
use starknet::providers::jsonrpc::HttpTransport;
Expand All @@ -634,10 +671,10 @@ mod tests {

#[test]
fn can_launch_katana_with_json_log() {
let katana = Katana::new().json_log(true).spawn();
let katana = Katana::new().json_log(true).chain_id(short_string!("SN_SEPOLIA")).spawn();
// Assert default values when using JSON logging
assert_eq!(katana.accounts().len(), 10);
assert_eq!(katana.chain_id(), short_string!("KATANA"));
assert_eq!(katana.chain_id(), short_string!("SN_SEPOLIA"));
// assert that all accounts have private key
assert!(katana.accounts().iter().all(|a| a.private_key.is_some()));
}
Expand Down Expand Up @@ -705,4 +742,17 @@ mod tests {
assert_eq!(addr.ip().to_string(), "127.0.0.1");
assert_eq!(addr.port(), 60817);
}

#[tokio::test]
async fn can_launch_katana_with_custom_chain_id() {
let custom_chain_id = Felt::from_str("0x1234").unwrap();
let katana = Katana::new().chain_id(custom_chain_id).spawn();

assert_eq!(katana.chain_id(), custom_chain_id);

let provider = JsonRpcClient::new(HttpTransport::new(katana.endpoint_url()));
let actual_chain_id = provider.chain_id().await.unwrap();

assert_eq!(custom_chain_id, actual_chain_id);
}
}
3 changes: 3 additions & 0 deletions crates/katana/node/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ impl Node {
///
/// This method will start all the node process, running them until the node is stopped.
pub async fn launch(self) -> Result<LaunchedNode> {
let chain = self.backend.chain_spec.id;
info!(%chain, "Starting node.");

if let Some(ref cfg) = self.metrics_config {
let addr = cfg.addr;
let mut reports = Vec::new();
Expand Down
1 change: 0 additions & 1 deletion crates/katana/primitives/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ serde_json.workspace = true
serde_with.workspace = true
starknet.workspace = true
starknet-crypto.workspace = true
strum_macros.workspace = true
thiserror.workspace = true

alloy-primitives.workspace = true
Expand Down
14 changes: 10 additions & 4 deletions crates/katana/primitives/src/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use starknet::macros::short_string;
use crate::{Felt, FromStrError};

/// Known chain ids that has been assigned a name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, strum_macros::Display)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum NamedChainId {
Mainnet,
Expand Down Expand Up @@ -44,6 +44,12 @@ impl NamedChainId {
}
}

impl std::fmt::Display for NamedChainId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.name())
}
}

/// This `struct` is created by the [`NamedChainId::try_from<u128>`] method.
#[derive(Debug, thiserror::Error)]
#[error("Unknown named chain id {0:#x}")]
Expand Down Expand Up @@ -185,9 +191,9 @@ mod tests {
assert_eq!(ChainId::from(sepolia_id), ChainId::SEPOLIA);
assert_eq!(ChainId::from(felt!("0x1337")), ChainId::Id(felt!("0x1337")));

assert_eq!(ChainId::MAINNET.to_string(), "Mainnet");
assert_eq!(ChainId::GOERLI.to_string(), "Goerli");
assert_eq!(ChainId::SEPOLIA.to_string(), "Sepolia");
assert_eq!(ChainId::MAINNET.to_string(), "SN_MAIN");
assert_eq!(ChainId::GOERLI.to_string(), "SN_GOERLI");
assert_eq!(ChainId::SEPOLIA.to_string(), "SN_SEPOLIA");
assert_eq!(ChainId::Id(felt!("0x1337")).to_string(), "0x1337");
}

Expand Down
23 changes: 21 additions & 2 deletions crates/katana/runner/macro/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub struct Configuration {
pub db_dir: Option<syn::Expr>,
pub block_time: Option<syn::Expr>,
pub log_path: Option<syn::Expr>,
pub chain_id: Option<syn::Expr>,
}

impl Configuration {
Expand All @@ -32,6 +33,7 @@ impl Configuration {
validation: None,
block_time: None,
crate_name: None,
chain_id: None,
}
}

Expand Down Expand Up @@ -105,6 +107,19 @@ impl Configuration {
self.accounts = Some(accounts);
Ok(())
}

fn set_chain_id(
&mut self,
chain_id: syn::Expr,
span: proc_macro2::Span,
) -> Result<(), syn::Error> {
if self.chain_id.is_some() {
return Err(syn::Error::new(span, "`chain_id` set multiple times."));
}

self.chain_id = Some(chain_id);
Ok(())
}
}

enum RunnerArg {
Expand All @@ -113,6 +128,7 @@ enum RunnerArg {
Validation,
Accounts,
DbDir,
ChainId,
}

impl std::str::FromStr for RunnerArg {
Expand All @@ -125,9 +141,10 @@ impl std::str::FromStr for RunnerArg {
"validation" => Ok(RunnerArg::Validation),
"accounts" => Ok(RunnerArg::Accounts),
"db_dir" => Ok(RunnerArg::DbDir),
"chain_id" => Ok(RunnerArg::ChainId),
_ => Err(format!(
"Unknown attribute {s} is specified; expected one of: `fee`, `validation`, \
`accounts`, `db_dir`, `block_time`",
`accounts`, `db_dir`, `block_time`, `chain_id`",
)),
}
}
Expand Down Expand Up @@ -172,7 +189,9 @@ pub fn build_config(
RunnerArg::DbDir => {
config.set_db_dir(expr.clone(), Spanned::span(&namevalue))?
}

RunnerArg::ChainId => {
config.set_chain_id(expr.clone(), Spanned::span(&namevalue))?
}
RunnerArg::Fee => config.set_fee(expr.clone(), Spanned::span(&namevalue))?,
}
}
Expand Down
6 changes: 5 additions & 1 deletion crates/katana/runner/macro/src/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ pub fn parse_knobs(input: ItemFn, is_test: bool, config: Configuration) -> Token
}

if let Some(value) = config.log_path {
cfg = quote_spanned! (last_stmt_start_span=> #cfg, log_path: Some(#value), );
cfg = quote_spanned! (last_stmt_start_span=> #cfg log_path: Some(#value), );
}

if let Some(value) = config.chain_id {
cfg = quote_spanned! (last_stmt_start_span=> #cfg chain_id: Some(#value), );
}

if config.dev {
Expand Down
7 changes: 7 additions & 0 deletions crates/katana/runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ pub struct KatanaRunnerConfig {
pub db_dir: Option<PathBuf>,
/// Whether to run the katana runner with the `dev` rpc endpoints.
pub dev: bool,
/// The chain id to use.
pub chain_id: Option<Felt>,
}

impl Default for KatanaRunnerConfig {
Expand All @@ -82,6 +84,7 @@ impl Default for KatanaRunnerConfig {
messaging: None,
db_dir: None,
dev: false,
chain_id: None,
}
}
}
Expand Down Expand Up @@ -123,6 +126,10 @@ impl KatanaRunner {
.dev(config.dev)
.fee(!config.disable_fee);

if let Some(id) = config.chain_id {
builder = builder.chain_id(id);
}

if let Some(block_time_ms) = config.block_time {
builder = builder.block_time(block_time_ms);
}
Expand Down
12 changes: 11 additions & 1 deletion crates/katana/runner/tests/runner.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
use katana_runner::RunnerCtx;
use starknet::macros::short_string;
use starknet::providers::Provider;

#[katana_runner::test(fee = false, accounts = 7)]
fn simple(runner: &RunnerCtx) {
assert_eq!(runner.accounts().len(), 7);
}

#[tokio::test(flavor = "multi_thread")]
#[katana_runner::test(chain_id = short_string!("SN_SEPOLIA"))]
async fn custom_chain_id(runner: &RunnerCtx) {
let provider = runner.provider();
let id = provider.chain_id().await.unwrap();
assert_eq!(id, short_string!("SN_SEPOLIA"));
}

#[katana_runner::test]
fn with_return(_: &RunnerCtx) -> Result<(), Box<dyn std::error::Error>> {
Ok(())
Expand All @@ -15,6 +24,7 @@ fn with_return(_: &RunnerCtx) -> Result<(), Box<dyn std::error::Error>> {
#[katana_runner::test]
async fn with_async(ctx: &RunnerCtx) -> Result<(), Box<dyn std::error::Error>> {
let provider = ctx.provider();
let _ = provider.chain_id().await?;
let id = provider.chain_id().await?;
assert_eq!(id, short_string!("KATANA"));
Ok(())
}

0 comments on commit 3508fb9

Please sign in to comment.