Skip to content

Commit

Permalink
Fixes and pending tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
romartin committed Dec 19, 2023
1 parent 28e0d02 commit c3b8a2d
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 43 deletions.
18 changes: 8 additions & 10 deletions ansible_wisdom/ai/api/model_client/wca_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def check(self, context: InferenceContext):

class ResponseStatusCode403UserTrialExpired(Check[InferenceContext]):
def check(self, context: InferenceContext):
CommonChecks.is_user_trial_expired(
is_user_trial_expired(
context.model_id, context.result.status_code, lambda: context.result.text
)

Expand Down Expand Up @@ -174,7 +174,7 @@ def check(self, context: ContentMatchContext):

class ResponseStatusCode403UserTrialExpired(Check[InferenceContext]):
def check(self, context: ContentMatchContext):
CommonChecks.is_user_trial_expired(
is_user_trial_expired(
context.model_id, context.result.status_code, lambda: context.result.text
)

Expand All @@ -192,11 +192,9 @@ def __init__(self):
)


class CommonChecks:
@staticmethod
def is_user_trial_expired(model_id, result_code, result_text_provider):
if result_code == 403:
text = result_text_provider()
# TODO: Improve condition efficiency. Eg: by only matching message_id == WCA-0001-E
if text and "CUH limit is reached" in text.lower():
raise WcaUserTrialExpired(model_id=model_id)
def is_user_trial_expired(model_id, result_code, result_text_provider):
if result_code == 403:
text = result_text_provider()
# TODO: Improve condition efficiency. Eg: by only matching message_id == WCA-0001-E
if text and "CUH limit is reached" in text.lower():
raise WcaUserTrialExpired(model_id=model_id)
31 changes: 18 additions & 13 deletions ansible_wisdom/ai/api/pipelines/completion_stages/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def process(self, context: CompletionContext) -> None:

predictions = None
exception = None
event = None
event_name = None
start_time = time.time()
try:
predictions = model_mesh_client.infer(
Expand Down Expand Up @@ -167,12 +169,11 @@ def process(self, context: CompletionContext) -> None:
f"User`s trial expired, when requesting suggestion {payload.suggestionId}"
)
event = {
"type": "inference",
"type": "prediction",
"modelName": model_id,
"suggestionId": str(suggestion_id),
}
# TODO: Sending both trialExpired & prediction? Makes sense?
send_segment_event(event, "trialExpired", request.user)
event_name = 'trialExpired'
raise WcaUserTrialExpiredException(cause=e)

except Exception as e:
Expand All @@ -195,16 +196,20 @@ def process(self, context: CompletionContext) -> None:
)
if model_id_in_exception:
model_id = model_id_in_exception
event = {
"duration": duration,
"exception": exception is not None,
"modelName": model_id,
"problem": None if exception is None else exception.__class__.__name__,
"request": data,
"response": anonymized_predictions,
"suggestionId": str(suggestion_id),
}
send_segment_event(event, "prediction", request.user)
if event:
event['modelName'] = model_id
else:
event = {
"duration": duration,
"exception": exception is not None,
"modelName": model_id,
"problem": None if exception is None else exception.__class__.__name__,
"request": data,
"response": anonymized_predictions,
"suggestionId": str(suggestion_id),
}
event_name = event_name if event_name else "prediction"
send_segment_event(event, event_name, request.user)

logger.debug(f"response from inference for suggestion id {suggestion_id}:\n{predictions}")

Expand Down
31 changes: 18 additions & 13 deletions ansible_wisdom/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ def perform_content_matching(
)

exception = None
event = None
event_name = None
start_time = time.time()
response_serializer = None
metadata = []
Expand Down Expand Up @@ -543,13 +545,12 @@ def perform_content_matching(
except WcaUserTrialExpired as e:
exception = e
logger.exception(f"User`s trial expired, when requesting suggestion {suggestion_id}")
event_name = "trialExpired"
event = {
"type": "contentMatch",
"type": "contentmatch",
"modelName": model_id,
"suggestionId": str(suggestion_id),
}
# TODO: Sending both trialExpired & prediction? Makes sense?
send_segment_event(event, "trialExpired", user)
raise WcaUserTrialExpiredException(cause=e)

except Exception as e:
Expand All @@ -565,16 +566,20 @@ def perform_content_matching(
)
if model_id_in_exception:
model_id = model_id_in_exception
self._write_to_segment(
request_data,
duration,
exception,
metadata,
model_id,
response_serializer.data if response_serializer else {},
suggestion_id,
user,
)
if event:
event['modelName'] = model_id
send_segment_event(event, event_name, user)
else:
self._write_to_segment(
request_data,
duration,
exception,
metadata,
model_id,
response_serializer.data if response_serializer else {},
suggestion_id,
user,
)

return response_serializer

Expand Down
22 changes: 15 additions & 7 deletions ansible_wisdom/ai/api/wca/model_id_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def validate(api_key, model_id):

def do_validated_operation(request, api_key_provider, model_id_provider, on_success, event_name):
exception = None
event = None
start_time = time.time()
model_id = UNKNOWN_MODEL_ID
try:
Expand Down Expand Up @@ -238,7 +239,11 @@ def do_validated_operation(request, api_key_provider, model_id_provider, on_succ
except WcaUserTrialExpired as e:
exception = e
logger.info(e, exc_info=True)
# TODO: Send segment event? Consider is actually also sending the modelIdValidate one.
event = {
"type": event_name,
"modelName": model_id,
}
event_name = 'trialExpired'
raise WcaUserTrialExpiredException(cause=e)

except Exception as e:
Expand All @@ -248,12 +253,15 @@ def do_validated_operation(request, api_key_provider, model_id_provider, on_succ

finally:
duration = round((time.time() - start_time) * 1000, 2)
event = {
"duration": duration,
"exception": exception is not None,
"problem": None if exception is None else exception.__class__.__name__,
"modelName": model_id,
}
if event:
event['modelName'] = model_id
else:
event = {
"duration": duration,
"exception": exception is not None,
"problem": None if exception is None else exception.__class__.__name__,
"modelName": model_id,
}
send_segment_event(event, event_name, request.user)


Expand Down

0 comments on commit c3b8a2d

Please sign in to comment.