From 14c441f592baaad2ce1a60042e7c55ae165b7b65 Mon Sep 17 00:00:00 2001 From: Michael Abashian Date: Mon, 3 Feb 2025 15:24:03 -0500 Subject: [PATCH] Attempt to move role gen to the saas pipeline --- .../model_pipelines/tests/test_wca_client.py | 7 +- .../api/model_pipelines/wca/pipelines_base.py | 78 ++++--------------- .../api/model_pipelines/wca/pipelines_saas.py | 74 +++++++++++++++++- 3 files changed, 90 insertions(+), 69 deletions(-) diff --git a/ansible_ai_connect/ai/api/model_pipelines/tests/test_wca_client.py b/ansible_ai_connect/ai/api/model_pipelines/tests/test_wca_client.py index 9bee4f58e..d2c00f305 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/tests/test_wca_client.py +++ b/ansible_ai_connect/ai/api/model_pipelines/tests/test_wca_client.py @@ -64,6 +64,7 @@ wca_codegen_playbook_hist, wca_codegen_playbook_retry_counter, wca_codegen_retry_counter, + wca_codegen_role_hist, wca_codematch_hist, wca_codematch_retry_counter, wca_explain_playbook_hist, @@ -537,7 +538,7 @@ def setUp(self): wca_client.session.post.return_value = response self.wca_client = wca_client - @assert_call_count_metrics(metric=wca_codegen_playbook_hist) + @assert_call_count_metrics(metric=wca_codegen_role_hist) @override_settings(ENABLE_ANSIBLE_LINT_POSTPROCESS=True) def test_role_gen_with_lint(self): fake_linter = Mock() @@ -554,9 +555,9 @@ def test_role_gen_with_lint(self): for file in files: self.assertEqual(file["content"], "I'm super fake!") - @assert_call_count_metrics(metric=wca_codegen_playbook_hist) + @assert_call_count_metrics(metric=wca_codegen_role_hist) @override_settings(ENABLE_ANSIBLE_LINT_POSTPROCESS=True) - def test_tole_gen_when_is_not_initialized(self): + def test_role_gen_when_is_not_initialized(self): self.mock_ansible_lint_caller_with(None) name, files, outline, warnings = self.wca_client.invoke( RoleGenerationParameters.init( diff --git a/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_base.py b/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_base.py index 6641f18e1..af5a031ad 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_base.py +++ b/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_base.py @@ -120,6 +120,12 @@ namespace=NAMESPACE, buckets=DEFAULT_LATENCY_BUCKETS, ) +wca_codegen_role_hist = Histogram( + "wca_codegen_role_latency_seconds", + "Histogram of WCA codegen-role API processing time", + namespace=NAMESPACE, + buckets=DEFAULT_LATENCY_BUCKETS, +) wca_explain_playbook_hist = Histogram( "wca_explain_playbook_latency_seconds", "Histogram of WCA explain-playbook API processing time", @@ -147,6 +153,11 @@ "Counter of WCA codegen-playbook API invocation retries", namespace=NAMESPACE, ) +wca_codegen_role_retry_counter = Counter( + "wca_codegen_role_retries", + "Counter of WCA codegen-role API invocation retries", + namespace=NAMESPACE, +) wca_explain_playbook_retry_counter = Counter( "wca_explain_playbook_retries", "Counter of WCA explain-playbook API invocation retries", @@ -244,7 +255,7 @@ def on_backoff_codegen_playbook(details): @staticmethod def on_backoff_codegen_role(details): WCABasePipeline.log_backoff_exception(details) - wca_codegen_playbook_retry_counter.inc() + wca_codegen_role_retry_counter.inc() @staticmethod def on_backoff_explain_playbook(details): @@ -512,70 +523,7 @@ def __init__(self, config: WCA_PIPELINE_CONFIGURATION): super().__init__(config=config) def invoke(self, params: RoleGenerationParameters) -> RoleGenerationResponse: - request = params.request - text = params.text - create_outline = params.create_outline - outline = params.outline - model_id = params.model_id - generation_id = params.generation_id - - organization_id = request.user.organization.id if request.user.organization else None - api_key = self.get_api_key(request.user, organization_id) - model_id = self.get_model_id(request.user, organization_id, model_id) - - headers = self.get_request_headers(api_key, generation_id) - data = { - "model_id": model_id, - "text": text, - "create_outline": create_outline, - } - if outline: - data["outline"] = outline - - @backoff.on_exception( - backoff.expo, - Exception, - max_tries=self.retries + 1, - giveup=self.fatal_exception, - on_backoff=self.on_backoff_codegen_role, - ) - @wca_codegen_playbook_hist.time() - def post_request(): - return self.session.post( - f"{self.config.inference_url}/v1/wca/codegen/ansible/roles", - headers=headers, - json=data, - verify=self.config.verify_ssl, - ) - - result = post_request() - - x_request_id = result.headers.get(WCA_REQUEST_ID_HEADER) - if generation_id and x_request_id: - # request/payload suggestion_id is a UUID not a string whereas - # HTTP headers are strings. - if x_request_id != str(generation_id): - raise WcaRequestIdCorrelationFailure(model_id=model_id, x_request_id=x_request_id) - - context = Context(model_id, result, False) - InferenceResponseChecks().run_checks(context) - result.raise_for_status() - - response = json.loads(result.text) - - name = response["name"] - files = response["files"] - outline = response["outline"] - warnings = response["warnings"] if "warnings" in response else [] - - from ansible_ai_connect.ai.apps import AiConfig - - ai_config = cast(AiConfig, apps.get_app_config("ai")) - if ansible_lint_caller := ai_config.get_ansible_lint_caller(): - for file in files: - file["content"] = ansible_lint_caller.run_linter(file["content"]) - - return name, files, outline, warnings + raise NotImplementedError class WCABasePlaybookExplanationPipeline( diff --git a/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_saas.py b/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_saas.py index 7e2180236..a6f9ed6bf 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_saas.py +++ b/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_saas.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging from abc import ABCMeta -from typing import TYPE_CHECKING, Generic, Optional +from typing import TYPE_CHECKING, Generic, Optional, cast import backoff from django.apps import apps @@ -32,6 +33,7 @@ WcaKeyNotFound, WcaModelIdNotFound, WcaNoDefaultModelId, + WcaRequestIdCorrelationFailure, WcaTokenFailure, ) from ansible_ai_connect.ai.api.model_pipelines.pipelines import ( @@ -65,8 +67,11 @@ WcaModelRequestException, WcaTokenRequestException, ibm_cloud_identity_token_hist, + wca_codegen_role_hist, ) from ansible_ai_connect.ai.api.model_pipelines.wca.wca_utils import ( + Context, + InferenceResponseChecks, TokenContext, TokenResponseChecks, ) @@ -336,6 +341,73 @@ class WCASaaSRoleGenerationPipeline( def __init__(self, config: WCASaaSConfiguration): super().__init__(config=config) + # This should be moved to the base WCA class when it becomes available on-prem + def invoke(self, params: RoleGenerationParameters) -> RoleGenerationResponse: + request = params.request + text = params.text + create_outline = params.create_outline + outline = params.outline + model_id = params.model_id + generation_id = params.generation_id + + organization_id = request.user.organization.id if request.user.organization else None + api_key = self.get_api_key(request.user, organization_id) + model_id = self.get_model_id(request.user, organization_id, model_id) + + headers = self.get_request_headers(api_key, generation_id) + data = { + "model_id": model_id, + "text": text, + "create_outline": create_outline, + } + if outline: + data["outline"] = outline + + @backoff.on_exception( + backoff.expo, + Exception, + max_tries=self.retries + 1, + giveup=self.fatal_exception, + on_backoff=self.on_backoff_codegen_role, + ) + @wca_codegen_role_hist.time() + def post_request(): + return self.session.post( + f"{self.config.inference_url}/v1/wca/codegen/ansible/roles", + headers=headers, + json=data, + verify=self.config.verify_ssl, + ) + + result = post_request() + + x_request_id = result.headers.get(WCA_REQUEST_ID_HEADER) + if generation_id and x_request_id: + # request/payload suggestion_id is a UUID not a string whereas + # HTTP headers are strings. + if x_request_id != str(generation_id): + raise WcaRequestIdCorrelationFailure(model_id=model_id, x_request_id=x_request_id) + + context = Context(model_id, result, False) + InferenceResponseChecks().run_checks(context) + result.raise_for_status() + + response = json.loads(result.text) + + name = response["name"] + files = response["files"] + outline = response["outline"] + warnings = response["warnings"] if "warnings" in response else [] + + from ansible_ai_connect.ai.apps import AiConfig + + ai_config = cast(AiConfig, apps.get_app_config("ai")) + if ansible_lint_caller := ai_config.get_ansible_lint_caller(): + for file in files: + file["content"] = ansible_lint_caller.run_linter(file["content"]) + + return name, files, outline, warnings + def self_test(self) -> Optional[HealthCheckSummary]: raise NotImplementedError