Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Feb 4, 2025
1 parent b08939c commit 69dff4b
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def install_libomp():
print('✅ Homebrew installed. Now installing libomp...') # noqa: T201
os.system('brew install libomp')
else:
print(
print( # noqa
'⚠️ Skipping Homebrew installation. Please install it manually: https://brew.sh/'
) # noqa: T201

Expand All @@ -55,6 +55,6 @@ def install_libomp():
if is_installed(['where', 'vcomp140.dll']) or is_installed(['where', 'libomp.dll']):
print('✅ libomp is already installed.') # noqa: T201
else:
print(
"⚠️ libomp not found. Please install 'Microsoft OpenMP Library' from https://visualstudio.microsoft.com/downloads/ (included in MSVC)."
print( # noqa
"⚠️ libomp not found. Please install 'Microsoft OpenMP Library' from https://visualstudio.microsoft.com/downloads/ (included in MSVC)." # noqa
) # noqa
10 changes: 4 additions & 6 deletions sdmetrics/single_table/data_augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@
from sklearn.metrics import confusion_matrix, precision_recall_curve, precision_score, recall_score
from sklearn.preprocessing import OrdinalEncoder

from sdmetrics.single_table.data_augmentation._libomp_installation import install_libomp


install_libomp()
from xgboost import XGBClassifier

from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric
from sdmetrics.single_table.data_augmentation._libomp_installation import install_libomp
from sdmetrics.single_table.data_augmentation.utils import _validate_inputs

install_libomp()
from xgboost import XGBClassifier # noqa

METRIC_NAME_TO_METHOD = {'recall': recall_score, 'precision': precision_score}


Expand Down
4 changes: 2 additions & 2 deletions sdmetrics/single_table/data_augmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _validate_data_and_metadata(
if prediction_column_name not in metadata['columns']:
raise ValueError(
f'The column `{prediction_column_name}` is not described in the metadata.'
'Please update your metadata.'
' Please update your metadata.'
)

if metadata['columns'][prediction_column_name]['sdtype'] not in ('categorical', 'boolean'):
Expand Down Expand Up @@ -83,7 +83,7 @@ def _validate_data_and_metadata(
raise ValueError(
f"The metric can't be computed because the value `{minority_class_label}` "
f'is not present in the column `{prediction_column_name}` for the real validation data.'
'The `precision`and `recall` are undefined for this case.'
' The `precision`and `recall` are undefined for this case.'
)

synthetic_labels = set(synthetic_data[prediction_column_name].unique())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def test_end_to_end(self):
'real_data_baseline': {
'recall_score_training': 0.8095238095238095,
'recall_score_validation': 0.07692307692307693,
'precision_score_validation': 0.25,
'precision_score_validation': 1.0,
'prediction_counts_validation': {
'true_positive': 1,
'false_positive': 3,
'true_negative': 22,
'false_positive': 0,
'true_negative': 25,
'false_negative': 12,
},
},
Expand All @@ -59,8 +59,8 @@ def test_end_to_end(self):
'precision_score_validation': 0.0,
'prediction_counts_validation': {
'true_positive': 0,
'false_positive': 2,
'true_negative': 23,
'false_positive': 0,
'true_negative': 25,
'false_negative': 13,
},
},
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_with_no_minority_class_in_validation(self):
)

# Run and Assert
with pytest.errors(ValueError, match=expected_error):
with pytest.raises(ValueError, match=expected_error):
BinaryClassifierPrecisionEfficacy.compute(
real_training_data=real_training,
synthetic_data=synthetic_data,
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/single_table/data_augmentation/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.metrics import precision_score
from sklearn.metrics import precision_score, recall_score

from sdmetrics.single_table.data_augmentation.base import BaseDataAugmentationMetric

Expand Down Expand Up @@ -101,7 +101,7 @@ def test__fit(self, real_training_data, metadata):
assert metric.prediction_column_name == prediction_column_name
assert metric.minority_class_label == minority_class_label
assert metric.fixed_value == fixed_recall_value
assert metric._metric_method == precision_score
assert metric._metric_method == recall_score
assert metric._classifier_name == classifier
# assert metric._classifier == 'XGBClassifier()'

Expand All @@ -119,7 +119,7 @@ def test__get_best_threshold(self, mock_precision_recall_curve, real_training_da
np.array([0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.0]),
np.array([0.02, 0.15, 0.25, 0.35, 0.42, 0.51, 0.63, 0.77, 0.82, 0.93, 0.97]),
]
metric.metric_name = 'precision'
metric.metric_name = 'recall'
metric.fixed_value = 0.69
train_data = real_training_data[['numerical']]
train_target = real_training_data['target']
Expand Down Expand Up @@ -188,6 +188,7 @@ def test__get_scores(self, real_training_data, real_validation_data):
metric = BaseDataAugmentationMetric()
metric.metric_name = 'precision'
metric._train_model = Mock(return_value=0.78)
metric._metric_to_fix = 'recall'
metric._compute_validation_scores = Mock(
return_value=(
1.0,
Expand All @@ -206,7 +207,7 @@ def test__get_scores(self, real_training_data, real_validation_data):

# Assert
assert scores == {
'precision_score_training': 0.78,
'recall_score_training': 0.78,
'recall_score_validation': 1.0,
'precision_score_validation': 0.5,
'prediction_counts_validation': {
Expand Down
8 changes: 5 additions & 3 deletions tests/unit/single_table/data_augmentation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ def test__validate_data_and_metadata():
expected_message_value = re.escape(
'The value `1` is not present in the column `target` for the real training data.'
)
expected_warning = re.escape(
'The value `1` is not present in the column `target` for the real validation data.'
expected_error_missing_minority = re.escape(
"The metric can't be computed because the value `1` is not present in "
'the column `target` for the real validation data. The `precision`and `recall`'
' are undefined for this case.'
)

# Run and Assert
Expand Down Expand Up @@ -141,7 +143,7 @@ def test__validate_data_and_metadata():
missing_minority_class_label_validation['real_validation_data'] = pd.DataFrame({
'target': [0, 0, 0]
})
with pytest.warns(UserWarning, match=expected_warning):
with pytest.raises(ValueError, match=expected_error_missing_minority):
_validate_data_and_metadata(**missing_minority_class_label_validation)


Expand Down

0 comments on commit 69dff4b

Please sign in to comment.