This repository has been archived by the owner on Jan 5, 2025. It is now read-only.

Ollama LLM and conversational retrieval... #239

Merged · 2 commits · Feb 18, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -8,3 +8,5 @@ dj_backend_server/nginx/nginx.conf
dj_backend_server.code-workspace
.aider*
.aiderignore
dj_backend_server/.vscode/settings.json

3 changes: 2 additions & 1 deletion dj_backend_server/.gitignore
@@ -37,4 +37,5 @@ pip-delete-this-directory.txt
website_data_sources/*
venv
open-llama-7B-open-instruct.ggmlv3.q4_K_M.bin
llama-2-7b-chat.ggmlv3.q4_K_M.bin
llama-2-7b-chat.ggmlv3.q4_K_M.bin
.vscode/
4 changes: 4 additions & 0 deletions dj_backend_server/CHANGELOG.MD
@@ -1,3 +1,7 @@
2.18.2024
- The conversational retrieval functionality is now operating as expected. It successfully sends the conversation history to the language model, allowing the context from previous interactions to be utilized effectively.
- Added support for Ollama as the Language Model (LLM). Ensure Ollama is specified in the .env configuration and the model is preloaded on the server.

2.17.2024
- Incorporate 'Ollama' into your example.env configuration and make sure to reflect these changes in your .env file for compatibility.
- Expanded logging in settings.py with logging.debug calls for more detailed insight; these remain inactive when DEBUG mode is off.
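The 2.18.2024 note above asks that Ollama be set in .env and that the model already be pulled on the server. A minimal sketch of such a check, using the `ollama` Python client the project previously imported; the default URL and the shape of the `client.list()` response are assumptions, not code from this PR:

```python
import os

from ollama import Client  # the client previously imported in get_openai_llm.py

# Illustrative defaults; the real values come from the .env file.
client = Client(host=os.environ.get("OLLAMA_URL", "http://localhost:11434"))
model = os.environ.get("OLLAMA_MODEL_NAME", "llama2")

# client.list() is expected to return {"models": [{"name": "llama2:latest", ...}, ...]}
pulled = [m.get("name", "") for m in client.list().get("models", [])]

if not any(name.split(":")[0] == model for name in pulled):
    raise RuntimeError(
        f"Model '{model}' is not available on the Ollama server; "
        f"run `ollama pull {model}` there first."
    )
```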
18 changes: 10 additions & 8 deletions dj_backend_server/api/utils/get_embeddings.py
@@ -18,23 +18,23 @@ def get_azure_embedding():
deployment = os.environ.get("AZURE_OPENAI_EMBEDDING_MODEL_NAME")
openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
client = os.environ.get("AZURE_OPENAI_API_TYPE")
openai_api_base = os.environ['AZURE_OPENAI_API_BASE']
openai_api_version = os.environ['AZURE_OPENAI_API_VERSION']
openai_api_base = os.environ["AZURE_OPENAI_API_BASE"]
openai_api_version = os.environ["AZURE_OPENAI_API_VERSION"]

return OpenAIEmbeddings(
openai_api_key=openai_api_key,
deployment=deployment,
client=client,
chunk_size=8,
openai_api_base=openai_api_base,
openai_api_version=openai_api_version
openai_api_version=openai_api_version,
)


def get_openai_embedding():
"""Gets embeddings using the OpenAI embedding provider."""
openai_api_key = os.environ.get("OPENAI_API_KEY")
return OpenAIEmbeddings(openai_api_key=openai_api_key, chunk_size=1)
return OpenAIEmbeddings(openai_api_key=openai_api_key, chunk_size=1)


def get_llama2_embedding():
@@ -48,15 +48,17 @@ def choose_embedding_provider():

if embedding_provider == EmbeddingProvider.azure.value:
return get_azure_embedding()

elif embedding_provider == EmbeddingProvider.OPENAI.value:
return get_openai_embedding()

elif embedding_provider == EmbeddingProvider.llama2.value:
return get_llama2_embedding()

else:
available_providers = ", ".join([service.value for service in EmbeddingProvider])
available_providers = ", ".join(
[service.value for service in EmbeddingProvider]
)
raise ValueError(
f"Embedding service '{embedding_provider}' is not currently available. "
f"Available services: {available_providers}"
@@ -66,4 +68,4 @@ def choose_embedding_provider():
# Main function to get embeddings
def get_embeddings() -> Embeddings:
"""Gets embeddings using the chosen embedding provider."""
return choose_embedding_provider()
return choose_embedding_provider()
62 changes: 23 additions & 39 deletions dj_backend_server/api/utils/get_openai_llm.py
@@ -5,15 +5,11 @@
from django.utils.timezone import make_aware
from datetime import datetime, timezone
from uuid import uuid4
from ollama import Client
from openai import OpenAI
from django.conf import settings
from langchain_openai.chat_models import ChatOpenAI
from langchain_community.llms import Ollama
from langchain_community.chat_models import ChatOllama
from langchain_community.llms import AzureOpenAI
from langchain_community.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from web.models.failed_jobs import FailedJob
@@ -62,12 +58,7 @@ def get_llama_llm():
def get_azure_openai_llm():
"""Returns AzureOpenAI instance configured from environment variables"""
try:
if settings.DEBUG:
openai_api_type = "openai" # JUST FOR DEVELOPMENT
logging.debug(f"DEVELOPMENT Using API Type: {openai_api_type}")
else:
openai_api_type = os.environ["AZURE_OPENAI_API_TYPE"]

openai_api_type = os.environ["AZURE_OPENAI_API_TYPE"]
openai_api_key = os.environ["AZURE_OPENAI_API_KEY"]
openai_deployment_name = os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"]
openai_model_name = os.environ["AZURE_OPENAI_COMPLETION_MODEL"]
@@ -134,30 +125,26 @@ def get_openai_llm():
traceback.print_exc()


def get_ollama_llm(sanitized_question):
"""Returns an Ollama Server instance configured from environment variables"""
llm = Client(host=os.environ.get("OLLAMA_URL"))
# Use the client to make a request
def get_ollama_llm():
"""Returns an Ollama instance configured from environment variables"""
try:
if sanitized_question:
response = llm.chat(
model=os.environ.get("OLLAMA_MODEL_NAME"),
messages=[{"role": "user", "content": sanitized_question}],
)
else:
raise ValueError("Question cannot be None.")
if response:
return response
else:
raise ValueError("Invalid response from Ollama.")
base_url = os.environ.get("OLLAMA_URL")
model = os.environ.get("OLLAMA_MODEL_NAME", "llama2")

llm = ChatOllama(
base_url=base_url,
model=model,
callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
)
return llm

except Exception as e:
logger.debug(f"Exception in get_ollama_llm: {e}")
failed_job = FailedJob(
uuid=str(uuid4()),
connection="default",
queue="default",
payload="get_openai_llm",
payload="get_ollama_llm",
exception=str(e),
failed_at=make_aware(datetime.now(), timezone.utc),
)
@@ -176,29 +163,26 @@ def get_llm():
"ollama": lambda: get_ollama_llm(),
}

# DEVENV
# if settings.DEBUG:
# api_type = "ollama"
api_type = os.environ.get("OPENAI_API_TYPE", "openai")

if api_type not in clients:
raise ValueError(f"Invalid OPENAI_API_TYPE: {api_type}")

logging.debug(f"Using LLM: {api_type}")

if api_type in clients:
if api_type == "ollama":
return clients[api_type]()
elif api_type != "ollama":
return clients[api_type]()
llm_instance = clients[api_type]()
if llm_instance is None:
logger.error(f"LLM instance for {api_type} could not be created.")
return None
return llm_instance
else:
raise ValueError(f"Invalid OPENAI_API_TYPE: {api_type}")

except Exception as e:
failed_job = FailedJob(
uuid=str(uuid4()),
connection="default",
queue="default",
payload="get_llm",
exception=str(e),
failed_at=datetime.now(),
)
failed_job = FailedJob(
uuid=str(uuid4()),
connection="default",
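The diff above replaces the raw `ollama.Client` call in `get_ollama_llm()` with LangChain's `ChatOllama`, so the Ollama backend plugs into the same chain interface as the other providers. A minimal sketch of the new wiring, with illustrative defaults for the two environment variables (the default URL, default model, and the sample prompt are assumptions):

```python
import os

from langchain_community.chat_models import ChatOllama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Illustrative defaults; in the project these come from the .env file.
base_url = os.environ.get("OLLAMA_URL", "http://localhost:11434")
model = os.environ.get("OLLAMA_MODEL_NAME", "llama2")

llm = ChatOllama(
    base_url=base_url,
    model=model,
    # Stream generated tokens to stdout, mirroring the handler used in the PR.
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
)

# The instance now behaves like any other LangChain chat model.
print(llm.invoke("Summarize what Ollama is in one sentence.").content)
```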
1 change: 1 addition & 0 deletions dj_backend_server/api/utils/make_chain.py
@@ -98,6 +98,7 @@ def getConversationRetrievalChain(
retriever=vector_store.as_retriever(),
verbose=True,
combine_docs_chain_kwargs={"prompt": prompt},
return_source_documents=True,
)
logger.debug(f"ConversationalRetrievalChain {llm}, created: {chain}")
return chain
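The single added keyword, `return_source_documents=True`, makes the chain return the retrieved documents alongside the answer. A sketch of how `getConversationRetrievalChain` presumably assembles the chain around the kwargs visible in the diff; the function signature and prompt handling here are assumptions:

```python
from langchain.chains import ConversationalRetrievalChain

def get_conversation_retrieval_chain(llm, vector_store, prompt):
    """Sketch: build a conversational retrieval chain over the vector store."""
    chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vector_store.as_retriever(),
        verbose=True,
        combine_docs_chain_kwargs={"prompt": prompt},
        # New in this PR: the result dict gains a "source_documents" key
        # listing the documents the retriever returned.
        return_source_documents=True,
    )
    return chain
```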
14 changes: 4 additions & 10 deletions dj_backend_server/api/views/views_chat.py
@@ -165,21 +165,15 @@ def get_completion_response(
elif chain_type == "conversation_retrieval":
chain = getConversationRetrievalChain(vector_store, mode, initial_prompt)
logger.debug("getConversationRetrievalChain")
chat_history_json = json.dumps(
get_chat_history_for_retrieval_chain(
session_id, limit=20, initial_prompt=initial_prompt
),
ensure_ascii=False,
chat_history = get_chat_history_for_retrieval_chain(
session_id, limit=20, initial_prompt=initial_prompt
)
chat_history_json = ""
logger.debug(f"Formatted Chat_history {chat_history_json}")
logger.debug(f"Formatted Chat_history {chat_history}")

response = chain.invoke(
{"question": sanitized_question, "chat_history": chat_history_json}
{"question": sanitized_question, "chat_history": chat_history},
)
logger.debug(f"response from chain.invoke: {response}")
response_text = response.get("answer")
logger.debug(f"response_text : {response_text}")
try:
# Attempt to parse the response_text as JSON
response_text = json.loads(response_text)
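With this change the chat history is handed to `chain.invoke` as a Python list instead of a JSON string. A usage sketch, assuming `chain` is the conversational retrieval chain built above; the tuple shape is what `ConversationalRetrievalChain` accepts, and the history values are hypothetical:

```python
# Hypothetical prior exchange; ConversationalRetrievalChain accepts
# (human_message, ai_message) tuples as chat history.
chat_history = [
    ("What models can the bot use?", "It supports OpenAI, Azure OpenAI, and now Ollama."),
]

response = chain.invoke(
    {"question": "How do I switch to Ollama?", "chat_history": chat_history},
)

answer = response.get("answer")             # the model's reply
sources = response.get("source_documents")  # present because return_source_documents=True
```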
32 changes: 13 additions & 19 deletions dj_backend_server/api/views/views_message.py
@@ -170,26 +170,23 @@ def send_chat(request):
"""
try:

if settings.DEBUG:
logger.debug("Entering send_chat function")
logger.debug("Entering send_chat function")
# You can add additional validation for 'history' and 'content_type' if needed.

bot_token = request.headers.get("X-Bot-Token")
bot = get_object_or_404(Chatbot, token=bot_token)

data = json.loads(request.body)
if settings.DEBUG:
logger.debug(
f"Request data: {data}"
) # {'from': 'user', 'type': 'text', 'content': 'input text from chat'}
logger.debug(
f"Request data: {data}"
) # {'from': 'user', 'type': 'text', 'content': 'input text from chat'}
# Validate the request data
content = data.get("content")
history = data.get("history")
if settings.DEBUG:
logger.debug(f"Content: {content}")
logger.debug(
f"History: {history}"
) # history is a list of chat history - None????
logger.debug(f"Content: {content}")
logger.debug(
f"History: {history}"
) # history is a list of chat history - None????
content_type = data.get("type")

session_id = get_session_id(request=request, bot_id=bot.id)
@@ -198,10 +195,9 @@
{"message": entry.message, "from_user": entry.from_user}
for entry in history
]
if settings.DEBUG:
logger.debug(
f"History entries in JSON: {history_entries} - and history in text from DB: {history}"
)
logger.debug(
f"History entries in JSON: {history_entries} - and history in text from DB: {history}"
)

# Implement the equivalent logic for validation
if not content:
@@ -211,8 +207,7 @@
)

# Implement the equivalent logic to send the HTTP request to the external API
if settings.DEBUG:
logger.debug(f"External API response START")
logger.debug(f"External API response START")
response = requests.post(
os.getenv("APP_URL") + "/api/chat/",
json={
@@ -226,8 +221,7 @@
},
timeout=200,
)
if settings.DEBUG:
logger.debug(f"External API response: {response.text} and {response}")
logger.debug(f"External API response: {response.text} and {response}")

"""
This block will first check if the response content is not empty. If it is empty,
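The edits above drop the `if settings.DEBUG:` guards around `logger.debug(...)` calls; a debug message is already a no-op unless the logging configuration enables DEBUG-level output. A sketch of the kind of Django `LOGGING` setting that keeps these messages out of production logs; the handler names and level choice are assumptions, not the project's actual configuration:

```python
# settings.py (sketch): emit DEBUG-level records only when Django's DEBUG flag is on.
import os

DEBUG = os.environ.get("DEBUG", "False") == "True"

LOGGING = {
    "version": 1,
    "disable_existing_loggers": False,
    "handlers": {
        "console": {"class": "logging.StreamHandler"},
    },
    "root": {
        "handlers": ["console"],
        "level": "DEBUG" if DEBUG else "INFO",
    },
}
```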
11 changes: 4 additions & 7 deletions dj_backend_server/example.env
@@ -19,7 +19,7 @@ OPENAI_API_TYPE=openai
OPENAI_API_MODEL=gpt-4-1106-preview
OPENAI_API_TEMPERATURE=1

# azure | openai | llama2 | ollama
# azure | openai | llama2 - change only if you know what you are doing
EMBEDDING_PROVIDER=openai

# If using azure
@@ -30,22 +30,20 @@ EMBEDDING_PROVIDER=openai
# AZURE_OPENAI_DEPLOYMENT_NAME=
# AZURE_OPENAI_COMPLETION_MODEL=gpt-35-turbo


# OLLAMA_URL="" # no trailing slash at the end or it will not work.
# OLLAMA_MODEL_NAME="" # e.g. openchat, llama2 - be sure the model is already pulled on the server: "ollama pull openchat"

# Vector Store, PINECONE|QDRANT
STORE=QDRANT


# if using pinecone
# PINECONE_API_KEY=
# PINECONE_ENV=
# VECTOR_STORE_INDEX_NAME=


# if using qdrant
QDRANT_URL=http://qdrant:6333


# optional, defaults to 15
MAX_PAGES_CRAWL=150

@@ -73,5 +71,4 @@ OCR_LLM = '1'

# retrieval_qa | conversation_retrieval, retrieval_qa works better with azure openai
# if you want to use the conversation_retrieval | retrieval_qa chain
CHAIN_TYPE=conversation_retrieval

CHAIN_TYPE=conversation_retrieval
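The commented-out `OLLAMA_URL` entry warns against a trailing slash. A small sketch of how the backend might read the variable defensively; the normalization is an assumption, not code from this PR:

```python
import os

# Strip a trailing slash so "http://ollama:11434/" and "http://ollama:11434" behave the same.
ollama_url = os.environ.get("OLLAMA_URL", "http://localhost:11434").rstrip("/")
```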
7 changes: 6 additions & 1 deletion dj_backend_server/requirements.txt
@@ -25,6 +25,9 @@ drf-spectacular==0.27.1
drf_spectacular.extensions==0.0.2
exceptiongroup==1.1.2
frozenlist==1.4.0
filelock==3.13.1
fsspec==2024.2.0
huggingface-hub==0.20.3
grpcio==1.56.2
grpcio-tools==1.56.2
h11==0.14.0
@@ -71,6 +74,7 @@ qdrant-client==1.7.0
redis==4.6.0
regex==2023.6.3
requests==2.31.0
safetensors==0.4.2
six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
@@ -79,6 +83,8 @@ sqlparse==0.4.4
tenacity==8.2.2
tiktoken==0.6.0
tqdm==4.65.0
tokenizers==0.15.2
transformers==4.37.2
typing-inspect==0.9.0
typing_extensions==4.7.1
tzdata==2023.3
@@ -88,4 +94,3 @@ wcwidth==0.2.6
yarl==1.9.2
django-cors-headers==4.3.1

