
Add 2 new scores: normalized MSE and MAE
scosman committed Feb 26, 2025
1 parent 50811b1 commit a493ccd
Showing 4 changed files with 102 additions and 1 deletion.
30 changes: 29 additions & 1 deletion app/desktop/studio_server/eval_api.py
@@ -27,6 +27,7 @@
from kiln_ai.datamodel.json_schema import string_to_json_key
from kiln_ai.datamodel.prompt_id import is_frozen_prompt
from kiln_ai.datamodel.task import RunConfigProperties, TaskRunConfig
from kiln_ai.datamodel.task_output import normalize_rating
from kiln_ai.utils.name_generator import generate_memorable_name
from kiln_server.task_api import task_from_id
from pydantic import BaseModel
@@ -144,7 +145,9 @@ class EvalResultSummary(BaseModel):

class EvalConfigScoreSummary(BaseModel):
    mean_absolute_error: float
    mean_normalized_absolute_error: float
    mean_squared_error: float
    mean_normalized_squared_error: float


class EvalConfigCompareSummary(BaseModel):
@@ -588,7 +591,9 @@ async def get_eval_configs_score_summary(

    # eval_config_id -> output_score_id -> scores/total
    total_squared_error: Dict[str, Dict[str, float]] = {}
    total_normalized_squared_error: Dict[str, Dict[str, float]] = {}
    total_absolute_error: Dict[str, Dict[str, float]] = {}
    total_normalized_absolute_error: Dict[str, Dict[str, float]] = {}
    total_count: Dict[str, Dict[str, int]] = {}

    # important: readonly makes this much faster
@@ -630,18 +635,33 @@ async def get_eval_configs_score_summary(
                        total_squared_error[eval_config_id] = {}
                        total_absolute_error[eval_config_id] = {}
                        total_count[eval_config_id] = {}
                        total_normalized_squared_error[eval_config_id] = {}
                        total_normalized_absolute_error[eval_config_id] = {}
                    if score_key not in total_squared_error[eval_config_id]:
                        total_squared_error[eval_config_id][score_key] = 0
                        total_absolute_error[eval_config_id][score_key] = 0
                        total_count[eval_config_id][score_key] = 0
                        total_normalized_squared_error[eval_config_id][score_key] = 0
                        total_normalized_absolute_error[eval_config_id][score_key] = 0

                    # TODO normalize MSE?
                    normalized_eval_score = normalize_rating(
                        eval_score, output_score.type
                    )
                    normalized_human_score = normalize_rating(
                        human_score, output_score.type
                    )
                    total_squared_error[eval_config_id][score_key] += (
                        eval_score - human_score
                    ) ** 2
                    total_normalized_squared_error[eval_config_id][score_key] += (
                        normalized_eval_score - normalized_human_score
                    ) ** 2
                    total_absolute_error[eval_config_id][score_key] += abs(
                        eval_score - human_score
                    )
                    total_normalized_absolute_error[eval_config_id][score_key] += abs(
                        normalized_eval_score - normalized_human_score
                    )
                    total_count[eval_config_id][score_key] += 1

    # Convert to score summaries
@@ -658,6 +678,14 @@ async def get_eval_configs_score_summary(
                    mean_absolute_error=(
                        total_absolute_error[eval_config_id][score_key] / count
                    ),
                    mean_normalized_squared_error=(
                        total_normalized_squared_error[eval_config_id][score_key]
                        / count
                    ),
                    mean_normalized_absolute_error=(
                        total_normalized_absolute_error[eval_config_id][score_key]
                        / count
                    ),
                )

    # Calculate the percent of the dataset that has been processed
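To make the new fields easier to read in the hunk above: alongside the raw MSE and MAE between the eval score and the human score, the endpoint now accumulates the same errors after both values are passed through normalize_rating (added in libs/core/kiln_ai/datamodel/task_output.py below), so that errors from different rating scales land on a common 0-1 range. The sketch below is illustrative only; the summarize_errors helper and its list-of-pairs input are assumptions, not code from this commit.

```python
from statistics import mean

from kiln_ai.datamodel.datamodel_enums import TaskOutputRatingType
from kiln_ai.datamodel.task_output import normalize_rating


def summarize_errors(
    pairs: list[tuple[float, float]], rating_type: TaskOutputRatingType
) -> dict[str, float]:
    """Illustrative sketch: the four scores the endpoint reports for one score key,
    i.e. MSE, MAE, and their normalized counterparts."""
    sq, ab, norm_sq, norm_ab = [], [], [], []
    for eval_score, human_score in pairs:
        # Raw errors on the original rating scale
        sq.append((eval_score - human_score) ** 2)
        ab.append(abs(eval_score - human_score))
        # Normalized errors: both scores mapped to 0-1 before differencing
        ne = normalize_rating(eval_score, rating_type)
        nh = normalize_rating(human_score, rating_type)
        norm_sq.append((ne - nh) ** 2)
        norm_ab.append(abs(ne - nh))
    return {
        "mean_squared_error": mean(sq),
        "mean_absolute_error": mean(ab),
        "mean_normalized_squared_error": mean(norm_sq),
        "mean_normalized_absolute_error": mean(norm_ab),
    }
```

For example, summarize_errors([(1.0, 5.0)], TaskOutputRatingType.five_star) yields a mean_normalized_squared_error of 1.0, matching the "max error: 1 v 5" comments in the test data below.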
10 changes: 10 additions & 0 deletions app/desktop/studio_server/test_eval_api.py
@@ -923,10 +923,14 @@ class EvalCondigSummaryTestData:
"overall_rating": {
"mean_squared_error": 16.0, # error 4.0^2
"mean_absolute_error": 4.0, # error 4.0
"mean_normalized_squared_error": 1, # max error: 1 v 5
"mean_normalized_absolute_error": 1, # max error: 1 v 5
},
"score1": {
"mean_squared_error": 2.25, # error (3.5-5.0)^2
"mean_absolute_error": 1.5, # error 1.5
"mean_normalized_squared_error": 0.140625, # hand calc
"mean_normalized_absolute_error": 0.375, # 1.5/4
},
}
    # 1 of total_in_dataset eval configs are in ec1 test
@@ -937,10 +941,14 @@ class EvalCondigSummaryTestData:
"overall_rating": {
"mean_squared_error": 2.5, # error (1^2 + 2^2) / 2
"mean_absolute_error": 1.5, # (1+2)/2
"mean_normalized_squared_error": 0.15625, # (0.25^2 + 0.5^2) / 2
"mean_normalized_absolute_error": 0.375, # (0.25 + 0.5) / 2
},
"score1": {
"mean_squared_error": 2.5, # (1^2+2^2)/2
"mean_absolute_error": 1.5, # (1+2)/2
"mean_normalized_squared_error": 0.15625, # (0.25^2 + 0.5^2) / 2
"mean_normalized_absolute_error": 0.375, # (0.25 + 0.5) / 2
},
}
    # 2 of total_in_dataset eval configs are in ec2 test
@@ -951,6 +959,8 @@ class EvalCondigSummaryTestData:
"overall_rating": {
"mean_squared_error": 4,
"mean_absolute_error": 2,
"mean_normalized_squared_error": 0.25,
"mean_normalized_absolute_error": 0.5,
},
}
    # 2 of total_in_dataset eval configs are in ec2 test
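The "# hand calc" value above is easy to reproduce. Assuming, as the raw errors in the same block imply, that the eval scored score1 at 3.5 against a human score of 5.0 on the five-star scale, normalize_rating maps those to 0.625 and 1.0, giving a normalized absolute error of 0.375 and a normalized squared error of 0.140625. A minimal check (the 3.5/5.0 pair is inferred, not stated in the diff):

```python
from kiln_ai.datamodel.datamodel_enums import TaskOutputRatingType
from kiln_ai.datamodel.task_output import normalize_rating

# score1 in the first block: eval score 3.5 vs human score 5.0, inferred from
# the raw MAE of 1.5 and MSE of 2.25 above.
ne = normalize_rating(3.5, TaskOutputRatingType.five_star)  # (3.5 - 1) / 4 = 0.625
nh = normalize_rating(5.0, TaskOutputRatingType.five_star)  # (5.0 - 1) / 4 = 1.0

assert abs(ne - nh) == 0.375  # mean_normalized_absolute_error
assert (ne - nh) ** 2 == 0.140625  # mean_normalized_squared_error ("hand calc")
```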
22 changes: 22 additions & 0 deletions libs/core/kiln_ai/datamodel/task_output.py
@@ -11,6 +11,7 @@
from kiln_ai.datamodel.datamodel_enums import TaskOutputRatingType
from kiln_ai.datamodel.json_schema import validate_schema
from kiln_ai.datamodel.strict_mode import strict_mode
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

if TYPE_CHECKING:
    from kiln_ai.datamodel.task import Task
@@ -25,6 +26,27 @@ class RequirementRating(BaseModel):
    type: TaskOutputRatingType = Field(description="The type of rating")


def normalize_rating(rating: float, rating_type: TaskOutputRatingType) -> float:
    """Normalize a rating to a 0-1 scale. Simple normalization, not z-score."""
    match rating_type:
        case TaskOutputRatingType.five_star:
            if rating < 1 or rating > 5:
                raise ValueError("Five star rating must be between 1 and 5")
            return (rating - 1) / 4
        case TaskOutputRatingType.pass_fail:
            if rating < 0 or rating > 1:
                raise ValueError("Pass fail rating must be 0 to 1")
            return rating
        case TaskOutputRatingType.pass_fail_critical:
            if rating < -1 or rating > 1:
                raise ValueError("Pass fail critical rating must be -1 to 1")
            return (rating + 1) / 2  # map -1 to 1 onto 0 to 1
        case TaskOutputRatingType.custom:
            raise ValueError("Custom rating type cannot be normalized")
        case _:
            raise_exhaustive_enum_error(rating_type)


class TaskOutputRating(KilnBaseModel):
"""
A rating for a task output, including an overall rating and ratings for each requirement.
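As a quick illustration of the mapping above (not part of the commit): five-star ratings are shifted from 1-5 onto 0-1, pass/fail values already live on 0-1 and pass through unchanged, pass/fail-critical values are shifted from -1 to 1 onto 0-1, and custom ratings raise because there are no fixed bounds to normalize against.

```python
from kiln_ai.datamodel.datamodel_enums import TaskOutputRatingType
from kiln_ai.datamodel.task_output import normalize_rating

print(normalize_rating(4, TaskOutputRatingType.five_star))  # (4 - 1) / 4 = 0.75
print(normalize_rating(1, TaskOutputRatingType.pass_fail))  # already 0-1 -> 1
print(normalize_rating(-1, TaskOutputRatingType.pass_fail_critical))  # (-1 + 1) / 2 = 0.0

try:
    normalize_rating(3, TaskOutputRatingType.custom)
except ValueError as err:
    # Custom ratings have no known bounds, so normalization is refused.
    print(err)
```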
41 changes: 41 additions & 0 deletions libs/core/kiln_ai/datamodel/test_task.py
@@ -1,8 +1,10 @@
import pytest
from pydantic import ValidationError

from kiln_ai.datamodel.datamodel_enums import TaskOutputRatingType
from kiln_ai.datamodel.prompt_id import PromptGenerators
from kiln_ai.datamodel.task import RunConfig, RunConfigProperties, Task, TaskRunConfig
from kiln_ai.datamodel.task_output import normalize_rating


def test_runconfig_valid_creation():
@@ -116,3 +118,42 @@ def test_task_run_config_missing_task_in_run_config(sample_task):
model_provider_name="openai",
task=None, # type: ignore
)


@pytest.mark.parametrize(
    "rating_type,rating,expected",
    [
        (TaskOutputRatingType.five_star, 1, 0),
        (TaskOutputRatingType.five_star, 2, 0.25),
        (TaskOutputRatingType.five_star, 3, 0.5),
        (TaskOutputRatingType.five_star, 4, 0.75),
        (TaskOutputRatingType.five_star, 5, 1),
        (TaskOutputRatingType.pass_fail, 0, 0),
        (TaskOutputRatingType.pass_fail, 1, 1),
        (TaskOutputRatingType.pass_fail, 0.5, 0.5),
        (TaskOutputRatingType.pass_fail_critical, -1, 0),
        (TaskOutputRatingType.pass_fail_critical, 0, 0.5),
        (TaskOutputRatingType.pass_fail_critical, 1, 1),
        (TaskOutputRatingType.pass_fail_critical, 0.5, 0.75),
    ],
)
def test_normalize_rating(rating_type, rating, expected):
    assert normalize_rating(rating, rating_type) == expected


@pytest.mark.parametrize(
    "rating_type,rating",
    [
        (TaskOutputRatingType.five_star, 0),
        (TaskOutputRatingType.five_star, 6),
        (TaskOutputRatingType.pass_fail, -0.5),
        (TaskOutputRatingType.pass_fail, 1.5),
        (TaskOutputRatingType.pass_fail_critical, -1.5),
        (TaskOutputRatingType.pass_fail_critical, 1.5),
        (TaskOutputRatingType.custom, 0),
        (TaskOutputRatingType.custom, 99),
    ],
)
def test_normalize_rating_errors(rating_type, rating):
    with pytest.raises(ValueError):
        normalize_rating(rating, rating_type)
