Skip to content

Commit

Permalink
Attempt to move role gen to the saas pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
mabashian committed Feb 3, 2025
1 parent e3a0106 commit 14c441f
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
wca_codegen_playbook_hist,
wca_codegen_playbook_retry_counter,
wca_codegen_retry_counter,
wca_codegen_role_hist,
wca_codematch_hist,
wca_codematch_retry_counter,
wca_explain_playbook_hist,
Expand Down Expand Up @@ -537,7 +538,7 @@ def setUp(self):
wca_client.session.post.return_value = response
self.wca_client = wca_client

@assert_call_count_metrics(metric=wca_codegen_playbook_hist)
@assert_call_count_metrics(metric=wca_codegen_role_hist)
@override_settings(ENABLE_ANSIBLE_LINT_POSTPROCESS=True)
def test_role_gen_with_lint(self):
fake_linter = Mock()
Expand All @@ -554,9 +555,9 @@ def test_role_gen_with_lint(self):
for file in files:
self.assertEqual(file["content"], "I'm super fake!")

@assert_call_count_metrics(metric=wca_codegen_playbook_hist)
@assert_call_count_metrics(metric=wca_codegen_role_hist)
@override_settings(ENABLE_ANSIBLE_LINT_POSTPROCESS=True)
def test_tole_gen_when_is_not_initialized(self):
def test_role_gen_when_is_not_initialized(self):
self.mock_ansible_lint_caller_with(None)
name, files, outline, warnings = self.wca_client.invoke(
RoleGenerationParameters.init(
Expand Down
78 changes: 13 additions & 65 deletions ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@
namespace=NAMESPACE,
buckets=DEFAULT_LATENCY_BUCKETS,
)
# Latency histogram for the WCA codegen-role (Ansible role generation) API call;
# observed via @wca_codegen_role_hist.time() around the HTTP POST.
wca_codegen_role_hist = Histogram(
    "wca_codegen_role_latency_seconds",
    "Histogram of WCA codegen-role API processing time",
    namespace=NAMESPACE,
    buckets=DEFAULT_LATENCY_BUCKETS,
)
wca_explain_playbook_hist = Histogram(
"wca_explain_playbook_latency_seconds",
"Histogram of WCA explain-playbook API processing time",
Expand Down Expand Up @@ -147,6 +153,11 @@
"Counter of WCA codegen-playbook API invocation retries",
namespace=NAMESPACE,
)
# Counts backoff retries of the WCA codegen-role API call; incremented from the
# pipeline's on_backoff_codegen_role handler.
wca_codegen_role_retry_counter = Counter(
    "wca_codegen_role_retries",
    "Counter of WCA codegen-role API invocation retries",
    namespace=NAMESPACE,
)
wca_explain_playbook_retry_counter = Counter(
"wca_explain_playbook_retries",
"Counter of WCA explain-playbook API invocation retries",
Expand Down Expand Up @@ -244,7 +255,7 @@ def on_backoff_codegen_playbook(details):
@staticmethod
def on_backoff_codegen_role(details):
WCABasePipeline.log_backoff_exception(details)
wca_codegen_playbook_retry_counter.inc()
wca_codegen_role_retry_counter.inc()

@staticmethod
def on_backoff_explain_playbook(details):
Expand Down Expand Up @@ -512,70 +523,7 @@ def __init__(self, config: WCA_PIPELINE_CONFIGURATION):
super().__init__(config=config)

def invoke(self, params: RoleGenerationParameters) -> RoleGenerationResponse:
request = params.request
text = params.text
create_outline = params.create_outline
outline = params.outline
model_id = params.model_id
generation_id = params.generation_id

organization_id = request.user.organization.id if request.user.organization else None
api_key = self.get_api_key(request.user, organization_id)
model_id = self.get_model_id(request.user, organization_id, model_id)

headers = self.get_request_headers(api_key, generation_id)
data = {
"model_id": model_id,
"text": text,
"create_outline": create_outline,
}
if outline:
data["outline"] = outline

@backoff.on_exception(
backoff.expo,
Exception,
max_tries=self.retries + 1,
giveup=self.fatal_exception,
on_backoff=self.on_backoff_codegen_role,
)
@wca_codegen_playbook_hist.time()
def post_request():
return self.session.post(
f"{self.config.inference_url}/v1/wca/codegen/ansible/roles",
headers=headers,
json=data,
verify=self.config.verify_ssl,
)

result = post_request()

x_request_id = result.headers.get(WCA_REQUEST_ID_HEADER)
if generation_id and x_request_id:
# request/payload suggestion_id is a UUID not a string whereas
# HTTP headers are strings.
if x_request_id != str(generation_id):
raise WcaRequestIdCorrelationFailure(model_id=model_id, x_request_id=x_request_id)

context = Context(model_id, result, False)
InferenceResponseChecks().run_checks(context)
result.raise_for_status()

response = json.loads(result.text)

name = response["name"]
files = response["files"]
outline = response["outline"]
warnings = response["warnings"] if "warnings" in response else []

from ansible_ai_connect.ai.apps import AiConfig

ai_config = cast(AiConfig, apps.get_app_config("ai"))
if ansible_lint_caller := ai_config.get_ansible_lint_caller():
for file in files:
file["content"] = ansible_lint_caller.run_linter(file["content"])

return name, files, outline, warnings
raise NotImplementedError


class WCABasePlaybookExplanationPipeline(
Expand Down
74 changes: 73 additions & 1 deletion ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_saas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from abc import ABCMeta
from typing import TYPE_CHECKING, Generic, Optional
from typing import TYPE_CHECKING, Generic, Optional, cast

import backoff
from django.apps import apps
Expand All @@ -32,6 +33,7 @@
WcaKeyNotFound,
WcaModelIdNotFound,
WcaNoDefaultModelId,
WcaRequestIdCorrelationFailure,
WcaTokenFailure,
)
from ansible_ai_connect.ai.api.model_pipelines.pipelines import (
Expand Down Expand Up @@ -65,8 +67,11 @@
WcaModelRequestException,
WcaTokenRequestException,
ibm_cloud_identity_token_hist,
wca_codegen_role_hist,
)
from ansible_ai_connect.ai.api.model_pipelines.wca.wca_utils import (
Context,
InferenceResponseChecks,
TokenContext,
TokenResponseChecks,
)
Expand Down Expand Up @@ -336,6 +341,73 @@ class WCASaaSRoleGenerationPipeline(
def __init__(self, config: WCASaaSConfiguration):
super().__init__(config=config)

# This should be moved to the base WCA class when it becomes available on-prem
def invoke(self, params: RoleGenerationParameters) -> RoleGenerationResponse:
    """
    Generate an Ansible role via the WCA SaaS codegen-role endpoint.

    Posts the request text (and optional outline) to
    ``/v1/wca/codegen/ansible/roles`` with backoff/retry, validates the
    response, optionally runs ansible-lint post-processing over the
    generated files, and returns ``(name, files, outline, warnings)``.

    :param params: role-generation parameters (request, text, outline,
        model_id, generation_id, ...).
    :raises WcaRequestIdCorrelationFailure: if the service echoes back a
        request id that does not match ``generation_id``.
    :raises HTTPError: if the service responds with an error status.
    """
    request = params.request
    text = params.text
    create_outline = params.create_outline
    outline = params.outline
    model_id = params.model_id
    generation_id = params.generation_id

    # Resolve credentials/model per requesting user's organization.
    organization_id = request.user.organization.id if request.user.organization else None
    api_key = self.get_api_key(request.user, organization_id)
    model_id = self.get_model_id(request.user, organization_id, model_id)

    headers = self.get_request_headers(api_key, generation_id)
    data = {
        "model_id": model_id,
        "text": text,
        "create_outline": create_outline,
    }
    if outline:
        data["outline"] = outline

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=self.retries + 1,
        giveup=self.fatal_exception,
        on_backoff=self.on_backoff_codegen_role,
    )
    @wca_codegen_role_hist.time()
    def post_request():
        return self.session.post(
            f"{self.config.inference_url}/v1/wca/codegen/ansible/roles",
            headers=headers,
            json=data,
            verify=self.config.verify_ssl,
        )

    result = post_request()

    x_request_id = result.headers.get(WCA_REQUEST_ID_HEADER)
    if generation_id and x_request_id:
        # request/payload suggestion_id is a UUID not a string whereas
        # HTTP headers are strings.
        if x_request_id != str(generation_id):
            raise WcaRequestIdCorrelationFailure(model_id=model_id, x_request_id=x_request_id)

    context = Context(model_id, result, False)
    InferenceResponseChecks().run_checks(context)
    result.raise_for_status()

    response = json.loads(result.text)

    name = response["name"]
    files = response["files"]
    outline = response["outline"]
    # "warnings" is optional in the service response.
    warnings = response.get("warnings", [])

    # Local import to avoid a circular import at module load time.
    from ansible_ai_connect.ai.apps import AiConfig

    ai_config = cast(AiConfig, apps.get_app_config("ai"))
    if ansible_lint_caller := ai_config.get_ansible_lint_caller():
        for file in files:
            file["content"] = ansible_lint_caller.run_linter(file["content"])

    return name, files, outline, warnings

def self_test(self) -> Optional[HealthCheckSummary]:
raise NotImplementedError

Expand Down

0 comments on commit 14c441f

Please sign in to comment.