CR feedback: better names, comments, stricter typing, fewer dict lookups
scosman committed Mar 1, 2025
1 parent 19d3c93 commit 00e8694
Showing 11 changed files with 92 additions and 95 deletions.
123 changes: 61 additions & 62 deletions app/desktop/studio_server/eval_api.py
@@ -22,7 +22,7 @@
EvalConfigType,
EvalOutputScore,
EvalRun,
EvalTemplate,
EvalTemplateId,
)
from kiln_ai.datamodel.json_schema import string_to_json_key
from kiln_ai.datamodel.prompt_id import is_frozen_prompt
@@ -47,7 +47,7 @@ def eval_from_id(project_id: str, task_id: str, eval_id: str) -> Eval:

raise HTTPException(
status_code=404,
detail=f"Task not found. ID: {task_id}",
detail=f"Eval not found. ID: {eval_id}",
)


@@ -79,9 +79,9 @@ def task_run_config_from_id(
)


# JS SSE client (EventSource) doesn't work with POST requests, so we use GET, even though post would be better
async def run_eval_runner_with_status(eval_runner: EvalRunner) -> StreamingResponse:
# Async messages via server side events (SSE)
# Yields async messages designed to be used with server sent events (SSE)
# https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events
async def event_generator():
async for progress in eval_runner.run():
data = {
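For readers unfamiliar with the SSE-over-GET pattern the comments above describe, here is a minimal, hypothetical sketch (route name and payload are illustrative, not taken from this commit): a GET endpoint streams "data: ...\n\n" frames with the text/event-stream media type so a browser EventSource can consume them.

    import json

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()


    # Hypothetical route, illustration only: GET because the browser EventSource API cannot send POST.
    @app.get("/demo/eval_progress")
    async def demo_eval_progress() -> StreamingResponse:
        async def event_generator():
            # The real endpoint would iterate an EvalRunner's async progress stream instead.
            for complete in (0, 25, 50, 75, 100):
                data = {"percent_complete": complete}
                # SSE frame format: "data: <payload>\n\n"
                yield f"data: {json.dumps(data)}\n\n"
            yield "data: complete\n\n"

        return StreamingResponse(event_generator(), media_type="text/event-stream")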
@@ -103,7 +103,7 @@ async def event_generator():
class CreateEvaluatorRequest(BaseModel):
name: str
description: str
template: EvalTemplate | None
template: EvalTemplateId | None
output_scores: list[EvalOutputScore]
eval_set_filter_id: DatasetFilterId
eval_configs_filter_id: DatasetFilterId
@@ -142,18 +142,18 @@ class EvalRunResult(BaseModel):

class EvalResultSummary(BaseModel):
# run_config_id -> output_score_id -> ScoreSummary
results: Dict[str, Dict[str, ScoreSummary]]
results: Dict[ID_TYPE, Dict[str, ScoreSummary]]
# run_config_id -> percent of the dataset that has been processed
run_config_percent_complete: Dict[str, float]
run_config_percent_complete: Dict[ID_TYPE, float]
# The total size of the dataset used for the eval
dataset_size: int


class EvalConfigCompareSummary(BaseModel):
# Summary of results. eval_config_id -> output_score_id -> CorrelationResult
results: Dict[str, Dict[str, CorrelationResult]]
results: Dict[ID_TYPE, Dict[str, CorrelationResult]]
# eval_config_id -> percent of the dataset that has been processed (run with eval scores)
eval_config_percent_complete: Dict[str, float]
eval_config_percent_complete: Dict[ID_TYPE, float]
# The total size of the dataset used for the eval config comparisons (eval.eval_configs_filter_id set size)
dataset_size: int
# The number of dataset items which are fully rated, partially rated, or not rated at all.
@@ -180,9 +180,10 @@ def human_score_from_task_run(
if score_key == "overall_rating":
human_score = task_run.output.rating.value
else:
req_rating = task_run.output.rating.requirement_ratings.get(
score_key_to_task_requirement_id[score_key], None
)
req_id = score_key_to_task_requirement_id.get(score_key, None)
if req_id is None:
return None
req_rating = task_run.output.rating.requirement_ratings.get(req_id, None)
if req_rating is not None:
human_score = req_rating.value
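The rewrite above swaps a direct [...] index (which raises KeyError for an unmapped score key) for two .get() lookups that fall back to None. A self-contained illustration of the pattern, using made-up data rather than the real Kiln models:

    from typing import Dict, Optional

    # Made-up mappings standing in for the task requirement and rating structures.
    score_key_to_requirement_id: Dict[str, str] = {"accuracy": "req_1"}
    requirement_ratings: Dict[str, float] = {"req_1": 4.0}


    def human_score(score_key: str) -> Optional[float]:
        req_id = score_key_to_requirement_id.get(score_key)
        if req_id is None:
            return None  # unknown score key: no requirement to look up
        return requirement_ratings.get(req_id)  # None when no human rating was recorded


    assert human_score("accuracy") == 4.0
    assert human_score("unknown_key") is None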

@@ -199,7 +200,6 @@ def count_human_evals(
partially_rated_count: int = 0
not_rated_count: int = 0
for dataset_item in items:
# Check it has all scores
has_all_scores = True
has_any_scores = False
for output_score in eval.output_scores:
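count_human_evals (only partially shown above) classifies each dataset item by whether all, some, or none of the eval's scores carry human ratings. A standalone sketch of that classification with invented data, not the actual datamodel:

    # Each dict stands in for one dataset item's recorded human ratings (hypothetical keys).
    items = [
        {"accuracy": 4.0, "tone": 3.0},  # fully rated
        {"accuracy": 4.0},               # partially rated
        {},                              # not rated
    ]
    expected_keys = ["accuracy", "tone"]

    fully_rated = partially_rated = not_rated = 0
    for item in items:
        present = [key for key in expected_keys if key in item]
        if len(present) == len(expected_keys):
            fully_rated += 1
        elif present:
            partially_rated += 1
        else:
            not_rated += 1

    assert (fully_rated, partially_rated, not_rated) == (1, 1, 1)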
@@ -346,8 +346,9 @@ async def create_eval_config(
eval_config.save_to_file()
return eval_config

# JS SSE client (EventSource) doesn't work with POST requests, so we use GET, even though post would be better
@app.get(
"/api/projects/{project_id}/tasks/{task_id}/eval/{eval_id}/eval_config/{eval_config_id}/run"
"/api/projects/{project_id}/tasks/{task_id}/eval/{eval_id}/eval_config/{eval_config_id}/run_task_run_eval"
)
async def run_eval_config(
project_id: str,
@@ -397,6 +398,7 @@ async def set_default_eval_config(

return eval

# JS SSE client (EventSource) doesn't work with POST requests, so we use GET, even though post would be better
@app.get(
"/api/projects/{project_id}/tasks/{task_id}/eval/{eval_id}/run_eval_config_eval"
)
@@ -440,6 +442,7 @@ async def get_eval_run_results(
run_config=run_config,
)

# This compares run_configs to each other on a given eval_config. Compare to below which compares eval_configs to each other.
@app.get(
"/api/projects/{project_id}/tasks/{task_id}/eval/{eval_id}/eval_config/{eval_config_id}/score_summary"
)
@@ -463,29 +466,27 @@ async def get_eval_config_score_summary(
)

# save a copy of the expected dataset ids for each run config, we'll update each as we process each eval run
remaining_expected_dataset_ids: Dict[str, Set[ID_TYPE]] = {
str(run_config.id): set(expected_dataset_ids)
for run_config in task_runs_configs
remaining_expected_dataset_ids: Dict[ID_TYPE, Set[ID_TYPE]] = {
run_config.id: set(expected_dataset_ids) for run_config in task_runs_configs
}
# Track how often we are missing scores in an eval_config. Should be 0 for a complete eval_config
partial_incomplete_counts: Dict[str, int] = {
str(run_config.id): 0 for run_config in task_runs_configs
partial_incomplete_counts: Dict[ID_TYPE, int] = {
run_config.id: 0 for run_config in task_runs_configs
}

# task_run_config_id -> output_score_id -> score/total
total_scores: Dict[str, Dict[str, float]] = {}
score_counts: Dict[str, Dict[str, int]] = {}
# task_run_config_id -> output_score_json_key -> score/total for calculating the mean score
total_scores: Dict[ID_TYPE, Dict[str, float]] = {}
score_counts: Dict[ID_TYPE, Dict[str, int]] = {}

# important: readonly makes this much faster
for eval_run in eval_config.runs(readonly=True):
if eval_run.task_run_config_id is None:
# This eval_run is not associated with a run_config, so we can't count it
# This eval_run is not associated with a run_config, so we should not count it
continue
run_config_id = str(eval_run.task_run_config_id)
run_config_id = eval_run.task_run_config_id

# Check if we should count this eval_run. Not every eval_run has to go into the stats:
# - a dataset_id can be removed from the dataset filter (removed a tag)
# - this dataset_id was already counted (not great there are dupes, but really shouldn't be double counted)
# - this dataset_id was already counted (not great there are dupes, but shouldn't be double counted if there are)
if eval_run.dataset_id not in remaining_expected_dataset_ids[run_config_id]:
continue
else:
Expand Down Expand Up @@ -513,25 +514,25 @@ async def get_eval_config_score_summary(
partial_incomplete_counts[run_config_id] += 1

# Convert to score summaries
results: Dict[str, Dict[str, ScoreSummary]] = {}
results: Dict[ID_TYPE, Dict[str, ScoreSummary]] = {}
for run_config_id, output_scores in total_scores.items():
results[run_config_id] = {}
for output_score_id, score in output_scores.items():
if score_counts[run_config_id][output_score_id] > 0:
count = score_counts[run_config_id][output_score_id]
if count > 0:
results[run_config_id][output_score_id] = ScoreSummary(
mean_score=score / score_counts[run_config_id][output_score_id]
mean_score=score / count
)

# Calculate the percent of the dataset that has been processed
run_config_percent_complete: Dict[str, float] = {}
run_config_percent_complete: Dict[ID_TYPE, float] = {}
for run_config in task_runs_configs:
run_config_id = str(run_config.id)
# Partial incomplete (missing scores), and fully incomplete (no eval_run)
incomplete_count = partial_incomplete_counts[run_config_id] + len(
remaining_expected_dataset_ids[run_config_id]
incomplete_count = partial_incomplete_counts[run_config.id] + len(
remaining_expected_dataset_ids[run_config.id]
)
percent_incomplete = incomplete_count / len(expected_dataset_ids)
run_config_percent_complete[str(run_config.id)] = 1 - percent_incomplete
run_config_percent_complete[run_config.id] = 1 - percent_incomplete

return EvalResultSummary(
results=results,
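The summary math in this hunk boils down to a per-key mean plus a completeness ratio; the refactor simply reads the count once instead of re-indexing the nested dict. A small sketch with invented numbers, not values from the commit:

    # Invented example values, not taken from the diff.
    total_score = 12.0   # running sum of eval scores for one (run_config, score key) pair
    score_count = 3      # eval runs that contributed a score for that pair
    dataset_size = 10    # items expected by the eval's dataset filter
    missing_runs = 2     # expected dataset items with no eval_run for this run_config
    partial_runs = 1     # eval_runs present but missing at least one score

    mean_score = total_score / score_count if score_count > 0 else None
    percent_complete = 1 - (missing_runs + partial_runs) / dataset_size

    assert mean_score == 4.0
    assert abs(percent_complete - 0.7) < 1e-9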
Expand Down Expand Up @@ -573,18 +574,15 @@ async def get_eval_configs_score_summary(
not_rated_count=0,
)

# save a copy of the expected dataset ids for each eval config, we'll update each as we process each eval run
remaining_expected_dataset_ids: Dict[str, Set[ID_TYPE]] = {
str(eval_config.id): set(expected_dataset_ids)
for eval_config in eval_configs
# save a copy of the expected dataset ids for each eval config id, we'll update each as we process each eval run
remaining_expected_dataset_ids: Dict[ID_TYPE, Set[ID_TYPE]] = {
eval_config.id: set(expected_dataset_ids) for eval_config in eval_configs
}

# eval_config_id -> output_score_id -> correlation calculator
correlation_calculators: Dict[str, Dict[str, CorrelationCalculator]] = {}
# eval_config_id -> output_score_json_key -> correlation calculator
correlation_calculators: Dict[ID_TYPE, Dict[str, CorrelationCalculator]] = {}

# important: readonly makes this much faster
for eval_config in eval_configs:
eval_config_id = str(eval_config.id)
for eval_run in eval_config.runs(readonly=True):
dataset_item = expected_dataset_items.get(eval_run.dataset_id, None)
if dataset_item is None:
Expand All @@ -593,14 +591,14 @@ async def get_eval_configs_score_summary(
continue

# Check if we should count this eval_run. Not every eval_run has to go into the stats:
# Example: this dataset_id was already counted (not great there are dupes, but really shouldn't be double counted)
# Example: this dataset_id was already counted (not great there are dupes, but shouldn't be double counted if there are)
if (
eval_run.dataset_id
not in remaining_expected_dataset_ids[eval_config_id]
not in remaining_expected_dataset_ids[eval_config.id]
):
continue
else:
remaining_expected_dataset_ids[eval_config_id].remove(
remaining_expected_dataset_ids[eval_config.id].remove(
eval_run.dataset_id
)

@@ -617,21 +615,23 @@
# This score doesn't have both a human eval and eval score, so we can't compare
continue

if eval_config_id not in correlation_calculators:
correlation_calculators[eval_config_id] = {}
if eval_config.id not in correlation_calculators:
correlation_calculators[eval_config.id] = {}

if score_key not in correlation_calculators[eval_config_id]:
correlation_calculators[eval_config_id][score_key] = (
CorrelationCalculator()
)
calculator = correlation_calculators[eval_config.id].get(
score_key, None
)
if calculator is None:
calculator = CorrelationCalculator()
correlation_calculators[eval_config.id][score_key] = calculator

normalized_eval_score = normalize_rating(
eval_score, output_score.type
)
normalized_human_score = normalize_rating(
human_score, output_score.type
)
correlation_calculators[eval_config_id][score_key].add_score(
calculator.add_score(
CorrelationScore(
measured_score=eval_score,
human_score=human_score,
@@ -641,27 +641,26 @@
)

# Convert to score summaries
results: Dict[str, Dict[str, CorrelationResult]] = {}
results: Dict[ID_TYPE, Dict[str, CorrelationResult]] = {}
for eval_config_id in correlation_calculators.keys():
results[eval_config_id] = {}
for score_key in correlation_calculators[eval_config_id].keys():
if not correlation_calculators[eval_config_id][score_key]:
calculator = correlation_calculators[eval_config_id].get(
score_key, None
)
if calculator is None:
# No scores to calculate correlation for this pair
continue

correlation_result = correlation_calculators[eval_config_id][
score_key
].calculate_correlation()
correlation_result = calculator.calculate_correlation()
results[eval_config_id][score_key] = correlation_result

# Calculate the percent of the dataset that has been processed
eval_config_percent_complete: Dict[str, float] = {}
eval_config_percent_complete: Dict[ID_TYPE, float] = {}
for eval_config in eval_configs:
eval_config_id = str(eval_config.id)
# Partial incomplete (missing scores), and fully incomplete (no eval_run)
incomplete_count = len(remaining_expected_dataset_ids[eval_config_id])
incomplete_count = len(remaining_expected_dataset_ids[eval_config.id])
percent_incomplete = incomplete_count / len(expected_dataset_ids)
eval_config_percent_complete[str(eval_config.id)] = 1 - percent_incomplete
eval_config_percent_complete[eval_config.id] = 1 - percent_incomplete

# Count how many dataset items have human evals
fully_rated_count, partially_rated_count, not_rated_count = count_human_evals(
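The "fewer dict lookups" theme of this commit shows up above as fetching a nested value once into a local variable (calculator = correlation_calculators[eval_config.id].get(score_key, None)) instead of re-indexing correlation_calculators[eval_config_id][score_key] on every use. A hedged, standalone sketch of that refactor with invented names; the defaultdict variant at the end is an alternative, not something this commit uses:

    from collections import defaultdict
    from typing import Dict, List


    class Calculator:
        """Stand-in for CorrelationCalculator; only the shape matters for the sketch."""

        def __init__(self) -> None:
            self.scores: List[float] = []

        def add_score(self, score: float) -> None:
            self.scores.append(score)


    calculators: Dict[str, Dict[str, Calculator]] = {}


    def add_score(config_id: str, score_key: str, score: float) -> None:
        # One lookup per level, then reuse the local reference instead of re-indexing.
        per_config = calculators.setdefault(config_id, {})
        calculator = per_config.get(score_key)
        if calculator is None:
            calculator = Calculator()
            per_config[score_key] = calculator
        calculator.add_score(score)


    add_score("config1", "accuracy", 4.0)
    add_score("config1", "accuracy", 5.0)
    assert calculators["config1"]["accuracy"].scores == [4.0, 5.0]

    # Alternative: defaultdict creates missing entries implicitly, trading explicitness for brevity.
    auto: Dict[str, Dict[str, Calculator]] = defaultdict(lambda: defaultdict(Calculator))
    auto["config1"]["accuracy"].add_score(4.0)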
10 changes: 5 additions & 5 deletions app/desktop/studio_server/test_eval_api.py
@@ -27,7 +27,7 @@
EvalConfigType,
EvalOutputScore,
EvalRun,
EvalTemplate,
EvalTemplateId,
)
from kiln_ai.datamodel.task import RunConfigProperties, TaskRunConfig

@@ -87,7 +87,7 @@ def mock_eval(mock_task):
id="eval1",
name="Test Eval",
description="Test Description",
template=EvalTemplate.bias,
template=EvalTemplateId.bias,
output_scores=[
EvalOutputScore(name="score1", description="desc1", type="five_star"),
EvalOutputScore(
@@ -177,7 +177,7 @@ def test_get_eval_not_found(client, mock_task, mock_task_from_id):
response = client.get("/api/projects/project1/tasks/task1/eval/non_existent")

assert response.status_code == 404
assert response.json()["detail"] == "Task not found. ID: task1"
assert response.json()["detail"] == "Eval not found. ID: non_existent"


@pytest.fixture
@@ -428,7 +428,7 @@ async def mock_run():

# Make request with specific run_config_ids
response = client.get(
"/api/projects/project1/tasks/task1/eval/eval1/eval_config/eval_config1/run",
"/api/projects/project1/tasks/task1/eval/eval1/eval_config/eval_config1/run_task_run_eval",
params={"run_config_ids": ["run_config1", "run_config2"]},
)

@@ -465,7 +465,7 @@ async def test_run_eval_config_no_run_configs_error(

# Make request with no run_config_ids and all_run_configs=False
response = client.get(
"/api/projects/project1/tasks/task1/eval/eval1/eval_config/eval_config1/run"
"/api/projects/project1/tasks/task1/eval/eval1/eval_config/eval_config1/run_task_run_eval"
)

assert response.status_code == 400