diff --git a/Cargo.lock b/Cargo.lock index f17e336..69fb50a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -466,9 +466,9 @@ checksum = "7b02b629252fe8ef6460461409564e2c21d0c8e77e0944f3d189ff06c4e932ad" [[package]] name = "cc" -version = "1.2.3" +version = "1.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27f657647bcff5394bf56c7317665bbf790a137a50eaaa5c6bfbb9e27a518f2d" +checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" dependencies = [ "shlex", ] @@ -548,12 +548,12 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "colored" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" +checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" dependencies = [ "lazy_static", - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -613,9 +613,9 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ "crossbeam-epoch", "crossbeam-utils", @@ -632,9 +632,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" @@ -974,7 +974,7 @@ dependencies = [ [[package]] name = "dkn-compute" -version = "0.2.30" +version = "0.2.31" dependencies = [ "async-trait", "base64 0.22.1", @@ -1007,7 +1007,7 @@ dependencies = [ [[package]] name = "dkn-monitor" -version = "0.2.30" +version = "0.2.31" dependencies = [ "async-trait", "dkn-compute", @@ -1027,7 +1027,7 @@ dependencies = [ [[package]] name = "dkn-p2p" -version = "0.2.30" +version = "0.2.31" dependencies = [ "dkn-utils", "env_logger 0.11.5", @@ -1041,11 +1041,11 @@ dependencies = [ [[package]] name = "dkn-utils" -version = "0.2.30" +version = "0.2.31" [[package]] name = "dkn-workflows" -version = "0.2.30" +version = "0.2.31" dependencies = [ "dkn-utils", "dotenvy", @@ -1964,9 +1964,9 @@ dependencies = [ [[package]] name = "hyper" -version = "0.14.31" +version = "0.14.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" dependencies = [ "bytes 1.9.0", "futures-channel", @@ -1988,9 +1988,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.5.1" +version = "1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" +checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" dependencies = [ "bytes 1.9.0", "futures-channel", @@ -2014,7 +2014,7 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.2.0", - "hyper 1.5.1", + "hyper 1.5.2", "hyper-util", "rustls", "rustls-pki-types", @@ -2045,7 +2045,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes 1.9.0", "http-body-util", - "hyper 1.5.1", + "hyper 1.5.2", "hyper-util", "native-tls", "tokio 1.42.0", @@ -2064,7 +2064,7 @@ dependencies = [ "futures-util", "http 1.2.0", "http-body 1.0.1", - "hyper 1.5.1", + "hyper 1.5.2", "pin-project-lite 0.2.15", "socket2 0.5.8", "tokio 1.42.0", @@ -2284,7 +2284,7 @@ dependencies = [ "bytes 1.9.0", "futures", "http 0.2.12", - "hyper 0.14.31", + "hyper 0.14.32", "log", "rand 0.8.5", "tokio 1.42.0", @@ -3140,9 +3140,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" dependencies = [ "adler2", ] @@ -3465,7 +3465,7 @@ dependencies = [ [[package]] name = "ollama-workflows" version = "0.1.0" -source = "git+https://github.com/andthattoo/ollama-workflows#c5c586cafeab0a8459015c6c42b0b078e2d14128" +source = "git+https://github.com/andthattoo/ollama-workflows#46dc2c5b0355aa60b5cd786f9dbffcb1e9f215e8" dependencies = [ "async-trait", "base64 0.22.1", @@ -3989,7 +3989,7 @@ dependencies = [ "rustc-hash", "rustls", "socket2 0.5.8", - "thiserror 2.0.6", + "thiserror 2.0.8", "tokio 1.42.0", "tracing", ] @@ -4008,7 +4008,7 @@ dependencies = [ "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.6", + "thiserror 2.0.8", "tinyvec", "tracing", "web-time", @@ -4016,9 +4016,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.8" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52cd4b1eff68bf27940dd39811292c49e007f4d0b4c357358dc9b0197be6b527" +checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904" dependencies = [ "cfg_aliases", "libc", @@ -4249,7 +4249,7 @@ dependencies = [ "http 1.2.0", "http-body 1.0.1", "http-body-util", - "hyper 1.5.1", + "hyper 1.5.2", "hyper-rustls", "hyper-tls 0.6.0", "hyper-util", @@ -4428,9 +4428,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" dependencies = [ "web-time", ] @@ -4553,9 +4553,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2" +checksum = "1863fd3768cd83c56a7f60faa4dc0d403f1b6df0a38c3c25f44b7894e45370d5" dependencies = [ "core-foundation-sys", "libc", @@ -5068,11 +5068,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.6" +version = "2.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fec2a1820ebd077e2b90c4df007bebf344cd394098a13c563957d0afc83ea47" +checksum = "08f5383f3e0071702bf93ab5ee99b52d26936be9dedd9413067cbdcddcb6141a" dependencies = [ - "thiserror-impl 2.0.6", + "thiserror-impl 2.0.8", ] [[package]] @@ -5088,9 +5088,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.6" +version = "2.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d65750cab40f4ff1929fb1ba509e9914eb756131cef4210da8d5d700d26f6312" +checksum = "f2f357fcec90b3caef6623a099691be676d033b40a058ac95d2a6ade6fa0c943" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index bd8571a..2d986d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ default-members = ["compute"] [workspace.package] edition = "2021" -version = "0.2.30" +version = "0.2.31" license = "Apache-2.0" readme = "README.md" diff --git a/compute/src/node.rs b/compute/src/node.rs index a357e70..872e5ad 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -471,13 +471,16 @@ impl DriaComputeNode { #[cfg(test)] mod tests { use super::*; - use std::env; #[tokio::test] #[ignore = "run this manually"] async fn test_publish_message() -> eyre::Result<()> { - env::set_var("RUST_LOG", "none,dkn_compute=debug,dkn_p2p=debug"); - let _ = env_logger::builder().is_test(true).try_init(); + let _ = env_logger::builder() + .filter_level(log::LevelFilter::Off) + .filter_module("dkn_compute", log::LevelFilter::Debug) + .filter_module("dkn_p2p", log::LevelFilter::Debug) + .is_test(true) + .try_init(); // create node let cancellation = CancellationToken::new(); diff --git a/compute/src/workers/workflow.rs b/compute/src/workers/workflow.rs index f311ab4..ad763da 100644 --- a/compute/src/workers/workflow.rs +++ b/compute/src/workers/workflow.rs @@ -106,6 +106,7 @@ impl WorkflowsWorker { // (1) there are no tasks, or, // (2) there are tasks less than the batch size and the channel is not empty while tasks.is_empty() || (tasks.len() < batch_size && !self.workflow_rx.is_empty()) { + log::info!("Waiting for more workflows to process ({})", tasks.len()); let limit = batch_size - tasks.len(); match self.workflow_rx.recv_many(&mut tasks, limit).await { // 0 tasks returned means that the channel is closed @@ -235,3 +236,109 @@ impl WorkflowsWorker { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::payloads::TaskStats; + + use dkn_workflows::{Executor, Model}; + use libsecp256k1::{PublicKey, SecretKey}; + use tokio::sync::mpsc; + + // cargo test --package dkn-compute --lib --all-features -- workers::workflow::tests::test_workflows_worker --exact --show-output --nocapture --ignored + #[tokio::test] + #[ignore = "run manually"] + async fn test_workflows_worker() { + let _ = env_logger::builder() + .filter_level(log::LevelFilter::Off) + .filter_module("dkn_compute", log::LevelFilter::Debug) + .is_test(true) + .try_init(); + + let (publish_tx, mut publish_rx) = mpsc::channel(1024); + let (mut worker, workflow_tx) = WorkflowsWorker::new(publish_tx); + + // create batch workflow worker + let worker_handle = tokio::spawn(async move { + worker.run_batch(4).await; + }); + + let num_tasks = 4; + let model = Model::O1Preview; + let workflow = serde_json::json!({ + "config": { + "max_steps": 10, + "max_time": 250, + "tools": [""] + }, + "tasks": [ + { + "id": "A", + "name": "", + "description": "", + "operator": "generation", + "messages": [{ "role": "user", "content": "Write a 4 paragraph poem about Julius Caesar." }], + "inputs": [], + "outputs": [ { "type": "write", "key": "result", "value": "__result" } ] + }, + { + "id": "__end", + "name": "end", + "description": "End of the task", + "operator": "end", + "messages": [{ "role": "user", "content": "End of the task" }], + "inputs": [], + "outputs": [] + } + ], + "steps": [ { "source": "A", "target": "__end" } ], + "return_value": { "input": { "type": "read", "key": "result" } + } + }); + + for i in 0..num_tasks { + log::info!("Sending task {}", i + 1); + + let workflow = serde_json::from_value(workflow.clone()).unwrap(); + + let executor = Executor::new(model.clone()); + let input = WorkflowsWorkerInput { + entry: None, + executor, + workflow, + public_key: PublicKey::from_secret_key(&SecretKey::default()), + task_id: "task_id".to_string(), + model_name: model.to_string(), + stats: TaskStats::default(), + batchable: true, + }; + + // send workflow to worker + workflow_tx.send(input).await.unwrap(); + } + + // now wait for all results + let mut results = Vec::new(); + for i in 0..num_tasks { + log::info!("Waiting for result {}", i + 1); + let result = publish_rx.recv().await.unwrap(); + log::info!( + "Got result {} (exeuction time: {})", + i + 1, + (result.stats.execution_time as f64) / 1_000_000_000f64 + ); + if result.result.is_err() { + println!("Error: {:?}", result.result); + } + results.push(result); + } + + log::info!("Got all results, closing channel."); + publish_rx.close(); + + // TODO: this bugs out + worker_handle.await.unwrap(); + log::info!("Done."); + } +} diff --git a/p2p/src/client.rs b/p2p/src/client.rs index c10031d..492a842 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -126,7 +126,9 @@ impl DriaP2PClient { // dial rpc nodes for rpc_addr in &nodes.rpc_nodes { log::info!("Dialing RPC node: {}", rpc_addr); - swarm.dial(rpc_addr.clone())?; + if let Err(e) = swarm.dial(rpc_addr.clone()) { + log::error!("Error dialing RPC node: {:?}", e); + }; } // create commander diff --git a/p2p/tests/listen_test.rs b/p2p/tests/listen_test.rs index bfeafd6..e4e70ef 100644 --- a/p2p/tests/listen_test.rs +++ b/p2p/tests/listen_test.rs @@ -8,7 +8,9 @@ async fn test_listen_topic_once() -> Result<()> { const TOPIC: &str = "pong"; let _ = env_logger::builder() - .parse_filters("none,listen_test=debug,dkn_p2p=debug") + .filter_level(log::LevelFilter::Off) + .filter_module("listen_test", log::LevelFilter::Debug) + .filter_module("dkn_p2p", log::LevelFilter::Debug) .is_test(true) .try_init();