-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(rag-llm): chat and streams (#89)
* add chat stream settings, template and prompts * fix RAG stream * add collection name parameter to endpoint `/llm-rag-question` * update QAF for new RAG method * removed unused variables and lambdas * return RAG response as a response string - Converted response from chat deltas to a string. * remove unused imports and fix stream naming * refactor stream lambdas to function calls
- Loading branch information
Showing
8 changed files
with
364 additions
and
45 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
from typing import Literal, Optional | ||
|
||
from langstream.contrib import OpenAIChatMessage | ||
|
||
from fai_backend.chat.settings import OpenAIStreamSettings, default_settings | ||
|
||
|
||
class MessageChatPrompt: | ||
""" | ||
Represents a chat prompt message. | ||
Args: | ||
template (str): The content of the message template. | ||
role (Optional[Literal['system', 'assistant', 'user', 'function']]): The role of the message. Defaults to None. | ||
name (Optional[str]): The name of the message. Defaults to None. | ||
""" | ||
|
||
def __init__(self, template: str, role: Optional[Literal['system', 'assistant', 'user', 'function']] = None, | ||
name: Optional[str] = None, input_map_fn=lambda input: {}): | ||
self.content = template | ||
self.template = template | ||
self.role = role | ||
self.name = name | ||
|
||
|
||
class UserChatPrompt(MessageChatPrompt): | ||
""" | ||
Represents a chat prompt message from the user. | ||
Args: | ||
template (str): The content of the message template. | ||
""" | ||
|
||
def __init__(self, template: str): | ||
super().__init__(template, "user") | ||
|
||
|
||
class SystemChatPrompt(MessageChatPrompt): | ||
""" | ||
Represents a chat prompt message from the system. | ||
Args: | ||
template (str): The content of the message template. | ||
""" | ||
|
||
def __init__(self, template: str): | ||
super().__init__(template, "system") | ||
|
||
|
||
class FunctionChatPrompt(MessageChatPrompt): | ||
""" | ||
Represents a chat prompt message from a function. | ||
Args: | ||
template (str): The content of the message template. | ||
name (str): The name of the function. | ||
""" | ||
|
||
def __init__(self, template: str, name: str): | ||
super().__init__(template, "function", name) | ||
|
||
|
||
class AssistantChatPrompt(MessageChatPrompt): | ||
""" | ||
Represents a chat prompt message from the assistant. | ||
Args: | ||
template (str): The content of the message template. | ||
""" | ||
|
||
def __init__(self, template: str): | ||
super().__init__(template, "assistant") | ||
|
||
|
||
class ChatPrompt: | ||
""" | ||
Represents a chat prompt. | ||
Args: | ||
name (str): The name of the chat prompt. | ||
messages (list[MessageChatPrompt], optional): The list of chat prompt messages. Defaults to []. | ||
settings (OpenAIStreamSettings, optional): The settings for the chat prompt. Defaults to default_settings. | ||
""" | ||
|
||
def __init__(self, name: str, messages: list[MessageChatPrompt] = [], | ||
settings: OpenAIStreamSettings = default_settings): | ||
self.templates = messages | ||
self.input_vars = {} | ||
self.settings = settings | ||
self.name = name | ||
|
||
def format_prompt(self, input_vars: dict[str, str]): | ||
""" | ||
Formats the chat prompt with the given input variables. | ||
Args: | ||
input_vars (dict[str, str]): The input variables to format the chat prompt with. | ||
""" | ||
self.input_vars = input_vars | ||
print(input_vars) | ||
for template in self.templates: | ||
template.content = template.template.format(**{**input_vars}) | ||
return self | ||
|
||
def to_messages(self) -> list[OpenAIChatMessage]: | ||
""" | ||
Converts the chat prompt to a list of OpenAIChatMessage objects. | ||
Returns: | ||
list[OpenAIChatMessage]: The list of OpenAIChatMessage objects. | ||
""" | ||
return [ | ||
OpenAIChatMessage(content=prompt.content, role=prompt.role, name=prompt.name) # type: ignore | ||
for prompt in self.templates | ||
] | ||
|
||
# def to_prompt(self) -> Prompt: | ||
# """ | ||
# Converts the chat prompt to a Prompt object. | ||
# | ||
# Returns: | ||
# Prompt: The Prompt object. | ||
# """ | ||
# return Prompt( | ||
# provider=ChatOpenAI.id, | ||
# inputs=self.input_vars, | ||
# settings={**(self.settings if self.settings else {})}, | ||
# messages=[ | ||
# PromptMessage(template=prompt.template, formatted=prompt.content, role=prompt.role, name=prompt.name) | ||
# # type: ignore | ||
# for prompt in self.templates | ||
# ], | ||
# ) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from typing import TypedDict | ||
|
||
OpenAIStreamSettings = TypedDict('OpenAIStreamSettings', { | ||
'model': str, | ||
'temperature': float | ||
}) | ||
|
||
default_settings = OpenAIStreamSettings(model='gpt-3.5-turbo', temperature=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from typing import Tuple, TypeVar, Iterable, Callable | ||
|
||
from langstream.contrib import OpenAIChatStream, OpenAIChatMessage | ||
|
||
from fai_backend.chat.prompt import ChatPrompt | ||
from fai_backend.chat.settings import OpenAIStreamSettings, default_settings | ||
|
||
T = TypeVar('T') | ||
U = TypeVar('U') | ||
|
||
|
||
def create_chat_prompt(prompt_args: dict) -> ChatPrompt: | ||
""" | ||
Create a ChatPrompt object. | ||
Args: | ||
prompt_args (dict): Arguments for creating a ChatPrompt object. | ||
Returns: | ||
ChatPrompt: The ChatPrompt object. | ||
""" | ||
return ChatPrompt(**prompt_args) | ||
|
||
|
||
def create_chat_stream( | ||
name: str, | ||
messages_fn: Callable[[T], Iterable[OpenAIChatMessage]], | ||
settings: OpenAIStreamSettings = default_settings, | ||
) -> OpenAIChatStream[T, U]: | ||
""" | ||
Create a chat stream. | ||
Args: | ||
name (str): The name of the chat stream. | ||
settings (OpenAIStreamSettings): Settings for the chat stream. | ||
input_map_fn (Callable[[T], U], optional): Function that maps input of type T to output of type U. Defaults to identity function. | ||
Returns: | ||
OpenAIChatStream[T, U]: The chat stream. | ||
""" | ||
return OpenAIChatStream[T, U]( | ||
name, | ||
lambda delta: [*messages_fn(delta)], | ||
**settings | ||
) | ||
|
||
|
||
def create_chat_stream_from_prompt( | ||
prompt_args: dict, | ||
) -> Tuple[OpenAIChatStream[T, U], ChatPrompt]: | ||
""" | ||
Create a general chat stream with a prompt. | ||
Args: | ||
prompt_args (dict): Arguments for creating a ChatPrompt object. | ||
settings (OpenAIStreamSettings): Settings for the chat stream. | ||
history (list[OpenAIChatMessage], optional): Chat history. Defaults to an empty list. | ||
input_map_fn (Callable[[T], U], optional): Function that maps input of type T to output of type U. Defaults to identity function. | ||
Returns: | ||
tuple[OpenAIChatStream[T, U], ChatPrompt]: A tuple containing the chat stream and prompt objects. | ||
""" | ||
|
||
input_map_fn = prompt_args["input_map_fn"] if "input_map_fn" in prompt_args else lambda x: x | ||
|
||
prompt = create_chat_prompt({ | ||
"name": prompt_args["name"], | ||
"messages": prompt_args["messages"], | ||
"settings": prompt_args["settings"] if "settings" in prompt_args else default_settings, | ||
}) | ||
|
||
def messages(p: T) -> Iterable[OpenAIChatMessage]: | ||
prompt.format_prompt(input_map_fn(p)) | ||
|
||
return prompt.to_messages() | ||
|
||
chat_stream = create_chat_stream( | ||
prompt.name, | ||
messages, | ||
prompt.settings | ||
) | ||
|
||
return chat_stream, prompt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import os | ||
from fai_backend.chat.prompt import UserChatPrompt, SystemChatPrompt | ||
|
||
CHAT_PROMPT_TEMPLATE_ARGS = { | ||
"name": "ChatStream", | ||
"messages": [ | ||
SystemChatPrompt( | ||
"You are a helpful AI assistant that helps people with answering questions about planning " | ||
"permission.<br> If you can't find the answer in the search result below, just say (in Swedish) " | ||
"\"Tyvärr kan jag inte svara på det.\" Don't try to make up an answer.<br> If the " | ||
"question is not related to the context, politely respond that you are tuned to only " | ||
"answer questions that are related to the context.<br> The questions are going to be " | ||
"asked in Swedish. Your response must always be in Swedish." | ||
), | ||
UserChatPrompt("{query}"), | ||
UserChatPrompt("Here are the results of the search:\n\n {results}"), | ||
], | ||
"input_map_fn": lambda input: { | ||
"query": list(input)[0]['query'], | ||
"results": ' | '.join([doc for doc, _ in list(input)[0]['results']]) | ||
}, | ||
"settings": { | ||
"model": os.environ.get("GPT_4_MODEL_NAME", "gpt-4"), | ||
"temperature": 0 | ||
}, | ||
} | ||
|
||
SCORING_PROMPT_TEMPLATE_ARGS = { | ||
"name": "ScoringStream", | ||
"messages": [ | ||
SystemChatPrompt("You are a scoring systems that classifies documents from 0-100 based on how well they answer a query."), | ||
UserChatPrompt("Query: {query}\n\nDocument: {document}"), | ||
], | ||
"input_map_fn": lambda input: {**(input)}, | ||
"settings": { | ||
"model": "gpt-3.5-turbo", | ||
"temperature": 0, | ||
"functions": [ | ||
{ | ||
"name": "score_document", | ||
"description": "Scores the previous document according to the user query\n\n Parameters\n ----------\n score\n A number from 0-100 scoring how well does the document matches the query. The higher the score, the better match for the query\n ", | ||
"parameters": { | ||
"type": "object", | ||
"properties": { | ||
"score": { | ||
"type": "number", | ||
} | ||
}, | ||
"required": ["score"], | ||
} | ||
} | ||
], | ||
"function_call": {"name": "score_document"}, | ||
}, | ||
} |
Oops, something went wrong.