Skip to content

Commit

Permalink
AACSAPIView: validate the payload in the class
Browse files Browse the repository at this point in the history
Validate the view parameter from within AACSAPIView.
  • Loading branch information
goneri committed Jan 9, 2025
1 parent 5f271e6 commit 7daa9f8
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 67 deletions.
8 changes: 7 additions & 1 deletion ansible_ai_connect/ai/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ class ExplanationRequestSerializer(Metadata):
required=False,
label="Custom prompt",
help_text="Custom prompt passed to the LLM when explaining a playbook.",
default="",
)
explanationId = serializers.UUIDField(
format="hex_verbose",
Expand All @@ -432,8 +433,9 @@ class ExplanationRequestSerializer(Metadata):
help_text=(
"A UUID that identifies the particular explanation data is being requested for."
),
default="",
)
model = serializers.CharField(required=False, allow_blank=True)
model = serializers.CharField(required=False, allow_blank=True, default="")
metadata = Metadata(required=False)

def validate(self, data):
Expand Down Expand Up @@ -483,12 +485,14 @@ class Meta:
required=False,
label="Custom prompt",
help_text="Custom prompt passed to the LLM when generating the text of a playbook.",
default="",
)
generationId = serializers.UUIDField(
format="hex_verbose",
required=False,
label="generation ID",
help_text=("A UUID that identifies the particular generation data is being requested for."),
default="",
)
createOutline = serializers.BooleanField(
required=False,
Expand All @@ -503,12 +507,14 @@ class Meta:
required=False,
label="outline",
help_text="A long step by step outline of the expected Ansible Playbook.",
default="",
)
wizardId = serializers.UUIDField(
format="hex_verbose",
required=False,
label="wizard ID",
help_text=("A UUID to track the succession of interaction from the user."),
default="",
)
model = serializers.CharField(required=False, allow_blank=True)

Expand Down
101 changes: 35 additions & 66 deletions ansible_ai_connect/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,15 +189,19 @@ def initialize_request(self, request, *args, **kwargs):
self.event: schema1.Schema1Event = self.schema1_event()
self.event.set_request(initialised_request)

# TODO: when we will move the request_serializer handling in this
# class, we will change this line below. The model_id attribute
# should exposed as self.validated_data.model, or using a special
# method.
# See: https://github.com/ansible/ansible-ai-connect-service/pull/1147/files#diff-ecfb6919dfd8379aafba96af7457b253e4dce528897dfe6bfc207ca2b3b2ada9R143-R151 # noqa: E501
self.req_model_id: str = ""
if hasattr(request, "data"):
self.load_parameters(request)

return initialised_request

def load_parameters(self, request) -> Response:
if hasattr(self, "request_serializer_class"):
request_serializer = self.request_serializer_class(data=request.data)
request_serializer.is_valid(raise_exception=True)
self.validated_data = request_serializer.validated_data
if req_model_id := self.validated_data.get("mode_id"):
self.req_model_id = req_model_id

def handle_exception(self, exc):
self.exception = exc

Expand Down Expand Up @@ -747,7 +751,7 @@ class Explanation(AACSAPIView):
permission_classes = PERMISSIONS_MAP.get(settings.DEPLOYMENT_MODE)
required_scopes = ["read", "write"]
schema1_event = schema1.ExplainPlaybookEvent

request_serializer_class = ExplanationRequestSerializer
throttle_cache_key_suffix = "_explanation"

@extend_schema(
Expand All @@ -763,32 +767,20 @@ class Explanation(AACSAPIView):
summary="Inline code suggestions",
)
def post(self, request) -> Response:
explanation_id: str = None
playbook: str = ""
answer = {}

# TODO: This request_serializer block will move in AACSAPIView later
request_serializer = ExplanationRequestSerializer(data=request.data)
request_serializer.is_valid(raise_exception=True)
explanation_id = str(request_serializer.validated_data.get("explanationId", ""))
playbook = request_serializer.validated_data.get("content")
custom_prompt = str(request_serializer.validated_data.get("customPrompt", ""))
self.req_model_id = str(request_serializer.validated_data.get("model", ""))

llm: ModelPipelinePlaybookExplanation = apps.get_app_config("ai").get_model_pipeline(
ModelPipelinePlaybookExplanation
)
explanation = llm.invoke(
PlaybookExplanationParameters.init(
request=request,
content=playbook,
custom_prompt=custom_prompt,
explanation_id=explanation_id,
model_id=self.model_id,
content=self.validated_data["content"],
custom_prompt=self.validated_data["customPrompt"],
explanation_id=self.validated_data["explanationId"],
model_id=self.validated_data["model"],
)
)
self.event.playbook_length = len(playbook)
self.event.explanationId = explanation_id
self.event.playbook_length = self.validated_data["content"]
self.event.explanationId = self.validated_data["explanationId"]

# Anonymize response
# Anonymized in the View to be consistent with where Completions are anonymized
Expand All @@ -799,7 +791,7 @@ def post(self, request) -> Response:
answer = {
"content": anonymized_explanation,
"format": "markdown",
"explanationId": explanation_id,
"explanationId": self.validated_data["explanationId"],
}

return Response(
Expand All @@ -816,7 +808,7 @@ class GenerationPlaybook(AACSAPIView):
permission_classes = PERMISSIONS_MAP.get(settings.DEPLOYMENT_MODE)
required_scopes = ["read", "write"]
schema1_event = schema1.GenerationPlaybookEvent

request_serializer_class = GenerationPlaybookRequestSerializer
throttle_cache_key_suffix = "_generation_playbook"

@extend_schema(
Expand All @@ -832,43 +824,25 @@ class GenerationPlaybook(AACSAPIView):
summary="Inline code suggestions",
)
def post(self, request) -> Response:
generation_id = None
wizard_id = None
create_outline = None
anonymized_playbook = ""
playbook = ""
answer = {}
model_id = ""

request_serializer = GenerationPlaybookRequestSerializer(data=request.data)
request_serializer.is_valid(raise_exception=True)
generation_id = str(request_serializer.validated_data.get("generationId", ""))
create_outline = request_serializer.validated_data["createOutline"]
outline = str(request_serializer.validated_data.get("outline", ""))
text = request_serializer.validated_data["text"]
custom_prompt = str(request_serializer.validated_data.get("customPrompt", ""))
wizard_id = str(request_serializer.validated_data.get("wizardId", ""))
self.req_model_id = str(request_serializer.validated_data.get("model", ""))

self.event.generationId = generation_id
self.event.wizardId = wizard_id
self.event.modelName = model_id

self.event.create_outline = create_outline
self.event.create_outline = self.validated_data["createOutline"]

llm: ModelPipelinePlaybookGeneration = apps.get_app_config("ai").get_model_pipeline(
ModelPipelinePlaybookGeneration
)
playbook, outline, warnings = llm.invoke(
PlaybookGenerationParameters.init(
request=request,
text=text,
custom_prompt=custom_prompt,
create_outline=create_outline,
outline=outline,
generation_id=generation_id,
model_id=model_id,
text=self.validated_data["text"],
custom_prompt=self.validated_data["customPrompt"],
create_outline=self.validated_data["createOutline"],
outline=self.validated_data["outline"],
generation_id=self.validated_data["generationId"],
model_id=self.model_id,
)
)
self.event.generationId = self.validated_data["generationId"]
self.event.wizard_id = self.validated_data["wizardId"]

# Anonymize responses
# Anonymized in the View to be consistent with where Completions are anonymized
Expand All @@ -885,7 +859,7 @@ def post(self, request) -> Response:
"outline": anonymized_outline,
"warnings": warnings,
"format": "plaintext",
"generationId": generation_id,
"generationId": self.validated_data["generationId"],
}

return Response(
Expand Down Expand Up @@ -1007,8 +981,9 @@ class ChatEndpointThrottle(EndpointRateThrottle):
IsRHEmployee | IsTestUser,
]
required_scopes = ["read", "write"]
throttle_classes = [ChatEndpointThrottle]
schema1_event = schema1.ChatBotOperationalEvent
request_serializer_class = ExplanationRequestSerializer
throttle_classes = [ChatEndpointThrottle]

def __init__(self):
self.chatbot_enabled = (
Expand Down Expand Up @@ -1036,7 +1011,6 @@ def __init__(self):
)
def post(self, request) -> Response:
headers = {"Content-Type": "application/json"}
request_serializer = ChatRequestSerializer(data=request.data)

if not self.chatbot_enabled:
raise ChatbotNotEnabledException()
Expand All @@ -1046,17 +1020,12 @@ def post(self, request) -> Response:

req_query = request_serializer.validated_data["query"]
req_system_prompt = (
request_serializer.validated_data["system_prompt"]
self.validated_data["system_prompt"]
if "system_prompt" in request_serializer.validated_data
else None
)
req_model_id = (
request_serializer.validated_data["model"]
if "model" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_MODEL
)
req_provider = (
request_serializer.validated_data["provider"]
self.validated_data["provider"]
if "provider" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_PROVIDER
)
Expand All @@ -1072,7 +1041,7 @@ def post(self, request) -> Response:
self.event.chat_system_prompt = req_system_prompt
self.event.provider_id = req_provider
self.event.conversation_id = conversation_id
self.event.modelName = req_model_id
self.event.modelName = self.req_model_id or settings.CHATBOT_DEFAULT_MODEL

llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline(
ModelPipelineChatBot
Expand Down
7 changes: 7 additions & 0 deletions tools/openapi-schema/ansible-ai-connect-service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -907,16 +907,19 @@ components:
description: The playbook that needs to be explained.
customPrompt:
type: string
default: ''
title: Custom prompt
description: Custom prompt passed to the LLM when explaining a playbook.
explanationId:
type: string
format: uuid
default: ''
title: Explanation ID
description: A UUID that identifies the particular explanation data is being
requested for.
model:
type: string
default: ''
metadata:
$ref: '#/components/schemas/Metadata'
required:
Expand Down Expand Up @@ -974,12 +977,14 @@ components:
description: The description that needs to be converted to a playbook.
customPrompt:
type: string
default: ''
title: Custom prompt
description: Custom prompt passed to the LLM when generating the text of
a playbook.
generationId:
type: string
format: uuid
default: ''
title: generation ID
description: A UUID that identifies the particular generation data is being
requested for.
Expand All @@ -991,10 +996,12 @@ components:
of the Ansible Playbook.
outline:
type: string
default: ''
description: A long step by step outline of the expected Ansible Playbook.
wizardId:
type: string
format: uuid
default: ''
title: wizard ID
description: A UUID to track the succession of interaction from the user.
model:
Expand Down

0 comments on commit 7daa9f8

Please sign in to comment.