[Core] use llama cpp (gguf file format) chat template (#8)
* feat: use chat template for gguf model (llama.cpp model)

* change: rename variable 'model' to 'llm' in NPCStepper.generate_local()

---------

Co-authored-by: A-Mahla <>
A-Mahla authored Jun 17, 2024
1 parent d004de3 commit 6fa19e3
Showing 2 changed files with 34 additions and 16 deletions.
11 changes: 11 additions & 0 deletions gigax/prompt.py
@@ -5,7 +5,9 @@
Location,
ProtagonistCharacter,
)
from typing import Literal
from gigax.parse import CharacterAction
from jinja2 import Template


@outlines.prompt
@@ -45,3 +47,12 @@ def NPCPrompt(
{{ protagonist.name }}:
"""


def llama_chat_template(
message: list[dict[Literal["role", "content"], str]],
bos_token: str,
chat_template: str,
):
tpl = Template(chat_template)
return tpl.render(messages=message, bos_token=bos_token)
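
As a quick illustration of the llama_chat_template() helper added above, the sketch below renders an invented, Phi-style chat template with Jinja2 the same way the helper does; the template string and messages are hypothetical examples, not values read from a real gguf file.

# Hedged sketch: render a hypothetical chat template the way
# llama_chat_template() does; the template string is invented, not one
# extracted from an actual gguf model.
from jinja2 import Template

chat_template = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "<|{{ message['role'] }}|>\n{{ message['content'] }}<|end|>\n"
    "{% endfor %}"
    "<|assistant|>"
)
messages = [{"role": "user", "content": "Describe the scene."}]
print(Template(chat_template).render(messages=messages, bos_token="<s>"))
# <s><|user|>
# Describe the scene.<|end|>
# <|assistant|>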
39 changes: 23 additions & 16 deletions gigax/step.py
@@ -3,7 +3,7 @@
import traceback

from openai import AsyncOpenAI
from gigax.prompt import NPCPrompt
from gigax.prompt import NPCPrompt, llama_chat_template
from gigax.scene import (
Character,
Item,
@@ -33,6 +33,11 @@ def __init__(
if isinstance(model, str) and not self.api_key:
raise ValueError("You must provide an API key to use our API.")

if not isinstance(model, (models.LlamaCpp, models.Transformers)):
raise NotImplementedError(
"Only LlamaCpp and Transformers models are supported in local mode for now."
)

async def generate_api(
self,
model: str,
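
For context on the check moved into __init__ above, here is a hedged construction sketch. The exact NPCStepper keyword arguments, the way outlines' models.LlamaCpp wraps a llama_cpp.Llama instance, and the gguf path are all assumptions.

# Hedged sketch (assumptions: NPCStepper takes the model via `model=`, and
# outlines' models.LlamaCpp wraps a llama_cpp.Llama instance; path is hypothetical).
from llama_cpp import Llama
from outlines import models
from gigax.step import NPCStepper

local_llm = models.LlamaCpp(Llama(model_path="npc-llm.gguf"))
stepper = NPCStepper(model=local_llm)  # passes the isinstance guard above
# A string model name plus an api_key would instead select the hosted API path.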
@@ -69,27 +74,29 @@ async def generate_api
async def generate_local(
self,
prompt: str,
model: models.LogitsGenerator,
llm: models.LogitsGenerator,
guided_regex: str,
) -> str:
if not isinstance(model, (models.LlamaCpp, models.Transformers)): # type: ignore
raise NotImplementedError(
"Only LlamaCpp and Transformers models are supported in local mode for now."
)

# Time the query
start = time.time()

generator = regex(model, guided_regex)
if isinstance(model, models.LlamaCpp): # type: ignore
generator = regex(llm, guided_regex)
messages = [
{"role": "user", "content": f"{prompt}"},
]
if isinstance(llm, models.LlamaCpp): # type: ignore

# Llama-cpp-python has a convenient create_chat_completion() method that guesses the chat prompt
# But outlines does not support it for generation, so we do this ugly hack instead
chat_prompt = f"<|user|>\n{prompt}<|end|>\n<|assistant|>"
elif isinstance(model, models.Transformers): # type: ignore
messages = [
{"role": "user", "content": f"{prompt}"},
]
chat_prompt = model.tokenizer.tokenizer.apply_chat_template(
bos_token = llm.model._model.token_get_text(
int(llm.model.metadata["tokenizer.ggml.bos_token_id"])
)
chat_prompt = llama_chat_template(
messages, bos_token, llm.model.metadata["tokenizer.chat_template"]
)

elif isinstance(llm, models.Transformers): # type: ignore
chat_prompt = llm.tokenizer.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
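
To make the new LlamaCpp branch easier to follow in isolation, here is a hedged, standalone sketch of the same steps: read the bos token and the embedded chat template from the gguf metadata, then render the prompt with llama_chat_template(). The model path is hypothetical, and the .metadata / ._model.token_get_text() accesses simply mirror what the diff uses on the wrapped llama_cpp.Llama.

# Standalone sketch of the gguf chat-template path (hypothetical model path;
# .metadata and ._model.token_get_text() mirror the accesses in the diff above).
from llama_cpp import Llama
from gigax.prompt import llama_chat_template

llm = Llama(model_path="npc-llm.gguf")
bos_token = llm._model.token_get_text(
    int(llm.metadata["tokenizer.ggml.bos_token_id"])
)
chat_prompt = llama_chat_template(
    [{"role": "user", "content": "Wave at the guard."}],
    bos_token,
    llm.metadata["tokenizer.chat_template"],
)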
@@ -136,7 +143,7 @@ async def get_action(
guided_regex = get_guided_regex(protagonist.skills, NPCs, locations, items)

# Generate the response
if isinstance(self.model, models.LogitsGenerator): # type: ignore
if isinstance(self.model, models.LogitsGenerator):
res = await self.generate_local(
prompt,
self.model,
