feat: add ChatLlamaCpp LLM using local llama-cpp-python inference engine
Benjoyo committed Apr 22, 2024
1 parent d998b35 commit 66dffbe
Showing 9 changed files with 405 additions and 97 deletions.
Empty file.
8 changes: 8 additions & 0 deletions bpm_ai_inference/llm/llama_cpp/_constants.py
@@ -0,0 +1,8 @@

DEFAULT_MODEL = "QuantFactory/dolphin-2.9-llama3-8b-GGUF"
DEFAULT_QUANT_LARGE = "*Q8_0.gguf"
DEFAULT_QUANT_BALANCED = "*Q4_K_M.gguf"
DEFAULT_QUANT_SMALL = "*Q2_K.gguf"
DEFAULT_TEMPERATURE = 0.0
DEFAULT_MAX_RETRIES = 8
203 changes: 203 additions & 0 deletions bpm_ai_inference/llm/llama_cpp/llama_chat.py
@@ -0,0 +1,203 @@
import json
import logging
import re
from typing import Dict, Any, Optional, List

from bpm_ai_core.llm.common.llm import LLM
from bpm_ai_core.llm.common.message import ChatMessage, AssistantMessage, SystemMessage, ToolCallMessage
from bpm_ai_core.llm.common.tool import Tool
from bpm_ai_core.llm.openai_chat.util import messages_to_openai_dicts
from bpm_ai_core.prompt.prompt import Prompt
from bpm_ai_core.tracing.tracing import Tracing
from bpm_ai_core.util.json_schema import expand_simplified_json_schema
from llama_cpp.llama_grammar import json_schema_to_gbnf, LlamaGrammar

from bpm_ai_inference.llm.llama_cpp._constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_MAX_RETRIES, \
DEFAULT_QUANT_BALANCED
from bpm_ai_inference.llm.llama_cpp.util import messages_to_llama_dicts

logger = logging.getLogger(__name__)

try:
from llama_cpp import Llama, CreateChatCompletionResponse, llama_grammar

has_llama_cpp_python = True
except ImportError:
has_llama_cpp_python = False


class ChatLlamaCpp(LLM):
"""
Local open-weight chat large language models based on ``llama-cpp-python``, running on CPU.
To use, you should have the ``llama-cpp-python`` Python package installed (and enough available RAM).
"""

def __init__(
self,
model: str = DEFAULT_MODEL,
filename: str = DEFAULT_QUANT_BALANCED,
temperature: float = DEFAULT_TEMPERATURE,
max_retries: int = DEFAULT_MAX_RETRIES,
):
if not has_llama_cpp_python:
raise ImportError('llama-cpp-python is not installed')
super().__init__(
model=model,
temperature=temperature,
max_retries=max_retries,
retryable_exceptions=[]
)
self.llm = Llama.from_pretrained(
repo_id=model,
filename=filename,
n_ctx=4096,
verbose=False
)

async def _generate_message(
self,
messages: List[ChatMessage],
output_schema: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
stop: list[str] = None,
current_try: int = None
) -> AssistantMessage:
completion = await self._run_completion(messages, output_schema, tools, stop, current_try)
message = completion["choices"][0]["message"]
if output_schema:
return AssistantMessage(content=self._parse_json(message["content"]))
elif tools:
tool_call = self._parse_json(message["content"])
return AssistantMessage(
tool_calls=[ToolCallMessage(
id=tool_call["name"],
name=tool_call["name"],
payload=tool_call["arguments"]
)]
)
else:
return AssistantMessage(content=message["content"])

async def _run_completion(
self,
messages: List[ChatMessage],
output_schema: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
stop: list[str] = None,
current_try: int = None
) -> CreateChatCompletionResponse:
messages = await messages_to_llama_dicts(messages)

grammar = None
prefix = None

if output_schema:
output_schema = expand_simplified_json_schema(output_schema)
grammar = self._bnf_grammar_for_json_schema(output_schema)
output_prompt = Prompt.from_file("output_schema", output_schema=json.dumps(output_schema, indent=2))
output_prompt = output_prompt.format()[0].content
if messages[0]["role"] == "system":
messages[0]["content"] += f"\n\n{output_prompt}"
else:
messages.insert(0, {"role": "system", "content": output_prompt})

elif tools:
grammar = self._bnf_grammar_for_json_schema(self._tool_call_json_schema(tools))
grammar = self._extend_root_rule(grammar)
tool_use_prompt = Prompt.from_file("tool_use", tool_schemas=json.dumps([self._get_function_schema(t) for t in tools], indent=2))
tool_use_prompt = tool_use_prompt.format()[0].content
if messages[0]["role"] == "system":
messages[0]["content"] += f"\n\n{tool_use_prompt}"
else:
messages.insert(0, {"role": "system", "content": tool_use_prompt})
if messages[-1]["role"] == "assistant":
logger.warning("Ignoring trailing assistant message.")
messages.pop()
prefix = "<tool_call>"
stop = ["</tool_call>"]

Tracing.tracers().start_llm_trace(self, messages, current_try, tools or ({"output_schema": output_schema} if output_schema else None))
completion: CreateChatCompletionResponse = self.llm.create_chat_completion(
messages=messages,
stop=stop or [],
grammar=LlamaGrammar.from_string(grammar, verbose=False) if grammar else None,
temperature=self.temperature,
)
completion["choices"][0]["message"]["content"] = completion["choices"][0]["message"]["content"].removeprefix(prefix or "").strip()
Tracing.tracers().end_llm_trace(completion["choices"][0]["message"])
return completion

@staticmethod
def _extend_root_rule(gbnf_string: str):
root_rule_pattern = r'(root\s*::=\s*)("\{"[^}]*"\}")'
def replace_root_rule(match):
prefix = match.group(1)
json_content = match.group(2)
extended_rule = f'{prefix}"<tool_call>" space {json_content} space "</tool_call>"'
return extended_rule
extended_gbnf = re.sub(root_rule_pattern, replace_root_rule, gbnf_string)
return extended_gbnf

@staticmethod
def _get_function_schema(tool: Tool) -> dict:
schema = tool.args_schema
return {
'type': 'function',
'function': {
'name': tool.name,
'description': tool.description,
**schema
}
}

@staticmethod
def _tool_call_json_schema(tools: list[Tool]) -> dict:
return {
"type": "object",
"properties": {
"name": {
"type": "string",
"enum": [t.name for t in tools]
},
"arguments": {
"oneOf": [t.args_schema for t in tools]
}
},
"required": ["name", "arguments"]
}

@staticmethod
def _bnf_grammar_for_json_schema(
json_schema: dict,
fallback_to_generic_json: bool = True
) -> str:
try:
schema_str = json.dumps(json_schema)
return json_schema_to_gbnf(schema_str)
except Exception as e:
if fallback_to_generic_json:
logger.warning("Exception while converting json schema to gbnf, falling back to generic json grammar.")
return llama_grammar.JSON_GBNF
else:
raise e

@staticmethod
def _parse_json(content: str) -> dict | None:
try:
json_object = json.loads(content)
except ValueError:
json_object = None
return json_object

def supports_images(self) -> bool:
return False

def supports_video(self) -> bool:
return False

def supports_audio(self) -> bool:
return False

def name(self) -> str:
return "llama"
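For orientation, here is a minimal sketch of the grammar-constrained generation flow that `_run_completion` is built around, using only `llama-cpp-python` calls that already appear above (`Llama.from_pretrained`, `json_schema_to_gbnf`, `LlamaGrammar.from_string`, `create_chat_completion`). The example schema and messages are illustrative assumptions, not part of the commit:

import json

from llama_cpp import Llama
from llama_cpp.llama_grammar import LlamaGrammar, json_schema_to_gbnf

# Same repo id / quant file pattern as the defaults in _constants.py
llm = Llama.from_pretrained(
    repo_id="QuantFactory/dolphin-2.9-llama3-8b-GGUF",
    filename="*Q4_K_M.gguf",
    n_ctx=4096,
    verbose=False,
)

# Hypothetical output schema, purely for illustration
schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
}

# Constrain decoding to JSON matching the schema via a GBNF grammar
grammar = LlamaGrammar.from_string(json_schema_to_gbnf(json.dumps(schema)), verbose=False)

completion = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "Extract the city the user mentions as JSON."},
        {"role": "user", "content": "I moved to Berlin last year."},
    ],
    grammar=grammar,
    temperature=0.0,
)
print(json.loads(completion["choices"][0]["message"]["content"]))  # e.g. {"city": "Berlin"}

ChatLlamaCpp layers the output-schema and tool-use system prompts, the <tool_call> grammar extension, and retry handling on top of this basic flow.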
7 changes: 7 additions & 0 deletions bpm_ai_inference/llm/llama_cpp/output_schema.prompt
@@ -0,0 +1,7 @@
<output_instructions>
Output your result as a valid JSON object precisely following the JSON schema given in <schema></schema> tags below.
</output_instructions>

<schema>
{{output_schema}}
</schema>
15 changes: 15 additions & 0 deletions bpm_ai_inference/llm/llama_cpp/tool_use.prompt
@@ -0,0 +1,15 @@
You are provided with tool signatures within <tools></tools> XML tags.
Please call a tool and wait for tool results to be provided to you in the next iteration.
Don't make assumptions about what values to plug into tool arguments.
Once you have called a tool, results will be fed back to you within <tool_response></tool_response> XML tags.
Don't make assumptions about tool results if <tool_response> XML tags are not present since the tool hasn't been executed yet.
Analyze the data once you get the results and call another tool.

Here are the available tools:
<tools>
{{tool_schemas}}
</tools>

Follow this json schema for each tool call you will make: {"type": "object", "properties": {"name": {"type": "string"}, "arguments": {"type": "object"}}, "required": ["name", "arguments"]}
For each tool call return a json object with tool name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call> {"name": <tool-name>, "arguments": <args-dict>} </tool_call>
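For illustration, one hypothetical turn under this protocol (tool name and values are invented) would look like the following; the model emits the call, and the result is fed back on the next iteration:

<tool_call> {"name": "get_weather", "arguments": {"city": "Berlin"}} </tool_call>

<tool_response> {"name": "get_weather", "content": "18°C, clear"} </tool_response>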
28 changes: 28 additions & 0 deletions bpm_ai_inference/llm/llama_cpp/util.py
@@ -0,0 +1,28 @@
import json
import logging

from bpm_ai_core.llm.common.message import ChatMessage, AssistantMessage, ToolResultMessage

logger = logging.getLogger(__name__)


async def messages_to_llama_dicts(messages: list[ChatMessage]):
return [await message_to_llama_dict(m) for m in messages]


async def message_to_llama_dict(message: ChatMessage) -> dict:
if isinstance(message, AssistantMessage) and message.has_tool_calls():
tool_call = message.tool_calls[0]
tool_content = json.dumps(tool_call.payload_dict())
content = '<tool_call>\n{"name": "' + tool_call.name + '", "arguments": ' + tool_content + '}\n</tool_call>'
elif isinstance(message, ToolResultMessage):
tool_response_content = f"{message.content}"
content = '<tool_response>\n{"name": "' + message.name + '", "content": ' + tool_response_content + '}\n</tool_response>'
else:
content = message.content

return {
"role": message.role,
**({"content": content} if content else {}),
**({"name": message.name} if message.name else {})
}
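For illustration, assuming an AssistantMessage carrying a single tool call and a ToolResultMessage whose content is already a JSON string (the tool name, arguments, and the role used for tool results are invented assumptions), message_to_llama_dict would produce dicts along these lines:

# Hypothetical assistant message with one tool call
{
    "role": "assistant",
    "content": '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Berlin"}}\n</tool_call>',
}

# Hypothetical tool result fed back to the model
{
    "role": "tool",  # assumption: whatever role bpm-ai-core assigns to ToolResultMessage
    "content": '<tool_response>\n{"name": "get_weather", "content": {"temperature": 18}}\n</tool_response>',
    "name": "get_weather",
}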
3 changes: 2 additions & 1 deletion bpm_ai_inference/util/optimum.py
@@ -134,7 +134,8 @@ def _optimize(repository_id: str, model_dir, task, push_to_hub=False):
if push_to_hub:
config = AutoOptimizationConfig.O2()
else:
- config = OptimizationConfig(optimization_level=99) # enable all optimizations
+ config = AutoOptimizationConfig.O2()
+ #config = OptimizationConfig(optimization_level=99) # enable all optimizations
optimizer.optimize(
optimization_config=config,
save_dir=model_dir