diff --git a/rdt/transformers/numerical.py b/rdt/transformers/numerical.py index 5761a974..82617aa8 100644 --- a/rdt/transformers/numerical.py +++ b/rdt/transformers/numerical.py @@ -670,6 +670,10 @@ def __init__( max_value=1.0, learn_rounding_scheme=False, ): + if min_value == max_value: + error_msg = 'The min_value and max_value for the logit function cannot be equal.' + raise TransformerInputError(error_msg) + super().__init__( missing_value_replacement=missing_value_replacement, missing_value_generation=missing_value_generation, @@ -707,18 +711,18 @@ def _transform(self, data): logit_vals = logit(transformed_vals, self.min_value, self.max_value) if transformed.ndim == 1: return logit_vals - else: - transformed[:, 0] = logit_vals - return transformed + + transformed[:, 0] = logit_vals + return transformed def _reverse_transform(self, data): if not isinstance(data, np.ndarray): data = data.to_numpy() sampled_vals = data if data.ndim == 1 else data[:, 0] - reversed = sigmoid(sampled_vals, self.min_value, self.max_value) + reversed_values = sigmoid(sampled_vals, self.min_value, self.max_value) if data.ndim == 1: - return super()._reverse_transform(reversed) - else: - data[:, 0] = reversed - return super()._reverse_transform(data) + return super()._reverse_transform(reversed_values) + + data[:, 0] = reversed_values + return super()._reverse_transform(data) diff --git a/tests/integration/test_transformers.py b/tests/integration/test_transformers.py index 0a3d2b3a..4f1a06a3 100644 --- a/tests/integration/test_transformers.py +++ b/tests/integration/test_transformers.py @@ -47,6 +47,28 @@ } +def _create_transformer_args_from_data(transformer_args, data): + """Helper to extract transformer arguments that are data-dependent. + + Args: + transformer_args (dict): + The transformer arguments. + data (pd.Series): + The data for the transformer. + + Returns: + dict: + The transformer arguments with data-specific arguments added. 
+ """ + if 'FROM_DATA' in transformer_args: + transformer_args = {**transformer_args} + args = transformer_args.pop('FROM_DATA') + for arg, arg_func in args.items(): + transformer_args[arg] = arg_func(data) + + return transformer_args + + def _validate_helper(validator_function, args, steps): """Wrap around validation functions to either return a boolean or assert. @@ -157,11 +179,7 @@ def _test_transformer_with_dataset(transformer_class, input_data, steps): """ transformer_args = TRANSFORMER_ARGS.get(transformer_class.__name__, {}) - if 'FROM_DATA' in transformer_args: - transformer_args = {**transformer_args} - args = transformer_args.pop('FROM_DATA') - for arg, arg_func in args.items(): - transformer_args[arg] = arg_func(input_data[TEST_COL]) + transformer_args = _create_transformer_args_from_data(transformer_args, input_data[TEST_COL]) transformer = transformer_class(**transformer_args) # Fit @@ -217,12 +235,9 @@ def _test_transformer_with_hypertransformer(transformer_class, input_data, steps transformer_args = TRANSFORMER_ARGS.get(transformer_class.__name__, {}) hypertransformer = HyperTransformer() if transformer_args: - if 'FROM_DATA' in transformer_args: - transformer_args = {**transformer_args} - args = transformer_args.pop('FROM_DATA') - for arg, arg_func in args.items(): - transformer_args[arg] = arg_func(input_data[TEST_COL]) - + transformer_args = _create_transformer_args_from_data( + transformer_args, input_data[TEST_COL] + ) field_transformers = {TEST_COL: transformer_class(**transformer_args)} else: diff --git a/tests/unit/transformers/test_numerical.py b/tests/unit/transformers/test_numerical.py index ab23eea2..1d4f6755 100644 --- a/tests/unit/transformers/test_numerical.py +++ b/tests/unit/transformers/test_numerical.py @@ -1889,7 +1889,18 @@ def test___init__(self): assert ls.max_value == 100.0 assert ls.min_value == 2.0 - def test__validate_logit_inputs(self): + def test___init___invalid_inputs(self): + """Test an error is raised when ``min_value`` 
equals ``max_value``.""" + # Setup + min_value = 10.0 + max_value = 10.0 + + # Run / Assert + expected_msg = 'The min_value and max_value for the logit function cannot be equal.' + with pytest.raises(TransformerInputError, match=re.escape(expected_msg)): + LogitScaler(max_value=max_value, min_value=min_value) + + def test__validate_logit_inputs_with_default_settings(self): """Test validating data against input arguments.""" # Setup ls = LogitScaler() @@ -1898,6 +1909,15 @@ def test__validate_logit_inputs(self): # Run and Assert ls._validate_logit_inputs(data) + def test__validate_logit_inputs_with_custom_inputs(self): + """Test validating data against input arguments.""" + # Setup + ls = LogitScaler(min_value=0, max_value=100) + data = pd.Series([0.0, 10.1, 20.2, 30.3, 100]) + + # Run and Assert + ls._validate_logit_inputs(data) + def test__validate_logit_inputs_errors_invalid_value(self): """Test error message contains invalid values.""" # Setup @@ -1944,7 +1964,7 @@ def test__fit(self): def test__transform(self, mock_logit): """Test the ``transform`` method.""" # Setup - min_value = (1.0,) + min_value = 1.0 max_value = 50.0 ls = LogitScaler(min_value=min_value, max_value=max_value) ls._validate_logit_inputs = Mock() @@ -1965,7 +1985,7 @@ def test__transform(self, mock_logit): def test__transform_multi_column(self, mock_logit): """Test the ``transform`` method with multiple columns.""" # Setup - min_value = (1.0,) + min_value = 1.0 max_value = 50.0 ls = LogitScaler(min_value=min_value, max_value=max_value) ls._validate_logit_inputs = Mock() @@ -2012,7 +2032,7 @@ def test__reverse_transform(self, mock_sigmoid, ff_reverse_transform_mock): def test__reverse_transform_multi_column(self, mock_sigmoid, ff_reverse_transform_mock): """Test the ``transform`` method with multiple columns.""" # Setup - min_value = (1.0,) + min_value = 1.0 max_value = 50.0 ls = LogitScaler(min_value=min_value, max_value=max_value) sampled_data = np.array([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 
4.0]) @@ -2026,11 +2046,11 @@ def test__reverse_transform_multi_column(self, mock_sigmoid, ff_reverse_transfor mock_sigmoid.return_value = sigmoid_vals # Run - reversed = ls._reverse_transform(data) + reversed_values = ls._reverse_transform(data) # Assert ff_reverse_transform_args = ff_reverse_transform_mock.call_args[0] np.testing.assert_array_equal( ff_reverse_transform_args[0], np.array([sigmoid_vals, is_null]).T ) - assert reversed == ff_reverse_transform_mock.return_value + assert reversed_values == ff_reverse_transform_mock.return_value