Skip to content

Commit

Permalink
FloatFormatter does not round the data correctly for integer columns …
Browse files Browse the repository at this point in the history
…when using _set_fitted_parameters (#875)
  • Loading branch information
R-Palazzo authored Aug 28, 2024
1 parent bb3262e commit 50cb707
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
3 changes: 2 additions & 1 deletion rdt/transformers/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,9 @@ def _set_fitted_parameters(
self._min_value = min(min_max_values)
self._max_value = max(min_max_values)

if rounding_digits:
if rounding_digits is not None:
self._rounding_digits = rounding_digits
self.learn_rounding_scheme = True

if self.null_transformer.models_missing_values():
self.output_columns.append(column_name + '.is_null')
Expand Down
20 changes: 20 additions & 0 deletions tests/integration/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,26 @@ def test__support__nullable_numerical_pandas_dtypes(self):
reverse_transformed[column].round(expected_rounding_digits[column]),
)

def test__set_fitted_parameter_rounding_to_integer(self):
"""Test the ``_set_fitted_parameters`` method with rounding_digits set to 0."""
# Setup
data = pd.DataFrame({
'col 1': 100 * np.random.random(10),
})
transformer = FloatFormatter()

# Run
transformer._set_fitted_parameters(
column_name='col 1',
null_transformer=NullTransformer(),
rounding_digits=0,
dtype='float',
)
reverse_transformed_data = transformer.reverse_transform(data)

# Assert
pd.testing.assert_frame_equal(reverse_transformed_data, data.round(0))


class TestGaussianNormalizer:
def test_stats(self):
Expand Down
1 change: 1 addition & 0 deletions tests/unit/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ def test__set_fitted_parameters(self):
assert transformer._max_value == 100.0
assert transformer._rounding_digits == rounding_digits
assert transformer._dtype == dtype
assert transformer.learn_rounding_scheme is True

def test__set_fitted_parameters_from_column(self):
"""Test ``_set_fitted_parameters`` sets the required parameters for transformer."""
Expand Down

0 comments on commit 50cb707

Please sign in to comment.