Skip to content

Commit

Permalink
integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Feb 4, 2025
1 parent ce7187a commit c655fd3
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 16 deletions.
22 changes: 15 additions & 7 deletions sdmetrics/single_table/data_augmentation/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Base class for Efficacy metrics for single table datasets."""

from copy import deepcopy

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, precision_recall_curve, precision_score, recall_score
Expand Down Expand Up @@ -44,7 +46,9 @@ def _fit(
self.prediction_column_name = prediction_column_name
self.minority_class_label = minority_class_label
self.fixed_value = fixed_value
self._metric_method = METRIC_NAME_TO_METHOD[self.metric_name]
# To assess the precision efficacy, we have to fix the recall, and vice versa
self._metric_to_fix = 'recall' if self.metric_name == 'precision' else 'precision'
self._metric_method = METRIC_NAME_TO_METHOD[self._metric_to_fix]
self._classifier_name = classifier
self._classifier = XGBClassifier(enable_categorical=True)

Expand Down Expand Up @@ -78,7 +82,8 @@ def _get_best_threshold(self, train_data, train_target):
"""Find the best threshold for the classifier model."""
target_probabilities = self._classifier.predict_proba(train_data)[:, 1]
precision, recall, thresholds = precision_recall_curve(train_target, target_probabilities)
metric = recall if self.metric_name == 'recall' else precision
# To assess the precision efficacy, we have to fix the recall, and vice versa
metric = precision if self.metric_name == 'recall' else recall
best_threshold = 0.0
valid_idx = np.where(metric >= self.fixed_value)[0]
if valid_idx.size:
Expand Down Expand Up @@ -111,21 +116,24 @@ def _compute_validation_scores(self, real_validation_data):
precision = precision_score(real_validation_target, predictions)
conf_matrix = confusion_matrix(real_validation_target, predictions)
prediction_counts_validation = {
'true_positive': conf_matrix[1, 1],
'false_positive': conf_matrix[0, 1],
'true_negative': conf_matrix[0, 0],
'false_negative': conf_matrix[1, 0],
'true_positive': int(conf_matrix[1, 1]),
'false_positive': int(conf_matrix[0, 1]),
'true_negative': int(conf_matrix[0, 0]),
'false_negative': int(conf_matrix[1, 0]),
}

return recall, precision, prediction_counts_validation

def _get_scores(self, training_table, validation_table):
"""Get the scores of the metric."""
training_table = deepcopy(training_table)
validation_table = deepcopy(validation_table)
training_score = self._train_model(training_table)
recall, precision, prediction_counts_validation = self._compute_validation_scores(
validation_table
)
scores = {
f'{self.metric_name}_score_training': training_score,
f'{self._metric_to_fix}_score_training': training_score,
'recall_score_validation': recall,
'precision_score_validation': precision,
'prediction_counts_validation': prediction_counts_validation,
Expand Down
26 changes: 20 additions & 6 deletions sdmetrics/single_table/data_augmentation/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Utils method for data augmentation metrics."""

import warnings

import pandas as pd


Expand Down Expand Up @@ -51,10 +49,16 @@ def _validate_data_and_metadata(
minority_class_label,
):
"""Validate the data and metadata of the Data Augmentation metrics."""
if prediction_column_name not in metadata['columns']:
raise ValueError(
f'The column `{prediction_column_name}` is not described in the metadata.'
' Please update your metadata.'
)

if metadata['columns'][prediction_column_name]['sdtype'] not in ('categorical', 'boolean'):
raise ValueError(
f'The column `{prediction_column_name}` must be either categorical or boolean.'
'Please update your metadata.'
' Please update your metadata.'
)

columns_match = (
Expand All @@ -76,9 +80,19 @@ def _validate_data_and_metadata(
)

if minority_class_label not in real_validation_data[prediction_column_name].unique():
warnings.warn(
f'The value `{minority_class_label}` is not present in the column '
f'`{prediction_column_name}` for the real validation data.'
raise ValueError(
f"The metric can't be computed because the value `{minority_class_label}` "
f'is not present in the column `{prediction_column_name}` for the real validation data.'
' The `precision`and `recall` are undefined for this case.'
)

synthetic_labels = set(synthetic_data[prediction_column_name].unique())
real_labels = set(real_training_data[prediction_column_name].unique())
if not synthetic_labels.issubset(real_labels):
raise ValueError(
f'The ``{prediction_column_name}`` column must have the same values in the real '
'and synthetic data. The synthetic data has the following unseen values: '
f'{sorted(synthetic_labels - real_labels)}'
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import re

import numpy as np
import pytest

from sdmetrics.demos import load_demo
from sdmetrics.single_table.data_augmentation import BinaryClassifierPrecisionEfficacy

Expand All @@ -6,8 +11,11 @@ class TestBinaryClassifierPrecisionEfficacy:
def test_end_to_end(self):
"""Test the metric end-to-end."""
# Setup
np.random.seed(0)
real_data, synthetic_data, metadata = load_demo(modality='single_table')
real_training, real_validation = real_data.train_test_split(test_size=0.2, random_state=0)
mask_validation = np.random.rand(len(real_data)) < 0.8
real_training = real_data[mask_validation]
real_validation = real_data[~mask_validation]

# Run
score_breakdown = BinaryClassifierPrecisionEfficacy.compute_breakdown(
Expand All @@ -21,5 +29,221 @@ def test_end_to_end(self):
fixed_recall_value=0.8,
)

score = BinaryClassifierPrecisionEfficacy.compute(
real_training_data=real_training,
synthetic_data=synthetic_data,
real_validation_data=real_validation,
metadata=metadata,
prediction_column_name='gender',
minority_class_label='F',
classifier='XGBoost',
fixed_recall_value=0.8,
)

# Assert
expected_score_breakdown = {
'real_data_baseline': {
'recall_score_training': 0.8095238095238095,
'recall_score_validation': 0.07692307692307693,
'precision_score_validation': 0.25,
'prediction_counts_validation': {
'true_positive': 1,
'false_positive': 3,
'true_negative': 22,
'false_negative': 12,
},
},
'augmented_data': {
'recall_score_training': 0.8057553956834532,
'recall_score_validation': 0.0,
'precision_score_validation': 0.0,
'prediction_counts_validation': {
'true_positive': 0,
'false_positive': 2,
'true_negative': 23,
'false_negative': 13,
},
},
'parameters': {
'prediction_column_name': 'gender',
'minority_class_label': 'F',
'classifier': 'XGBoost',
'fixed_recall_value': 0.8,
},
'score': 0,
}
assert np.isclose(
score_breakdown['real_data_baseline']['recall_score_training'], 0.8, atol=0.1
)
assert np.isclose(
score_breakdown['augmented_data']['recall_score_validation'], 0.1, atol=0.1
)
assert score_breakdown == expected_score_breakdown
assert score == score_breakdown['score']

def test_with_no_minority_class_in_validation(self):
    """Test the metric when the minority class is not present in the validation data.

    The ``gender`` column of the validation split is overwritten with ``'M'`` so the
    minority label ``'F'`` never appears in it. Computing the metric is then expected
    to raise a ``ValueError`` whose message matches ``expected_error``.
    """
    # Setup
    np.random.seed(0)
    real_data, synthetic_data, metadata = load_demo(modality='single_table')
    mask_validation = np.random.rand(len(real_data)) < 0.8
    real_training = real_data[mask_validation]
    # Take an explicit copy so the column overwrite below acts on an independent
    # frame rather than on a slice of ``real_data``.
    real_validation = real_data[~mask_validation].copy()
    real_validation['gender'] = 'M'
    expected_error = re.escape(
        "The metric can't be computed because the value `F` is not present in the column "
        '`gender` for the real validation data. The `precision`and `recall` are undefined'
        ' for this case.'
    )

    # Run and Assert
    # ``pytest.raises`` is the context manager for expected exceptions;
    # ``pytest.errors`` does not exist and would fail with an ``AttributeError``.
    with pytest.raises(ValueError, match=expected_error):
        BinaryClassifierPrecisionEfficacy.compute(
            real_training_data=real_training,
            synthetic_data=synthetic_data,
            real_validation_data=real_validation,
            metadata=metadata,
            prediction_column_name='gender',
            minority_class_label='F',
            classifier='XGBoost',
            fixed_recall_value=0.8,
        )

def test_with_nan_target_column(self):
    """Test the metric when the target column has NaN values.

    After the train/validation split, the ``gender`` value of index labels 0 through 3
    in the training split and 0 through 5 in the validation split is replaced with NaN
    before the breakdown is computed.
    """
    # Setup
    np.random.seed(35)
    real_data, synthetic_data, metadata = load_demo(modality='single_table')
    in_training = np.random.rand(len(real_data)) < 0.8
    real_training = real_data[in_training].reset_index(drop=True)
    real_validation = real_data[~in_training].reset_index(drop=True)
    # ``.loc`` slices are label-inclusive, hence the reset indexes above.
    real_training.loc[:3, 'gender'] = np.nan
    real_validation.loc[:5, 'gender'] = np.nan
    parameters = {
        'prediction_column_name': 'gender',
        'minority_class_label': 'F',
        'classifier': 'XGBoost',
        'fixed_recall_value': 0.8,
    }

    # Run
    result_breakdown = BinaryClassifierPrecisionEfficacy.compute_breakdown(
        real_training_data=real_training,
        synthetic_data=synthetic_data,
        real_validation_data=real_validation,
        metadata=metadata,
        **parameters,
    )

    # Assert
    baseline_scores = {
        'recall_score_training': 0.8135593220338984,
        'recall_score_validation': 0.23076923076923078,
        'precision_score_validation': 0.42857142857142855,
        'prediction_counts_validation': {
            'true_positive': 3,
            'false_positive': 4,
            'true_negative': 29,
            'false_negative': 10,
        },
    }
    augmented_scores = {
        'recall_score_training': 0.8,
        'recall_score_validation': 0.07692307692307693,
        'precision_score_validation': 0.25,
        'prediction_counts_validation': {
            'true_positive': 1,
            'false_positive': 3,
            'true_negative': 30,
            'false_negative': 12,
        },
    }
    expected_result = {
        'real_data_baseline': baseline_scores,
        'augmented_data': augmented_scores,
        'parameters': parameters,
        'score': 0,
    }
    assert result_breakdown == expected_result

def test_with_minority_being_majority(self):
    """Test the metric when the minority class is the majority class.

    Runs ``compute`` on a seeded random split of the demo data and checks that the
    returned score is ``0``.
    """
    # Setup
    np.random.seed(0)
    real_data, synthetic_data, metadata = load_demo(modality='single_table')
    in_training = np.random.rand(len(real_data)) < 0.8
    compute_kwargs = {
        'real_training_data': real_data[in_training],
        'synthetic_data': synthetic_data,
        'real_validation_data': real_data[~in_training],
        'metadata': metadata,
        'prediction_column_name': 'gender',
        'minority_class_label': 'F',
        'classifier': 'XGBoost',
        'fixed_recall_value': 0.8,
    }

    # Run
    score = BinaryClassifierPrecisionEfficacy.compute(**compute_kwargs)

    # Assert
    assert score == 0

def test_with_multi_class(self):
    """Test the metric with multi-class classification.

    The `high_spec` column has 3 values `Commerce`, `Science`, and `Arts`; `Science`
    is passed as the minority class label.
    """
    # Setup
    np.random.seed(0)
    real_data, synthetic_data, metadata = load_demo(modality='single_table')
    in_training = np.random.rand(len(real_data)) < 0.8
    real_training = real_data[in_training]
    real_validation = real_data[~in_training]
    parameters = {
        'prediction_column_name': 'high_spec',
        'minority_class_label': 'Science',
        'classifier': 'XGBoost',
        'fixed_recall_value': 0.8,
    }

    # Run
    score_breakdown = BinaryClassifierPrecisionEfficacy.compute_breakdown(
        real_training_data=real_training,
        synthetic_data=synthetic_data,
        real_validation_data=real_validation,
        metadata=metadata,
        **parameters,
    )

    # Assert
    assert 'real_data_baseline' in score_breakdown
    baseline_scores = {
        'recall_score_training': 0.8076923076923077,
        'recall_score_validation': 0.6923076923076923,
        'precision_score_validation': 0.9,
        'prediction_counts_validation': {
            'true_positive': 9,
            'false_positive': 1,
            'true_negative': 24,
            'false_negative': 4,
        },
    }
    augmented_scores = {
        'recall_score_training': 0.8035714285714286,
        'recall_score_validation': 0.5384615384615384,
        'precision_score_validation': 0.875,
        'prediction_counts_validation': {
            'true_positive': 7,
            'false_positive': 1,
            'true_negative': 24,
            'false_negative': 6,
        },
    }
    expected_score_breakdown = {
        'real_data_baseline': baseline_scores,
        'augmented_data': augmented_scores,
        'parameters': parameters,
        'score': 0,
    }
    assert score_breakdown == expected_score_breakdown
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@ def test__validate_data_and_metadata():
'prediction_column_name': 'target',
'minority_class_label': 1,
}
expected_message_missing_prediction_column = re.escape(
'The column `target` is not described in the metadata. Please update your metadata.'
)
expected_message_sdtype = re.escape(
'The column `target` must be either categorical or boolean.Please update your metadata.'
'The column `target` must be either categorical or boolean. Please update your metadata.'
)
expected_message_column_missmatch = re.escape(
'`real_training_data`, `synthetic_data` and `real_validation_data` must have the '
Expand All @@ -109,6 +112,11 @@ def test__validate_data_and_metadata():

# Run and Assert
_validate_data_and_metadata(**inputs)
missing_prediction_column = deepcopy(inputs)
missing_prediction_column['metadata']['columns'].pop('target')
with pytest.raises(ValueError, match=expected_message_missing_prediction_column):
_validate_data_and_metadata(**missing_prediction_column)

wrong_inputs_sdtype = deepcopy(inputs)
wrong_inputs_sdtype['metadata']['columns']['target']['sdtype'] = 'numerical'
with pytest.raises(ValueError, match=expected_message_sdtype):
Expand Down

0 comments on commit c655fd3

Please sign in to comment.