Skip to content

Commit

Permalink
expose /generations/playbook to be consistent with roleGen
Browse files Browse the repository at this point in the history
This way we have:
- `/api/v0/ai/generations/playbook` for playbook generation
and `/api/v0/ai/generations/role` for roles.
  • Loading branch information
goneri committed Dec 16, 2024
1 parent 1a7622f commit 8072274
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 30 deletions.
4 changes: 2 additions & 2 deletions ansible_ai_connect/ai/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ class ExplanationResponseSerializer(serializers.Serializer):
)


class GenerationRequestSerializer(serializers.Serializer):
class GenerationPlaybookRequestSerializer(serializers.Serializer):
class Meta:
fields = [
"text",
Expand Down Expand Up @@ -611,7 +611,7 @@ class GenerationRoleFileEntrySerializer(serializers.Serializer):
file_type = serializers.CharField()


class GenerationResponseSerializer(serializers.Serializer):
class GenerationPlaybookResponseSerializer(serializers.Serializer):
playbook = serializers.CharField()
format = serializers.CharField()
generationId = serializers.UUIDField(
Expand Down
30 changes: 15 additions & 15 deletions ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3262,7 +3262,7 @@ def test_ok(self):
"ansibleExtensionVersion": "24.4.0",
}
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertIsNotNone(r.data["playbook"])
self.assertEqual(r.data["format"], "plaintext")
Expand All @@ -3280,7 +3280,7 @@ def test_ok_with_model_id(self):
}
self.client.force_authenticate(user=self.user)
with self.assertLogs(logger="root", level="DEBUG") as log:
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
segment_events = self.extractSegmentEventsFromLog(log)
self.assertEqual(segment_events[0]["properties"]["modelName"], "mymodel")
self.assertEqual(r.status_code, HTTPStatus.OK)
Expand All @@ -3303,7 +3303,7 @@ def test_with_pii(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
self.client.post(reverse("generations"), payload, format="json")
self.client.post(reverse("generations/playbook"), payload, format="json")

args: PlaybookGenerationParameters = mocked_client.invoke.call_args[0][0]
self.assertFalse(args.create_outline)
Expand All @@ -3322,7 +3322,7 @@ def test_unauthorized(self):
Mock(return_value=MockedPipelinePlaybookGeneration(self.response_data)),
):
# Hit the API without authentication
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.UNAUTHORIZED)

@override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=True)
Expand All @@ -3336,7 +3336,7 @@ def test_bad_request(self):
Mock(return_value=MockedPipelinePlaybookGeneration(self.response_data)),
):
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)

@override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=True)
Expand All @@ -3354,7 +3354,7 @@ def test_with_anonymized_response(self):
Mock(return_value=MockedPipelinePlaybookGeneration(self.response_pii_data)),
):
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertIsNotNone(r.data["playbook"])
self.assertIsNotNone(r.data["outline"])
Expand All @@ -3375,7 +3375,7 @@ def test_service_unavailable(self, invoke):
}
self.client.force_authenticate(user=self.user)
with self.assertRaises(Exception):
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.SERVICE_UNAVAILABLE)

@override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=True)
Expand All @@ -3395,7 +3395,7 @@ def test_with_custom_prompt_valid(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
self.client.post(reverse("generations"), payload, format="json")
self.client.post(reverse("generations/playbook"), payload, format="json")

args: PlaybookGenerationParameters = mocked_client.invoke.call_args[0][0]
self.assertFalse(args.create_outline)
Expand All @@ -3422,7 +3422,7 @@ def test_with_custom_prompt_blank(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)
self.assertFalse(mocked_client.generate_playbook.called)
self.assertIn("detail", r.data)
Expand All @@ -3448,7 +3448,7 @@ def test_with_custom_prompt_missing_goal(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)
self.assertFalse(mocked_client.generate_playbook.called)
self.assertIn("detail", r.data)
Expand All @@ -3472,7 +3472,7 @@ def test_with_custom_prompt_missing_outline(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)
self.assertFalse(mocked_client.generate_playbook.called)
self.assertIn("detail", r.data)
Expand All @@ -3498,7 +3498,7 @@ def test_with_custom_prompt_missing_outline_when_not_needed(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
self.client.post(reverse("generations"), payload, format="json")
self.client.post(reverse("generations/playbook"), payload, format="json")

args: PlaybookGenerationParameters = mocked_client.invoke.call_args[0][0]
self.assertFalse(args.create_outline)
Expand Down Expand Up @@ -3587,7 +3587,7 @@ def assert_test(
Mock(return_value=model_client),
):
with self.assertLogs(logger="root", level="DEBUG") as log:
r = self.client.post(reverse("generations"), self.payload, format="json")
r = self.client.post(reverse("generations/playbook"), self.payload, format="json")
self.assertEqual(r.status_code, expected_status_code)
if expected_exception() is not None:
self.assert_error_detail(
Expand Down Expand Up @@ -3898,7 +3898,7 @@ def stub_wca_client(self):
@override_settings(ANSIBLE_AI_ENABLE_PLAYBOOK_ENDPOINT=False)
def test_feature_not_enabled_yet(self):
self.client.force_login(user=self.aap_user)
r = self.client.post(reverse("generations"), self.payload_json, format="json")
r = self.client.post(reverse("generations/playbook"), self.payload_json, format="json")
self.assertEqual(r.status_code, 404)

@override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=False)
Expand All @@ -3912,7 +3912,7 @@ def test_feature_enabled(self):
"get_model_pipeline",
Mock(return_value=self.stub_wca_client()),
):
r = self.client.post(reverse("generations"), self.payload_json, format="json")
r = self.client.post(reverse("generations/playbook"), self.payload_json, format="json")
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertEqual(r.data["playbook"], "---\n- hosts: all\n")
self.assertEqual(r.data["format"], "plaintext")
Expand Down
6 changes: 4 additions & 2 deletions ansible_ai_connect/ai/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@
ContentMatches,
Explanation,
Feedback,
Generation,
GenerationPlaybook,
GenerationRole,
)

urlpatterns = [
path("completions/", Completions.as_view(), name="completions"),
path("contentmatches/", ContentMatches.as_view(), name="contentmatches"),
path("explanations/", Explanation.as_view(), name="explanations"),
path("generations/", Generation.as_view(), name="generations"),
# Legacy
path("generations/", GenerationPlaybook.as_view(), name="generations"),
path("generations/playbook", GenerationPlaybook.as_view(), name="generations/playbook"),
path("generations/role", GenerationRole.as_view(), name="generations/role"),
path("feedback/", Feedback.as_view(), name="feedback"),
path("chat/", Chat.as_view(), name="chat"),
Expand Down
14 changes: 7 additions & 7 deletions ansible_ai_connect/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@
ExplanationRequestSerializer,
ExplanationResponseSerializer,
FeedbackRequestSerializer,
GenerationRequestSerializer,
GenerationResponseSerializer,
GenerationPlaybookRequestSerializer,
GenerationPlaybookResponseSerializer,
GenerationRoleRequestSerializer,
GenerationRoleResponseSerializer,
InlineSuggestionFeedback,
Expand Down Expand Up @@ -797,20 +797,20 @@ def post(self, request) -> Response:
)


class Generation(APIView):
class GenerationPlaybook(APIView):
"""
Returns a playbook based on a text input.
"""

permission_classes = PERMISSIONS_MAP.get(settings.DEPLOYMENT_MODE)
required_scopes = ["read", "write"]

throttle_cache_key_suffix = "_generation"
throttle_cache_key_suffix = "_generation_playbook"

@extend_schema(
request=GenerationRequestSerializer,
request=GenerationPlaybookRequestSerializer,
responses={
200: GenerationResponseSerializer,
200: GenerationPlaybookResponseSerializer,
204: OpenApiResponse(description="Empty response"),
400: OpenApiResponse(description="Bad Request"),
401: OpenApiResponse(description="Unauthorized"),
Expand All @@ -827,7 +827,7 @@ def post(self, request) -> Response:
create_outline = None
anonymized_playbook = ""
playbook = ""
request_serializer = GenerationRequestSerializer(data=request.data)
request_serializer = GenerationPlaybookRequestSerializer(data=request.data)
answer = {}
model_id = ""
try:
Expand Down
49 changes: 45 additions & 4 deletions tools/openapi-schema/ansible-ai-connect-service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,54 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/GenerationRequest'
$ref: '#/components/schemas/GenerationPlaybookRequest'
application/x-www-form-urlencoded:
schema:
$ref: '#/components/schemas/GenerationRequest'
$ref: '#/components/schemas/GenerationPlaybookRequest'
multipart/form-data:
schema:
$ref: '#/components/schemas/GenerationRequest'
$ref: '#/components/schemas/GenerationPlaybookRequest'
required: true
security:
- oauth2:
- read
- write
- cookieAuth: []
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/GenerationResponse'
description: ''
'204':
description: Empty response
'400':
description: Bad Request
'401':
description: Unauthorized
'429':
description: Request was throttled
'503':
description: Service Unavailable
/api/v0/ai/generations/playbook:
post:
operationId: ai_generations_playbook_create
description: Returns a playbook based on a text input.
summary: Inline code suggestions
tags:
- ai
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/GenerationPlaybookRequest'
application/x-www-form-urlencoded:
schema:
$ref: '#/components/schemas/GenerationPlaybookRequest'
multipart/form-data:
schema:
$ref: '#/components/schemas/GenerationPlaybookRequest'
required: true
security:
- oauth2:
Expand Down Expand Up @@ -920,7 +961,7 @@ components:
$ref: '#/components/schemas/SuggestionQualityFeedback'
chatFeedback:
$ref: '#/components/schemas/ChatFeedback'
GenerationRequest:
GenerationPlaybookRequest:
type: object
properties:
text:
Expand Down

0 comments on commit 8072274

Please sign in to comment.