From 30c749cc11a593e699acd9e29aeac494115f0f10 Mon Sep 17 00:00:00 2001 From: Felipe Date: Thu, 3 Oct 2024 09:47:20 -0700 Subject: [PATCH 1/7] Remove % operando --- pyproject.toml | 1 + rdt/transformers/numerical.py | 2 +- rdt/transformers/utils.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8ed11836..4ed7a0ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ test = [ 'rundoc>=0.4.3,<0.5', 'pytest-subtests>=0.5,<1.0', 'pytest-runner >= 2.11.1', + 'pyarrow', 'tomli>=2.0.0,<3', ] dev = [ diff --git a/rdt/transformers/numerical.py b/rdt/transformers/numerical.py index 1425670b..12bc345e 100644 --- a/rdt/transformers/numerical.py +++ b/rdt/transformers/numerical.py @@ -104,7 +104,7 @@ def _raise_out_of_bounds_error(self, value, name, bound_type, min_bound, max_bou def _validate_values_within_bounds(self, data): if not self.computer_representation.startswith('Float'): - fractions = data[~data.isna() & data % 1 != 0] + fractions = data[~data.isna() & (data != (data // 1))] if not fractions.empty: raise ValueError( f"The column '{data.name}' contains float values {fractions.tolist()}. " diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py index e2053677..a0ddaac4 100644 --- a/rdt/transformers/utils.py +++ b/rdt/transformers/utils.py @@ -261,7 +261,7 @@ def learn_rounding_digits(data): return None # Doesn't contain decimal digits - if ((roundable_data % 1) == 0).all(): + if (roundable_data == roundable_data.astype(int)).all(): return 0 # Try to round to fewer digits From 98255ba8c9f0465d50547b99ca57b6d226650e28 Mon Sep 17 00:00:00 2001 From: Felipe Date: Fri, 4 Oct 2024 09:39:12 -0700 Subject: [PATCH 2/7] feedback --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4ed7a0ef..1c9c26b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ test = [ 'rundoc>=0.4.3,<0.5', 'pytest-subtests>=0.5,<1.0', 'pytest-runner >= 2.11.1', - 'pyarrow', + 'pyarrow >= 17.0.0', 'tomli>=2.0.0,<3', ] dev = [ From 85bcb01dd37c532be4dc1cfa2312d64db2026241 Mon Sep 17 00:00:00 2001 From: Felipe Date: Tue, 8 Oct 2024 08:37:44 -0700 Subject: [PATCH 3/7] Add tests --- pyproject.toml | 3 ++- tests/unit/transformers/test_numerical.py | 11 +++++++++++ tests/unit/transformers/test_utils.py | 13 +++++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1c9c26b2..04006a52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,9 @@ rdt = { main = 'rdt.cli.__main__:main' } [project.optional-dependencies] copulas = ['copulas>=0.11.0',] +pyarrow = ['pyarrow>=17.0.0'] test = [ + 'rdt[pyarrow]', 'rdt[copulas]', 'pytest>=3.4.2', @@ -58,7 +60,6 @@ test = [ 'rundoc>=0.4.3,<0.5', 'pytest-subtests>=0.5,<1.0', 'pytest-runner >= 2.11.1', - 'pyarrow >= 17.0.0', 'tomli>=2.0.0,<3', ] dev = [ diff --git a/tests/unit/transformers/test_numerical.py b/tests/unit/transformers/test_numerical.py index 0af6b449..a878075f 100644 --- a/tests/unit/transformers/test_numerical.py +++ b/tests/unit/transformers/test_numerical.py @@ -5,6 +5,7 @@ import copulas import numpy as np import pandas as pd +import pyarrow as pa import pytest from copulas import univariate from pandas.api.types import is_float_dtype @@ -44,6 +45,16 @@ def test__validate_values_within_bounds(self): # Run transformer._validate_values_within_bounds(data) + def test__validate_values_within_bounds_pyarrow(self): + """Test it works with pyarrow.""" + # Setup + data = pd.Series(range(10), dtype='int64[pyarrow]') + transformer = FloatFormatter() + transformer.computer_representation = 'UInt8' + + # Run + transformer._validate_values_within_bounds(data) + def test__validate_values_within_bounds_under_minimum(self): """Test the ``_validate_values_within_bounds`` method. diff --git a/tests/unit/transformers/test_utils.py b/tests/unit/transformers/test_utils.py index c2bccead..164b3a20 100644 --- a/tests/unit/transformers/test_utils.py +++ b/tests/unit/transformers/test_utils.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd +import pyarrow as pa import pytest from rdt.transformers.utils import ( @@ -225,6 +226,18 @@ def test_learn_rounding_digits_less_than_15_decimals(): assert output == 3 +def test_learn_rounding_digits_pyarrow(): + """Test it works with pyarrow.""" + # Setup + data = pd.Series(range(10), dtype='int64[pyarrow]') + + # Run + output = learn_rounding_digits(data) + + # Assert + assert output == 0 + + def test_learn_rounding_digits_negative_decimals_float(): """Test the learn_rounding_digits method with floats multiples of powers of 10. From 3e9bd7a954420925dd8b36da59fc7fd36667228f Mon Sep 17 00:00:00 2001 From: Felipe Date: Tue, 8 Oct 2024 11:08:03 -0700 Subject: [PATCH 4/7] Add new test --- rdt/transformers/utils.py | 5 ++++- tests/unit/transformers/test_numerical.py | 1 - tests/unit/transformers/test_utils.py | 13 ++++++++++++- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py index a0ddaac4..8be7609d 100644 --- a/rdt/transformers/utils.py +++ b/rdt/transformers/utils.py @@ -252,8 +252,11 @@ def learn_rounding_digits(data): int or None: Number of digits to round to. """ - # check if data has any decimals name = data.name + if isinstance(data.dtype, pd.ArrowDtype): + data = data.to_numpy() + + # check if data has any decimals roundable_data = data[~(np.isinf(data.astype(float)) | pd.isna(data))] # Doesn't contain numbers diff --git a/tests/unit/transformers/test_numerical.py b/tests/unit/transformers/test_numerical.py index a878075f..314b6206 100644 --- a/tests/unit/transformers/test_numerical.py +++ b/tests/unit/transformers/test_numerical.py @@ -5,7 +5,6 @@ import copulas import numpy as np import pandas as pd -import pyarrow as pa import pytest from copulas import univariate from pandas.api.types import is_float_dtype diff --git a/tests/unit/transformers/test_utils.py b/tests/unit/transformers/test_utils.py index 164b3a20..4a0ba2ea 100644 --- a/tests/unit/transformers/test_utils.py +++ b/tests/unit/transformers/test_utils.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd -import pyarrow as pa import pytest from rdt.transformers.utils import ( @@ -238,6 +237,18 @@ def test_learn_rounding_digits_pyarrow(): assert output == 0 +def test_learn_rounding_digits_pyarrow_float(): + """Test it learns the proper amount of digits with pyarrow.""" + # Setup + data = pd.Series([0.5, 0.19, 3], dtype='float64[pyarrow]') + + # Run + output = learn_rounding_digits(data) + + # Assert + assert output == 2 + + def test_learn_rounding_digits_negative_decimals_float(): """Test the learn_rounding_digits method with floats multiples of powers of 10. From cbe1abf66da427890a74137b63acae361a47793e Mon Sep 17 00:00:00 2001 From: Felipe Date: Wed, 9 Oct 2024 09:17:42 -0700 Subject: [PATCH 5/7] Fix pyarrow for old pandas --- rdt/transformers/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py index 8be7609d..faa4d280 100644 --- a/rdt/transformers/utils.py +++ b/rdt/transformers/utils.py @@ -252,8 +252,7 @@ def learn_rounding_digits(data): int or None: Number of digits to round to. """ - name = data.name - if isinstance(data.dtype, pd.ArrowDtype): + if str(data.dtype).endswith("[pyarrow]"): data = data.to_numpy() # check if data has any decimals @@ -276,7 +275,7 @@ def learn_rounding_digits(data): # Can't round, not equal after MAX_DECIMALS digits of precision LOGGER.info( "No rounding scheme detected for column '%s'. Data will not be rounded.", - name, + data.name, ) return None From a28f04a870fd495686289fb20c2a6c870fff7ceb Mon Sep 17 00:00:00 2001 From: Felipe Date: Wed, 9 Oct 2024 10:03:53 -0700 Subject: [PATCH 6/7] Update test --- rdt/transformers/utils.py | 1 + tests/unit/transformers/test_numerical.py | 5 ++++- tests/unit/transformers/test_utils.py | 10 ++++++++-- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py index faa4d280..6a851d98 100644 --- a/rdt/transformers/utils.py +++ b/rdt/transformers/utils.py @@ -252,6 +252,7 @@ def learn_rounding_digits(data): int or None: Number of digits to round to. """ + # check it ends with pyarrow if str(data.dtype).endswith("[pyarrow]"): data = data.to_numpy() diff --git a/tests/unit/transformers/test_numerical.py b/tests/unit/transformers/test_numerical.py index 314b6206..a947e518 100644 --- a/tests/unit/transformers/test_numerical.py +++ b/tests/unit/transformers/test_numerical.py @@ -47,7 +47,10 @@ def test__validate_values_within_bounds(self): def test__validate_values_within_bounds_pyarrow(self): """Test it works with pyarrow.""" # Setup - data = pd.Series(range(10), dtype='int64[pyarrow]') + try: + data = pd.Series(range(10), dtype='int64[pyarrow]') + except TypeError: + pytest.skip("Skipping as old numpy/pandas versions don't support arrow") transformer = FloatFormatter() transformer.computer_representation = 'UInt8' diff --git a/tests/unit/transformers/test_utils.py b/tests/unit/transformers/test_utils.py index 4a0ba2ea..a99843f9 100644 --- a/tests/unit/transformers/test_utils.py +++ b/tests/unit/transformers/test_utils.py @@ -228,7 +228,10 @@ def test_learn_rounding_digits_less_than_15_decimals(): def test_learn_rounding_digits_pyarrow(): """Test it works with pyarrow.""" # Setup - data = pd.Series(range(10), dtype='int64[pyarrow]') + try: + data = pd.Series(range(10), dtype='int64[pyarrow]') + except TypeError: + pytest.skip("Skipping as old numpy/pandas versions don't support arrow") # Run output = learn_rounding_digits(data) @@ -240,7 +243,10 @@ def test_learn_rounding_digits_pyarrow(): def test_learn_rounding_digits_pyarrow_float(): """Test it learns the proper amount of digits with pyarrow.""" # Setup - data = pd.Series([0.5, 0.19, 3], dtype='float64[pyarrow]') + try: + data = pd.Series([0.5, 0.19, 3], dtype='float64[pyarrow]') + except TypeError: + pytest.skip("Skipping as old numpy/pandas versions don't support arrow") # Run output = learn_rounding_digits(data) From 00e7e0d5b0ef91bbe450486c356d2120f64c58d3 Mon Sep 17 00:00:00 2001 From: Felipe Date: Fri, 11 Oct 2024 09:24:48 -0700 Subject: [PATCH 7/7] fix lint --- rdt/transformers/utils.py | 9 ++++----- tests/unit/transformers/test_utils.py | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py index 6a851d98..9544ff1e 100644 --- a/rdt/transformers/utils.py +++ b/rdt/transformers/utils.py @@ -252,11 +252,10 @@ def learn_rounding_digits(data): int or None: Number of digits to round to. """ - # check it ends with pyarrow - if str(data.dtype).endswith("[pyarrow]"): - data = data.to_numpy() - # check if data has any decimals + name = data.name + if str(data.dtype).endswith('[pyarrow]'): + data = data.to_numpy() roundable_data = data[~(np.isinf(data.astype(float)) | pd.isna(data))] # Doesn't contain numbers @@ -276,7 +275,7 @@ def learn_rounding_digits(data): # Can't round, not equal after MAX_DECIMALS digits of precision LOGGER.info( "No rounding scheme detected for column '%s'. Data will not be rounded.", - data.name, + name, ) return None diff --git a/tests/unit/transformers/test_utils.py b/tests/unit/transformers/test_utils.py index a99843f9..f37ee58b 100644 --- a/tests/unit/transformers/test_utils.py +++ b/tests/unit/transformers/test_utils.py @@ -1,7 +1,7 @@ import sre_parse import warnings from sre_constants import MAXREPEAT -from unittest.mock import patch +from unittest.mock import Mock, patch import numpy as np import pandas as pd @@ -329,6 +329,20 @@ def test_learn_rounding_digits_nullable_numerical_pandas_dtypes(): assert output == expected_output[column] +def test_learn_rounding_digits_pyarrow_to_numpy(): + """Test that ``learn_rounding_digits`` works with pyarrow to numpy conversion.""" + # Setup + data = Mock() + data.dtype = 'int64[pyarrow]' + data.to_numpy.return_value = np.array([1, 2, 3]) + + # Run + learn_rounding_digits(data) + + # Assert + assert data.to_numpy.called + + def test_warn_dict(): """Test that ``WarnDict`` will raise a warning when called with `text`.""" # Setup