Merge pull request #159 from firstbatchxyz/erhant/better-publish-timestamp-logic

Better publish timestamps & better batch logic
erhant authored Dec 13, 2024
2 parents daa9fba + 0ef85b5 commit 7b3fc82
Showing 7 changed files with 86 additions and 63 deletions.
14 changes: 7 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default.

9 changes: 2 additions & 7 deletions Cargo.toml
```diff
@@ -9,7 +9,7 @@ default-members = ["compute"]
 
 [workspace.package]
 edition = "2021"
-version = "0.2.29"
+version = "0.2.30"
 license = "Apache-2.0"
 readme = "README.md"
@@ -18,14 +18,9 @@ readme = "README.md"
 inherits = "release"
 debug = true
 
-
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
 [workspace.dependencies]
 # async stuff
-tokio-util = { version = "0.7.10", features = [
-    "rt",
-] } # tokio-util provides CancellationToken
+tokio-util = { version = "0.7.10", features = ["rt"] }
 tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] }
 async-trait = "0.1.81"
```
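The dropped inline comment noted that `tokio-util` is pulled in for `CancellationToken`, which the compute node uses for cooperative shutdown. A minimal sketch of how such a token propagates cancellation; this is illustrative only, and assumes tokio's `time` feature on top of the features listed above:

```rust
use std::time::Duration;
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    let child = token.child_token();

    // a worker that runs until cancelled; the sleep stands in for real work
    let worker = tokio::spawn(async move {
        tokio::select! {
            _ = child.cancelled() => println!("worker: cancelled, cleaning up"),
            _ = tokio::time::sleep(Duration::from_secs(60)) => println!("worker: done"),
        }
    });

    token.cancel(); // all child tokens observe the cancellation
    worker.await.unwrap();
}
```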
5 changes: 3 additions & 2 deletions compute/src/main.rs
```diff
@@ -118,8 +118,9 @@ async fn main() -> Result<()> {
     let node_token = cancellation.clone();
     task_tracker.spawn(async move {
         if let Err(err) = node.run(node_token).await {
-            log::error!("Node launch error: {}", err);
-            panic!("Node failed.")
+            log::error!("Error within main node loop: {}", err);
+            log::error!("Shutting down node.");
+            node.shutdown().await.expect("could not shutdown node");
         };
         log::info!("Closing node.")
     });
```
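The fix swaps a `panic!` for a logged, graceful `node.shutdown()`, so a failing main loop still lets tracked sibling tasks wind down. A simplified sketch of the pattern; `Node` and its method signatures here are stand-ins, not the crate's actual types:

```rust
use tokio_util::{sync::CancellationToken, task::TaskTracker};

struct Node;

impl Node {
    async fn run(&mut self, token: CancellationToken) -> Result<(), String> {
        token.cancelled().await; // placeholder for the real event loop
        Err("simulated failure".into())
    }

    async fn shutdown(&mut self) -> Result<(), String> {
        Ok(()) // close connections, flush state, etc.
    }
}

#[tokio::main]
async fn main() {
    let cancellation = CancellationToken::new();
    let task_tracker = TaskTracker::new();

    let node_token = cancellation.clone();
    let mut node = Node;
    task_tracker.spawn(async move {
        if let Err(err) = node.run(node_token).await {
            // log and shut down in place instead of panicking the whole process
            eprintln!("Error within main node loop: {}", err);
            node.shutdown().await.expect("could not shutdown node");
        }
    });

    task_tracker.close();
    cancellation.cancel(); // trigger shutdown
    task_tracker.wait().await; // wait for the node task to finish cleanly
}
```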
12 changes: 8 additions & 4 deletions compute/src/node.rs
```diff
@@ -331,10 +331,9 @@ impl DriaComputeNode {
 
         loop {
             tokio::select! {
-                // check peer count every now and then
-                _ = peer_refresh_interval.tick() => self.handle_diagnostic_refresh().await,
-                // available nodes are refreshed every now and then
-                _ = available_node_refresh_interval.tick() => self.handle_available_nodes_refresh().await,
+                // prioritize the branches in the order below
+                biased;
+
                 // a Workflow message to be published is received from the channel
                 // this is expected to be sent by the workflow worker
                 publish_msg_opt = self.publish_rx.recv() => {
@@ -358,6 +357,11 @@ impl DriaComputeNode {
                         break;
                     };
                 },
+
+                // check peer count every now and then
+                _ = peer_refresh_interval.tick() => self.handle_diagnostic_refresh().await,
+                // available nodes are refreshed every now and then
+                _ = available_node_refresh_interval.tick() => self.handle_available_nodes_refresh().await,
                 // a GossipSub message is received from the channel
                 // this is expected to be sent by the p2p client
                 gossipsub_msg_opt = self.message_rx.recv() => {
```
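`biased;` switches `tokio::select!` from random to top-to-bottom polling, so pending publish results are always handled before timer ticks or GossipSub traffic. A toy demonstration of the ordering; the channel and timer here are illustrative, not the node's actual branches:

```rust
use std::time::Duration;
use tokio::{sync::mpsc, time};

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<u32>(8);
    let mut ticker = time::interval(Duration::from_millis(10));

    for i in 0..3 {
        tx.send(i).await.unwrap();
    }
    drop(tx); // close the channel so recv() eventually yields None

    loop {
        tokio::select! {
            // poll branches top-to-bottom instead of randomly: even though the
            // timer's first tick is ready immediately, all queued messages win
            biased;

            msg = rx.recv() => match msg {
                Some(m) => println!("message {}", m),
                None => break,
            },
            _ = ticker.tick() => println!("tick"),
        }
    }
}
```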
18 changes: 18 additions & 0 deletions compute/src/payloads/stats.rs
```diff
@@ -12,7 +12,12 @@ pub struct TaskStats {
     /// Timestamp at which the task was published back to network.
     pub published_at: u128,
     /// Time taken to execute the task.
+    /// FIXME: will be removed after
     pub execution_time: u128,
+    /// Timestamp at which the task execution had started.
+    pub execution_started_at: u128,
+    /// Timestamp at which the task execution had finished.
+    pub execution_ended_time: u128,
 }
@@ -33,7 +38,20 @@ impl TaskStats {
         self
     }
 
+    /// Records the execution start time within `execution_started_at`.
+    pub fn record_execution_started_at(mut self) -> Self {
+        self.execution_started_at = get_current_time_nanos();
+        self
+    }
+
+    /// Records the execution end time within `execution_ended_time`.
+    pub fn record_execution_ended_at(mut self) -> Self {
+        self.execution_ended_time = get_current_time_nanos();
+        self
+    }
+
     /// Records the execution time of the task.
+    /// TODO: #[deprecated = "will be removed later"]
     pub fn record_execution_time(mut self, started_at: Instant) -> Self {
         self.execution_time = Instant::now().duration_since(started_at).as_nanos();
         self
```
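Unlike `execution_time`, which is a monotonic `Instant` delta, the new fields are absolute wall-clock timestamps, so a consumer can line execution up against `published_at`. A self-contained sketch of the builder-style recording; the struct and helper below are trimmed stand-ins for the crate's own definitions:

```rust
use std::time::{SystemTime, UNIX_EPOCH};

// stand-in for the crate's helper of the same name
fn get_current_time_nanos() -> u128 {
    SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_nanos()
}

#[derive(Debug, Default)]
struct TaskStats {
    execution_started_at: u128,
    execution_ended_time: u128,
}

impl TaskStats {
    fn record_execution_started_at(mut self) -> Self {
        self.execution_started_at = get_current_time_nanos();
        self
    }

    fn record_execution_ended_at(mut self) -> Self {
        self.execution_ended_time = get_current_time_nanos();
        self
    }
}

fn main() {
    let stats = TaskStats::default().record_execution_started_at();
    // ... the task would execute here ...
    let stats = stats.record_execution_ended_at();

    // elapsed time is recoverable as the difference of the two timestamps
    println!(
        "took {} ns",
        stats.execution_ended_time - stats.execution_started_at
    );
}
```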
76 changes: 44 additions & 32 deletions compute/src/workers/workflow.rs
```diff
@@ -78,7 +78,7 @@ impl WorkflowsWorker {
 
         if let Some(task) = task {
             log::info!("Processing single workflow for task {}", task.task_id);
-            WorkflowsWorker::execute((task, self.publish_tx.clone())).await
+            WorkflowsWorker::execute((task, &self.publish_tx)).await
         } else {
             return self.shutdown();
         };
@@ -93,76 +93,85 @@ impl WorkflowsWorker {
     ///
     /// Batch size must NOT be larger than `MAX_BATCH_SIZE`, otherwise will panic.
     pub async fn run_batch(&mut self, batch_size: usize) {
+        // TODO: need some better batch_size error handling here
         assert!(
             batch_size <= Self::MAX_BATCH_SIZE,
             "Batch size must not be larger than {}",
             Self::MAX_BATCH_SIZE
         );
 
         loop {
-            // get tasks in batch from the channel
-            let mut task_buffer = Vec::new();
-            let num_tasks = self
-                .workflow_rx
-                .recv_many(&mut task_buffer, batch_size)
-                .await;
-
-            if num_tasks == 0 {
-                return self.shutdown();
-            }
+            let mut tasks = Vec::new();
+
+            // get tasks in batch from the channel, we enter the loop if:
+            // (1) there are no tasks, or,
+            // (2) there are tasks less than the batch size and the channel is not empty
+            while tasks.len() == 0 || (tasks.len() < batch_size && !self.workflow_rx.is_empty()) {
+                let limit = batch_size - tasks.len();
+                match self.workflow_rx.recv_many(&mut tasks, limit).await {
+                    // 0 tasks returned means that the channel is closed
+                    0 => return self.shutdown(),
+                    _ => {
+                        // wait a small amount of time to allow for more tasks to be sent into the channel
+                        tokio::time::sleep(std::time::Duration::from_millis(256)).await;
+                    }
+                }
+            }
 
             // process the batch
+            let num_tasks = tasks.len();
+            debug_assert!(
+                num_tasks <= batch_size,
+                "number of tasks cant be larger than batch size"
+            );
+            debug_assert!(num_tasks != 0, "number of tasks cant be zero");
             log::info!("Processing {} workflows in batch", num_tasks);
-            let mut batch = task_buffer
-                .into_iter()
-                .map(|b| (b, self.publish_tx.clone()));
+            let mut batch = tasks.into_iter().map(|b| (b, &self.publish_tx));
             match num_tasks {
                 1 => {
-                    let r0 = WorkflowsWorker::execute(batch.next().unwrap()).await;
-                    vec![r0]
+                    WorkflowsWorker::execute(batch.next().unwrap()).await;
                 }
                 2 => {
-                    let (r0, r1) = tokio::join!(
+                    tokio::join!(
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap())
                     );
-                    vec![r0, r1]
                 }
                 3 => {
-                    let (r0, r1, r2) = tokio::join!(
+                    tokio::join!(
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap())
                    );
-                    vec![r0, r1, r2]
                 }
                 4 => {
-                    let (r0, r1, r2, r3) = tokio::join!(
+                    tokio::join!(
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap())
                     );
-                    vec![r0, r1, r2, r3]
                 }
                 5 => {
-                    let (r0, r1, r2, r3, r4) = tokio::join!(
+                    tokio::join!(
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap())
                     );
-                    vec![r0, r1, r2, r3, r4]
                 }
                 6 => {
-                    let (r0, r1, r2, r3, r4, r5) = tokio::join!(
+                    tokio::join!(
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap())
                     );
-                    vec![r0, r1, r2, r3, r4, r5]
                 }
                 7 => {
-                    let (r0, r1, r2, r3, r4, r5, r6) = tokio::join!(
+                    tokio::join!(
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
@@ -171,10 +180,9 @@ impl WorkflowsWorker {
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap())
                     );
-                    vec![r0, r1, r2, r3, r4, r5, r6]
                 }
                 8 => {
-                    let (r0, r1, r2, r3, r4, r5, r6, r7) = tokio::join!(
+                    tokio::join!(
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap()),
@@ -184,7 +192,6 @@ impl WorkflowsWorker {
                         WorkflowsWorker::execute(batch.next().unwrap()),
                         WorkflowsWorker::execute(batch.next().unwrap())
                     );
-                    vec![r0, r1, r2, r3, r4, r5, r6, r7]
                 }
                 _ => {
                     unreachable!(
@@ -199,23 +206,28 @@ impl WorkflowsWorker {
 
     /// Executes a single task, and publishes the output.
     pub async fn execute(
-        (input, publish_tx): (WorkflowsWorkerInput, mpsc::Sender<WorkflowsWorkerOutput>),
+        (input, publish_tx): (WorkflowsWorkerInput, &mpsc::Sender<WorkflowsWorkerOutput>),
     ) {
+        let mut stats = input.stats;
+
         let mut memory = ProgramMemory::new();
 
+        // TODO: will be removed later
         let started_at = std::time::Instant::now();
+        stats = stats.record_execution_started_at();
         let result = input
             .executor
             .execute(input.entry.as_ref(), &input.workflow, &mut memory)
             .await;
+        stats = stats.record_execution_ended_at();
 
         let output = WorkflowsWorkerOutput {
             result,
             public_key: input.public_key,
             task_id: input.task_id,
             model_name: input.model_name,
             batchable: input.batchable,
-            stats: input.stats.record_execution_time(started_at),
+            stats: stats.record_execution_time(started_at),
         };
 
         if let Err(e) = publish_tx.send(output).await {
```
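The new loop keeps calling `recv_many` until the batch is full or the channel has drained, sleeping briefly between calls so tasks that are about to arrive can join the same batch. A trimmed, runnable sketch of that collection strategy (the task type, sizes, and timings are placeholders):

```rust
use std::time::Duration;
use tokio::sync::mpsc;

async fn collect_batch(rx: &mut mpsc::Receiver<u32>, batch_size: usize) -> Option<Vec<u32>> {
    let mut tasks = Vec::new();
    // enter if (1) no tasks yet, or (2) batch not full and more are queued
    while tasks.is_empty() || (tasks.len() < batch_size && !rx.is_empty()) {
        let limit = batch_size - tasks.len();
        match rx.recv_many(&mut tasks, limit).await {
            0 => return None, // channel closed with nothing buffered
            _ => tokio::time::sleep(Duration::from_millis(256)).await,
        }
    }
    Some(tasks)
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel(16);
    for i in 0..5 {
        tx.send(i).await.unwrap();
    }
    drop(tx);

    while let Some(batch) = collect_batch(&mut rx, 8).await {
        println!("processing batch of {}: {:?}", batch.len(), batch);
    }
}
```

The per-arity `tokio::join!` arms then run up to `MAX_BATCH_SIZE` executions concurrently on the worker's own task; since `execute` now borrows `&self.publish_tx` instead of cloning it, the futures are not `'static`, which is presumably why `join!` is used rather than spawning onto a `JoinSet`.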
15 changes: 4 additions & 11 deletions p2p/src/client.rs
```diff
@@ -129,16 +129,6 @@ impl DriaP2PClient {
             swarm.dial(rpc_addr.clone())?;
         }
 
-        // add rpcs as explicit peers
-        // TODO: may not be necessary
-        // for rpc_peer_id in &nodes.rpc_peerids {
-        //     log::info!("Adding {} as explicit peer.", rpc_peer_id);
-        //     swarm
-        //         .behaviour_mut()
-        //         .gossipsub
-        //         .add_explicit_peer(rpc_peer_id);
-        // }
-
         // create commander
         let (cmd_tx, cmd_rx) = mpsc::channel(COMMAND_CHANNEL_BUFSIZE);
         let commander = DriaP2PCommander::new(cmd_tx, protocol.clone());
@@ -161,7 +151,9 @@ impl DriaP2PClient {
     pub async fn run(mut self) {
         loop {
             tokio::select! {
-                event = self.swarm.select_next_some() => self.handle_event(event).await,
+                // this is a special keyword that changes the polling order from random to linear,
+                // which will effectively prioritize commands over events
+                biased;
                 command = self.cmd_rx.recv() => match command {
                     Some(c) => self.handle_command(c).await,
                     // channel closed, thus shutting down the network event loop
@@ -170,6 +162,7 @@ impl DriaP2PClient {
                        return
                     },
                 },
+                event = self.swarm.select_next_some() => self.handle_event(event).await,
             }
         }
     }
```
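The same `biased;` trick is applied here so that commands from `DriaP2PCommander` preempt swarm events. A stripped-down sketch of that command-channel pattern; the `Command` enum and the event source below are hypothetical stand-ins for the real client's types:

```rust
use std::time::Duration;
use tokio::sync::mpsc;

// hypothetical command set; the real client has its own
enum Command {
    Publish(String),
    Shutdown,
}

async fn run(mut cmd_rx: mpsc::Receiver<Command>) {
    loop {
        tokio::select! {
            // commands are polled first, so a pending Shutdown is never
            // starved by a busy event stream
            biased;

            command = cmd_rx.recv() => match command {
                Some(Command::Publish(msg)) => println!("publishing: {}", msg),
                Some(Command::Shutdown) | None => return,
            },
            _ = tokio::time::sleep(Duration::from_millis(50)) => {
                // stand-in for `self.swarm.select_next_some()`
                println!("handled a network event");
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (cmd_tx, cmd_rx) = mpsc::channel(32);
    let client = tokio::spawn(run(cmd_rx));

    cmd_tx.send(Command::Publish("hello".into())).await.unwrap();
    cmd_tx.send(Command::Shutdown).await.unwrap();
    client.await.unwrap();
}
```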
