fix llm_runner for tinyllama (#407)
jinchen62 authored Feb 8, 2024
1 parent d14a9f6 commit a1b22ea
Showing 1 changed file with 16 additions and 8 deletions.
models/turbine_models/custom_models/llm_runner.py
@@ -50,9 +50,7 @@
 parser.add_argument(
     "--prompt",
     type=str,
-    default="""<s>[INST] <<SYS>>
-Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> hi what are you? [/INST]
-""",
+    default="hi what are you?",
     help="prompt for llm model",
 )
 parser.add_argument(
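
Note: the old default baked the Llama-2 "<s>[INST] <<SYS>> ... [/INST]" template into the prompt string, which does not match TinyLlama's chat format; the bare default leaves templating to the prompt helpers at run time. A minimal sketch of the new behavior, assuming nothing beyond the argument definition above:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        default="hi what are you?",
        help="prompt for llm model",
    )

    # With no --prompt flag given, the raw text comes through untouched;
    # any chat template is applied later by append_user_prompt.
    args = parser.parse_args([])
    assert args.prompt == "hi what are you?"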
@@ -183,14 +181,14 @@ def run_llm(
         streaming_llm=streaming_llm,
     )
     if not chat_mode:
+        prompt = append_user_prompt(chat_sys_prompt, prompt)
         initial_input = tokenizer(prompt, return_tensors="pt")
         example_input_id = initial_input.input_ids
         turbine_results = llm.generate(example_input_id)
         return tokenizer.decode(turbine_results)
-    prompt = chat_sys_prompt
     while True:
         user_prompt = input("User prompt: ")
-        prompt = append_user_prompt(prompt, user_prompt)
+        prompt = append_user_prompt(chat_sys_prompt, user_prompt)
         initial_input = tokenizer(prompt, return_tensors="pt")
         example_input_id = initial_input.input_ids
         result = llm.generate(example_input_id)
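
Note: both the one-shot and chat paths now rebuild the prompt from chat_sys_prompt via append_user_prompt rather than threading the system prompt through by hand. The helper's definition is not part of this diff; a plausible sketch under the Llama-2 chat convention that the removed default embedded (hypothetical, inferred only from the call sites):

    B_INST, E_INST = "[INST]", "[/INST]"  # assumed markers, matching the removed default prompt

    def append_user_prompt(history, input_prompt):
        # Hypothetical: wrap the user turn in instruction markers and
        # append it to the running conversation string.
        return history + f"{B_INST} {input_prompt} {E_INST}"

    def append_bot_prompt(history, bot_response):
        # Hypothetical: append the decoded reply and close the turn.
        return history + f"{bot_response} </s>"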
@@ -199,7 +197,13 @@ def run_llm(
         prompt = append_bot_prompt(prompt, bot_response)


-def run_torch_llm(hf_model_name, hf_auth_token, prompt, streaming_llm=False):
+def run_torch_llm(
+    hf_model_name,
+    hf_auth_token,
+    prompt,
+    streaming_llm=False,
+    chat_sys_prompt=DEFAULT_CHAT_SYS_PROMPT,
+):
     from turbine_models.model_builder import HFTransformerBuilder
     from transformers import AutoModelForCausalLM

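Note: the widened signature lets callers pass the system prompt explicitly, with DEFAULT_CHAT_SYS_PROMPT as the fallback. A usage sketch (the model id is an example and not taken from this diff):

    output = run_torch_llm(
        hf_model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # example id, not from this diff
        hf_auth_token=None,
        prompt="hi what are you?",
        streaming_llm=False,
        chat_sys_prompt=DEFAULT_CHAT_SYS_PROMPT,  # module default named in the signature
    )
    print(output)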
Expand All @@ -210,13 +214,13 @@ def run_torch_llm(hf_model_name, hf_auth_token, prompt, streaming_llm=False):
hf_auth_token=hf_auth_token,
auto_tokenizer=AutoTokenizer,
)
model_builder.build_model()
if streaming_llm is True:
enable_llama_pos_shift_attention(model_builder.model)

def get_token_from_logits(logits):
return torch.argmax(logits[:, -1, :], dim=1)

prompt = append_user_prompt(chat_sys_prompt, prompt)
initial_input = model_builder.tokenizer(prompt, return_tensors="pt")
example_input_id = initial_input.input_ids

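Note: get_token_from_logits performs greedy decoding: it takes the logits at the last sequence position and returns the highest-scoring vocabulary id. A self-contained illustration with dummy shapes:

    import torch

    batch, seq_len, vocab = 1, 12, 32000  # illustrative sizes only
    logits = torch.randn(batch, seq_len, vocab)

    # Slice out the final position's logits, then argmax over the vocabulary.
    next_token = torch.argmax(logits[:, -1, :], dim=1)
    print(next_token.shape)  # torch.Size([1])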
@@ -256,6 +260,10 @@ def get_token_from_logits(logits):
     if args.compare_vs_torch:
         print("generating torch output: ")
         torch_output = run_torch_llm(
-            args.hf_model_name, args.hf_auth_token, args.prompt
+            args.hf_model_name,
+            args.hf_auth_token,
+            args.prompt,
+            args.streaming_llm,
+            args.chat_sys_prompt,
         )
         print(torch_output)
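
Note: forwarding args.streaming_llm and args.chat_sys_prompt means the eager torch reference formats its prompt exactly like the compiled path, so the two outputs are comparable. A sketch of driving the comparison (only --prompt is confirmed by this diff; the other flag names are assumed to mirror the args.* attributes):

    import subprocess

    subprocess.run(
        [
            "python", "models/turbine_models/custom_models/llm_runner.py",
            "--hf_model_name", "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # example id
            "--prompt", "hi what are you?",
            "--compare_vs_torch",   # assumed store_true flag
            "--streaming_llm",      # assumed store_true flag
        ],
        check=True,
    )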
