Skip to content

Commit

Permalink
refactor(katana-node): flatten rpc server building logic (dojoengine#…
Browse files Browse the repository at this point in the history
  • Loading branch information
kariy authored Dec 10, 2024
1 parent 71db0b4 commit cc4c800
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 55 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/dojo/test-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ anyhow.workspace = true
assert_fs.workspace = true
async-trait.workspace = true
camino.workspace = true
jsonrpsee = { workspace = true, features = [ "server" ] }
katana-core = { workspace = true }
katana-executor = { workspace = true, features = [ "blockifier" ] }
katana-node.workspace = true
katana-primitives = { workspace = true }
katana-rpc.workspace = true
scarb.workspace = true
scarb-ui.workspace = true
serde.workspace = true
Expand Down
7 changes: 4 additions & 3 deletions crates/dojo/test-utils/src/sequencer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::collections::HashSet;
use std::sync::Arc;

use jsonrpsee::core::Error;
use katana_core::backend::Backend;
use katana_core::constants::DEFAULT_SEQUENCER_ADDRESS;
use katana_executor::implementation::blockifier::BlockifierFactory;
Expand All @@ -11,6 +10,7 @@ pub use katana_node::config::*;
use katana_node::LaunchedNode;
use katana_primitives::chain::ChainId;
use katana_primitives::chain_spec::ChainSpec;
use katana_rpc::Error;
use starknet::accounts::{ExecutionEncoding, SingleOwnerAccount};
use starknet::core::chain_id;
use starknet::core::types::{BlockId, BlockTag, Felt};
Expand Down Expand Up @@ -42,7 +42,8 @@ impl TestSequencer {
.await
.expect("Failed to launch node");

let url = Url::parse(&format!("http://{}", handle.rpc.addr)).expect("Failed to parse URL");
let url =
Url::parse(&format!("http://{}", handle.rpc.addr())).expect("Failed to parse URL");

let account = handle.node.backend.chain_spec.genesis.accounts().next().unwrap();
let account = TestAccount {
Expand Down Expand Up @@ -104,7 +105,7 @@ impl TestSequencer {
}

pub fn stop(self) -> Result<(), Error> {
self.handle.rpc.handle.stop()
self.handle.rpc.stop()
}

pub fn url(&self) -> Url {
Expand Down
2 changes: 1 addition & 1 deletion crates/katana/node/src/exit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl<'a> NodeStoppedFuture<'a> {
pub(crate) fn new(handle: &'a LaunchedNode) -> Self {
let fut = Box::pin(async {
handle.node.task_manager.wait_for_shutdown().await;
handle.stop().await?;
handle.rpc.clone().stopped().await;
Ok(())
});
Self { fut }
Expand Down
83 changes: 37 additions & 46 deletions crates/katana/node/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::future::IntoFuture;
use std::sync::Arc;

use anyhow::Result;
use config::rpc::{ApiKind, RpcConfig};
use config::rpc::ApiKind;
use config::Config;
use dojo_metrics::exporters::prometheus::PrometheusRecorder;
use dojo_metrics::{Report, Server as MetricsServer};
Expand All @@ -28,7 +28,7 @@ use katana_core::env::BlockContextGenerator;
use katana_core::service::block_producer::BlockProducer;
use katana_db::mdbx::DbEnv;
use katana_executor::implementation::blockifier::BlockifierFactory;
use katana_executor::{ExecutionFlags, ExecutorFactory};
use katana_executor::ExecutionFlags;
use katana_pool::ordering::FiFo;
use katana_pool::TxPool;
use katana_primitives::block::GasPrices;
Expand Down Expand Up @@ -64,7 +64,7 @@ impl LaunchedNode {
/// This will instruct the node to stop and wait until it has actually stop.
pub async fn stop(&self) -> Result<()> {
// TODO: wait for the rpc server to stop instead of just stopping it.
self.rpc.handle.stop()?;
self.rpc.stop()?;
self.node.task_manager.shutdown().await;
Ok(())
}
Expand All @@ -83,18 +83,18 @@ impl LaunchedNode {
pub struct Node {
pub pool: TxPool,
pub db: Option<DbEnv>,
pub rpc_server: RpcServer,
pub task_manager: TaskManager,
pub backend: Arc<Backend<BlockifierFactory>>,
pub block_producer: BlockProducer<BlockifierFactory>,
pub config: Arc<Config>,
forked_client: Option<ForkedClient>,
}

impl Node {
/// Start the node.
///
/// This method will start all the node process, running them until the node is stopped.
pub async fn launch(mut self) -> Result<LaunchedNode> {
pub async fn launch(self) -> Result<LaunchedNode> {
let chain = self.backend.chain_spec.id;
info!(%chain, "Starting node.");

Expand Down Expand Up @@ -135,16 +135,18 @@ impl Node {
.name("Sequencing")
.spawn(sequencing.into_future());

let node_components = (pool, backend, block_producer, self.forked_client.take());
let rpc = spawn(node_components, self.config.rpc.clone()).await?;
// --- start the rpc server

let rpc_handle = self.rpc_server.start(self.config.rpc.socket_addr()).await?;

// --- start the gas oracle worker task

if let Some(ref url) = self.config.l1_provider_url {
self.backend.gas_oracle.run_worker(self.task_manager.task_spawner());
info!(%url, "Gas Price Oracle started.");
};

Ok(LaunchedNode { node: self, rpc })
Ok(LaunchedNode { node: self, rpc: rpc_handle })
}
}

Expand Down Expand Up @@ -240,36 +242,18 @@ pub async fn build(mut config: Config) -> Result<Node> {
let validator = block_producer.validator();
let pool = TxPool::new(validator.clone(), FiFo::new());

let node = Node {
db,
pool,
backend,
forked_client,
block_producer,
config: Arc::new(config),
task_manager: TaskManager::current(),
};

Ok(node)
}

// Moved from `katana_rpc` crate
pub async fn spawn<EF: ExecutorFactory>(
node_components: (TxPool, Arc<Backend<EF>>, BlockProducer<EF>, Option<ForkedClient>),
config: RpcConfig,
) -> Result<RpcServerHandle> {
let (pool, backend, block_producer, forked_client) = node_components;
// --- build rpc server

let mut modules = RpcModule::new(());
let mut rpc_modules = RpcModule::new(());

let cors = Cors::new()
.allow_origins(config.cors_origins.clone())
// Allow `POST` when accessing the resource
.allow_methods([Method::POST, Method::GET])
.allow_headers([hyper::header::CONTENT_TYPE, "argent-client".parse().unwrap(), "argent-version".parse().unwrap()]);
.allow_origins(config.rpc.cors_origins.clone())
// Allow `POST` when accessing the resource
.allow_methods([Method::POST, Method::GET])
.allow_headers([hyper::header::CONTENT_TYPE, "argent-client".parse().unwrap(), "argent-version".parse().unwrap()]);

if config.apis.contains(&ApiKind::Starknet) {
let cfg = StarknetApiConfig { max_event_page_size: config.max_event_page_size };
if config.rpc.apis.contains(&ApiKind::Starknet) {
let cfg = StarknetApiConfig { max_event_page_size: config.rpc.max_event_page_size };

let api = if let Some(client) = forked_client {
StarknetApi::new_forked(
Expand All @@ -283,28 +267,35 @@ pub async fn spawn<EF: ExecutorFactory>(
StarknetApi::new(backend.clone(), pool.clone(), Some(block_producer.clone()), cfg)
};

modules.merge(StarknetApiServer::into_rpc(api.clone()))?;
modules.merge(StarknetWriteApiServer::into_rpc(api.clone()))?;
modules.merge(StarknetTraceApiServer::into_rpc(api))?;
rpc_modules.merge(StarknetApiServer::into_rpc(api.clone()))?;
rpc_modules.merge(StarknetWriteApiServer::into_rpc(api.clone()))?;
rpc_modules.merge(StarknetTraceApiServer::into_rpc(api))?;
}

if config.apis.contains(&ApiKind::Dev) {
if config.rpc.apis.contains(&ApiKind::Dev) {
let api = DevApi::new(backend.clone(), block_producer.clone());
modules.merge(api.into_rpc())?;
rpc_modules.merge(api.into_rpc())?;
}

if config.apis.contains(&ApiKind::Torii) {
if config.rpc.apis.contains(&ApiKind::Torii) {
let api = ToriiApi::new(backend.clone(), pool.clone(), block_producer.clone());
modules.merge(api.into_rpc())?;
rpc_modules.merge(api.into_rpc())?;
}

if config.apis.contains(&ApiKind::Saya) {
if config.rpc.apis.contains(&ApiKind::Saya) {
let api = SayaApi::new(backend.clone(), block_producer.clone());
modules.merge(api.into_rpc())?;
rpc_modules.merge(api.into_rpc())?;
}

let server = RpcServer::new().metrics().health_check().cors(cors).module(modules);
let handle = server.start(config.socket_addr()).await?;
let rpc_server = RpcServer::new().metrics().health_check().cors(cors).module(rpc_modules);

Ok(handle)
Ok(Node {
db,
pool,
backend,
rpc_server,
block_producer,
config: Arc::new(config),
task_manager: TaskManager::current(),
})
}
15 changes: 12 additions & 3 deletions crates/katana/rpc/rpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,17 @@ pub enum Error {
AlreadyStopped,
}

#[derive(Debug)]
/// The RPC server handle.
#[derive(Debug, Clone)]
pub struct RpcServerHandle {
pub addr: SocketAddr,
pub handle: ServerHandle,
/// The actual address that the server is binded to.
addr: SocketAddr,
/// The handle to the spawned [`jsonrpsee::server::Server`].
handle: ServerHandle,
}

impl RpcServerHandle {
/// Tell the server to stop without waiting for the server to stop.
pub fn stop(&self) -> Result<(), Error> {
self.handle.stop().map_err(|_| Error::AlreadyStopped)
}
Expand All @@ -48,6 +52,11 @@ impl RpcServerHandle {
pub async fn stopped(self) {
self.handle.stopped().await
}

/// Returns the socket address the server is listening on.
pub fn addr(&self) -> &SocketAddr {
&self.addr
}
}

#[derive(Debug)]
Expand Down

0 comments on commit cc4c800

Please sign in to comment.