Skip to content

Commit

Permalink
Fix broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Jan 22, 2025
1 parent 07ac708 commit e06247d
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 8 deletions.
1 change: 1 addition & 0 deletions rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ClusterBasedNormalizer,
FloatFormatter,
GaussianNormalizer,
LogitScaler,
)
from rdt.transformers.pii.anonymizer import (
AnonymizedFaker,
Expand Down
22 changes: 18 additions & 4 deletions rdt/transformers/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,9 +702,23 @@ def _fit(self, data):

def _transform(self, data):
transformed = super()._transform(data)
self._validate_logit_inputs(transformed)
return logit(transformed, self.min_value, self.max_value)
transformed_vals = transformed if transformed.ndim == 1 else transformed[:, 0]
self._validate_logit_inputs(transformed_vals)
logit_vals = logit(transformed_vals, self.min_value, self.max_value)
if transformed.ndim == 1:
return logit_vals
else:
transformed[:, 0] = logit_vals
return transformed

def _reverse_transform(self, data):
reversed = sigmoid(data, self.min_value, self.max_value)
return super()._reverse_transform(reversed)
if not isinstance(data, np.ndarray):
data = data.to_numpy()

sampled_vals = data if data.ndim == 1 else data[:, 0]
reversed = sigmoid(sampled_vals, self.min_value, self.max_value)
if data.ndim == 1:
return super()._reverse_transform(reversed)
else:
data[:, 0] = reversed
return super()._reverse_transform(data)
7 changes: 4 additions & 3 deletions tests/integration/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
'GaussianNormalizer': {'missing_value_generation': 'from_column'},
'ClusterBasedNormalizer': {'missing_value_generation': 'from_column'},
'LogitScaler': {
'missing_value_generation': 'from_column',
'FROM_DATA': {
'min_value': lambda x: np.nanmin(x) - 1,
'max_value': lambda x: np.nanmax(x) + 1,
}
'min_value': lambda x: np.nanmin(x) - 0.01,
'max_value': lambda x: np.nanmax(x) + 0.01,
},
},
}

Expand Down
55 changes: 54 additions & 1 deletion tests/unit/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,6 +1961,28 @@ def test__transform(self, mock_logit):
mock_logit.assert_called_once_with(data, ls.min_value, ls.max_value)
assert transformed == mock_logit.return_value

@patch('rdt.transformers.numerical.logit')
def test__transform_multi_column(self, mock_logit):
"""Test the ``transform`` method with multiple columns."""
# Setup
min_value = (1.0,)
max_value = 50.0
ls = LogitScaler(min_value=min_value, max_value=max_value)
ls._validate_logit_inputs = Mock()
data = pd.Series([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
null_transformer_mock = Mock()
is_null = np.array([0, 0, 0, 1, 0, 1, 0])
null_transformer_mock.transform.return_value = np.array([data.to_numpy(), is_null]).T
ls.null_transformer = null_transformer_mock
logit_values = np.array([0.0, 0.1, 0.2, 0.3, 0.3, 1.4, 2.5])
mock_logit.return_value = logit_values

# Run
transformed = ls._transform(data)

# Assert
np.testing.assert_array_equal(transformed, np.array([logit_values, is_null]).T)

@patch('rdt.transformers.numerical.FloatFormatter._reverse_transform')
@patch('rdt.transformers.numerical.sigmoid')
def test__reverse_transform(self, mock_sigmoid, ff_reverse_transform_mock):
Expand All @@ -1978,6 +2000,37 @@ def test__reverse_transform(self, mock_sigmoid, ff_reverse_transform_mock):
reversed = ls._reverse_transform(data)

# Assert
mock_sigmoid.assert_called_once_with(data, ls.min_value, ls.max_value)
mock_sigmoid_args = mock_sigmoid.call_args[0]
np.testing.assert_array_equal(mock_sigmoid_args[0], data.to_numpy())
assert mock_sigmoid_args[1] == ls.min_value
assert mock_sigmoid_args[2] == ls.max_value
ff_reverse_transform_mock.assert_called_once_with(mock_sigmoid.return_value)
assert reversed == ff_reverse_transform_mock.return_value

@patch('rdt.transformers.numerical.FloatFormatter._reverse_transform')
@patch('rdt.transformers.numerical.sigmoid')
def test__reverse_transform_multi_column(self, mock_sigmoid, ff_reverse_transform_mock):
"""Test the ``transform`` method with multiple columns."""
# Setup
min_value = (1.0,)
max_value = 50.0
ls = LogitScaler(min_value=min_value, max_value=max_value)
sampled_data = np.array([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
is_null = np.array([0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
data = pd.DataFrame({'column': sampled_data, 'column.is_null': is_null})
null_transformer_mock = Mock()
reversed = np.array([1.0, 1.1, np.nan, np.nan, 2.0, np.nan, np.nan])
null_transformer_mock.reverse_transform.return_value = reversed
ls.null_transformer = null_transformer_mock
sigmoid_vals = np.array([3.0, 3.1, 3.3, 3.4, 2.1, 4.0, 4.6])
mock_sigmoid.return_value = sigmoid_vals

# Run
reversed = ls._reverse_transform(data)

# Assert
ff_reverse_transform_args = ff_reverse_transform_mock.call_args[0]
np.testing.assert_array_equal(
ff_reverse_transform_args[0], np.array([sigmoid_vals, is_null]).T
)
assert reversed == ff_reverse_transform_mock.return_value

0 comments on commit e06247d

Please sign in to comment.