Skip to content

Commit

Permalink
feat: make threading work and hands dirty
Browse files Browse the repository at this point in the history
  • Loading branch information
meloalright committed Oct 6, 2024
1 parent d14d6bd commit 3f28886
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ members = ["interpreter"]
[dependencies]
rustyline = { version = "12.0.0", optional = true }
rustyline-derive = { version = "0.4.0", optional = true }
three_body_interpreter = { version = "0.6.1", path = "./interpreter", features = ["sophon"] }
three_body_interpreter = { version = "0.6.1", path = "./interpreter", features = ["sophon", "threading"] }

[[bin]]
name = "3body"
Expand Down
4 changes: 3 additions & 1 deletion interpreter/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ rand = { version = "0.8.5" }
llm = { version = "0.1.1", optional = true }
llm-base = { version = "0.1.1", optional = true }
spinoff = { version = "0.7.0", default-features = false, features = ["dots", "arc", "line"], optional = true }
tokio = { version = "1.40.0", features = ["sync", "time", "macros", "rt-multi-thread"], optional = true }

[features]
default = []
sophon = ["llm", "llm-base", "spinoff"]
sophon = ["llm", "llm-base", "spinoff"]
threading = ["tokio"]
91 changes: 86 additions & 5 deletions interpreter/src/evaluator/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ use llm::{load_progress_callback_stdout as load_callback, InferenceParameters, M
use llm_base::InferenceRequest;
#[cfg(feature="sophon")]
use std::{convert::Infallible, io::Write, path::Path};
use std::time::Duration;
#[cfg(feature="sophon")]
use spinoff;
use crate::evaluator::object::Object::Native;

pub fn new_builtins() -> HashMap<String, Object> {
let mut builtins = HashMap::new();
Expand Down Expand Up @@ -44,6 +46,11 @@ pub fn new_builtins() -> HashMap<String, Object> {
String::from("智子工程"),
Object::Builtin(1, three_body_sophon_engineering),
);
#[cfg(feature="threading")] // threading
builtins.insert(
String::from("饱和式救援"),
Object::Builtin(1, three_body_threading),
);
builtins
}

Expand Down Expand Up @@ -174,7 +181,7 @@ fn three_body_sophon_engineering(args: Vec<Object>) -> Object {
};

let model_type = model_type.as_str();


let model_path = {
match model_path {
Expand Down Expand Up @@ -208,7 +215,7 @@ fn three_body_sophon_engineering(args: Vec<Object>) -> Object {
.unwrap_or_else(|err| {
panic!("Failed to load {model_type} model from {model_path:?}: {err}")
});

let model = Box::leak(model);

println!(
Expand Down Expand Up @@ -289,11 +296,11 @@ fn three_body_sophon_engineering(args: Vec<Object>) -> Object {
|t| {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(())
},
);

match res {
Err(err) => println!("\n{err}"),
_ => ()
Expand All @@ -315,7 +322,7 @@ fn three_body_sophon_engineering(args: Vec<Object>) -> Object {
NativeObject::LLMModel(model_ptr) => {
model_ptr.clone()
},
_ => panic!()
_ => panic!(),
}
},
_ => panic!()
Expand All @@ -337,6 +344,75 @@ fn three_body_sophon_engineering(args: Vec<Object>) -> Object {
}
}

use std::cell::RefCell;
use std::rc::Rc;
use crate::evaluator;
use crate::parser;
use crate::ast;
use crate::ast::{BlockStmt, Stmt};
use crate::lexer;

fn eval(input: &str) -> Option<Object> {
evaluator::Evaluator {
env: Rc::new(RefCell::new(evaluator::env::Env::from(new_builtins()))),
}
.eval(&parser::Parser::new(lexer::Lexer::new(input)).parse())
}

#[cfg(feature="threading")]
fn three_body_threading(args: Vec<Object>) -> Object {
match &args[0] {
Object::Int(o) => {
async fn local_task(id: i64) {
println!("Local task {} is running!", id);
tokio::time::sleep(Duration::from_secs(1)).await;
println!("Local task {} completed!", id);
}

let o = (*o).clone();

let mut handle = std::thread::spawn(move || {
let local_set = tokio::task::LocalSet::new();
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();

// 在 LocalSet 中安排任务
local_set.spawn_local(local_task(o));

// 运行 LocalSet 直到其中的任务完成
rt.block_on(local_set);
});
Object::Null
},
Object::String(input) => {
// async fn local_task(stmt: &BlockStmt) {
// println!("Local task {} is running!", id);
// tokio::time::sleep(Duration::from_secs(1)).await;
// // println!("Local task {} completed!", id);
// }
let input = (*input).clone();

let mut handle = std::thread::spawn(move || {
let local_set = tokio::task::LocalSet::new();
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();

// 在 LocalSet 中安排任务
local_set.spawn_local(async move { eval(&input) });

// 运行 LocalSet 直到其中的任务完成
rt.block_on(local_set);
});
Object::Null
},
_ => Object::Null,
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -578,4 +654,9 @@ mod tests {
assert_eq!(got, expected);
}
}

#[test]
#[cfg(feature="threading")]
fn test_three_body_threading() {
}
}

0 comments on commit 3f28886

Please sign in to comment.