Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AAP-33915: Use the new exception handler for the Chat view #1478

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions ansible_ai_connect/ai/api/telemetry/schema1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
version_info = VersionInfo()


def _anonymize_struct(value):
def anonymize_struct(value):
return anonymizer.anonymize_struct(value, value_template=Template("{{ _${variable_name}_ }}"))


Expand Down Expand Up @@ -167,7 +167,7 @@ class ChatBotResponseDocsReferences:


@define
class ChatBotBaseEvent:
class ChatBotBaseEvent(Schema1Event):
manstis marked this conversation as resolved.
Show resolved Hide resolved
chat_prompt: str = field(validator=validators.instance_of(str), converter=str, default="")
chat_system_prompt: str = field(
validator=validators.instance_of(str), converter=str, default=""
Expand All @@ -183,29 +183,20 @@ class ChatBotBaseEvent:
converter=str,
default=settings.CHATBOT_DEFAULT_PROVIDER,
)
modelName: str = field(
validator=validators.instance_of(str), converter=str, default=settings.CHATBOT_DEFAULT_MODEL
)
rh_user_org_id: int = field(validator=validators.instance_of(int), converter=int, default=-1)
timestamp: str = field(
default=Factory(lambda self: timezone.now().isoformat(), takes_self=True)
)

def __attrs_post_init__(self):
manstis marked this conversation as resolved.
Show resolved Hide resolved
self.chat_prompt = _anonymize_struct(self.chat_prompt)
self.chat_response = _anonymize_struct(self.chat_response)
self.chat_prompt = anonymize_struct(self.chat_prompt)
self.chat_response = anonymize_struct(self.chat_response)


@define
class ChatBotFeedbackEvent(ChatBotBaseEvent):
event_name: str = "chatFeedbackEvent"
sentiment: int = field(
validator=[validators.instance_of(int), validators.in_([0, 1])], converter=int, default=0
)


@define
class ChatBotOperationalEvent(ChatBotBaseEvent):
req_duration: float = field(
validator=[validators.instance_of(float)], converter=float, default=0
)
exception: str = field(validator=validators.instance_of(str), converter=str, default="")
event_name: str = "chatOperationalEvent"
6 changes: 2 additions & 4 deletions ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4072,7 +4072,6 @@ def assert_test(
):
if user is None:
user = self.user
self.client.force_authenticate(user=user)
with (
patch.object(
apps.get_app_config("ai"),
Expand Down Expand Up @@ -4224,9 +4223,8 @@ def test_operational_telemetry_error(self):
123,
)
self.assertEqual(
segment_events[0]["properties"]["exception"],
"An exception <class 'ansible_ai_connect.ai.api.exceptions."
"ChatbotInvalidResponseException'> occurred during a chat generation",
segment_events[0]["properties"]["problem"],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change.

AACSAPIView uses Schema1Event.exception: bool to indicate whether an exception occurred. Schema1Event.problem contains the exception message. Is this going to cause Marty etc issues?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, well at least if he's using this field as for reports in Amplitude. Although this looks not a blocker, we just need to let him know in order to "refactor", if necessary, any query.

Also @manstis this is a good candidate to make the integration tests to fail :) I think we're doing assertions over this field, it worth checking too, IMO :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: I knew the way we use this field is not always consistent depending on the view and I would like to converge to the new approach.

Feel free to object/cry/shout :-).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@romartin I could only find these tests. This seems to only be a basic sanity check. Do you know of any others?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh @manstis , correct! Sorry for the confusion.

"Invalid response",
)

@override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE")
Expand Down
4 changes: 2 additions & 2 deletions ansible_ai_connect/ai/api/utils/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ def send_chatbot_event(event: ChatBotBaseEvent, event_name: str, user: User) ->
if not settings.SEGMENT_WRITE_KEY:
logger.info("segment write key not set, skipping event")
return
if _is_segment_message_exceeds_limit(asdict(event)):
if is_segment_message_exceeds_limit(asdict(event)):
# Prioritize the prompt and referenced documents.
event.chat_response = ""
base_send_segment_event(asdict(event), event_name, user, analytics)


def _is_segment_message_exceeds_limit(msg_dict):
def is_segment_message_exceeds_limit(msg_dict):
msg_dict = clean(msg_dict)
msg_size = len(json.dumps(msg_dict, cls=DatetimeSerializer).encode())
if msg_size > MAX_MSG_SIZE:
Expand Down
184 changes: 78 additions & 106 deletions ansible_ai_connect/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

import logging
import time
import timeit
import traceback
from string import Template

from ansible_anonymizer import anonymizer
from attr import asdict
from django.apps import apps
from django.conf import settings
from django_prometheus.conf import NAMESPACE
Expand Down Expand Up @@ -134,10 +134,10 @@
)
from .telemetry.schema1 import (
ChatBotFeedbackEvent,
ChatBotOperationalEvent,
ChatBotResponseDocsReferences,
anonymize_struct,
)
from .utils.segment import send_segment_event
from .utils.segment import is_segment_message_exceeds_limit, send_segment_event

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -176,9 +176,18 @@


class AACSAPIView(APIView):
def __init__(self):
manstis marked this conversation as resolved.
Show resolved Hide resolved
super().__init__()
self.event = None
self.model_id = None
self.exception = None

def initialize_request(self, request, *args, **kwargs):
# Call super first to ensure request object is correctly
manstis marked this conversation as resolved.
Show resolved Hide resolved
# initialised before instantiating event
initialised_request = super().initialize_request(request, *args, **kwargs)
self.event: schema1.Schema1Event = self.schema1_event()
self.event.set_request(request)
self.event.set_request(initialised_request)

# TODO: when we will move the request_serializer handling in this
# class, we will change this line below. The model_id attribute
Expand All @@ -187,7 +196,7 @@ def initialize_request(self, request, *args, **kwargs):
# See: https://github.com/ansible/ansible-ai-connect-service/pull/1147/files#diff-ecfb6919dfd8379aafba96af7457b253e4dce528897dfe6bfc207ca2b3b2ada9R143-R151 # noqa: E501
self.model_id: str = ""

return super().initialize_request(request, *args, **kwargs)
return initialised_request

def handle_exception(self, exc):
self.exception = exc
Expand Down Expand Up @@ -1109,7 +1118,7 @@ def post(self, request) -> Response:
)


class Chat(APIView):
class Chat(AACSAPIView):
"""
Send a message to the backend chatbot service and get a reply.
"""
Expand All @@ -1124,6 +1133,7 @@ class ChatEndpointThrottle(EndpointRateThrottle):
]
required_scopes = ["read", "write"]
throttle_classes = [ChatEndpointThrottle]
schema1_event = schema1.ChatBotOperationalEvent

def __init__(self):
self.chatbot_enabled = (
Expand Down Expand Up @@ -1152,114 +1162,76 @@ def __init__(self):
def post(self, request) -> Response:
headers = {"Content-Type": "application/json"}
request_serializer = ChatRequestSerializer(data=request.data)
rh_user_org_id = getattr(request.user, "org_id", None)

data = {}
req_query = "<undefined>"
req_system_prompt = "<undefined>"
req_model_id = "<undefined>"
req_provider = "<undefined>"
duration = 0
operational_event = {}

try:
if not self.chatbot_enabled:
raise ChatbotNotEnabledException()

if not request_serializer.is_valid():
raise ChatbotInvalidRequestException()

req_query = request_serializer.validated_data["query"]
req_system_prompt = (
request_serializer.validated_data["system_prompt"]
if "system_prompt" in request_serializer.validated_data
else None
)
req_model_id = (
request_serializer.validated_data["model"]
if "model" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_MODEL
)
req_provider = (
request_serializer.validated_data["provider"]
if "provider" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_PROVIDER
)
conversation_id = (
request_serializer.validated_data["conversation_id"]
if "conversation_id" in request_serializer.validated_data
else None
)
if not self.chatbot_enabled:
raise ChatbotNotEnabledException()

start = timeit.default_timer()
if not request_serializer.is_valid():
manstis marked this conversation as resolved.
Show resolved Hide resolved
raise ChatbotInvalidRequestException()

llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline(
ModelPipelineChatBot
)

data = llm.invoke(
ChatBotParameters.init(
request=request,
query=req_query,
system_prompt=req_system_prompt,
model_id=req_model_id,
provider=req_provider,
conversation_id=conversation_id,
)
)

duration = timeit.default_timer() - start
req_query = request_serializer.validated_data["query"]
req_system_prompt = (
request_serializer.validated_data["system_prompt"]
if "system_prompt" in request_serializer.validated_data
else None
)
req_model_id = (
request_serializer.validated_data["model"]
if "model" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_MODEL
)
req_provider = (
request_serializer.validated_data["provider"]
if "provider" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_PROVIDER
)
conversation_id = (
request_serializer.validated_data["conversation_id"]
if "conversation_id" in request_serializer.validated_data
else None
)

response_serializer = ChatResponseSerializer(data=data)
# Initialise Segment Event early, in case of exceptions
self.model_id = req_model_id
self.event.chat_prompt = anonymize_struct(req_query)
self.event.chat_system_prompt = req_system_prompt
self.event.provider_id = req_provider
self.event.conversation_id = conversation_id
self.event.modelName = req_model_id

if not response_serializer.is_valid():
raise ChatbotInvalidResponseException()
llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline(
ModelPipelineChatBot
)

operational_event = ChatBotOperationalEvent(
chat_prompt=req_query,
chat_system_prompt=req_system_prompt,
chat_response=data["response"],
chat_truncated=bool(data["truncated"]),
chat_referenced_documents=[
ChatBotResponseDocsReferences(docs_url=rd["docs_url"], title=rd["title"])
for rd in data["referenced_documents"]
],
conversation_id=data["conversation_id"],
provider_id=req_provider,
modelName=req_model_id,
rh_user_org_id=rh_user_org_id,
req_duration=duration,
data = llm.invoke(
ChatBotParameters.init(
request=request,
query=req_query,
system_prompt=req_system_prompt,
model_id=req_model_id,
provider=req_provider,
conversation_id=conversation_id,
)
)

return Response(
data,
status=rest_framework_status.HTTP_200_OK,
headers=headers,
)
response_serializer = ChatResponseSerializer(data=data)

except Exception as exc:
if "detail" in data:
detail = data.get("detail", "")
operational_event = ChatBotOperationalEvent(
chat_prompt=req_query,
chat_system_prompt=req_system_prompt,
provider_id=req_provider,
modelName=req_model_id,
rh_user_org_id=rh_user_org_id,
req_duration=duration,
exception=detail,
)
else:
exception_message = (
f"An exception {exc.__class__} occurred during a chat generation"
)
logger.exception(exception_message)
operational_event = ChatBotOperationalEvent(
exception=exception_message,
rh_user_org_id=rh_user_org_id,
)
if not response_serializer.is_valid():
raise ChatbotInvalidResponseException()

raise exc
# Finalise Segment Event with response details
self.event.chat_truncated = bool(data["truncated"])
self.event.chat_referenced_documents = [
ChatBotResponseDocsReferences(docs_url=rd["docs_url"], title=rd["title"])
for rd in data["referenced_documents"]
]
self.event.chat_response = anonymize_struct(data["response"])
self.event.chat_response = (
"" if is_segment_message_exceeds_limit(asdict(self.event)) else self.event.chat_response
)

finally:
send_chatbot_event(operational_event, "chatOperationalEvent", request.user)
return Response(
data,
status=rest_framework_status.HTTP_200_OK,
headers=headers,
)
Loading