From 7daa9f804ca56f4ab81359df5a1ea789032f368d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gon=C3=A9ri=20Le=20Bouder?=
Date: Mon, 16 Dec 2024 16:05:09 -0500
Subject: [PATCH] AACSAPIView: validate the payload in the class

Validate the request payload from within AACSAPIView.
---
 ansible_ai_connect/ai/api/serializers.py |   8 +-
 ansible_ai_connect/ai/api/views.py       | 101 ++++++------
 .../ansible-ai-connect-service.yaml      |   7 ++
 3 files changed, 49 insertions(+), 67 deletions(-)

diff --git a/ansible_ai_connect/ai/api/serializers.py b/ansible_ai_connect/ai/api/serializers.py
index f29c18dd4..e4b2e2f0f 100644
--- a/ansible_ai_connect/ai/api/serializers.py
+++ b/ansible_ai_connect/ai/api/serializers.py
@@ -424,6 +424,7 @@ class ExplanationRequestSerializer(Metadata):
         required=False,
         label="Custom prompt",
         help_text="Custom prompt passed to the LLM when explaining a playbook.",
+        default="",
     )
     explanationId = serializers.UUIDField(
         format="hex_verbose",
@@ -432,8 +433,9 @@
         help_text=(
             "A UUID that identifies the particular explanation data is being requested for."
         ),
+        default="",
     )
-    model = serializers.CharField(required=False, allow_blank=True)
+    model = serializers.CharField(required=False, allow_blank=True, default="")
     metadata = Metadata(required=False)
 
     def validate(self, data):
@@ -483,12 +485,14 @@ class Meta:
         required=False,
         label="Custom prompt",
         help_text="Custom prompt passed to the LLM when generating the text of a playbook.",
+        default="",
     )
     generationId = serializers.UUIDField(
         format="hex_verbose",
         required=False,
         label="generation ID",
         help_text=("A UUID that identifies the particular generation data is being requested for."),
+        default="",
     )
     createOutline = serializers.BooleanField(
         required=False,
@@ -503,12 +507,14 @@ class Meta:
         required=False,
         label="outline",
         help_text="A long step by step outline of the expected Ansible Playbook.",
+        default="",
     )
     wizardId = serializers.UUIDField(
         format="hex_verbose",
         required=False,
         label="wizard ID",
         help_text=("A UUID to track the succession of interaction from the user."),
+        default="",
     )
     model = serializers.CharField(required=False, allow_blank=True)
 
diff --git a/ansible_ai_connect/ai/api/views.py b/ansible_ai_connect/ai/api/views.py
index 006fccaed..f86edf856 100644
--- a/ansible_ai_connect/ai/api/views.py
+++ b/ansible_ai_connect/ai/api/views.py
@@ -189,15 +189,19 @@ def initialize_request(self, request, *args, **kwargs):
         self.event: schema1.Schema1Event = self.schema1_event()
         self.event.set_request(initialised_request)
 
-        # TODO: when we will move the request_serializer handling in this
-        # class, we will change this line below. The model_id attribute
-        # should exposed as self.validated_data.model, or using a special
-        # method.
-        # See: https://github.com/ansible/ansible-ai-connect-service/pull/1147/files#diff-ecfb6919dfd8379aafba96af7457b253e4dce528897dfe6bfc207ca2b3b2ada9R143-R151  # noqa: E501
-        self.req_model_id: str = ""
+        if hasattr(request, "data"):
+            self.load_parameters(request)
 
         return initialised_request
 
+    def load_parameters(self, request) -> None:
+        if hasattr(self, "request_serializer_class"):
+            request_serializer = self.request_serializer_class(data=request.data)
+            request_serializer.is_valid(raise_exception=True)
+            self.validated_data = request_serializer.validated_data
+            if req_model_id := self.validated_data.get("model"):
+                self.req_model_id = req_model_id
+
     def handle_exception(self, exc):
         self.exception = exc
 
@@ -747,7 +751,7 @@ class Explanation(AACSAPIView):
     permission_classes = PERMISSIONS_MAP.get(settings.DEPLOYMENT_MODE)
     required_scopes = ["read", "write"]
     schema1_event = schema1.ExplainPlaybookEvent
-
+    request_serializer_class = ExplanationRequestSerializer
     throttle_cache_key_suffix = "_explanation"
 
     @extend_schema(
@@ -763,32 +767,20 @@ class Explanation(AACSAPIView):
         summary="Inline code suggestions",
     )
     def post(self, request) -> Response:
-        explanation_id: str = None
-        playbook: str = ""
-        answer = {}
-
-        # TODO: This request_serializer block will move in AACSAPIView later
-        request_serializer = ExplanationRequestSerializer(data=request.data)
-        request_serializer.is_valid(raise_exception=True)
-        explanation_id = str(request_serializer.validated_data.get("explanationId", ""))
-        playbook = request_serializer.validated_data.get("content")
-        custom_prompt = str(request_serializer.validated_data.get("customPrompt", ""))
-        self.req_model_id = str(request_serializer.validated_data.get("model", ""))
-
         llm: ModelPipelinePlaybookExplanation = apps.get_app_config("ai").get_model_pipeline(
             ModelPipelinePlaybookExplanation
         )
         explanation = llm.invoke(
             PlaybookExplanationParameters.init(
                 request=request,
-                content=playbook,
-                custom_prompt=custom_prompt,
-                explanation_id=explanation_id,
-                model_id=self.model_id,
+                content=self.validated_data["content"],
+                custom_prompt=self.validated_data["customPrompt"],
+                explanation_id=self.validated_data["explanationId"],
+                model_id=self.validated_data["model"],
             )
         )
-        self.event.playbook_length = len(playbook)
-        self.event.explanationId = explanation_id
+        self.event.playbook_length = len(self.validated_data["content"])
+        self.event.explanationId = self.validated_data["explanationId"]
 
         # Anonymize response
         # Anonymized in the View to be consistent with where Completions are anonymized
@@ -799,7 +791,7 @@ def post(self, request) -> Response:
         answer = {
             "content": anonymized_explanation,
             "format": "markdown",
-            "explanationId": explanation_id,
+            "explanationId": self.validated_data["explanationId"],
         }
 
         return Response(
@@ -816,7 +808,7 @@ class GenerationPlaybook(AACSAPIView):
     permission_classes = PERMISSIONS_MAP.get(settings.DEPLOYMENT_MODE)
     required_scopes = ["read", "write"]
     schema1_event = schema1.GenerationPlaybookEvent
-
+    request_serializer_class = GenerationPlaybookRequestSerializer
     throttle_cache_key_suffix = "_generation_playbook"
 
     @extend_schema(
@@ -832,43 +824,25 @@ class GenerationPlaybook(AACSAPIView):
         summary="Inline code suggestions",
    )
     def post(self, request) -> Response:
-        generation_id = None
-        wizard_id = None
-        create_outline = None
-        anonymized_playbook = ""
-        playbook = ""
-        answer = {}
-        model_id = ""
-        request_serializer = GenerationPlaybookRequestSerializer(data=request.data)
-        generation_id = str(request_serializer.validated_data.get("generationId", ""))
-        create_outline = request_serializer.validated_data["createOutline"]
-        outline = str(request_serializer.validated_data.get("outline", ""))
-        text = request_serializer.validated_data["text"]
-        custom_prompt = str(request_serializer.validated_data.get("customPrompt", ""))
-        wizard_id = str(request_serializer.validated_data.get("wizardId", ""))
-        self.req_model_id = str(request_serializer.validated_data.get("model", ""))
-
-        self.event.generationId = generation_id
-        self.event.wizardId = wizard_id
-        self.event.modelName = model_id
-
-        self.event.create_outline = create_outline
+        self.event.create_outline = self.validated_data["createOutline"]
+
         llm: ModelPipelinePlaybookGeneration = apps.get_app_config("ai").get_model_pipeline(
             ModelPipelinePlaybookGeneration
         )
         playbook, outline, warnings = llm.invoke(
             PlaybookGenerationParameters.init(
                 request=request,
-                text=text,
-                custom_prompt=custom_prompt,
-                create_outline=create_outline,
-                outline=outline,
-                generation_id=generation_id,
-                model_id=model_id,
+                text=self.validated_data["text"],
+                custom_prompt=self.validated_data["customPrompt"],
+                create_outline=self.validated_data["createOutline"],
+                outline=self.validated_data["outline"],
+                generation_id=self.validated_data["generationId"],
+                model_id=self.model_id,
             )
         )
+        self.event.generationId = self.validated_data["generationId"]
+        self.event.wizardId = self.validated_data["wizardId"]
 
         # Anonymize responses
         # Anonymized in the View to be consistent with where Completions are anonymized
@@ -885,7 +859,7 @@ def post(self, request) -> Response:
             "outline": anonymized_outline,
             "warnings": warnings,
             "format": "plaintext",
-            "generationId": generation_id,
+            "generationId": self.validated_data["generationId"],
         }
 
         return Response(
@@ -1007,8 +981,9 @@ class ChatEndpointThrottle(EndpointRateThrottle):
         IsRHEmployee | IsTestUser,
     ]
     required_scopes = ["read", "write"]
-    throttle_classes = [ChatEndpointThrottle]
     schema1_event = schema1.ChatBotOperationalEvent
+    request_serializer_class = ChatRequestSerializer
+    throttle_classes = [ChatEndpointThrottle]
 
     def __init__(self):
         self.chatbot_enabled = (
@@ -1036,7 +1011,6 @@ def __init__(self):
     )
     def post(self, request) -> Response:
         headers = {"Content-Type": "application/json"}
-        request_serializer = ChatRequestSerializer(data=request.data)
 
         if not self.chatbot_enabled:
             raise ChatbotNotEnabledException()
@@ -1046,17 +1020,12 @@ def post(self, request) -> Response:
-        req_query = request_serializer.validated_data["query"]
+        req_query = self.validated_data["query"]
         req_system_prompt = (
-            request_serializer.validated_data["system_prompt"]
-            if "system_prompt" in request_serializer.validated_data
+            self.validated_data["system_prompt"]
+            if "system_prompt" in self.validated_data
             else None
         )
-        req_model_id = (
-            request_serializer.validated_data["model"]
-            if "model" in request_serializer.validated_data
-            else settings.CHATBOT_DEFAULT_MODEL
-        )
         req_provider = (
-            request_serializer.validated_data["provider"]
-            if "provider" in request_serializer.validated_data
+            self.validated_data["provider"]
+            if "provider" in self.validated_data
             else settings.CHATBOT_DEFAULT_PROVIDER
         )
@@ -1072,7 +1041,7 @@ def post(self, request) -> Response:
         self.event.chat_system_prompt = req_system_prompt
         self.event.provider_id = req_provider
         self.event.conversation_id = conversation_id
-        self.event.modelName = req_model_id
+        self.event.modelName = self.req_model_id or settings.CHATBOT_DEFAULT_MODEL
 
         llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline(
             ModelPipelineChatBot
diff --git a/tools/openapi-schema/ansible-ai-connect-service.yaml b/tools/openapi-schema/ansible-ai-connect-service.yaml
index b26c7ce76..45e01f441 100644
--- a/tools/openapi-schema/ansible-ai-connect-service.yaml
+++ b/tools/openapi-schema/ansible-ai-connect-service.yaml
@@ -907,16 +907,19 @@ components:
           description: The playbook that needs to be explained.
         customPrompt:
           type: string
+          default: ''
           title: Custom prompt
           description: Custom prompt passed to the LLM when explaining a playbook.
         explanationId:
           type: string
           format: uuid
+          default: ''
           title: Explanation ID
           description: A UUID that identifies the particular explanation data is
             being requested for.
         model:
           type: string
+          default: ''
         metadata:
           $ref: '#/components/schemas/Metadata'
       required:
@@ -974,12 +977,14 @@ components:
           description: The description that needs to be converted to a playbook.
         customPrompt:
           type: string
+          default: ''
           title: Custom prompt
           description: Custom prompt passed to the LLM when generating the text of
             a playbook.
         generationId:
           type: string
           format: uuid
+          default: ''
           title: generation ID
           description: A UUID that identifies the particular generation data is being
             requested for.
@@ -991,10 +996,12 @@ components:
             of the Ansible Playbook.
         outline:
           type: string
+          default: ''
           description: A long step by step outline of the expected Ansible Playbook.
         wizardId:
          type: string
           format: uuid
+          default: ''
           title: wizard ID
           description: A UUID to track the succession of interaction from the user.
         model:
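
Reviewer note: the new flow is easiest to follow end to end in isolation. Below is a
minimal, self-contained sketch of the pattern this patch introduces. BaseView only
mirrors the AACSAPIView behaviour added above, and DemoSerializer/DemoView are
hypothetical names for illustration, not classes from this repository:

    from rest_framework import serializers
    from rest_framework.response import Response
    from rest_framework.views import APIView


    class DemoSerializer(serializers.Serializer):
        content = serializers.CharField()
        model = serializers.CharField(required=False, allow_blank=True, default="")


    class BaseView(APIView):
        # Stand-in for AACSAPIView: validate the payload once, up front.
        req_model_id: str = ""

        def initialize_request(self, request, *args, **kwargs):
            initialised_request = super().initialize_request(request, *args, **kwargs)
            # The wrapped DRF Request exposes the parsed payload as `.data`.
            if hasattr(initialised_request, "data"):
                self.load_parameters(initialised_request)
            return initialised_request

        def load_parameters(self, request) -> None:
            # Views opt in by declaring request_serializer_class; an invalid
            # payload raises ValidationError, which DRF turns into a 400.
            if hasattr(self, "request_serializer_class"):
                request_serializer = self.request_serializer_class(data=request.data)
                request_serializer.is_valid(raise_exception=True)
                self.validated_data = request_serializer.validated_data
                if req_model_id := self.validated_data.get("model"):
                    self.req_model_id = req_model_id


    class DemoView(BaseView):
        request_serializer_class = DemoSerializer

        def post(self, request) -> Response:
            # The payload is already validated; the handler is pure business logic.
            return Response({"length": len(self.validated_data["content"])})

With this in place each view's post() shrinks to its business logic, which is what
the views.py hunks above do for Explanation, GenerationPlaybook, and the chat
endpoint.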