diff --git a/.github/workflows/Sophon.yml b/.github/workflows/Sophon.yml index da401b8..9e6cba1 100644 --- a/.github/workflows/Sophon.yml +++ b/.github/workflows/Sophon.yml @@ -28,4 +28,4 @@ jobs: run: | ./target/release/3body -V git clone https://huggingface.co/huantian2415/vicuna-13b-chinese-4bit-ggml - ./target/release/3body -c 'let 智子 = 智子工程({ "type": "llama", "path": "./vicuna-13b-chinese-4bit-ggml/Vicuna-13B-chinese.bin", "prompt": "你是三体文明的智子" }); 智子.infer(智子, "中国最佳科幻小说是哪个")' + ./target/release/3body -c 'let 智子 = fn () { let instance = 智子工程({ "type": "llama", "path": "/Users/m2/Downloads/Vicuna-13B-chinese.bin", "prompt": "你是三体文明的智子" }); return { "回答": fn (问题) { instance.infer(instance, 问题) } } }(); 智子.回答("中国最佳科幻小说是哪个")' diff --git a/interpreter/src/evaluator/builtins.rs b/interpreter/src/evaluator/builtins.rs index f3b69af..5a737a4 100644 --- a/interpreter/src/evaluator/builtins.rs +++ b/interpreter/src/evaluator/builtins.rs @@ -147,110 +147,6 @@ fn three_body_deep_equal(args: Vec) -> Object { } } -fn three_body_sophon_infer(args: Vec) -> Object { - match &args[0] { - Object::Hash(hash) => { - let model_ptr = match hash.get(&Object::String("model".to_owned())).unwrap() { - Object::Native(native_object) => { - match **native_object { - NativeObject::LLMModel(model_ptr) => { - model_ptr.clone() - }, - _ => panic!() - } - }, - _ => panic!() - }; - let character = hash.get(&Object::String("character".to_owned())).unwrap(); - let model = unsafe { & *model_ptr }; - - let mut session = model.start_session(Default::default()); - let meessage = format!("{}", &args[1]); - let prompt = &format!(" -下面是描述一项任务的说明。需要适当地完成请求的响应。 - -### 角色设定: - -{} - -### 提问: - -{} - -### 回答: - -", character, meessage); - - let sp = spinoff::Spinner::new(spinoff::spinners::Arc, "".to_string(), None); - - if let Err(llm::InferenceError::ContextFull) = session.feed_prompt::( - model, - &InferenceParameters { - ..Default::default() - }, - prompt, - &mut Default::default(), - |t| { - Ok(()) - }, - ) { - println!("Prompt exceeds context window length.") - }; - sp.clear(); - - let res = session.infer::( - model, - &mut thread_rng(), - &InferenceRequest { - prompt: "", - ..Default::default() - }, - // OutputRequest - &mut Default::default(), - |t| { - print!("{t}"); - std::io::stdout().flush().unwrap(); - - Ok(()) - }, - ); - - match res { - Err(err) => println!("\n{err}"), - _ => () - } - Object::Null - }, - _ => panic!() - } -} - - - -fn three_body_sophon_close(args: Vec) -> Object { - match &args[0] { - Object::Hash(hash) => { - let model_ptr = match hash.get(&Object::String("model".to_owned())).unwrap() { - Object::Native(native_object) => { - match **native_object { - NativeObject::LLMModel(model_ptr) => { - model_ptr.clone() - }, - _ => panic!() - } - }, - _ => panic!() - }; - // let model = unsafe { & *model_ptr }; - unsafe { Box::from_raw(model_ptr) }; - // std::mem::drop(model); - Object::Null - }, - _ => panic!() - } -} - - fn three_body_sophon_engineering(args: Vec) -> Object { match &args[0] { Object::Hash(o) => { @@ -314,13 +210,121 @@ fn three_body_sophon_engineering(args: Vec) -> Object { now.elapsed().as_millis() ); + + let model_ptr = &mut *model as *mut dyn Model; let mut session_hash = HashMap::new(); session_hash.insert(Object::String("model".to_owned()), Object::Native(Box::new(NativeObject::LLMModel(model_ptr)))); session_hash.insert(Object::String("character".to_owned()), Object::String(character.to_string())); - session_hash.insert(Object::String("infer".to_owned()), Object::Builtin(2, three_body_sophon_infer)); - session_hash.insert(Object::String("close".to_owned()), Object::Builtin(1, three_body_sophon_close)); + + { + + fn three_body_sophon_infer(args: Vec) -> Object { + match &args[0] { + Object::Hash(hash) => { + let model_ptr = match hash.get(&Object::String("model".to_owned())).unwrap() { + Object::Native(native_object) => { + match **native_object { + NativeObject::LLMModel(model_ptr) => { + model_ptr.clone() + }, + _ => panic!() + } + }, + _ => panic!() + }; + let character = hash.get(&Object::String("character".to_owned())).unwrap(); + let model = unsafe { & *model_ptr }; + + let mut session = model.start_session(Default::default()); + let meessage = format!("{}", &args[1]); + let prompt = &format!(" + 下面是描述一项任务的说明。需要适当地完成请求的响应。 + + ### 角色设定: + + {} + + ### 提问: + + {} + + ### 回答: + + ", character, meessage); + + let sp = spinoff::Spinner::new(spinoff::spinners::Arc, "".to_string(), None); + + if let Err(llm::InferenceError::ContextFull) = session.feed_prompt::( + model, + &InferenceParameters { + ..Default::default() + }, + prompt, + &mut Default::default(), + |t| { + Ok(()) + }, + ) { + println!("Prompt exceeds context window length.") + }; + sp.clear(); + + let res = session.infer::( + model, + &mut thread_rng(), + &InferenceRequest { + prompt: "", + ..Default::default() + }, + // OutputRequest + &mut Default::default(), + |t| { + print!("{t}"); + std::io::stdout().flush().unwrap(); + + Ok(()) + }, + ); + + match res { + Err(err) => println!("\n{err}"), + _ => () + } + Object::Null + }, + _ => panic!() + } + } + + + + fn three_body_sophon_close(args: Vec) -> Object { + match &args[0] { + Object::Hash(hash) => { + let model_ptr = match hash.get(&Object::String("model".to_owned())).unwrap() { + Object::Native(native_object) => { + match **native_object { + NativeObject::LLMModel(model_ptr) => { + model_ptr.clone() + }, + _ => panic!() + } + }, + _ => panic!() + }; + // let model = unsafe { & *model_ptr }; + unsafe { Box::from_raw(model_ptr) }; + // std::mem::drop(model); + Object::Null + }, + _ => panic!() + } + } + session_hash.insert(Object::String("infer".to_owned()), Object::Builtin(2, three_body_sophon_infer)); + session_hash.insert(Object::String("close".to_owned()), Object::Builtin(1, three_body_sophon_close)); + } Object::Hash(session_hash) } _ => Object::Null,