Skip to content

Commit

Permalink
Merge branch 'main' into dev-1939-use-self-hosted-runner
Browse files Browse the repository at this point in the history
  • Loading branch information
christian-calabrese authored Oct 3, 2024
2 parents cd8b8e6 + 320e268 commit f8006f2
Show file tree
Hide file tree
Showing 14 changed files with 604 additions and 825 deletions.
5 changes: 5 additions & 0 deletions .changeset/tame-maps-judge.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"chatbot": major
---

Using Gemini LLM instead of one of those provided by AWS
1 change: 1 addition & 0 deletions apps/chatbot/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Create a `.env` file inside the folder and write to the file the following envir
CHB_AWS_GUARDRAIL_ID=...
CHB_AWS_GUARDRAIL_VERSION=...
CHB_REDIS_URL=...
CHB_REDIS_INDEX_NAME=...
CHB_WEBSITE_URL=...
CHB_GOOGLE_API_KEY=...
CHB_PROVIDER=...
Expand Down
10 changes: 6 additions & 4 deletions apps/chatbot/config/prompts.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
qa_prompt_str: |
You are an Italian customer services chatbot.
Given the context information:
You are a customer services chatbot.
Your name is Discovery and your duty is to assist the user with the PagoPA DevPortal documentation!
--------------------
Context information:
{context_str}
--------------------
Chatbot Policy:
- the answer must be generated using only the context information and not prior knowledge.
- the answer must be brief, concise, and composed of 80 words maximum.
- the answer must be brief, concise, and composed of 100 words maximum.
- the answer must be respectful and polite.
- the answer must be clear, non-redundant, and must not contain repeated sentences.
- the answer must not include the query.
- the answer must be in Italian.
- If your answer is based on this retrieved context, include a "Rif" section at the end of the response, listing the titles and filenames from the source nodes used. If no context is used, do not include a reference.
- the answer must be with the same language of the query.
--------------------
Output Examples:
Query: Cos'è il nodo dei pagamenti?
Expand Down
995 changes: 403 additions & 592 deletions apps/chatbot/poetry.lock

Large diffs are not rendered by default.

7 changes: 0 additions & 7 deletions apps/chatbot/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,6 @@ llama-index-readers-web = "^0.2.2"
llama-index-vector-stores-redis = "^0.3.2"
llama-index-storage-docstore-redis = "^0.2.0"
llama-index-storage-index-store-redis = "^0.3.0"
langchain = "^0.2.16"
langchain-community = "^0.2.16"
langchain-aws = "^0.1.18"
langchainhub = "^0.1.21"
langchain-core = "^0.2.39"
llama-index-embeddings-langchain = "^0.2.1"
llama-index-llms-langchain = "^0.4.1"
llama-index-llms-gemini = "^0.3.4"
google-generativeai = "^0.5.2"
llama-index-embeddings-gemini = "^0.2.0"
Expand Down
10 changes: 6 additions & 4 deletions apps/chatbot/scripts/redis-tunnel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ SERVICE_NAME="cms-ecs"
LOAD_BALANCER_NAME="chatbot-load-balancer"
REMOTE_PORT="6379"
LOCAL_PORT="16379"
AWS_REGION="eu-south-1"


# Get the ECS Task ARN
TASK_ARN=$(aws ecs list-tasks --cluster $CLUSTER_NAME --service-name $SERVICE_NAME --query 'taskArns[0]' --output text)
TASK_ARN=$(aws ecs list-tasks --cluster $CLUSTER_NAME --service-name $SERVICE_NAME --query 'taskArns[0]' --output text --region $AWS_REGION)

if [ -z "$TASK_ARN" ]; then
echo "No ECS task found for the service $SERVICE_NAME in cluster $CLUSTER_NAME."
Expand All @@ -25,15 +26,15 @@ if [ -z "$TASK_ID" ]; then
fi

# Get the container name (task definition details)
CONTAINER_ID=$(aws ecs describe-tasks --cluster $CLUSTER_NAME --tasks $TASK_ARN --query 'tasks[0].containers[0].runtimeId' --output text)
CONTAINER_ID=$(aws ecs describe-tasks --cluster $CLUSTER_NAME --tasks $TASK_ARN --query 'tasks[0].containers[0].runtimeId' --output text --region $AWS_REGION)

if [ -z "$CONTAINER_ID" ]; then
echo "No container found for task $TASK_ARN."
exit 1
fi

# Get the Network Load Balancer DNS name
TARGET_HOST=$(aws elbv2 describe-load-balancers --names $LOAD_BALANCER_NAME --query 'LoadBalancers[0].DNSName' --output text)
TARGET_HOST=$(aws elbv2 describe-load-balancers --names $LOAD_BALANCER_NAME --query 'LoadBalancers[0].DNSName' --output text --region $AWS_REGION)

if [ -z "$TARGET_HOST" ]; then
echo "No Load Balancer found with the name $LOAD_BALANCER_NAME."
Expand All @@ -47,4 +48,5 @@ TARGET="ecs:${CLUSTER_NAME}_${TASK_ID}_${CONTAINER_ID}"
aws ssm start-session \
--target $TARGET \
--document-name AWS-StartPortForwardingSessionToRemoteHost \
--parameters "{\"host\":[\"$TARGET_HOST\"],\"portNumber\":[\"$REMOTE_PORT\"],\"localPortNumber\":[\"$LOCAL_PORT\"]}"
--parameters "{\"host\":[\"$TARGET_HOST\"],\"portNumber\":[\"$REMOTE_PORT\"],\"localPortNumber\":[\"$LOCAL_PORT\"]}" \
--region $AWS_REGION
215 changes: 55 additions & 160 deletions apps/chatbot/src/modules/chatbot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import re
import logging
from collections import Counter
from typing import List, Union, Tuple
from typing import Union, Tuple

from langdetect import detect_langs

from llama_index.core import PromptTemplate
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.base.response.schema import (
Response, StreamingResponse, AsyncStreamingResponse, PydanticResponse
)
Expand All @@ -23,8 +22,6 @@
ITALIAN_THRESHOLD = 0.85
NUM_MIN_WORDS_QUERY = 3
NUM_MIN_REFERENCES = 1
GUARDRAIL_ANSWER = """Mi dispiace, non mi è consentito elaborare contenuti inappropriati.
Riformula la domanda in modo che non violi queste linee guida."""
RESPONSE_TYPE = Union[
Response, StreamingResponse, AsyncStreamingResponse, PydanticResponse
]
Expand Down Expand Up @@ -79,7 +76,14 @@ def __init__(
chunk_overlap=params["vector_index"]["chunk_overlap"],
)

self.messages = []
self.history = [
ChatMessage(
role=MessageRole.ASSISTANT,
content="""You are an Italian customer services chatbot.
Your name is Discovery and it is your duty to assist the user answering his questions about the PagoPA DevPortal documentation!
"""
)
]
self.qa_prompt_tmpl, self.ref_prompt_tmpl = self._get_prompt_templates()
self.engine = get_automerging_query_engine(
self.index,
Expand All @@ -89,6 +93,7 @@ def __init__(
verbose=self.params["engine"]["verbose"]
)


def _get_prompt_templates(self) -> Tuple[PromptTemplate, PromptTemplate]:

# create templates
Expand All @@ -112,47 +117,16 @@ def _get_prompt_templates(self) -> Tuple[PromptTemplate, PromptTemplate]:
return qa_prompt_tmpl, ref_prompt_tmpl


def _update_messages(self, role: str, message: str):
self.messages.append({"role": role, "text": message})
def _update_history(self, role: MessageRole, message: str):
    """Record one turn of the conversation in the chat history."""
    entry = ChatMessage(role=role, content=message)
    self.history.append(entry)


def reset_chat_history(self):
    """Drop every recorded turn and return the fresh, empty history."""
    fresh = []
    self.messages = fresh
    return fresh


def _messages_to_str(self, **kwargs) -> str:
text = ""
if len(self.messages) > 0:
for speach in self.messages:
text += f"{speach['role']}: {speach['text']}\n\n"

return text


def get_chat_history(self) -> List[dict]:
    """Return the accumulated chat turns (dicts with "role" and "text" keys)."""
    return self.messages


def _check_language(self, message_str):
    """Return the langdetect probability that *message_str* is Italian.

    Falls back to 0.0 when Italian is not among the detected languages
    or when detection itself raises.
    """
    try:
        detected = detect_langs(message_str)
    except Exception as e:
        logging.warning(f"LangDetectException: {e}. Now list of detected languages is empty.")
        detected = []

    it_score = 0.0
    for candidate in detected:
        if candidate.lang == "it":
            it_score = candidate.prob

    logging.info(f"Detected Italian with score {it_score:.4f} at the last user's message.")

    return it_score
self.history = []
return self.history


def _get_response_str(self, query: str, engine_response: RESPONSE_TYPE) -> str:
def _get_response_str(self, engine_response: RESPONSE_TYPE) -> str:

if isinstance(engine_response, StreamingResponse):
typed_response = engine_response.get_response()
Expand All @@ -162,124 +136,38 @@ def _get_response_str(self, query: str, engine_response: RESPONSE_TYPE) -> str:
response_str = typed_response.response.strip()
nodes = typed_response.source_nodes

response_str = self._remove_redundancy(query, response_str)

if response_str is None or response_str == "Empty Response" or response_str == "" or len(nodes) == 0:
response_str = """Mi dispiace, posso rispondere solo a domande riguardo la documentazione del [PagoPA DevPortal | Home](https://developer.pagopa.it/).
Prova a riformulare la domanda.
"""
elif response_str == GUARDRAIL_ANSWER:
logging.warning(f"AWS guardrail: Detected harmful category.")
else:
response_str = self._add_reference(response_str, nodes)
response_str = self._unmask_reference(response_str, nodes)

return response_str


def _remove_redundancy(self, query: str, response: str) -> str:

sentences = re.split(r"(?<=[\n])", response)
sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
unique_sentences = list(Counter(sentences).keys())
indexes_to_remove = []
for i, unique_sentence in enumerate(unique_sentences):
for j, us in enumerate(unique_sentences):
if i != j and unique_sentence in us:
indexes_to_remove.append(i)

for idx in indexes_to_remove[::-1]:
try:
unique_sentences.pop(idx)
except Exception as e:
unique_sentences = []
print(response)
logging.info(f"{e}: the generated response has too many redundacy problems. The output is now empty string.")
break

respose_str = "\n".join(unique_sentences)
respose_str = respose_str.replace(query, "")

return respose_str.strip()


def _add_reference(self, response_str: str, nodes) -> str:
def _unmask_reference(self, response_str: str, nodes) -> str:

pattern = r'[a-fA-F0-9]{64}'

# Find all matches in the text
hashed_urls = re.findall(pattern, response_str)

if not hashed_urls:
num_refs = 0
freq = {}
for node in nodes:
title = node.metadata["title"]
hashed_url = node.metadata["filename"]
if (title, hashed_url) not in freq.keys():
freq[(title, hashed_url)] = 1
else:
freq[(title, hashed_url)] += 1

freq_sorted = dict(sorted(freq.items(), key=lambda item: item[1], reverse=True))

response_str += "\n\nRif:"

for i, (k, v) in enumerate(freq_sorted.items()):

if i == 0 and v == NUM_MIN_REFERENCES:
num_refs += 1
title = nodes[0].metadata["title"]
hashed_url = nodes[0].metadata["filename"]
if self.use_redis:
url = REDIS_KVSTORE.get(
collection="hash_table",
key=hashed_url
)
if url is None:
url = "{URL}"
else:
if hashed_url in self.hash_table.keys():
url = self.hash_table[hashed_url]
else:
url = "{URL}"
response_str += f"\n[{title}]({url})"
break
else:
if v > NUM_MIN_REFERENCES:
num_refs += 1
title, hashed_url = k
if self.use_redis:
url = REDIS_KVSTORE.get(
collection="hash_table",
key=hashed_url
)
if url is None:
url = "{URL}"
else:
if hashed_url in self.hash_table.keys():
url = self.hash_table[hashed_url]
else:
url = "{URL}"
response_str += f"\n[{title}]({url})"

logging.info(f"Generated answer had no references. Added {num_refs} references taken from {len(nodes)} nodes. First node has score: {nodes[0].score:.4f}.")

else:
logging.info(f"Generated answer has {len(hashed_urls)} references taken from {len(nodes)} nodes. First node has score: {nodes[0].score:.4f}.")
for hashed_url in hashed_urls:
if self.use_redis:
url = REDIS_KVSTORE.get(
collection="hash_table",
key=hashed_url
)
if url is None:
url = "{URL}"
logging.info(f"Generated answer has {len(hashed_urls)} references taken from {len(nodes)} nodes. First node has score: {nodes[0].score:.4f}.")
for hashed_url in hashed_urls:
if self.use_redis:
url = REDIS_KVSTORE.get(
collection="hash_table",
key=hashed_url
)
if url is None:
url = "{URL}"
else:
if hashed_url in self.hash_table.keys():
url = self.hash_table[hashed_url]
else:
if hashed_url in self.hash_table.keys():
url = self.hash_table[hashed_url]
else:
url = "{URL}"
response_str = response_str.replace(hashed_url, url)
url = "{URL}"
response_str = response_str.replace(hashed_url, url)

# remove sentences with generated masked url: {URL}
parts = re.split(r"(?<=[\.\?\!\n])", response_str)
Expand All @@ -290,23 +178,30 @@ def _add_reference(self, response_str: str, nodes) -> str:


def generate(self, query_str: str) -> str:

num_words = len(query_str.split(" "))
it_score = self._check_language(query_str)
if num_words < NUM_MIN_WORDS_QUERY:
response_str = """Mi dispiace, la domanda fornita è insufficiente.
Per piacere, riformula la tua domanda.
"""
elif num_words >= NUM_MIN_WORDS_QUERY and it_score < ITALIAN_THRESHOLD:
response_str = """Mi dispiace, la domanda fornita non è propriamente formulata in italiano.
Per piacere, riformula la tua domanda.
"""
else:

try:
engine_response = self.engine.query(query_str)
response_str = self._get_response_str(query_str, engine_response)
response_str = self._get_response_str(engine_response)

self._update_history(MessageRole.USER, query_str)
self._update_history(MessageRole.ASSISTANT, response_str)

# update messages
self._update_messages("user", query_str)
self._update_messages("assistant", response_str)
except Exception as e:
exception_str = str(e)
if "SAFETY" in exception_str:
if "HARM_CATEGORY_HARASSMENT" in exception_str:
response_str = "Mi dispiace, ma non posso rispondere a domande offensive o minacciose."
logging.info("Gemini Safety: blocked query because retrieved HARASSMENT content in it.")
if "HARM_CATEGORY_SEXUALLY_EXPLICIT" in exception_str:
response_str = "Mi dispiace, ma non posso rispondere a domande di natura sessualmente esplicita."
logging.info("Gemini Safety: blocked query because retrieved SEXUALLY_EXPLICIT content in it.")
if "HARM_CATEGORY_HATE_SPEECH" in exception_str:
response_str = "Mi dispiace, ma non posso accettare discorsi di odio. Per favore, evita di usare linguaggio."
logging.info("Gemini Safety: blocked query because retrieved HATE_SPEECH content in it.")
if "HARM_CATEGORY_DANGEROUS_CONTENT" in exception_str:
response_str = "Mi dispiace, ma non posso fornire informazioni che potrebbero essere pericolose o dannose."
logging.info("Gemini Safety: blocked query because retrieved DANGEROUS_CONTENT in it.")
else:
logging.info(exception_str)

return response_str
4 changes: 2 additions & 2 deletions apps/chatbot/src/modules/create_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
load_dotenv()
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))

CHB_AWS_DEFAULT_REGION = os.getenv('CHB_AWS_DEFAULT_REGION', os.getenv('AWS_DEFAULT_REGION'))
CHB_AWS_DEFAULT_REGION = os.getenv("CHB_AWS_DEFAULT_REGION", os.getenv("AWS_DEFAULT_REGION"))
AWS_S3_BUCKET = os.getenv("CHB_AWS_S3_BUCKET", os.getenv("AWS_S3_BUCKET"))


Expand All @@ -30,7 +30,7 @@
if params["vector_index"]["use_redis"] and params["vector_index"]["use_s3"]:
raise Exception("Vector Store Error: use s3 or Redis or none of them.")

model = get_llm(params)
model = get_llm()
embed_model = get_embed_model()

# create vector index
Expand Down
Loading

0 comments on commit f8006f2

Please sign in to comment.