-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
C ready for initial publishing
- Loading branch information
Showing
9 changed files
with
190 additions
and
185 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,91 +1,42 @@ | ||
# !/usr/bin/python | ||
import os | ||
|
||
import typer | ||
from typer import Typer | ||
|
||
from typing import List, Union | ||
from .query import Query | ||
from .conversation import Conversation | ||
from .migrations import Migration | ||
|
||
from rich.console import Console | ||
|
||
from langchain.prompts.chat import ( | ||
ChatPromptTemplate, | ||
HumanMessagePromptTemplate, | ||
SystemMessagePromptTemplate, | ||
app: Typer = Typer( | ||
help="Chino is a chatbot based on OpenAI. It can also provide responses about queries on user-provided data." | ||
) | ||
from langchain.schema import HumanMessage, SystemMessage, BaseMessage | ||
from langchain_openai import ChatOpenAI | ||
|
||
from .store.query import query_data | ||
from .store.migrations import generate_data_store | ||
|
||
|
||
# Typer application object, created with Typer's defaults.
app = typer.Typer()

# Single rich console through which the helpers in this module print output
# and read interactive input.
console = Console()

# Chat model client, built once at import time with ChatOpenAI's defaults.
model: ChatOpenAI = ChatOpenAI()

# Conversation history handed to the model on every call. It starts with the
# system prompt that gives the assistant its "Chino" persona.
messages: List[Union[SystemMessage, HumanMessage]] = [
    SystemMessage(
        content="You are Chino, a chatbot based on ChatGPT. 'Chino' means 'intelligence' in Japanese."
    ),
]
|
||
|
||
def get_response(prompt: str) -> None:
    """Send ``prompt`` to the chat model and print its reply.

    The prompt is appended to the module-level ``messages`` history as a
    ``HumanMessage`` and the model is invoked with the whole history, so the
    model sees the earlier turns as well.

    Args:
        prompt: The user's text for this turn.
    """
    with console.status("[i]thinking...[/i]"):
        messages.append(HumanMessage(content=prompt))
        response: BaseMessage = model.invoke(messages)
        # Keep the message object the model returned rather than re-wrapping
        # its text in a SystemMessage: the system role is for instructions,
        # so replaying earlier answers under it mislabels the model's own turns.
        # NOTE(review): the module-level annotation of ``messages`` only lists
        # SystemMessage/HumanMessage and may need widening to BaseMessage.
        messages.append(response)

    console.print(f"[b blue]Chino:[/b blue] {response.content}")
    console.rule()
|
||
|
||
def run_query(prompt: str) -> None:
    """Answer ``prompt`` using the user's stored data and print the reply with its sources.

    ``query_data`` turns the prompt into the text that is actually sent to the
    model and also returns the sources it drew on; those sources are printed
    underneath the model's answer.

    Args:
        prompt: The user's question for this turn.
    """
    with console.status("[i]thinking...[/i]"):
        query_text, query_sources = query_data(prompt)
        messages.append(HumanMessage(content=query_text))
        response: BaseMessage = model.invoke(messages)
        # Keep the message object the model returned rather than re-wrapping
        # its text in a SystemMessage: the system role is for instructions,
        # so replaying earlier answers under it mislabels the model's own turns.
        # NOTE(review): the module-level annotation of ``messages`` only lists
        # SystemMessage/HumanMessage and may need widening to BaseMessage.
        messages.append(response)

    console.print(
        f"[b blue]Chino:[/b blue] {response.content}\n\n[i violet]Sources:[/i violet]{query_sources}"
    )
    console.rule()
|
||
|
||
def run_conversation(prompt: str, query: bool) -> None:
    """Answer a single prompt, or run an interactive chat loop when none is given.

    Args:
        prompt: One-shot prompt. When falsy, prompts are read from the console
            until the user enters ``quit``.
        query: When true, every prompt is routed through ``run_query``
            regardless of its prefix.
    """
    # The one-shot path used to recognise only "\q:" and the interactive loop
    # only "\query:". Both spellings are now accepted everywhere, so the same
    # input behaves the same way in either mode.
    query_prefixes = ("\\q:", "\\query:")

    def dispatch(text: str) -> None:
        # Route one prompt to the data-backed query path or to plain chat.
        if query or text.lower().startswith(query_prefixes):
            run_query(text)
        else:
            get_response(text)

    if prompt:
        dispatch(prompt)
        return

    while True:
        prompt = console.input("[b green]You: [/b green]")
        if prompt == "quit":
            break
        dispatch(prompt)
|
||
|
||
def main( | ||
prompt: str = typer.Option(None, "-p", "--prompt", help="Prompt for ChatGPT"), | ||
query: bool = typer.Option(False, "-q", "--query", help="Query for your data"), | ||
process: bool = typer.Option(False, "--process", help="Process your data"), | ||
@app.command() | ||
def migrate( | ||
chroma_path: str = os.path.expanduser("~/.local/share/chino/chroma/"), | ||
data_path: str = os.path.expanduser("~/.local/share/chino/data/"), | ||
) -> None: | ||
if process: | ||
console.status("Processing your data...") | ||
generate_data_store() | ||
return | ||
run_conversation(prompt, query) | ||
|
||
|
||
if __name__ == "__main__": | ||
typer.run(main) | ||
"""Migrate the data to the chroma using vector embeddings.""" | ||
|
||
migration = Migration(chroma_path, data_path) | ||
migration.generate_data_store() | ||
|
||
|
||
@app.command("start")
def main():
    """Start the main event loop function. A chat interface will be opened."""
    # The docstring above is left untouched because it belongs to a CLI
    # command; implementation notes are kept in comments instead.
    #
    # Loop contract: read a prompt, stop on the literal "quit", send prompts
    # that start with "\q:" (case-insensitive) to Conversation.run_query and
    # everything else to Conversation.get_response.

    conversation: Conversation = Conversation()
    try:
        while True:
            prompt = conversation.console.input("[bold green]You: [/bold green]")
            if prompt == "quit":
                conversation.console.print("[bold red]Quiting...[/bold red]")
                break
            # A blank line carries nothing to answer; ask again instead of
            # spending a model call on it.
            if not prompt.strip():
                continue
            if prompt.lower().startswith("\\q:"):
                conversation.run_query(prompt)
            else:
                conversation.get_response(prompt)
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C and end-of-input (Ctrl-D or a closed stdin) both end the
        # session with the same farewell instead of a traceback.
        conversation.console.print("\n[bold red]Quiting...[/bold red]")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import os | ||
|
||
from typing import List, Union | ||
|
||
from rich.console import Console | ||
|
||
from langchain.schema import HumanMessage, SystemMessage, BaseMessage | ||
from langchain_openai import ChatOpenAI | ||
|
||
from .query import Query | ||
|
||
|
||
class Conversation:
    """A chat session: the model client, the running message history and a console.

    Every call to ``get_response`` or ``run_query`` appends the user's turn
    and the model's reply to ``messages``, and the whole history is passed to
    the model on each invocation.
    """

    def __init__(self) -> None:
        """Create the model client, seed the history with the system prompt and open a console."""
        self.model = ChatOpenAI()
        # The history holds system, human and model messages, hence the base type.
        self.messages: List[BaseMessage] = [
            SystemMessage(
                content="You are Chino, a chatbot based on ChatGPT. 'Chino' means 'intelligence' in Japanese."
            ),
        ]
        self.console = Console()

    def _exchange(self, content: str) -> "BaseMessage":
        """Record ``content`` as the user's turn, invoke the model, then record and return its reply."""
        self.messages.append(HumanMessage(content=content))
        response: BaseMessage = self.model.invoke(self.messages)
        # Store the message object the model returned rather than re-wrapping
        # its text in a SystemMessage: the system role is for instructions,
        # so replaying earlier answers under it mislabels the model's own turns.
        self.messages.append(response)
        return response

    def get_response(self, prompt: str) -> None:
        """Send ``prompt`` to the model and print its reply.

        Args:
            prompt: The user's text for this turn.
        """
        from rich.markup import escape

        with self.console.status("[i]thinking...[/i]"):
            response = self._exchange(prompt)

        # Model output is arbitrary text. Escaping it stops bracketed
        # fragments such as "[/b]" from being parsed as console markup.
        reply = escape(str(response.content))
        self.console.print(f"[b blue]Chino:[/b blue] {reply}")
        self.console.rule()

    def run_query(self, prompt: str) -> None:
        """Answer ``prompt`` from the user's stored data and print the reply with its sources.

        ``Query.query_data`` produces the text that is actually sent to the
        model together with the sources it used; the sources are printed
        underneath the answer.

        Args:
            prompt: The user's question for this turn.
        """
        from rich.markup import escape

        with self.console.status("[i]thinking...[/i]"):
            query_text, query_sources = Query(
                prompt, os.path.expanduser("~/.local/share/chino/chroma/")
            ).query_data()
            response = self._exchange(query_text)

        # Escape both dynamic parts so neither model output nor source paths
        # can be mistaken for console markup.
        reply = escape(str(response.content))
        sources = escape(str(query_sources))
        self.console.print(
            f"[b blue]Chino:[/b blue] {reply}\n\n[i violet]Sources:[/i violet]{sources}"
        )
        self.console.rule()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import os | ||
import shutil | ||
|
||
from typing import Any, List | ||
|
||
from langchain_community.document_loaders import DirectoryLoader | ||
from langchain.text_splitter import RecursiveCharacterTextSplitter | ||
from langchain.schema import Document | ||
from langchain_openai import OpenAIEmbeddings | ||
from langchain.vectorstores.chroma import Chroma | ||
|
||
|
||
class Migration:
    """Build a Chroma vector store from the documents found in a data directory.

    ``CHROMA_PATH`` is the directory the store is persisted to and
    ``DATA_PATH`` is the directory the documents are loaded from; both come
    straight from the constructor arguments.
    """

    def __init__(self, chroma_path: str = None, data_path: str = None) -> None:
        """Remember where to read documents from and where to write the store.

        Args:
            chroma_path: Directory for the persisted Chroma database.
            data_path: Directory handed to ``DirectoryLoader``.
        """
        self.CHROMA_PATH: str = chroma_path
        self.DATA_PATH: str = data_path

    def generate_data_store(self) -> None:
        """Load the documents, split them into chunks and write a fresh store.

        Raises:
            ValueError: If either path was left unset.
        """
        # Fail early with a clear message instead of an opaque TypeError from
        # deep inside the loader or the filesystem calls.
        if self.CHROMA_PATH is None or self.DATA_PATH is None:
            raise ValueError(
                "Migration requires both chroma_path and data_path to be set."
            )

        # Call the helpers through the instance so subclass overrides are
        # honoured; naming the class explicitly would bypass them.
        documents: Any = self._load_documents(self.DATA_PATH)
        chunks: List[Document] = self._split_text(documents)
        self._save_to_chroma(chunks, self.CHROMA_PATH)

    @classmethod
    def _load_documents(cls, data_path):
        """Return whatever ``DirectoryLoader(data_path).load()`` yields."""
        loader: DirectoryLoader = DirectoryLoader(data_path)
        return loader.load()

    @classmethod
    def _split_text(cls, documents: List[Document]) -> List[Document]:
        """Split ``documents`` into 300-character chunks that overlap by 100 characters.

        Args:
            documents: Documents to split.

        Returns:
            The resulting chunks; a summary line is printed as a side effect.
        """
        text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
            chunk_size=300,
            chunk_overlap=100,
            length_function=len,
            add_start_index=True,
        )
        chunks: List[Document] = text_splitter.split_documents(documents)
        print(f"Split {len(documents)} documents into {len(chunks)} chunks.")

        return chunks

    @classmethod
    def _save_to_chroma(cls, chunks: List[Document], chroma_path: str):
        """Replace any store at ``chroma_path`` with one built from ``chunks``.

        Args:
            chunks: Document chunks to embed and store.
            chroma_path: Directory the store is persisted to; an existing
                directory at this path is deleted first.
        """
        # Clear out the database first.
        if os.path.exists(chroma_path):
            shutil.rmtree(chroma_path)

        # Create a new DB from the documents.
        db: Chroma = Chroma.from_documents(
            chunks, OpenAIEmbeddings(), persist_directory=chroma_path
        )
        db.persist()
        print(f"Saved {len(chunks)} chunks to {chroma_path}.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import sys | ||
|
||
from typing import List, Tuple | ||
|
||
from langchain.vectorstores.chroma import Chroma | ||
from langchain_openai import OpenAIEmbeddings, ChatOpenAI | ||
from langchain.prompts import ChatPromptTemplate | ||
|
||
|
||
class Query:
    """Look up stored context for a prompt and build the text to send to the model."""

    def __init__(self, prompt: str = None, chroma_path: str = None) -> None:
        """Store the user's prompt and the location of the persisted Chroma store.

        Args:
            prompt: The user's question.
            chroma_path: Directory of the persisted Chroma database.
        """
        self.query_text = prompt
        self.CHROMA_PATH = chroma_path

    def _prepare_db(self) -> Chroma:
        """Open the Chroma store at ``CHROMA_PATH`` with OpenAI embeddings."""
        embeddings = OpenAIEmbeddings()
        return Chroma(persist_directory=self.CHROMA_PATH, embedding_function=embeddings)

    def query_data(self) -> Tuple[str, List[str]]:
        """Build a context-augmented prompt for the stored question.

        The three most similar stored chunks are retrieved and joined into a
        context section, which is combined with the question.

        Returns:
            A pair of the prompt text (prefixed with ``QUERY:``) and the
            ``source`` metadata value of each retrieved chunk.

        NOTE(review): when nothing is found, or the best relevance score is
        below 0.7, this prints a message and ends the whole process through
        ``sys.exit()`` rather than reporting back to the caller.
        """
        # Wording of the text handed to the model; {context} and {question}
        # are filled in below.
        template_text: str = (
            "\n"
            "Details\n"
            "{context}\n"
            "---\n"
            "Answer this question based on the above details and previous messages: {question}\n"
        )

        store: Chroma = self._prepare_db()

        # Search the DB.
        matches: List = store.similarity_search_with_relevance_scores(
            self.query_text, k=3
        )
        if not matches or matches[0][1] < 0.7:
            print("Unable to find matching results.")
            sys.exit()

        # Walk the matches once, collecting chunk text and chunk source together.
        contents = []
        sources = []
        for document, _score in matches:
            contents.append(document.page_content)
            sources.append(document.metadata.get("source", None))

        # This prompt will serve as a context for the model to generate a response.
        template = ChatPromptTemplate.from_template(template_text)
        rendered = template.format(
            context="\n\n---\n\n".join(contents), question=self.query_text
        )
        return "QUERY:" + rendered, sources
Empty file.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.