Skip to content

Commit

Permalink
AACSAPIView: validate the payload in the class
Browse files Browse the repository at this point in the history
Validate the view parameter from within AACSAPIView.
  • Loading branch information
goneri committed Jan 9, 2025
1 parent 5f271e6 commit 2e1acca
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 101 deletions.
5 changes: 0 additions & 5 deletions ansible_ai_connect/ai/api/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,6 @@ class ChatbotNotEnabledException(BaseWisdomAPIException):
default_detail = "Chatbot is not enabled"


class ChatbotInvalidRequestException(WisdomBadRequest):
default_code = "error__chatbot_invalid_request"
default_detail = "Invalid request"


class ChatbotInvalidResponseException(BaseWisdomAPIException):
status_code = 500
default_code = "error__chatbot_invalid_response"
Expand Down
8 changes: 7 additions & 1 deletion ansible_ai_connect/ai/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ class ExplanationRequestSerializer(Metadata):
required=False,
label="Custom prompt",
help_text="Custom prompt passed to the LLM when explaining a playbook.",
default="",
)
explanationId = serializers.UUIDField(
format="hex_verbose",
Expand All @@ -432,8 +433,9 @@ class ExplanationRequestSerializer(Metadata):
help_text=(
"A UUID that identifies the particular explanation data is being requested for."
),
default="",
)
model = serializers.CharField(required=False, allow_blank=True)
model = serializers.CharField(required=False, allow_blank=True, default="")
metadata = Metadata(required=False)

def validate(self, data):
Expand Down Expand Up @@ -483,12 +485,14 @@ class Meta:
required=False,
label="Custom prompt",
help_text="Custom prompt passed to the LLM when generating the text of a playbook.",
default="",
)
generationId = serializers.UUIDField(
format="hex_verbose",
required=False,
label="generation ID",
help_text=("A UUID that identifies the particular generation data is being requested for."),
default="",
)
createOutline = serializers.BooleanField(
required=False,
Expand All @@ -503,12 +507,14 @@ class Meta:
required=False,
label="outline",
help_text="A long step by step outline of the expected Ansible Playbook.",
default="",
)
wizardId = serializers.UUIDField(
format="hex_verbose",
required=False,
label="wizard ID",
help_text=("A UUID to track the succession of interaction from the user."),
default="",
)
model = serializers.CharField(required=False, allow_blank=True)

Expand Down
9 changes: 0 additions & 9 deletions ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from ansible_ai_connect.ai.api.exceptions import (
ChatbotForbiddenException,
ChatbotInternalServerException,
ChatbotInvalidRequestException,
ChatbotInvalidResponseException,
ChatbotNotEnabledException,
ChatbotPromptTooLongException,
Expand Down Expand Up @@ -4106,14 +4105,6 @@ def test_chat_not_enabled_exception(self):
TestChatView.VALID_PAYLOAD, 503, ChatbotNotEnabledException, "Chatbot is not enabled"
)

def test_chat_invalid_request_exception(self):
self.assert_test(
TestChatView.INVALID_PAYLOAD,
400,
ChatbotInvalidRequestException,
"ChatbotInvalidRequestException",
)

def test_chat_invalid_response_exception(self):
self.assert_test(
TestChatView.PAYLOAD_INVALID_RESPONSE,
Expand Down
124 changes: 38 additions & 86 deletions ansible_ai_connect/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from ansible_ai_connect.ai.api.aws.exceptions import WcaSecretManagerError
from ansible_ai_connect.ai.api.exceptions import (
BaseWisdomAPIException,
ChatbotInvalidRequestException,
ChatbotInvalidResponseException,
ChatbotNotEnabledException,
FeedbackInternalServerException,
Expand Down Expand Up @@ -189,15 +188,19 @@ def initialize_request(self, request, *args, **kwargs):
self.event: schema1.Schema1Event = self.schema1_event()
self.event.set_request(initialised_request)

# TODO: when we will move the request_serializer handling in this
# class, we will change this line below. The model_id attribute
# should exposed as self.validated_data.model, or using a special
# method.
# See: https://github.com/ansible/ansible-ai-connect-service/pull/1147/files#diff-ecfb6919dfd8379aafba96af7457b253e4dce528897dfe6bfc207ca2b3b2ada9R143-R151 # noqa: E501
self.req_model_id: str = ""
if hasattr(request, "data"):
self.load_parameters(request)

return initialised_request

def load_parameters(self, request) -> None:
    """Validate the request payload with the view's serializer, if declared.

    When the view defines ``request_serializer_class``, the incoming
    ``request.data`` is validated with it; a DRF ``ValidationError``
    (HTTP 400) is raised on invalid input. The validated payload is
    stored on ``self.validated_data`` and, when it carries a non-empty
    ``model`` entry, that value is mirrored on ``self.req_model_id`` so
    views can pass it to the model pipeline.
    """
    if hasattr(self, "request_serializer_class"):
        request_serializer = self.request_serializer_class(data=request.data)
        request_serializer.is_valid(raise_exception=True)
        self.validated_data = request_serializer.validated_data
        # Bug fix: the serializer field is named "model" ("mode_id" was a
        # typo that silently left req_model_id empty for every request).
        if req_model_id := self.validated_data.get("model"):
            self.req_model_id = req_model_id

def handle_exception(self, exc):
self.exception = exc

Expand Down Expand Up @@ -747,7 +750,7 @@ class Explanation(AACSAPIView):
permission_classes = PERMISSIONS_MAP.get(settings.DEPLOYMENT_MODE)
required_scopes = ["read", "write"]
schema1_event = schema1.ExplainPlaybookEvent

request_serializer_class = ExplanationRequestSerializer
throttle_cache_key_suffix = "_explanation"

@extend_schema(
Expand All @@ -763,32 +766,20 @@ class Explanation(AACSAPIView):
summary="Inline code suggestions",
)
def post(self, request) -> Response:
explanation_id: str = None
playbook: str = ""
answer = {}

# TODO: This request_serializer block will move in AACSAPIView later
request_serializer = ExplanationRequestSerializer(data=request.data)
request_serializer.is_valid(raise_exception=True)
explanation_id = str(request_serializer.validated_data.get("explanationId", ""))
playbook = request_serializer.validated_data.get("content")
custom_prompt = str(request_serializer.validated_data.get("customPrompt", ""))
self.req_model_id = str(request_serializer.validated_data.get("model", ""))

llm: ModelPipelinePlaybookExplanation = apps.get_app_config("ai").get_model_pipeline(
ModelPipelinePlaybookExplanation
)
explanation = llm.invoke(
PlaybookExplanationParameters.init(
request=request,
content=playbook,
custom_prompt=custom_prompt,
explanation_id=explanation_id,
model_id=self.model_id,
content=self.validated_data["content"],
custom_prompt=self.validated_data["customPrompt"],
explanation_id=self.validated_data["explanationId"],
model_id=self.validated_data["model"],
)
)
self.event.playbook_length = len(playbook)
self.event.explanationId = explanation_id
self.event.playbook_length = len(self.validated_data["content"])
self.event.explanationId = self.validated_data["explanationId"]

# Anonymize response
# Anonymized in the View to be consistent with where Completions are anonymized
Expand All @@ -799,7 +790,7 @@ def post(self, request) -> Response:
answer = {
"content": anonymized_explanation,
"format": "markdown",
"explanationId": explanation_id,
"explanationId": self.validated_data["explanationId"],
}

return Response(
Expand All @@ -816,7 +807,7 @@ class GenerationPlaybook(AACSAPIView):
permission_classes = PERMISSIONS_MAP.get(settings.DEPLOYMENT_MODE)
required_scopes = ["read", "write"]
schema1_event = schema1.GenerationPlaybookEvent

request_serializer_class = GenerationPlaybookRequestSerializer
throttle_cache_key_suffix = "_generation_playbook"

@extend_schema(
Expand All @@ -832,43 +823,25 @@ class GenerationPlaybook(AACSAPIView):
summary="Inline code suggestions",
)
def post(self, request) -> Response:
generation_id = None
wizard_id = None
create_outline = None
anonymized_playbook = ""
playbook = ""
answer = {}
model_id = ""

request_serializer = GenerationPlaybookRequestSerializer(data=request.data)
request_serializer.is_valid(raise_exception=True)
generation_id = str(request_serializer.validated_data.get("generationId", ""))
create_outline = request_serializer.validated_data["createOutline"]
outline = str(request_serializer.validated_data.get("outline", ""))
text = request_serializer.validated_data["text"]
custom_prompt = str(request_serializer.validated_data.get("customPrompt", ""))
wizard_id = str(request_serializer.validated_data.get("wizardId", ""))
self.req_model_id = str(request_serializer.validated_data.get("model", ""))

self.event.generationId = generation_id
self.event.wizardId = wizard_id
self.event.modelName = model_id

self.event.create_outline = create_outline
self.event.create_outline = self.validated_data["createOutline"]

llm: ModelPipelinePlaybookGeneration = apps.get_app_config("ai").get_model_pipeline(
ModelPipelinePlaybookGeneration
)
playbook, outline, warnings = llm.invoke(
PlaybookGenerationParameters.init(
request=request,
text=text,
custom_prompt=custom_prompt,
create_outline=create_outline,
outline=outline,
generation_id=generation_id,
model_id=model_id,
text=self.validated_data["text"],
custom_prompt=self.validated_data["customPrompt"],
create_outline=self.validated_data["createOutline"],
outline=self.validated_data["outline"],
generation_id=self.validated_data["generationId"],
model_id=self.model_id,
)
)
self.event.generationId = self.validated_data["generationId"]
self.event.wizardId = self.validated_data["wizardId"]

# Anonymize responses
# Anonymized in the View to be consistent with where Completions are anonymized
Expand All @@ -885,7 +858,7 @@ def post(self, request) -> Response:
"outline": anonymized_outline,
"warnings": warnings,
"format": "plaintext",
"generationId": generation_id,
"generationId": self.validated_data["generationId"],
}

return Response(
Expand Down Expand Up @@ -1007,8 +980,9 @@ class ChatEndpointThrottle(EndpointRateThrottle):
IsRHEmployee | IsTestUser,
]
required_scopes = ["read", "write"]
throttle_classes = [ChatEndpointThrottle]
schema1_event = schema1.ChatBotOperationalEvent
request_serializer_class = ChatRequestSerializer
throttle_classes = [ChatEndpointThrottle]

def __init__(self):
self.chatbot_enabled = (
Expand Down Expand Up @@ -1036,43 +1010,21 @@ def __init__(self):
)
def post(self, request) -> Response:
headers = {"Content-Type": "application/json"}
request_serializer = ChatRequestSerializer(data=request.data)

if not self.chatbot_enabled:
raise ChatbotNotEnabledException()

if not request_serializer.is_valid():
raise ChatbotInvalidRequestException()

req_query = request_serializer.validated_data["query"]
req_system_prompt = (
request_serializer.validated_data["system_prompt"]
if "system_prompt" in request_serializer.validated_data
else None
)
req_model_id = (
request_serializer.validated_data["model"]
if "model" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_MODEL
)
req_provider = (
request_serializer.validated_data["provider"]
if "provider" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_PROVIDER
)
conversation_id = (
request_serializer.validated_data["conversation_id"]
if "conversation_id" in request_serializer.validated_data
else None
)
req_query = self.validated_data["query"]
req_system_prompt = self.validated_data.get("system_prompt")
req_provider = self.validated_data.get("provider", settings.CHATBOT_DEFAULT_PROVIDER)
conversation_id = self.validated_data.get("conversation_id")

# Initialise Segment Event early, in case of exceptions
self.model_id = req_model_id
self.event.chat_prompt = anonymize_struct(req_query)
self.event.chat_system_prompt = req_system_prompt
self.event.provider_id = req_provider
self.event.conversation_id = conversation_id
self.event.modelName = req_model_id
self.event.modelName = self.req_model_id or settings.CHATBOT_DEFAULT_MODEL

llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline(
ModelPipelineChatBot
Expand All @@ -1083,7 +1035,7 @@ def post(self, request) -> Response:
request=request,
query=req_query,
system_prompt=req_system_prompt,
model_id=req_model_id,
model_id=self.req_model_id,
provider=req_provider,
conversation_id=conversation_id,
)
Expand Down
7 changes: 7 additions & 0 deletions tools/openapi-schema/ansible-ai-connect-service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -907,16 +907,19 @@ components:
description: The playbook that needs to be explained.
customPrompt:
type: string
default: ''
title: Custom prompt
description: Custom prompt passed to the LLM when explaining a playbook.
explanationId:
type: string
format: uuid
default: ''
title: Explanation ID
description: A UUID that identifies the particular explanation data is being
requested for.
model:
type: string
default: ''
metadata:
$ref: '#/components/schemas/Metadata'
required:
Expand Down Expand Up @@ -974,12 +977,14 @@ components:
description: The description that needs to be converted to a playbook.
customPrompt:
type: string
default: ''
title: Custom prompt
description: Custom prompt passed to the LLM when generating the text of
a playbook.
generationId:
type: string
format: uuid
default: ''
title: generation ID
description: A UUID that identifies the particular generation data is being
requested for.
Expand All @@ -991,10 +996,12 @@ components:
of the Ansible Playbook.
outline:
type: string
default: ''
description: A long step by step outline of the expected Ansible Playbook.
wizardId:
type: string
format: uuid
default: ''
title: wizard ID
description: A UUID to track the succession of interaction from the user.
model:
Expand Down

0 comments on commit 2e1acca

Please sign in to comment.