Skip to content

Commit

Permalink
Comments
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Jan 23, 2025
1 parent e06247d commit b9eb720
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 25 deletions.
20 changes: 12 additions & 8 deletions rdt/transformers/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,10 @@ def __init__(
max_value=1.0,
learn_rounding_scheme=False,
):
if min_value == max_value:
error_msg = 'The min_value and max_value for the logit function cannot be equal.'
raise TransformerInputError(error_msg)

super().__init__(
missing_value_replacement=missing_value_replacement,
missing_value_generation=missing_value_generation,
Expand Down Expand Up @@ -707,18 +711,18 @@ def _transform(self, data):
logit_vals = logit(transformed_vals, self.min_value, self.max_value)
if transformed.ndim == 1:
return logit_vals
else:
transformed[:, 0] = logit_vals
return transformed

transformed[:, 0] = logit_vals
return transformed

def _reverse_transform(self, data):
if not isinstance(data, np.ndarray):
data = data.to_numpy()

sampled_vals = data if data.ndim == 1 else data[:, 0]
reversed = sigmoid(sampled_vals, self.min_value, self.max_value)
reversed_values = sigmoid(sampled_vals, self.min_value, self.max_value)
if data.ndim == 1:
return super()._reverse_transform(reversed)
else:
data[:, 0] = reversed
return super()._reverse_transform(data)
return super()._reverse_transform(reversed_values)

data[:, 0] = reversed_values
return super()._reverse_transform(data)
37 changes: 26 additions & 11 deletions tests/integration/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,28 @@
}


def _create_transformer_args_from_data(transformer_args, data):
"""Helper to extract transformer arguments that are data-dependent.
Args:
transformer_args (dict):
The transformer arguments.
data (pd.Series):
The data for the transformer.
Returns:
dict:
The transformer arguments with data-specific arguments added.
"""
if 'FROM_DATA' in transformer_args:
transformer_args = {**transformer_args}
args = transformer_args.pop('FROM_DATA')
for arg, arg_func in args.items():
transformer_args[arg] = arg_func(data)

return transformer_args


def _validate_helper(validator_function, args, steps):
"""Wrap around validation functions to either return a boolean or assert.
Expand Down Expand Up @@ -157,11 +179,7 @@ def _test_transformer_with_dataset(transformer_class, input_data, steps):
"""

transformer_args = TRANSFORMER_ARGS.get(transformer_class.__name__, {})
if 'FROM_DATA' in transformer_args:
transformer_args = {**transformer_args}
args = transformer_args.pop('FROM_DATA')
for arg, arg_func in args.items():
transformer_args[arg] = arg_func(input_data[TEST_COL])
transformer_args = _create_transformer_args_from_data(transformer_args, input_data[TEST_COL])

transformer = transformer_class(**transformer_args)
# Fit
Expand Down Expand Up @@ -217,12 +235,9 @@ def _test_transformer_with_hypertransformer(transformer_class, input_data, steps
transformer_args = TRANSFORMER_ARGS.get(transformer_class.__name__, {})
hypertransformer = HyperTransformer()
if transformer_args:
if 'FROM_DATA' in transformer_args:
transformer_args = {**transformer_args}
args = transformer_args.pop('FROM_DATA')
for arg, arg_func in args.items():
transformer_args[arg] = arg_func(input_data[TEST_COL])

transformer_args = _create_transformer_args_from_data(
transformer_args, input_data[TEST_COL]
)
field_transformers = {TEST_COL: transformer_class(**transformer_args)}

else:
Expand Down
32 changes: 26 additions & 6 deletions tests/unit/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1889,7 +1889,18 @@ def test___init__(self):
assert ls.max_value == 100.0
assert ls.min_value == 2.0

def test__validate_logit_inputs(self):
def test___init___invalid_inputs(self):
"""Test super() arguments are properly passed and set as attributes."""
# Setup
min_value = 10.0
max_value = 10.0

# Run / Assert
expected_msg = 'The min_value and max_value for the logit function cannot be equal.'
with pytest.raises(TransformerInputError, match=re.escape(expected_msg)):
LogitScaler(max_value=max_value, min_value=min_value)

def test__validate_logit_inputs_with_default_settings(self):
"""Test validating data against input arguments."""
# Setup
ls = LogitScaler()
Expand All @@ -1898,6 +1909,15 @@ def test__validate_logit_inputs(self):
# Run and Assert
ls._validate_logit_inputs(data)

def test__validate_logit_inputs_with_custom_inputs(self):
"""Test validating data against input arguments."""
# Setup
ls = LogitScaler(min_value=0, max_value=100)
data = pd.Series([0.0, 10.1, 20.2, 30.3, 100])

# Run and Assert
ls._validate_logit_inputs(data)

def test__validate_logit_inputs_errors_invalid_value(self):
"""Test error message contains invalid values."""
# Setup
Expand Down Expand Up @@ -1944,7 +1964,7 @@ def test__fit(self):
def test__transform(self, mock_logit):
"""Test the ``transform`` method."""
# Setup
min_value = (1.0,)
min_value = 1.0
max_value = 50.0
ls = LogitScaler(min_value=min_value, max_value=max_value)
ls._validate_logit_inputs = Mock()
Expand All @@ -1965,7 +1985,7 @@ def test__transform(self, mock_logit):
def test__transform_multi_column(self, mock_logit):
"""Test the ``transform`` method with multiple columns."""
# Setup
min_value = (1.0,)
min_value = 1.0
max_value = 50.0
ls = LogitScaler(min_value=min_value, max_value=max_value)
ls._validate_logit_inputs = Mock()
Expand Down Expand Up @@ -2012,7 +2032,7 @@ def test__reverse_transform(self, mock_sigmoid, ff_reverse_transform_mock):
def test__reverse_transform_multi_column(self, mock_sigmoid, ff_reverse_transform_mock):
"""Test the ``transform`` method with multiple columns."""
# Setup
min_value = (1.0,)
min_value = 1.0
max_value = 50.0
ls = LogitScaler(min_value=min_value, max_value=max_value)
sampled_data = np.array([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
Expand All @@ -2026,11 +2046,11 @@ def test__reverse_transform_multi_column(self, mock_sigmoid, ff_reverse_transfor
mock_sigmoid.return_value = sigmoid_vals

# Run
reversed = ls._reverse_transform(data)
reversed_values = ls._reverse_transform(data)

# Assert
ff_reverse_transform_args = ff_reverse_transform_mock.call_args[0]
np.testing.assert_array_equal(
ff_reverse_transform_args[0], np.array([sigmoid_vals, is_null]).T
)
assert reversed == ff_reverse_transform_mock.return_value
assert reversed_values == ff_reverse_transform_mock.return_value

0 comments on commit b9eb720

Please sign in to comment.