Skip to content

Commit

Permalink
AAP-24704: Refactor ModelMeshClient environment variable names (#1087)
Browse files Browse the repository at this point in the history
  • Loading branch information
manstis authored Jul 10, 2024
1 parent 5bf75ec commit e11e291
Show file tree
Hide file tree
Showing 16 changed files with 148 additions and 107 deletions.
12 changes: 5 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ SECRET_KEY="somesecretvalue"
ENABLE_ARI_POSTPROCESS="False"
WCA_SECRET_BACKEND_TYPE="dummy"
# configure model server
ANSIBLE_AI_MODEL_MESH_HOST="http://host.containers.internal"
ANSIBLE_AI_MODEL_MESH_INFERENCE_PORT="11434"
ANSIBLE_AI_MODEL_MESH_API_URL="http://host.containers.internal:11434"
ANSIBLE_AI_MODEL_MESH_API_TYPE="ollama"
ANSIBLE_AI_MODEL_NAME="mistral:instruct"
ANSIBLE_AI_MODEL_MESH_MODEL_ID="mistral:instruct"
```

### Start service and dependencies
Expand Down Expand Up @@ -96,7 +95,7 @@ command line the variable `DEBUG=True`.
The Django service listens on <http://127.0.0.1:8000>.

Note that there is no pytorch service defined in the docker-compose
file. You should adjust the `ANSIBLE_AI_MODEL_MESH_HOST`
file. You should adjust the `ANSIBLE_AI_MODEL_MESH_API_URL`
configuration key to point on an existing service.

## <a name="aws-config">Use the WCA API Keys Manager</a>
Expand Down Expand Up @@ -444,10 +443,9 @@ To connect to the Mistal 7b Instruct model running on locally on [llama.cpp](htt
```
1. Set the appropriate environment variables
```bash
ANSIBLE_AI_MODEL_MESH_HOST=http://$YOUR_REAL_IP
ANSIBLE_AI_MODEL_MESH_INFERENCE_PORT=8080
ANSIBLE_AI_MODEL_MESH_API_URL=http://$YOUR_REAL_IP:8080
ANSIBLE_AI_MODEL_MESH_API_TYPE=llamacpp
ANSIBLE_AI_MODEL_NAME=mistral-7b-instruct-v0.2.Q5_K_M.gguf
ANSIBLE_AI_MODEL_MESH_MODEL_ID=mistral-7b-instruct-v0.2.Q5_K_M.gguf
ENABLE_ARI_POSTPROCESS=False
```

Expand Down
2 changes: 1 addition & 1 deletion ansible_ai_connect/ai/api/model_client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_model_id(
organization_id: Optional[int] = None,
requested_model_id: str = "",
) -> str:
return requested_model_id or settings.ANSIBLE_AI_MODEL_NAME
return requested_model_id or settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID

def timeout(self, task_count=1):
return self._timeout * task_count if self._timeout else None
Expand Down
6 changes: 1 addition & 5 deletions ansible_ai_connect/ai/api/model_client/grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,7 @@ def infer(self, request, model_input, model_id="", suggestion_id=None) -> Dict[s
raise

def self_test(self) -> HealthCheckSummary:
url = (
f"{settings.ANSIBLE_AI_MODEL_MESH_API_HEALTHCHECK_PROTOCOL}://"
f"{settings.ANSIBLE_AI_MODEL_MESH_HOST}:"
f"{settings.ANSIBLE_AI_MODEL_MESH_API_HEALTHCHECK_PORT}/oauth/healthz"
)
url = f"{settings.ANSIBLE_GRPC_HEALTHCHECK_URL}/oauth/healthz"
summary: HealthCheckSummary = HealthCheckSummary(
{
MODEL_MESH_HEALTH_CHECK_PROVIDER: settings.ANSIBLE_AI_MODEL_MESH_API_TYPE,
Expand Down
2 changes: 1 addition & 1 deletion ansible_ai_connect/ai/api/model_client/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def infer(self, request, model_input, model_id="", suggestion_id=None) -> Dict[s
raise ModelTimeoutError

def self_test(self) -> HealthCheckSummary:
url = f"{settings.ANSIBLE_AI_MODEL_MESH_INFERENCE_URL}/ping"
url = f"{self._inference_url}/ping"
summary: HealthCheckSummary = HealthCheckSummary(
{
MODEL_MESH_HEALTH_CHECK_PROVIDER: settings.ANSIBLE_AI_MODEL_MESH_API_TYPE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def setUp(self):
}

@override_settings(ANSIBLE_AI_MODEL_MESH_API_KEY="my_key")
@override_settings(ANSIBLE_AI_MODEL_NAME="test")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID="test")
@responses.activate
def test_infer(self):
model = "test"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def setUp(self):
"model_id": "test",
}

@override_settings(ANSIBLE_AI_MODEL_NAME="test")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID="test")
@responses.activate
def test_infer(self):
model_client = LlamaCPPClient(inference_url=self.inference_url)
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_infer_override(self):
response = model_client.infer(None, self.model_input, model_id=model)
self.assertEqual(json.dumps(self.expected_response), json.dumps(response))

@override_settings(ANSIBLE_AI_MODEL_NAME="test")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID="test")
@responses.activate
def test_infer_timeout(self):
model_client = LlamaCPPClient(inference_url=self.inference_url)
Expand Down
22 changes: 11 additions & 11 deletions ansible_ai_connect/ai/api/model_client/tests/test_wca_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_count():

@override_settings(WCA_SECRET_BACKEND_TYPE="dummy")
@override_settings(ANSIBLE_AI_MODEL_MESH_API_KEY=None)
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_NAME=None)
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID=None)
class TestWCAClient(WisdomAppsBackendMocking, WisdomServiceLogAwareTestCase):
@override_settings(WCA_SECRET_DUMMY_SECRETS="11009103:my-key<sep>my-optimized-model")
def test_mock_wca_get_api_key(self):
Expand Down Expand Up @@ -210,13 +210,13 @@ def test_get_model_id_org_cannot_have_no_model(self):
with self.assertRaises(WcaModelIdNotFound):
wca_client.get_model_id(123, None)

@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_NAME="gemini")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID="gemini")
def test_model_id_with_environment_override(self):
wca_client = WCAClient(inference_url="http://example.com/")
model_id = wca_client.get_model_id(123, None)
self.assertEqual(model_id, "gemini")

@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_NAME="gemini")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID="gemini")
def test_model_id_with_environment_and_user_override(self):
wca_client = WCAClient(inference_url="http://example.com/")
model_id = wca_client.get_model_id(123, "bard")
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_fatal_exception(self):

@override_settings(WCA_SECRET_BACKEND_TYPE="dummy")
@override_settings(ANSIBLE_AI_MODEL_MESH_API_KEY=None)
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_NAME=None)
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID=None)
@override_settings(ENABLE_ANSIBLE_LINT_POSTPROCESS=False)
class TestWCAClientExpGen(WisdomAppsBackendMocking, WisdomServiceLogAwareTestCase):
def setUp(self):
Expand Down Expand Up @@ -966,7 +966,7 @@ def test_codematch_empty_response(self):
self.assertEqual(e.exception.model_id, model_id)


@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_NAME=None)
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID=None)
class TestDummySecretManager(TestCase):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -998,7 +998,7 @@ def test_get_secret(self):


@override_settings(WCA_SECRET_BACKEND_TYPE="dummy")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_NAME=None)
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID=None)
@override_settings(WCA_SECRET_DUMMY_SECRETS="")
class TestWCAClientOnPrem(WisdomAppsBackendMocking, WisdomServiceLogAwareTestCase):
@override_settings(ANSIBLE_WCA_USERNAME="username")
Expand All @@ -1017,23 +1017,23 @@ def test_get_api_key_without_setting(self):

@override_settings(ANSIBLE_WCA_USERNAME="username")
@override_settings(ANSIBLE_AI_MODEL_MESH_API_KEY="12345")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_NAME="model-name")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID="model-name")
def test_get_model_id(self):
model_client = WCAOnPremClient(inference_url="http://example.com/")
model_id = model_client.get_model_id(11009103)
self.assertEqual(model_id, "model-name")

@override_settings(ANSIBLE_WCA_USERNAME="username")
@override_settings(ANSIBLE_AI_MODEL_MESH_API_KEY="12345")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_NAME="model-name")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID="model-name")
def test_get_model_id_with_override(self):
model_client = WCAOnPremClient(inference_url="http://example.com/")
model_id = model_client.get_model_id(11009103, "override-model-name")
self.assertEqual(model_id, "override-model-name")

@override_settings(ANSIBLE_WCA_USERNAME="username")
@override_settings(ANSIBLE_AI_MODEL_MESH_API_KEY="12345")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_NAME=None)
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID=None)
def test_get_model_id_without_setting(self):
model_client = WCAOnPremClient(inference_url="http://example.com/")
with self.assertRaises(WcaModelIdNotFound):
Expand All @@ -1043,7 +1043,7 @@ def test_get_model_id_without_setting(self):
@override_settings(ANSIBLE_WCA_RETRY_COUNT=1)
@override_settings(ANSIBLE_WCA_USERNAME="username")
@override_settings(ANSIBLE_AI_MODEL_MESH_API_KEY="12345")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_NAME="model-name")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID="model-name")
@override_settings(ANSIBLE_AI_MODEL_MESH_API_TIMEOUT=None)
class TestWCAOnPremCodegen(WisdomServiceLogAwareTestCase):
def test_headers(self):
Expand Down Expand Up @@ -1089,7 +1089,7 @@ def test_headers(self):
@override_settings(ANSIBLE_WCA_RETRY_COUNT=1)
@override_settings(ANSIBLE_WCA_USERNAME="username")
@override_settings(ANSIBLE_AI_MODEL_MESH_API_KEY="12345")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_NAME="model-name")
@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID="model-name")
@override_settings(ANSIBLE_AI_MODEL_MESH_API_TIMEOUT=None)
class TestWCAOnPremCodematch(WisdomServiceLogAwareTestCase):
def test_headers(self):
Expand Down
11 changes: 5 additions & 6 deletions ansible_ai_connect/ai/api/model_client/wca_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i

headers = self.get_request_headers(api_key, suggestion_id)
task_count = len(get_task_names_from_prompt(prompt))
# path matches ANSIBLE_WCA_INFERENCE_URL="https://api.dataplatform.test.cloud.ibm.com"
prediction_url = f"{self._inference_url}/v1/wca/codegen/ansible"

@backoff.on_exception(
Expand Down Expand Up @@ -446,8 +445,8 @@ def get_model_id(
# requested_model_id defined: i.e. not None, not "", not {} etc.
# let them use what they ask for
return requested_model_id
elif settings.ANSIBLE_AI_MODEL_MESH_MODEL_NAME:
return settings.ANSIBLE_AI_MODEL_MESH_MODEL_NAME
elif settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID:
return settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID
elif organization_id is None:
logger.error(
"User is not linked to an organization and no default WCA model ID is found"
Expand Down Expand Up @@ -633,7 +632,7 @@ def __init__(self, inference_url):
raise WcaUsernameNotFound
if not settings.ANSIBLE_AI_MODEL_MESH_API_KEY:
raise WcaKeyNotFound
# ANSIBLE_AI_MODEL_MESH_MODEL_NAME cannot be validated until runtime. The
# ANSIBLE_AI_MODEL_MESH_MODEL_ID cannot be validated until runtime. The
# User may provide an override value if the Environment Variable is not set.

def get_api_key(self, organization_id: Optional[int]) -> str:
Expand All @@ -648,8 +647,8 @@ def get_model_id(
# requested_model_id defined: let them use what they ask for
return requested_model_id

if settings.ANSIBLE_AI_MODEL_MESH_MODEL_NAME:
return settings.ANSIBLE_AI_MODEL_MESH_MODEL_NAME
if settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID:
return settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID

raise WcaModelIdNotFound()

Expand Down
32 changes: 16 additions & 16 deletions ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def test_full_payload(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
Expand All @@ -798,7 +798,7 @@ def test_multi_task_prompt_commercial(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [
"- name: Install Apache\n ansible.builtin.apt:\n name: apache2\n state: latest\n- name: start Apache\n ansible.builtin.service:\n name: apache2\n state: started\n enabled: yes\n" # noqa: E501
],
Expand Down Expand Up @@ -842,7 +842,7 @@ def test_multi_task_prompt_commercial_with_pii(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [
"- name: Install Apache\n ansible.builtin.apt:\n name: apache2\n state: latest\n- name: say hello [email protected]\n ansible.builtin.debug:\n msg: Hello there [email protected]\n" # noqa: E501
],
Expand Down Expand Up @@ -879,7 +879,7 @@ def test_rate_limit(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
Expand All @@ -903,7 +903,7 @@ def test_missing_prompt(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
Expand All @@ -924,7 +924,7 @@ def test_authentication_error(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
# self.client.force_authenticate(user=self.user)
Expand Down Expand Up @@ -953,7 +953,7 @@ def test_completions_preprocessing_error(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
Expand All @@ -979,7 +979,7 @@ def test_completions_preprocessing_error_without_name_prompt(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
Expand All @@ -1003,7 +1003,7 @@ def test_full_payload_without_ARI(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
Expand All @@ -1028,7 +1028,7 @@ def test_full_payload_with_recommendation_with_broken_last_line(self):
}
# quotation in the last line is not closed, but the truncate function can handle this.
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [
' ansible.builtin.apt:\n name: apache2\n register: "test'
],
Expand All @@ -1055,7 +1055,7 @@ def test_completions_postprocessing_error_for_invalid_yaml(self):
}
# this prediction has indentation problem with the prompt above
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n garbage name: apache2"],
}
self.client.force_authenticate(user=self.user)
Expand Down Expand Up @@ -1126,7 +1126,7 @@ def test_payload_with_ansible_lint_without_commercial(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
Expand All @@ -1149,7 +1149,7 @@ def test_full_payload_without_ansible_lint_without_commercial(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
Expand Down Expand Up @@ -1178,7 +1178,7 @@ def test_full_payload_without_ansible_lint_with_commercial_user(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
Expand Down Expand Up @@ -1343,7 +1343,7 @@ def test_completions_pii_clean_up(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [""],
}
self.client.force_authenticate(user=self.user)
Expand All @@ -1364,7 +1364,7 @@ def test_full_completion_post_response(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_NAME,
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
Expand Down
Loading

0 comments on commit e11e291

Please sign in to comment.