[Core] use llama cpp (gguf file format) chat template #8

Merged
merged 3 commits on Jun 17, 2024
Changes from all commits
11 changes: 11 additions & 0 deletions gigax/prompt.py
@@ -5,7 +5,9 @@
     Location,
     ProtagonistCharacter,
 )
+from typing import Literal
 from gigax.parse import CharacterAction
+from jinja2 import Template
 
 
 @outlines.prompt
@@ -45,3 +47,12 @@ def NPCPrompt(
 
     {{ protagonist.name }}:
     """
+
+
+def llama_chat_template(
+    message: list[dict[Literal["role", "content"], str]],
+    bos_token: str,
+    chat_template: str,
+):
+    tpl = Template(chat_template)
+    return tpl.render(messages=message, bos_token=bos_token)
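
For reference, a minimal sketch of how the new llama_chat_template helper can be exercised on its own. The Jinja template and message below are illustrative stand-ins, not a template shipped in any GGUF file:

    # Toy usage of llama_chat_template; the template string is a Phi-3-style example,
    # not one read from a real model.
    from gigax.prompt import llama_chat_template

    toy_template = (
        "{{ bos_token }}{% for m in messages %}"
        "<|{{ m['role'] }}|>\n{{ m['content'] }}<|end|>\n"
        "{% endfor %}<|assistant|>"
    )
    rendered = llama_chat_template(
        [{"role": "user", "content": "Hello there"}],
        bos_token="<s>",
        chat_template=toy_template,
    )
    # rendered == "<s><|user|>\nHello there<|end|>\n<|assistant|>"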
39 changes: 23 additions & 16 deletions gigax/step.py
@@ -3,7 +3,7 @@
 import traceback
 
 from openai import AsyncOpenAI
-from gigax.prompt import NPCPrompt
+from gigax.prompt import NPCPrompt, llama_chat_template
 from gigax.scene import (
     Character,
     Item,
@@ -33,6 +33,11 @@ def __init__(
         if isinstance(model, str) and not self.api_key:
             raise ValueError("You must provide an API key to use our API.")
 
+        if not isinstance(model, (models.LlamaCpp, models.Transformers)):
+            raise NotImplementedError(
+                "Only LlamaCpp and Transformers models are supported in local mode for now."
+            )
+
     async def generate_api(
         self,
         model: str,
@@ -69,27 +74,29 @@ async def generate_api(
     async def generate_local(
         self,
         prompt: str,
-        model: models.LogitsGenerator,
+        llm: models.LogitsGenerator,
         guided_regex: str,
     ) -> str:
-        if not isinstance(model, (models.LlamaCpp, models.Transformers)):  # type: ignore
-            raise NotImplementedError(
-                "Only LlamaCpp and Transformers models are supported in local mode for now."
-            )
-
         # Time the query
         start = time.time()
 
-        generator = regex(model, guided_regex)
-        if isinstance(model, models.LlamaCpp):  # type: ignore
+        generator = regex(llm, guided_regex)
+        messages = [
+            {"role": "user", "content": f"{prompt}"},
+        ]
+        if isinstance(llm, models.LlamaCpp):  # type: ignore
+
             # Llama-cpp-python has a convenient create_chat_completion() method that guesses the chat prompt
             # But outlines does not support it for generation, so we do this ugly hack instead
-            chat_prompt = f"<|user|>\n{prompt}<|end|>\n<|assistant|>"
-        elif isinstance(model, models.Transformers):  # type: ignore
-            messages = [
-                {"role": "user", "content": f"{prompt}"},
-            ]
-            chat_prompt = model.tokenizer.tokenizer.apply_chat_template(
+            bos_token = llm.model._model.token_get_text(
+                int(llm.model.metadata["tokenizer.ggml.bos_token_id"])
+            )
+            chat_prompt = llama_chat_template(
+                messages, bos_token, llm.model.metadata["tokenizer.chat_template"]
+            )
+
+        elif isinstance(llm, models.Transformers):  # type: ignore
+            chat_prompt = llm.tokenizer.tokenizer.apply_chat_template(
                 messages,
                 tokenize=False,
                 add_generation_prompt=True,
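
Note on the LlamaCpp branch above: the chat template and BOS token come straight from the GGUF metadata that llama-cpp-python exposes. A standalone sketch of that lookup, assuming a local .gguf file at a placeholder path and reusing the same (private) attributes the diff reaches through llm.model:

    # Sketch: read the GGUF chat template and BOS token with llama-cpp-python,
    # then render the prompt with llama_chat_template. The model path is a placeholder.
    from llama_cpp import Llama
    from gigax.prompt import llama_chat_template

    llama = Llama(model_path="path/to/model.gguf", verbose=False)
    chat_template = llama.metadata["tokenizer.chat_template"]
    bos_token = llama._model.token_get_text(  # private attribute, mirrors the code above
        int(llama.metadata["tokenizer.ggml.bos_token_id"])
    )
    chat_prompt = llama_chat_template(
        [{"role": "user", "content": "Hello"}], bos_token, chat_template
    )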
@@ -136,7 +143,7 @@ async def get_action(
         guided_regex = get_guided_regex(protagonist.skills, NPCs, locations, items)
 
         # Generate the response
-        if isinstance(self.model, models.LogitsGenerator):  # type: ignore
+        if isinstance(self.model, models.LogitsGenerator):
             res = await self.generate_local(
                 prompt,
                 self.model,
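
For comparison, the Transformers branch in generate_local keeps relying on Hugging Face's built-in chat templating rather than the manual Jinja rendering used for GGUF models. A minimal standalone sketch; the model id is an arbitrary example of a chat-tuned model:

    # Sketch: the Hugging Face equivalent of the template rendering done above.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    chat_prompt = tok.apply_chat_template(
        [{"role": "user", "content": "Hello"}],
        tokenize=False,
        add_generation_prompt=True,
    )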