Skip to content

Commit

Permalink
AAP-33915: Use the new exception handler for the Chat view (#1478)
Browse files: browse the repository at this point in the history
  • Loading branch information
manstis authored Jan 8, 2025
1 parent 48acf18 commit 05fcb28
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 127 deletions.
21 changes: 6 additions & 15 deletions ansible_ai_connect/ai/api/telemetry/schema1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
version_info = VersionInfo()


def _anonymize_struct(value):
def anonymize_struct(value):
return anonymizer.anonymize_struct(value, value_template=Template("{{ _${variable_name}_ }}"))


Expand Down Expand Up @@ -169,7 +169,7 @@ class ChatBotResponseDocsReferences:


@define
class ChatBotBaseEvent:
class ChatBotBaseEvent(Schema1Event):
chat_prompt: str = field(validator=validators.instance_of(str), converter=str, default="")
chat_system_prompt: str = field(
validator=validators.instance_of(str), converter=str, default=""
Expand All @@ -185,29 +185,20 @@ class ChatBotBaseEvent:
converter=str,
default=settings.CHATBOT_DEFAULT_PROVIDER,
)
modelName: str = field(
validator=validators.instance_of(str), converter=str, default=settings.CHATBOT_DEFAULT_MODEL
)
rh_user_org_id: int = field(validator=validators.instance_of(int), converter=int, default=-1)
timestamp: str = field(
default=Factory(lambda self: timezone.now().isoformat(), takes_self=True)
)

def __attrs_post_init__(self):
self.chat_prompt = _anonymize_struct(self.chat_prompt)
self.chat_response = _anonymize_struct(self.chat_response)
self.chat_prompt = anonymize_struct(self.chat_prompt)
self.chat_response = anonymize_struct(self.chat_response)


@define
class ChatBotFeedbackEvent(ChatBotBaseEvent):
event_name: str = "chatFeedbackEvent"
sentiment: int = field(
validator=[validators.instance_of(int), validators.in_([0, 1])], converter=int, default=0
)


@define
class ChatBotOperationalEvent(ChatBotBaseEvent):
req_duration: float = field(
validator=[validators.instance_of(float)], converter=float, default=0
)
exception: str = field(validator=validators.instance_of(str), converter=str, default="")
event_name: str = "chatOperationalEvent"
6 changes: 2 additions & 4 deletions ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4072,7 +4072,6 @@ def assert_test(
):
if user is None:
user = self.user
self.client.force_authenticate(user=user)
with (
patch.object(
apps.get_app_config("ai"),
Expand Down Expand Up @@ -4224,9 +4223,8 @@ def test_operational_telemetry_error(self):
123,
)
self.assertEqual(
segment_events[0]["properties"]["exception"],
"An exception <class 'ansible_ai_connect.ai.api.exceptions."
"ChatbotInvalidResponseException'> occurred during a chat generation",
segment_events[0]["properties"]["problem"],
"Invalid response",
)

@override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE")
Expand Down
4 changes: 2 additions & 2 deletions ansible_ai_connect/ai/api/utils/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ def send_chatbot_event(event: ChatBotBaseEvent, event_name: str, user: User) ->
if not settings.SEGMENT_WRITE_KEY:
logger.info("segment write key not set, skipping event")
return
if _is_segment_message_exceeds_limit(asdict(event)):
if is_segment_message_exceeds_limit(asdict(event)):
# Prioritize the prompt and referenced documents.
event.chat_response = ""
base_send_segment_event(asdict(event), event_name, user, analytics)


def _is_segment_message_exceeds_limit(msg_dict):
def is_segment_message_exceeds_limit(msg_dict):
msg_dict = clean(msg_dict)
msg_size = len(json.dumps(msg_dict, cls=DatetimeSerializer).encode())
if msg_size > MAX_MSG_SIZE:
Expand Down
184 changes: 78 additions & 106 deletions ansible_ai_connect/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

import logging
import time
import timeit
import traceback
from string import Template

from ansible_anonymizer import anonymizer
from attr import asdict
from django.apps import apps
from django.conf import settings
from django_prometheus.conf import NAMESPACE
Expand Down Expand Up @@ -134,10 +134,10 @@
)
from .telemetry.schema1 import (
ChatBotFeedbackEvent,
ChatBotOperationalEvent,
ChatBotResponseDocsReferences,
anonymize_struct,
)
from .utils.segment import send_segment_event
from .utils.segment import is_segment_message_exceeds_limit, send_segment_event

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -176,9 +176,18 @@


class AACSAPIView(APIView):
def __init__(self):
super().__init__()
self.event = None
self.model_id = None
self.exception = None

def initialize_request(self, request, *args, **kwargs):
# Call super first to ensure request object is correctly
# initialised before instantiating event
initialised_request = super().initialize_request(request, *args, **kwargs)
self.event: schema1.Schema1Event = self.schema1_event()
self.event.set_request(request)
self.event.set_request(initialised_request)

# TODO: once we move the request_serializer handling into this
# class, we will change the line below. The model_id attribute
Expand All @@ -187,7 +196,7 @@ def initialize_request(self, request, *args, **kwargs):
# See: https://github.com/ansible/ansible-ai-connect-service/pull/1147/files#diff-ecfb6919dfd8379aafba96af7457b253e4dce528897dfe6bfc207ca2b3b2ada9R143-R151 # noqa: E501
self.model_id: str = ""

return super().initialize_request(request, *args, **kwargs)
return initialised_request

def handle_exception(self, exc):
self.exception = exc
Expand Down Expand Up @@ -1109,7 +1118,7 @@ def post(self, request) -> Response:
)


class Chat(APIView):
class Chat(AACSAPIView):
"""
Send a message to the backend chatbot service and get a reply.
"""
Expand All @@ -1124,6 +1133,7 @@ class ChatEndpointThrottle(EndpointRateThrottle):
]
required_scopes = ["read", "write"]
throttle_classes = [ChatEndpointThrottle]
schema1_event = schema1.ChatBotOperationalEvent

def __init__(self):
self.chatbot_enabled = (
Expand Down Expand Up @@ -1152,114 +1162,76 @@ def __init__(self):
def post(self, request) -> Response:
headers = {"Content-Type": "application/json"}
request_serializer = ChatRequestSerializer(data=request.data)
rh_user_org_id = getattr(request.user, "org_id", None)

data = {}
req_query = "<undefined>"
req_system_prompt = "<undefined>"
req_model_id = "<undefined>"
req_provider = "<undefined>"
duration = 0
operational_event = {}

try:
if not self.chatbot_enabled:
raise ChatbotNotEnabledException()

if not request_serializer.is_valid():
raise ChatbotInvalidRequestException()

req_query = request_serializer.validated_data["query"]
req_system_prompt = (
request_serializer.validated_data["system_prompt"]
if "system_prompt" in request_serializer.validated_data
else None
)
req_model_id = (
request_serializer.validated_data["model"]
if "model" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_MODEL
)
req_provider = (
request_serializer.validated_data["provider"]
if "provider" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_PROVIDER
)
conversation_id = (
request_serializer.validated_data["conversation_id"]
if "conversation_id" in request_serializer.validated_data
else None
)
if not self.chatbot_enabled:
raise ChatbotNotEnabledException()

start = timeit.default_timer()
if not request_serializer.is_valid():
raise ChatbotInvalidRequestException()

llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline(
ModelPipelineChatBot
)

data = llm.invoke(
ChatBotParameters.init(
request=request,
query=req_query,
system_prompt=req_system_prompt,
model_id=req_model_id,
provider=req_provider,
conversation_id=conversation_id,
)
)

duration = timeit.default_timer() - start
req_query = request_serializer.validated_data["query"]
req_system_prompt = (
request_serializer.validated_data["system_prompt"]
if "system_prompt" in request_serializer.validated_data
else None
)
req_model_id = (
request_serializer.validated_data["model"]
if "model" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_MODEL
)
req_provider = (
request_serializer.validated_data["provider"]
if "provider" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_PROVIDER
)
conversation_id = (
request_serializer.validated_data["conversation_id"]
if "conversation_id" in request_serializer.validated_data
else None
)

response_serializer = ChatResponseSerializer(data=data)
# Initialise Segment Event early, in case of exceptions
self.model_id = req_model_id
self.event.chat_prompt = anonymize_struct(req_query)
self.event.chat_system_prompt = req_system_prompt
self.event.provider_id = req_provider
self.event.conversation_id = conversation_id
self.event.modelName = req_model_id

if not response_serializer.is_valid():
raise ChatbotInvalidResponseException()
llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline(
ModelPipelineChatBot
)

operational_event = ChatBotOperationalEvent(
chat_prompt=req_query,
chat_system_prompt=req_system_prompt,
chat_response=data["response"],
chat_truncated=bool(data["truncated"]),
chat_referenced_documents=[
ChatBotResponseDocsReferences(docs_url=rd["docs_url"], title=rd["title"])
for rd in data["referenced_documents"]
],
conversation_id=data["conversation_id"],
provider_id=req_provider,
modelName=req_model_id,
rh_user_org_id=rh_user_org_id,
req_duration=duration,
data = llm.invoke(
ChatBotParameters.init(
request=request,
query=req_query,
system_prompt=req_system_prompt,
model_id=req_model_id,
provider=req_provider,
conversation_id=conversation_id,
)
)

return Response(
data,
status=rest_framework_status.HTTP_200_OK,
headers=headers,
)
response_serializer = ChatResponseSerializer(data=data)

except Exception as exc:
if "detail" in data:
detail = data.get("detail", "")
operational_event = ChatBotOperationalEvent(
chat_prompt=req_query,
chat_system_prompt=req_system_prompt,
provider_id=req_provider,
modelName=req_model_id,
rh_user_org_id=rh_user_org_id,
req_duration=duration,
exception=detail,
)
else:
exception_message = (
f"An exception {exc.__class__} occurred during a chat generation"
)
logger.exception(exception_message)
operational_event = ChatBotOperationalEvent(
exception=exception_message,
rh_user_org_id=rh_user_org_id,
)
if not response_serializer.is_valid():
raise ChatbotInvalidResponseException()

raise exc
# Finalise Segment Event with response details
self.event.chat_truncated = bool(data["truncated"])
self.event.chat_referenced_documents = [
ChatBotResponseDocsReferences(docs_url=rd["docs_url"], title=rd["title"])
for rd in data["referenced_documents"]
]
self.event.chat_response = anonymize_struct(data["response"])
self.event.chat_response = (
"" if is_segment_message_exceeds_limit(asdict(self.event)) else self.event.chat_response
)

finally:
send_chatbot_event(operational_event, "chatOperationalEvent", request.user)
return Response(
data,
status=rest_framework_status.HTTP_200_OK,
headers=headers,
)

0 comments on commit 05fcb28

Please sign in to comment.