Skip to content

Commit

Permalink
naming: remove preprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Feb 6, 2025
1 parent 233bde5 commit fbc5d37
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 19 deletions.
8 changes: 4 additions & 4 deletions sdmetrics/single_table/data_augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class BaseDataAugmentationMetric(SingleTableMetric):
max_value = 1.0

@classmethod
def _fit_preprocess(cls, data, metadata, prediction_column_name):
def _fit(cls, data, metadata, prediction_column_name):
"""Fit preprocessing parameters."""
discrete_columns = []
datetime_columns = []
Expand All @@ -115,7 +115,7 @@ def _fit_preprocess(cls, data, metadata, prediction_column_name):
return discrete_columns, datetime_columns

@classmethod
def _transform_preprocess(
def _transform(
cls,
tables,
discrete_columns,
Expand Down Expand Up @@ -152,7 +152,7 @@ def _fit_transform(
minority_class_label,
):
"""Fit and transform the metric."""
discrete_columns, datetime_columns = cls._fit_preprocess(
discrete_columns, datetime_columns = cls._fit(
real_training_data, metadata, prediction_column_name
)
tables = {
Expand All @@ -161,7 +161,7 @@ def _fit_transform(
'real_validation_data': real_validation_data,
}

return cls._transform_preprocess(
return cls._transform(
tables,
discrete_columns,
datetime_columns,
Expand Down
26 changes: 11 additions & 15 deletions tests/unit/single_table/data_augmentation/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,22 +185,20 @@ def test_get_scores(self, real_training_data, real_validation_data):
class TestBaseDataAugmentationMetric:
"""Test the BaseDataAugmentationMetric class."""

def test__fit_preprocess(self, real_training_data, metadata):
"""Test the ``_fit_preprocess`` method."""
def test__fit(self, real_training_data, metadata):
"""Test the ``_fit`` method."""
# Setup
metric = BaseDataAugmentationMetric()

# Run
discrete_columns, datetime_columns = metric._fit_preprocess(
real_training_data, metadata, 'target'
)
discrete_columns, datetime_columns = metric._fit(real_training_data, metadata, 'target')

# Assert
assert discrete_columns == ['categorical', 'boolean']
assert datetime_columns == ['datetime']

def test__transform_preprocess(self, real_training_data, synthetic_data, real_validation_data):
"""Test the ``_transform_preprocess`` method."""
def test__transform(self, real_training_data, synthetic_data, real_validation_data):
"""Test the ``_transform`` method."""
# Setup
metric = BaseDataAugmentationMetric()
discrete_columns = ['categorical', 'boolean']
Expand All @@ -212,9 +210,7 @@ def test__transform_preprocess(self, real_training_data, synthetic_data, real_va
}

# Run
transformed = metric._transform_preprocess(
tables, discrete_columns, datetime_columns, 'target', 1
)
transformed = metric._transform(tables, discrete_columns, datetime_columns, 'target', 1)

# Assert
expected_transformed = {
Expand Down Expand Up @@ -257,10 +253,10 @@ def test__fit_transform(
"""Test the ``_fit_transform`` method."""
# Setup
metric = BaseDataAugmentationMetric()
BaseDataAugmentationMetric._fit_preprocess = Mock()
BaseDataAugmentationMetric._fit = Mock()
discrete_columns = ['categorical', 'boolean']
datetime_columns = ['datetime']
BaseDataAugmentationMetric._fit_preprocess.return_value = (
BaseDataAugmentationMetric._fit.return_value = (
discrete_columns,
datetime_columns,
)
Expand All @@ -269,18 +265,18 @@ def test__fit_transform(
'synthetic_data': synthetic_data,
'real_validation_data': real_validation_data,
}
BaseDataAugmentationMetric._transform_preprocess = Mock(return_value=tables)
BaseDataAugmentationMetric._transform = Mock(return_value=tables)

# Run
transformed = metric._fit_transform(
real_training_data, synthetic_data, real_validation_data, metadata, 'target', 1
)

# Assert
BaseDataAugmentationMetric._fit_preprocess.assert_called_once_with(
BaseDataAugmentationMetric._fit.assert_called_once_with(
real_training_data, metadata, 'target'
)
BaseDataAugmentationMetric._transform_preprocess.assert_called_once_with(
BaseDataAugmentationMetric._transform.assert_called_once_with(
tables, discrete_columns, datetime_columns, 'target', 1
)
for table_name, table in transformed.items():
Expand Down

0 comments on commit fbc5d37

Please sign in to comment.