✨ Feat(chat): integrate with rag capabilities
MasterKenth committed May 20, 2024
1 parent 7e3ca39 commit d89e562
Showing 10 changed files with 201 additions and 98 deletions.
78 changes: 74 additions & 4 deletions fai-rag-app/fai-backend/fai_backend/chat/template.py
@@ -1,5 +1,74 @@
import dataclasses
import os
from fai_backend.chat.prompt import UserChatPrompt, SystemChatPrompt
from typing import List, Any, Callable

from fai_backend.chat.prompt import UserChatPrompt, SystemChatPrompt, MessageChatPrompt


@dataclasses.dataclass
class PromptTemplate:
name: str
messages: List[MessageChatPrompt]

# Takes the input T from a Stream[T,U] and returns a dict of replacement variables
input_map_fn: Callable[[Any], dict[str, str]]

settings: dict[str, Any]


chatPromptTemplate = PromptTemplate(
name="ChatStream",
messages=[
SystemChatPrompt(
"You are a helpful AI assistant that helps people with answering questions about planning "
"permission.<br> If you can't find the answer in the search result below, just say (in Swedish) "
"\"Tyvärr kan jag inte svara på det.\" Don't try to make up an answer.<br> If the "
"question is not related to the context, politely respond that you are tuned to only "
"answer questions that are related to the context.<br> The questions are going to be "
"asked in Swedish. Your response must always be in Swedish."
),
UserChatPrompt("{query}"),
UserChatPrompt("Here are the results of the search:\n\n {results}"),
],
input_map_fn=lambda input: {
"query": list(input)[0]['query'],
"results": ' | '.join([doc for doc, _ in list(input)[0]['results']])
},
settings={
"model": os.environ.get("GPT_4_MODEL_NAME", "gpt-4"),
"temperature": 0
}
)
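For orientation, a minimal sketch (not part of this commit) of how input_map_fn turns a raw stream payload into the variables that fill the {query} and {results} placeholders; example_input and the final format() call are purely illustrative:

example_input = [{
    "query": "Behöver jag bygglov för ett attefallshus?",
    "results": [("Utdrag ur PBL 9 kap. ...", 92), ("Utdrag ur detaljplanen ...", 81)],
}]
variables = chatPromptTemplate.input_map_fn(example_input)
# variables == {"query": "Behöver jag bygglov ...", "results": "Utdrag ur PBL 9 kap. ... | Utdrag ur detaljplanen ..."}
rendered = "Here are the results of the search:\n\n {results}".format(**variables)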

scoringPromptTemplate = PromptTemplate(
name="ScoringStream",
messages=[
SystemChatPrompt(
"You are a scoring systems that classifies documents from 0-100 based on how well they answer a query."),
UserChatPrompt("Query: {query}\n\nDocument: {document}"),
],
input_map_fn=lambda input: {**(input)},
settings={
"model": "gpt-3.5-turbo",
"temperature": 0,
"functions": [
{
"name": "score_document",
"description": "Scores the previous document according to the user query\n\n Parameters\n ----------\n score\n A number from 0-100 scoring how well does the document matches the query. The higher the score, the better match for the query\n ",
"parameters": {
"type": "object",
"properties": {
"score": {
"type": "number",
}
},
"required": ["score"],
}
}
],
"function_call": {"name": "score_document"},
}
)
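A hedged sketch (not from this commit) of how the score_document function-call arguments could be turned into a number; parse_score is a hypothetical helper, and the JSON string stands in for whatever arguments payload the OpenAI response actually carries:

import json

def parse_score(function_call_arguments: str) -> float:
    # The model returns the tool arguments as a JSON object string, e.g. '{"score": 83}'.
    try:
        return float(json.loads(function_call_arguments).get("score", 0))
    except (ValueError, TypeError):
        return 0.0

assert parse_score('{"score": 83}') == 83.0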

CHAT_PROMPT_TEMPLATE_ARGS = {
"name": "ChatStream",
@@ -20,15 +89,16 @@
"results": ' | '.join([doc for doc, _ in list(input)[0]['results']])
},
"settings": {
"model": os.environ.get("GPT_4_MODEL_NAME", "gpt-4"),
"model": "gpt-4o",
"temperature": 0
},
}

SCORING_PROMPT_TEMPLATE_ARGS = {
"name": "ScoringStream",
"messages": [
SystemChatPrompt("You are a scoring systems that classifies documents from 0-100 based on how well they answer a query."),
SystemChatPrompt(
"You are a scoring systems that classifies documents from 0-100 based on how well they answer a query."),
UserChatPrompt("Query: {query}\n\nDocument: {document}"),
],
"input_map_fn": lambda input: {**(input)},
@@ -52,4 +122,4 @@
],
"function_call": {"name": "score_document"},
},
}
}
17 changes: 7 additions & 10 deletions fai-rag-app/fai-backend/fai_backend/llm/impl/openai.py
@@ -1,18 +1,17 @@
from typing import AsyncGenerator

from langstream import StreamOutput, Stream
from langstream import Stream
from langstream.contrib import OpenAIChatStream, OpenAIChatDelta, OpenAIChatMessage

from fai_backend.llm.protocol import ILLMProtocol
from fai_backend.llm.protocol import ILLMStreamProtocol
from fai_backend.llm.models import LLMDataPacket


class OpenAILLM(ILLMProtocol):
class OpenAILLM(ILLMStreamProtocol):

def __init__(self, template: str):
self.template = template

def run(self, input_message: str) -> AsyncGenerator[StreamOutput[str], str]:
llm_stream: Stream[str, str] = OpenAIChatStream[str, OpenAIChatDelta](
async def create(self) -> Stream[str, LLMDataPacket]:
return OpenAIChatStream[str, OpenAIChatDelta](
"RecipeStream",
lambda user_question: [
OpenAIChatMessage(
@@ -26,6 +25,4 @@ def run(self, input_message: str) -> AsyncGenerator[StreamOutput[str], str]:
],
model="gpt-4",
temperature=0,
).map(lambda delta: delta.content)

return llm_stream(input_message)
).map(lambda delta: LLMDataPacket(content=delta.content, user_friendly=True))
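A hedged usage sketch (not part of the commit) of the new create() API; the template string and question are placeholders, an OpenAI API key is assumed to be configured, and iteration assumes langstream's StreamOutput exposes .final and .data:

import asyncio

async def demo():
    llm = OpenAILLM(template="You answer questions about planning permission.")
    stream = await llm.create()
    async for output in stream("Behöver jag bygglov för ett växthus?"):
        if output.final:
            print(output.data.content, end="", flush=True)

asyncio.run(demo())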
15 changes: 7 additions & 8 deletions fai-rag-app/fai-backend/fai_backend/llm/impl/parrot.py
@@ -1,13 +1,13 @@
import asyncio
from random import uniform
from typing import AsyncGenerator

from langstream import StreamOutput, Stream
from langstream import Stream

from fai_backend.llm.protocol import ILLMProtocol
from fai_backend.llm.protocol import ILLMStreamProtocol
from fai_backend.llm.models import LLMDataPacket


class ParrotLLM(ILLMProtocol):
class ParrotLLM(ILLMStreamProtocol):
"""
Parrot (mock) LLM protocol reference implementation.
@@ -25,9 +25,8 @@ async def to_generator(self, input_message: str):
yield part
await asyncio.sleep(uniform(self.min_delay, self.max_delay))

def run(self, input_message: str) -> AsyncGenerator[StreamOutput[str], str]:
stream = Stream[str, str](
async def create(self) -> Stream[str, LLMDataPacket]:
return Stream[str, str](
"ParrotStream",
self.to_generator
)
return stream(input_message)
).map(lambda delta: LLMDataPacket(content=delta, user_friendly=True))
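Since ParrotLLM needs no API key, it doubles as a quick check of the new stream protocol; a minimal sketch, assuming ParrotLLM() can be constructed with its default delays:

import asyncio

async def demo():
    stream = await ParrotLLM().create()
    packets = [out.data async for out in stream("echo this back") if out.final]
    print("".join(p.content for p in packets))

asyncio.run(demo())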
26 changes: 26 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/llm/impl/rag_wrapper.py
@@ -0,0 +1,26 @@
from langstream import Stream

from fai_backend.llm.protocol import ILLMStreamProtocol
from fai_backend.llm.models import LLMDataPacket
from fai_backend.llm.service import create_rag_stream


class RAGWrapper(ILLMStreamProtocol):
"""
Wraps an underlying Stream with RAG capabilities.
The underlying stream will be supplied with document extracts in plaintext
from the given collection along with the original question.
"""

def __init__(self, input_query: str, base_llm: ILLMStreamProtocol, rag_collection_name: str):
self.input_query = input_query
self.rag_collection_name = rag_collection_name
self.base_llm = base_llm

async def create(self) -> Stream[str, LLMDataPacket]:
rag_stream = await create_rag_stream(self.input_query, self.rag_collection_name)
base_stream = await self.base_llm.create()

return (rag_stream
.and_then(base_stream))
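A hedged composition sketch (not part of this commit): wrapping a base LLM with RAG for a given collection; the collection name, question, and choice of ParrotLLM as the base model are placeholders:

import asyncio
from fai_backend.llm.impl.parrot import ParrotLLM

async def demo():
    question = "Behöver jag bygglov för ett plank?"
    wrapped = RAGWrapper(question, base_llm=ParrotLLM(), rag_collection_name="planning_docs")
    stream = await wrapped.create()
    async for output in stream(question):
        if output.final:
            print(output.data.content, end="")

asyncio.run(demo())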
7 changes: 7 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/llm/models.py
@@ -1,3 +1,4 @@
import dataclasses
from datetime import datetime

from pydantic import BaseModel
@@ -8,3 +9,9 @@ class LLMMessage(BaseModel):
date: datetime
source: str | None = None
content: str | None = None


@dataclasses.dataclass
class LLMDataPacket:
content: str
user_friendly: bool = False
20 changes: 8 additions & 12 deletions fai-rag-app/fai-backend/fai_backend/llm/protocol.py
@@ -1,18 +1,14 @@
from typing import Protocol, AsyncGenerator
from typing import Protocol

from langstream import Stream, StreamOutput
from langstream import Stream

from fai_backend.llm.models import LLMDataPacket

class ILLMProtocol(Protocol):
def run(self, input_message: str) -> AsyncGenerator[StreamOutput[str], str]:
"""
Takes an input message and returns an async stream resolving to the response from the LLM.
:param input_message: The input message, as a string.
:type input_message: str
:return: The resulting `Stream[str, str]` object.
:rtype: Stream[str, str]

class ILLMStreamProtocol(Protocol):
async def create(self) -> Stream[str, LLMDataPacket]:
"""
Create a Stream that takes a str (generally a question) and returns
a stream of tokens (strings) of the response given by the LLM.
"""
...
90 changes: 44 additions & 46 deletions fai-rag-app/fai-backend/fai_backend/llm/service.py
@@ -8,7 +8,7 @@
from fai_backend.config import settings
from fai_backend.llm.impl.openai import OpenAILLM
from fai_backend.llm.impl.parrot import ParrotLLM
from fai_backend.llm.protocol import ILLMProtocol
from fai_backend.llm.protocol import ILLMStreamProtocol
from fai_backend.vector.service import VectorService
from fai_backend.vector.factory import vector_db

@@ -22,7 +22,7 @@

class LLMFactory:
@staticmethod
def get(backend: str = settings.LLM_BACKEND) -> ILLMProtocol:
def get(backend: str = settings.LLM_BACKEND) -> ILLMStreamProtocol:
return {
"parrot": lambda: ParrotLLM(),
"openai": lambda: OpenAILLM(template=SYSTEM_TEMPLATE),
@@ -64,31 +64,14 @@ async def ask_llm_question(question: str):
return await join_final_output(llm_stream(question))


async def ask_llm_raq_question(question: str, collection_name: str):
add_document, list_documents = (lambda documents: (
lambda document: (documents.append(document), document)[1],
lambda: [*documents]
))([])

chat_stream, _ = create_chat_stream_from_prompt(CHAT_PROMPT_TEMPLATE_ARGS)
scoring_stream, _ = create_chat_stream_from_prompt(SCORING_PROMPT_TEMPLATE_ARGS)
async def create_rag_stream(query: str, vector_collection_name: str) -> Stream[str, str]:
vector_service = VectorService(vector_db=vector_db)

def append_score_to_documents(scores):
return zip(list_documents(), [s[0] for s in scores])

def sort_and_slice_documents(scored_documents, slice_size: int):
first_element = list(scored_documents)[0]
sorted_scores = sorted(first_element, key=lambda x: x[1], reverse=True)
return sorted_scores[:slice_size]

def create_query_document_pair(query, document):
return {"query": query, "document": document}
scoring_stream, _ = create_chat_stream_from_prompt(SCORING_PROMPT_TEMPLATE_ARGS)

vector_db_query_result = await query_vector(
vector_service=vector_service,
collection_name=collection_name,
query=question,
collection_name=vector_collection_name,
query=query,
)

scoring_stream = scoring_stream.map(
@@ -97,26 +80,41 @@ def create_query_document_pair(query, document):
else 0
)

def stream(query):
try:
return (
vector_db_query_result
.map(add_document)
.map(lambda document: create_query_document_pair(query, document))
.map(scoring_stream)
.gather()
.and_then(append_score_to_documents)
.and_then(lambda scored_documents: sort_and_slice_documents(scored_documents, 6))
.and_then(lambda results: {"query": query, "results": results[0]})
.and_then(chat_stream)
.map(lambda delta: delta.content)
)(query)
except Exception as e:
print(f"Error processing query: {e}", {str(e)})
raise e

try:
return await join_final_output(stream(question))
except Exception as e:
print(f"Error joining final output '{question}': {str(e)}")
raise e
documents: list[str] = []

def store_and_return_document(document: str):
documents.append(document)
return document

def pair_query_document(document: str):
return {"query": query, "document": document}

def append_score_to_documents(scores):
return zip(documents, [s[0] for s in scores])

def sort_and_slice_documents(scored_documents, slice_size: int):
first_element = list(scored_documents)[0]
sorted_scores = sorted(first_element, key=lambda x: x[1], reverse=True)
return sorted_scores[:slice_size]

return (
vector_db_query_result
.map(store_and_return_document)
.map(pair_query_document)
.map(scoring_stream)
.gather()
.and_then(append_score_to_documents)
.and_then(lambda scored_documents: sort_and_slice_documents(scored_documents, 6))
.and_then(lambda results: {"query": query, "results": results[0]})
)


async def ask_llm_raq_question(question: str, collection_name: str):
rag_stream = await create_rag_stream(question, collection_name)
chat_stream, _ = create_chat_stream_from_prompt(CHAT_PROMPT_TEMPLATE_ARGS)

final_generator = (rag_stream
.and_then(chat_stream)
.map(lambda delta: delta.content)(question))

return await join_final_output(final_generator)
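For reference, a small hedged walk-through (not part of the commit) of what the gathered scoring output looks like as it passes through append_score_to_documents and sort_and_slice_documents; the documents and scores below are made up:

documents = ["doc A", "doc B", "doc C"]
scores = [[78], [92], [15]]  # .gather() yields one list of outputs per scored document
scored = zip(documents, [s[0] for s in scores])  # -> ("doc A", 78), ("doc B", 92), ("doc C", 15)
top = sorted(scored, key=lambda pair: pair[1], reverse=True)[:6]  # best matches first, capped at 6
print(top)  # [('doc B', 92), ('doc A', 78), ('doc C', 15)]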