diff --git a/rdt/transformers/numerical.py b/rdt/transformers/numerical.py index 64a5967f..1425670b 100644 --- a/rdt/transformers/numerical.py +++ b/rdt/transformers/numerical.py @@ -236,8 +236,9 @@ def _set_fitted_parameters( self._min_value = min(min_max_values) self._max_value = max(min_max_values) - if rounding_digits: + if rounding_digits is not None: self._rounding_digits = rounding_digits + self.learn_rounding_scheme = True if self.null_transformer.models_missing_values(): self.output_columns.append(column_name + '.is_null') diff --git a/tests/integration/transformers/test_numerical.py b/tests/integration/transformers/test_numerical.py index 056cb94a..93a78ab6 100644 --- a/tests/integration/transformers/test_numerical.py +++ b/tests/integration/transformers/test_numerical.py @@ -287,6 +287,26 @@ def test__support__nullable_numerical_pandas_dtypes(self): reverse_transformed[column].round(expected_rounding_digits[column]), ) + def test__set_fitted_parameter_rounding_to_integer(self): + """Test the ``_set_fitted_parameters`` method with rounding_digits set to 0.""" + # Setup + data = pd.DataFrame({ + 'col 1': 100 * np.random.random(10), + }) + transformer = FloatFormatter() + + # Run + transformer._set_fitted_parameters( + column_name='col 1', + null_transformer=NullTransformer(), + rounding_digits=0, + dtype='float', + ) + reverse_transformed_data = transformer.reverse_transform(data) + + # Assert + pd.testing.assert_frame_equal(reverse_transformed_data, data.round(0)) + class TestGaussianNormalizer: def test_stats(self): diff --git a/tests/unit/transformers/test_numerical.py b/tests/unit/transformers/test_numerical.py index 08a950ad..0af6b449 100644 --- a/tests/unit/transformers/test_numerical.py +++ b/tests/unit/transformers/test_numerical.py @@ -748,6 +748,7 @@ def test__set_fitted_parameters(self): assert transformer._max_value == 100.0 assert transformer._rounding_digits == rounding_digits assert transformer._dtype == dtype + assert transformer.learn_rounding_scheme is True def test__set_fitted_parameters_from_column(self): """Test ``_set_fitted_parameters`` sets the required parameters for transformer."""