Commit

fix: percentage fix (#206)
rivamarco authored Dec 9, 2024
1 parent 5c3ca24 commit 523d197
Showing 9 changed files with 2,602 additions and 23 deletions.
40 changes: 20 additions & 20 deletions spark/jobs/metrics/percentages.py
@@ -62,7 +62,7 @@ def calculate_percentages(
 
         model_quality_reference = metrics_service.calculate_model_quality()
 
-        def compute_mq_percentage(metrics_cur, metric_ref):
+        def _compute_mq_percentage(metrics_cur, metric_ref):
             metrics_cur_np = np.array(metrics_cur)
 
             # bootstrap Parameters
@@ -80,6 +80,7 @@ def compute_mq_percentage(metrics_cur, metric_ref):
             # calculate 95% confidence interval
             lower_bound = np.percentile(bootstrap_means, 2.5)
             upper_bound = np.percentile(bootstrap_means, 97.5)
+
             return 1 if not (lower_bound <= metric_ref <= upper_bound) else 0
 
         perc_model_quality = {"value": 0, "details": []}
@@ -95,15 +96,15 @@ def compute_mq_percentage(metrics_cur, metric_ref):
                     perc_model_quality["value"] = -1
                     break
                 else:
-                    is_flag = compute_mq_percentage(metrics_cur, metric_ref)
+                    is_flag = _compute_mq_percentage(metrics_cur, metric_ref)
                     flagged_metrics += is_flag
                     if is_flag:
                         perc_model_quality["details"].append(
                             {"feature_name": key_m, "score": -1}
                         )
-            perc_model_quality["value"] = 1 - (
-                flagged_metrics / len(model_quality_reference)
-            )
+            perc_model_quality["value"] = 1 - (
+                flagged_metrics / len(model_quality_current["grouped_metrics"])
+            )
 
         elif model.model_type == ModelType.MULTI_CLASS:
             flagged_metrics = 0
@@ -119,23 +120,22 @@ def compute_mq_percentage(metrics_cur, metric_ref):
                         # not enough values to do the test, return -1
                         cumulative_sum -= 10000
                     else:
-                        is_flag = compute_mq_percentage(metrics_cur, metric_ref)
+                        is_flag = _compute_mq_percentage(metrics_cur, metric_ref)
                         flagged_metrics += is_flag
-                        perc_model_quality["details"].append(
-                            {
-                                "feature_name": cm["class_name"] + "_" + key_m,
-                                "score": -1,
-                            }
-                        )
-            cumulative_sum += 1 - (
-                flagged_metrics / len(model_quality_reference)
-            )
-        perc_model_quality["value"] = (
-            cumulative_sum
-            / (
-                len(model_quality_reference["classes"])
-                * len(model_quality_reference["class_metrics"][0])
-            )
+                        if is_flag:
+                            perc_model_quality["details"].append(
+                                {
+                                    "feature_name": cm["class_name"] + "_" + key_m,
+                                    "score": -1,
+                                }
+                            )
+            cumulative_sum += 1 - (
+                flagged_metrics
+                / len(model_quality_reference["class_metrics"][0]["metrics"])
+            )
+            flagged_metrics = 0
+        perc_model_quality["value"] = (
+            cumulative_sum / len(model_quality_reference["classes"])
             if cumulative_sum > 0
             else -1
         )
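
For context, `_compute_mq_percentage` implements a bootstrap test: it resamples the current metric values, builds a 95% percentile confidence interval from the resampled means, and flags the metric (returns 1) when the reference value falls outside that interval. The hunks above elide the resampling loop, so here is a minimal self-contained sketch of the check; `bootstrap_flag`, `n_bootstrap`, and `seed` are illustrative assumptions, not the job's actual names or parameters:

import numpy as np

def bootstrap_flag(metrics_cur, metric_ref, n_bootstrap=1000, seed=42):
    """Hypothetical helper: return 1 when metric_ref falls outside the
    bootstrapped 95% CI of the current metric values, else 0."""
    rng = np.random.default_rng(seed)
    metrics_cur_np = np.array(metrics_cur, dtype=float)

    # resample with replacement and keep the mean of each resample
    bootstrap_means = np.array([
        rng.choice(metrics_cur_np, size=metrics_cur_np.size, replace=True).mean()
        for _ in range(n_bootstrap)
    ])

    # calculate 95% confidence interval
    lower_bound = np.percentile(bootstrap_means, 2.5)
    upper_bound = np.percentile(bootstrap_means, 97.5)
    return 1 if not (lower_bound <= metric_ref <= upper_bound) else 0

The aggregation around this check is what the commit fixes: the binary branch now divides the flag count by the number of current grouped metrics instead of len(model_quality_reference) (which counted the keys of the reference dict), and the multiclass branch appends a details entry only when is_flag is set, normalizes each class by its number of metrics, resets flagged_metrics between classes, and averages the per-class scores over len(model_quality_reference["classes"]).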
2 changes: 1 addition & 1 deletion spark/jobs/utils/reference_regression.py
@@ -20,7 +20,7 @@ def calculate_model_quality(self) -> ModelQualityRegression:
             model=self.reference.model,
             dataframe=self.reference.reference,
             dataframe_count=self.reference.reference_count,
-        ).dict()
+        ).model_dump()
 
         metrics["residuals"] = ModelQualityRegressionCalculator.residual_metrics(
             model=self.reference.model,
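
The `.dict()` to `.model_dump()` change tracks Pydantic v2, which deprecates `BaseModel.dict()` in favor of `model_dump()`. A minimal sketch with an illustrative model (the fields are assumptions, not the project's ModelQualityRegression schema):

from pydantic import BaseModel

class RegressionMetrics(BaseModel):
    # illustrative stand-in for a metrics model, not the project's class
    mse: float
    r2: float

m = RegressionMetrics(mse=0.42, r2=0.87)
print(m.model_dump())  # {'mse': 0.42, 'r2': 0.87}; .dict() still works in v2 but warns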
199 changes: 199 additions & 0 deletions spark/tests/percentages_test.py
@@ -64,6 +64,34 @@ def dataset_perfect_classes(spark_fixture, test_data_dir):
    )


@pytest.fixture()
def dataset_talk(spark_fixture, test_data_dir):
    yield (
        spark_fixture.read.csv(
            f"{test_data_dir}/reference/multiclass/reference_sentiment_analysis_talk.csv",
            header=True,
        ),
        spark_fixture.read.csv(
            f"{test_data_dir}/current/multiclass/current_sentiment_analysis_talk.csv",
            header=True,
        ),
    )


@pytest.fixture()
def dataset_demo(spark_fixture, test_data_dir):
    yield (
        spark_fixture.read.csv(
            f"{test_data_dir}/reference/multiclass/3_classes_reference.csv",
            header=True,
        ),
        spark_fixture.read.csv(
            f"{test_data_dir}/current/multiclass/3_classes_current1.csv",
            header=True,
        ),
    )


def test_calculation_dataset_perfect_classes(spark_fixture, dataset_perfect_classes):
    output = OutputType(
        prediction=ColumnDefinition(
@@ -363,3 +391,174 @@ def test_percentages_abalone(spark_fixture, test_dataset_abalone):
        ignore_order=True,
        significant_digits=6,
    )


def test_percentages_dataset_talk(spark_fixture, dataset_talk):
    output = OutputType(
        prediction=ColumnDefinition(
            name="content", type=SupportedTypes.int, field_type=FieldTypes.categorical
        ),
        prediction_proba=None,
        output=[
            ColumnDefinition(
                name="content",
                type=SupportedTypes.int,
                field_type=FieldTypes.categorical,
            )
        ],
    )
    target = ColumnDefinition(
        name="label", type=SupportedTypes.int, field_type=FieldTypes.categorical
    )
    timestamp = ColumnDefinition(
        name="rbit_prediction_ts",
        type=SupportedTypes.datetime,
        field_type=FieldTypes.datetime,
    )
    granularity = Granularity.HOUR
    features = [
        ColumnDefinition(
            name="total_tokens",
            type=SupportedTypes.int,
            field_type=FieldTypes.numerical,
        ),
        ColumnDefinition(
            name="prompt_tokens",
            type=SupportedTypes.int,
            field_type=FieldTypes.numerical,
        ),
    ]
    model = ModelOut(
        uuid=uuid.uuid4(),
        name="talk model",
        description="description",
        model_type=ModelType.MULTI_CLASS,
        data_type=DataType.TABULAR,
        timestamp=timestamp,
        granularity=granularity,
        outputs=output,
        target=target,
        features=features,
        frameworks="framework",
        algorithm="algorithm",
        created_at=str(datetime.datetime.now()),
        updated_at=str(datetime.datetime.now()),
    )

    raw_reference_dataset, raw_current_dataset = dataset_talk
    current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset)
    reference_dataset = ReferenceDataset(
        model=model, raw_dataframe=raw_reference_dataset
    )

    drift = DriftCalculator.calculate_drift(
        spark_session=spark_fixture,
        current_dataset=current_dataset,
        reference_dataset=reference_dataset,
    )

    metrics_service = CurrentMetricsMulticlassService(
        spark_session=spark_fixture,
        current=current_dataset,
        reference=reference_dataset,
    )

    model_quality = metrics_service.calculate_model_quality()

    percentages = PercentageCalculator.calculate_percentages(
        spark_session=spark_fixture,
        drift=drift,
        model_quality_current=model_quality,
        current_dataset=current_dataset,
        reference_dataset=reference_dataset,
        model=model,
    )

    assert not deepdiff.DeepDiff(
        percentages,
        res.test_dataset_talk,
        ignore_order=True,
        significant_digits=6,
    )


def test_percentages_dataset_demo(spark_fixture, dataset_demo):
    output = OutputType(
        prediction=ColumnDefinition(
            name="prediction",
            type=SupportedTypes.int,
            field_type=FieldTypes.categorical,
        ),
        prediction_proba=None,
        output=[
            ColumnDefinition(
                name="prediction",
                type=SupportedTypes.int,
                field_type=FieldTypes.categorical,
            )
        ],
    )
    target = ColumnDefinition(
        name="ground_truth", type=SupportedTypes.int, field_type=FieldTypes.categorical
    )
    timestamp = ColumnDefinition(
        name="timestamp", type=SupportedTypes.datetime, field_type=FieldTypes.datetime
    )
    granularity = Granularity.DAY
    features = [
        ColumnDefinition(
            name="age", type=SupportedTypes.int, field_type=FieldTypes.numerical
        )
    ]
    model = ModelOut(
        uuid=uuid.uuid4(),
        name="talk model",
        description="description",
        model_type=ModelType.MULTI_CLASS,
        data_type=DataType.TABULAR,
        timestamp=timestamp,
        granularity=granularity,
        outputs=output,
        target=target,
        features=features,
        frameworks="framework",
        algorithm="algorithm",
        created_at=str(datetime.datetime.now()),
        updated_at=str(datetime.datetime.now()),
    )

    raw_reference_dataset, raw_current_dataset = dataset_demo
    current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset)
    reference_dataset = ReferenceDataset(
        model=model, raw_dataframe=raw_reference_dataset
    )

    drift = DriftCalculator.calculate_drift(
        spark_session=spark_fixture,
        current_dataset=current_dataset,
        reference_dataset=reference_dataset,
    )

    metrics_service = CurrentMetricsMulticlassService(
        spark_session=spark_fixture,
        current=current_dataset,
        reference=reference_dataset,
    )

    model_quality = metrics_service.calculate_model_quality()

    percentages = PercentageCalculator.calculate_percentages(
        spark_session=spark_fixture,
        drift=drift,
        model_quality_current=model_quality,
        current_dataset=current_dataset,
        reference_dataset=reference_dataset,
        model=model,
    )

    assert not deepdiff.DeepDiff(
        percentages,
        res.test_dataset_demo,
        ignore_order=True,
        significant_digits=6,
    )
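
Both new tests compare the computed percentages against expected results stored in the test resources using DeepDiff, ignoring ordering and comparing numbers to six digits after the decimal point. A quick illustration of that tolerance with made-up values:

import deepdiff

computed = {"value": 0.6666667, "details": []}
expected = {"value": 0.6666668, "details": []}

# with significant_digits=6 both numbers round to 0.666667,
# so the diff is empty and the assertion passes
assert not deepdiff.DeepDiff(computed, expected, ignore_order=True, significant_digits=6)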