From a1b22ea197212e08af8a3d52d104c4cb30390297 Mon Sep 17 00:00:00 2001
From: jinchen62 <49575973+jinchen62@users.noreply.github.com>
Date: Wed, 7 Feb 2024 17:30:52 -0800
Subject: [PATCH] fix llm_runner for tinyllama (#407)

---
 .../custom_models/llm_runner.py               | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/models/turbine_models/custom_models/llm_runner.py b/models/turbine_models/custom_models/llm_runner.py
index 7632d1e65..3e3b82b3b 100644
--- a/models/turbine_models/custom_models/llm_runner.py
+++ b/models/turbine_models/custom_models/llm_runner.py
@@ -50,9 +50,7 @@ parser.add_argument(
 parser.add_argument(
     "--prompt",
     type=str,
-    default="""[INST] <<SYS>>
-Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> hi what are you? [/INST]
-""",
+    default="hi what are you?",
     help="prompt for llm model",
 )
 parser.add_argument(
@@ -183,14 +181,14 @@ def run_llm(
         streaming_llm=streaming_llm,
     )
     if not chat_mode:
+        prompt = append_user_prompt(chat_sys_prompt, prompt)
         initial_input = tokenizer(prompt, return_tensors="pt")
         example_input_id = initial_input.input_ids
         turbine_results = llm.generate(example_input_id)
         return tokenizer.decode(turbine_results)
-    prompt = chat_sys_prompt
     while True:
         user_prompt = input("User prompt: ")
-        prompt = append_user_prompt(prompt, user_prompt)
+        prompt = append_user_prompt(chat_sys_prompt, user_prompt)
         initial_input = tokenizer(prompt, return_tensors="pt")
         example_input_id = initial_input.input_ids
         result = llm.generate(example_input_id)
@@ -199,7 +197,13 @@ def run_llm(
         prompt = append_bot_prompt(prompt, bot_response)
 
 
-def run_torch_llm(hf_model_name, hf_auth_token, prompt, streaming_llm=False):
+def run_torch_llm(
+    hf_model_name,
+    hf_auth_token,
+    prompt,
+    streaming_llm=False,
+    chat_sys_prompt=DEFAULT_CHAT_SYS_PROMPT,
+):
     from turbine_models.model_builder import HFTransformerBuilder
     from transformers import AutoModelForCausalLM
 
@@ -210,13 +214,13 @@ def run_torch_llm(hf_model_name, hf_auth_token, prompt, streaming_llm=False):
         hf_auth_token=hf_auth_token,
         auto_tokenizer=AutoTokenizer,
     )
-    model_builder.build_model()
     if streaming_llm is True:
         enable_llama_pos_shift_attention(model_builder.model)
 
     def get_token_from_logits(logits):
         return torch.argmax(logits[:, -1, :], dim=1)
 
+    prompt = append_user_prompt(chat_sys_prompt, prompt)
     initial_input = model_builder.tokenizer(prompt, return_tensors="pt")
     example_input_id = initial_input.input_ids
 
@@ -256,6 +260,10 @@ def get_token_from_logits(logits):
     if args.compare_vs_torch:
         print("generating torch output: ")
         torch_output = run_torch_llm(
-            args.hf_model_name, args.hf_auth_token, args.prompt
+            args.hf_model_name,
+            args.hf_auth_token,
+            args.prompt,
+            args.streaming_llm,
+            args.chat_sys_prompt,
         )
         print(torch_output)