Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expose /generations/playbook to be consistent with roleGen #1465

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ansible_ai_connect/ai/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ class ExplanationResponseSerializer(serializers.Serializer):
)


class GenerationRequestSerializer(serializers.Serializer):
class GenerationPlaybookRequestSerializer(serializers.Serializer):
class Meta:
fields = [
"text",
Expand Down Expand Up @@ -616,7 +616,7 @@ class GenerationRoleFileEntrySerializer(serializers.Serializer):
file_type = serializers.CharField()


class GenerationResponseSerializer(serializers.Serializer):
class GenerationPlaybookResponseSerializer(serializers.Serializer):
playbook = serializers.CharField()
format = serializers.CharField()
generationId = serializers.UUIDField(
Expand Down
30 changes: 15 additions & 15 deletions ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3262,7 +3262,7 @@ def test_ok(self):
"ansibleExtensionVersion": "24.4.0",
}
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertIsNotNone(r.data["playbook"])
self.assertEqual(r.data["format"], "plaintext")
Expand All @@ -3280,7 +3280,7 @@ def test_ok_with_model_id(self):
}
self.client.force_authenticate(user=self.user)
with self.assertLogs(logger="root", level="DEBUG") as log:
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
segment_events = self.extractSegmentEventsFromLog(log)
self.assertEqual(segment_events[0]["properties"]["modelName"], "mymodel")
self.assertEqual(r.status_code, HTTPStatus.OK)
Expand All @@ -3303,7 +3303,7 @@ def test_with_pii(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
self.client.post(reverse("generations"), payload, format="json")
self.client.post(reverse("generations/playbook"), payload, format="json")

args: PlaybookGenerationParameters = mocked_client.invoke.call_args[0][0]
self.assertFalse(args.create_outline)
Expand All @@ -3322,7 +3322,7 @@ def test_unauthorized(self):
Mock(return_value=MockedPipelinePlaybookGeneration(self.response_data)),
):
# Hit the API without authentication
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.UNAUTHORIZED)

@override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=True)
Expand All @@ -3336,7 +3336,7 @@ def test_bad_request(self):
Mock(return_value=MockedPipelinePlaybookGeneration(self.response_data)),
):
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)

@override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=True)
Expand All @@ -3354,7 +3354,7 @@ def test_with_anonymized_response(self):
Mock(return_value=MockedPipelinePlaybookGeneration(self.response_pii_data)),
):
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertIsNotNone(r.data["playbook"])
self.assertIsNotNone(r.data["outline"])
Expand All @@ -3375,7 +3375,7 @@ def test_service_unavailable(self, invoke):
}
self.client.force_authenticate(user=self.user)
with self.assertRaises(Exception):
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.SERVICE_UNAVAILABLE)

@override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=True)
Expand All @@ -3395,7 +3395,7 @@ def test_with_custom_prompt_valid(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
self.client.post(reverse("generations"), payload, format="json")
self.client.post(reverse("generations/playbook"), payload, format="json")

args: PlaybookGenerationParameters = mocked_client.invoke.call_args[0][0]
self.assertFalse(args.create_outline)
Expand All @@ -3422,7 +3422,7 @@ def test_with_custom_prompt_blank(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)
self.assertFalse(mocked_client.generate_playbook.called)
self.assertIn("detail", r.data)
Expand All @@ -3448,7 +3448,7 @@ def test_with_custom_prompt_missing_goal(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)
self.assertFalse(mocked_client.generate_playbook.called)
self.assertIn("detail", r.data)
Expand All @@ -3472,7 +3472,7 @@ def test_with_custom_prompt_missing_outline(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
r = self.client.post(reverse("generations/playbook"), payload, format="json")
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)
self.assertFalse(mocked_client.generate_playbook.called)
self.assertIn("detail", r.data)
Expand All @@ -3498,7 +3498,7 @@ def test_with_custom_prompt_missing_outline_when_not_needed(self):
Mock(return_value=mocked_client),
):
self.client.force_authenticate(user=self.user)
self.client.post(reverse("generations"), payload, format="json")
self.client.post(reverse("generations/playbook"), payload, format="json")

args: PlaybookGenerationParameters = mocked_client.invoke.call_args[0][0]
self.assertFalse(args.create_outline)
Expand Down Expand Up @@ -3587,7 +3587,7 @@ def assert_test(
Mock(return_value=model_client),
):
with self.assertLogs(logger="root", level="DEBUG") as log:
r = self.client.post(reverse("generations"), self.payload, format="json")
r = self.client.post(reverse("generations/playbook"), self.payload, format="json")
self.assertEqual(r.status_code, expected_status_code)
if expected_exception() is not None:
self.assert_error_detail(
Expand Down Expand Up @@ -3898,7 +3898,7 @@ def stub_wca_client(self):
@override_settings(ANSIBLE_AI_ENABLE_PLAYBOOK_ENDPOINT=False)
def test_feature_not_enabled_yet(self):
self.client.force_login(user=self.aap_user)
r = self.client.post(reverse("generations"), self.payload_json, format="json")
r = self.client.post(reverse("generations/playbook"), self.payload_json, format="json")
self.assertEqual(r.status_code, 404)

@override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=False)
Expand All @@ -3912,7 +3912,7 @@ def test_feature_enabled(self):
"get_model_pipeline",
Mock(return_value=self.stub_wca_client()),
):
r = self.client.post(reverse("generations"), self.payload_json, format="json")
r = self.client.post(reverse("generations/playbook"), self.payload_json, format="json")
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertEqual(r.data["playbook"], "---\n- hosts: all\n")
self.assertEqual(r.data["format"], "plaintext")
Expand Down
6 changes: 4 additions & 2 deletions ansible_ai_connect/ai/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@
ContentMatches,
Explanation,
Feedback,
Generation,
GenerationPlaybook,
GenerationRole,
)

urlpatterns = [
path("completions/", Completions.as_view(), name="completions"),
path("contentmatches/", ContentMatches.as_view(), name="contentmatches"),
path("explanations/", Explanation.as_view(), name="explanations"),
path("generations/", Generation.as_view(), name="generations"),
# Legacy
path("generations/", GenerationPlaybook.as_view(), name="generations"),
path("generations/playbook", GenerationPlaybook.as_view(), name="generations/playbook"),
path("generations/role", GenerationRole.as_view(), name="generations/role"),
path("feedback/", Feedback.as_view(), name="feedback"),
path("chat/", Chat.as_view(), name="chat"),
Expand Down
14 changes: 7 additions & 7 deletions ansible_ai_connect/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@
ExplanationRequestSerializer,
ExplanationResponseSerializer,
FeedbackRequestSerializer,
GenerationRequestSerializer,
GenerationResponseSerializer,
GenerationPlaybookRequestSerializer,
GenerationPlaybookResponseSerializer,
GenerationRoleRequestSerializer,
GenerationRoleResponseSerializer,
InlineSuggestionFeedback,
Expand Down Expand Up @@ -797,20 +797,20 @@ def post(self, request) -> Response:
)


class Generation(APIView):
class GenerationPlaybook(APIView):
"""
Returns a playbook based on a text input.
"""

permission_classes = PERMISSIONS_MAP.get(settings.DEPLOYMENT_MODE)
required_scopes = ["read", "write"]

throttle_cache_key_suffix = "_generation"
throttle_cache_key_suffix = "_generation_playbook"

@extend_schema(
request=GenerationRequestSerializer,
request=GenerationPlaybookRequestSerializer,
responses={
200: GenerationResponseSerializer,
200: GenerationPlaybookResponseSerializer,
204: OpenApiResponse(description="Empty response"),
400: OpenApiResponse(description="Bad Request"),
401: OpenApiResponse(description="Unauthorized"),
Expand All @@ -827,7 +827,7 @@ def post(self, request) -> Response:
create_outline = None
anonymized_playbook = ""
playbook = ""
request_serializer = GenerationRequestSerializer(data=request.data)
request_serializer = GenerationPlaybookRequestSerializer(data=request.data)
answer = {}
model_id = ""
try:
Expand Down
53 changes: 47 additions & 6 deletions tools/openapi-schema/ansible-ai-connect-service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,13 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/GenerationRequest'
$ref: '#/components/schemas/GenerationPlaybookRequest'
application/x-www-form-urlencoded:
schema:
$ref: '#/components/schemas/GenerationRequest'
$ref: '#/components/schemas/GenerationPlaybookRequest'
multipart/form-data:
schema:
$ref: '#/components/schemas/GenerationRequest'
$ref: '#/components/schemas/GenerationPlaybookRequest'
required: true
security:
- oauth2:
Expand All @@ -263,7 +263,48 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/GenerationResponse'
$ref: '#/components/schemas/GenerationPlaybookResponse'
description: ''
'204':
description: Empty response
'400':
description: Bad Request
'401':
description: Unauthorized
'429':
description: Request was throttled
'503':
description: Service Unavailable
/api/v0/ai/generations/playbook:
post:
operationId: ai_generations_playbook_create
description: Returns a playbook based on a text input.
summary: Inline code suggestions
tags:
- ai
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/GenerationPlaybookRequest'
application/x-www-form-urlencoded:
schema:
$ref: '#/components/schemas/GenerationPlaybookRequest'
multipart/form-data:
schema:
$ref: '#/components/schemas/GenerationPlaybookRequest'
required: true
security:
- oauth2:
- read
- write
- cookieAuth: []
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/GenerationPlaybookResponse'
description: ''
'204':
description: Empty response
Expand Down Expand Up @@ -924,7 +965,7 @@ components:
$ref: '#/components/schemas/SuggestionQualityFeedback'
chatFeedback:
$ref: '#/components/schemas/ChatFeedback'
GenerationRequest:
GenerationPlaybookRequest:
type: object
properties:
text:
Expand Down Expand Up @@ -962,7 +1003,7 @@ components:
$ref: '#/components/schemas/Metadata'
required:
- text
GenerationResponse:
GenerationPlaybookResponse:
type: object
properties:
playbook:
Expand Down
Loading