Skip to content

Commit

Permalink
Add LogitScaler transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Jan 17, 2025
1 parent d6ea5d1 commit 07ac708
Show file tree
Hide file tree
Showing 5 changed files with 367 additions and 3 deletions.
86 changes: 84 additions & 2 deletions rdt/transformers/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import pandas as pd
import scipy

from rdt.errors import TransformerInputError
from rdt.errors import InvalidDataError, TransformerInputError
from rdt.transformers.base import BaseTransformer
from rdt.transformers.null import NullTransformer
from rdt.transformers.utils import learn_rounding_digits
from rdt.transformers.utils import learn_rounding_digits, logit, sigmoid

EPSILON = np.finfo(np.float32).eps
INTEGER_BOUNDS = {
Expand Down Expand Up @@ -626,3 +626,85 @@ def _reverse_transform(self, data):
recovered_data = np.stack([recovered_data, data[:, -1]], axis=1) # noqa: PD013

return super()._reverse_transform(recovered_data)


class LogitScaler(FloatFormatter):
"""Transformer for numerical data by applying a logit function.
This transformer works by replacing the values with a scaled
version and then applying a logit function. The reverse transform
applies a sigmoid to the data and then scales it back to the original space.
Null values are replaced using a ``NullTransformer``.
Args:
missing_value_replacement (object):
Indicate what to replace the null values with. If an integer or float is given,
replace them with the given value. If the strings ``'mean'`` or ``'mode'``
are given, replace them with the corresponding aggregation and if ``'random'``
replace each null value with a random value in the data range. Defaults to ``mean``.
missing_value_generation (str or None):
The way missing values are being handled. There are three strategies:
* ``random``: Randomly generates missing values based on the percentage of
missing values.
* ``from_column``: Creates a binary column that describes whether the original
value was missing. Then use it to recreate missing values.
* ``None``: Do nothing with the missing values on the reverse transform. Simply
pass whatever data we get through.
min_value (float):
The min value for the logit function. Defaults to 0.
max_value (float):
max_value (float): The max value for the logit function. Defaults to 1.0.
learn_rounding_scheme (bool):
Whether or not to learn what place to round to based on the data seen during ``fit``.
If ``True``, the data returned by ``reverse_transform`` will be rounded to that place.
Defaults to ``False``.
"""

def __init__(
self,
missing_value_replacement='mean',
missing_value_generation='random',
min_value=0.0,
max_value=1.0,
learn_rounding_scheme=False,
):
super().__init__(
missing_value_replacement=missing_value_replacement,
missing_value_generation=missing_value_generation,
learn_rounding_scheme=learn_rounding_scheme,
)
self.min_value = min_value
self.max_value = max_value

def _validate_logit_inputs(self, data):
out_of_range_vals = data[(data < self.min_value) | (data > self.max_value)]
if len(out_of_range_vals) > 0:
num_vals_to_print = 5
out_of_range_vals = [str(x) for x in sorted(out_of_range_vals, key=lambda x: str(x))]
if len(out_of_range_vals) > 5:
extra_missing_vals = f'+ {len(out_of_range_vals) - num_vals_to_print} more'
out_of_range_vals = (

Check warning on line 688 in rdt/transformers/numerical.py

View check run for this annotation

Codecov / codecov/patch

rdt/transformers/numerical.py#L684-L688

Added lines #L684 - L688 were not covered by tests
f'[{", ".join(out_of_range_vals[:num_vals_to_print])} {extra_missing_vals}]'
)
else:
out_of_range_vals = f'[{", ".join(out_of_range_vals)}]'

Check warning on line 692 in rdt/transformers/numerical.py

View check run for this annotation

Codecov / codecov/patch

rdt/transformers/numerical.py#L692

Added line #L692 was not covered by tests

raise InvalidDataError(

Check warning on line 694 in rdt/transformers/numerical.py

View check run for this annotation

Codecov / codecov/patch

rdt/transformers/numerical.py#L694

Added line #L694 was not covered by tests
f"Unable to apply logit function to column '{self.columns[0]}' due to out of "
f'range values ({out_of_range_vals}).'
)

def _fit(self, data):
self._validate_logit_inputs(data)
return super()._fit(data)

def _transform(self, data):
transformed = super()._transform(data)
self._validate_logit_inputs(transformed)
return logit(transformed, self.min_value, self.max_value)

def _reverse_transform(self, data):
reversed = sigmoid(data, self.min_value, self.max_value)
return super()._reverse_transform(reversed)
56 changes: 56 additions & 0 deletions rdt/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import warnings
from collections import defaultdict
from decimal import Decimal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -126,6 +127,17 @@ def _from_generators(generators, max_repeat):
yield ''.join(reversed(generated))


def _cast_to_type(data, dtype):
if isinstance(data, pd.Series):
data = data.apply(dtype)
elif isinstance(data, (np.ndarray, list)):
data = np.array([dtype(value) for value in data])
else:
data = dtype(data)

Check warning on line 136 in rdt/transformers/utils.py

View check run for this annotation

Codecov / codecov/patch

rdt/transformers/utils.py#L136

Added line #L136 was not covered by tests

return data


def strings_from_regex(regex, max_repeat=16):
"""Generate strings that match the given regular expression.
Expand Down Expand Up @@ -280,6 +292,50 @@ def learn_rounding_digits(data):
return None


def logit(data, low, high):
"""Apply a logit function to the data using ``low`` and ``high``.
Args:
data (pd.Series, pd.DataFrame, np.array, int, or float):
Data to apply the logit function to.
low (pd.Series, np.array, int, or float):
Low value/s to use when scaling.
high (pd.Series, np.array, int, or float):
High value/s to use when scaling.
Returns:
Logit scaled version of the input data.
"""
data = (data - low) / (high - low)
data = _cast_to_type(data, Decimal)
data = data * Decimal(0.95) + Decimal(0.025)
data = _cast_to_type(data, float)
return np.log(data / (1.0 - data))


def sigmoid(data, low, high):
"""Apply a sigmoid function to the data using ``low`` and ``high``.
Args:
data (pd.Series, pd.DataFrame, np.array, int, float or datetime):
Data to apply the logit function to.
low (pd.Series, np.array, int, float or datetime):
Low value/s to use when scaling.
high (pd.Series, np.array, int, float or datetime):
High value/s to use when scaling.
Returns:
Sigmoid transform of the input data.
"""
data = 1 / (1 + np.exp(-data))
data = _cast_to_type(data, Decimal)
data = (data - Decimal(0.025)) / Decimal(0.95)
data = _cast_to_type(data, float)
data = data * (high - low) + low

return data


class WarnDict(dict):
"""Custom dictionary to raise a deprecation warning."""

Expand Down
19 changes: 19 additions & 0 deletions tests/integration/test_transformers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import defaultdict

import numpy as np
import pandas as pd
import pytest

Expand All @@ -23,6 +24,12 @@
'FloatFormatter': {'missing_value_generation': 'from_column'},
'GaussianNormalizer': {'missing_value_generation': 'from_column'},
'ClusterBasedNormalizer': {'missing_value_generation': 'from_column'},
'LogitScaler': {
'FROM_DATA': {
'min_value': lambda x: np.nanmin(x) - 1,
'max_value': lambda x: np.nanmax(x) + 1,
}
},
}

# Mapping of rdt sdtype to dtype
Expand Down Expand Up @@ -149,6 +156,12 @@ def _test_transformer_with_dataset(transformer_class, input_data, steps):
"""

transformer_args = TRANSFORMER_ARGS.get(transformer_class.__name__, {})
if 'FROM_DATA' in transformer_args:
transformer_args = {**transformer_args}
args = transformer_args.pop('FROM_DATA')
for arg, arg_func in args.items():
transformer_args[arg] = arg_func(input_data[TEST_COL])

transformer = transformer_class(**transformer_args)
# Fit
transformer.fit(input_data, [TEST_COL])
Expand Down Expand Up @@ -203,6 +216,12 @@ def _test_transformer_with_hypertransformer(transformer_class, input_data, steps
transformer_args = TRANSFORMER_ARGS.get(transformer_class.__name__, {})
hypertransformer = HyperTransformer()
if transformer_args:
if 'FROM_DATA' in transformer_args:
transformer_args = {**transformer_args}
args = transformer_args.pop('FROM_DATA')
for arg, arg_func in args.items():
transformer_args[arg] = arg_func(input_data[TEST_COL])

field_transformers = {TEST_COL: transformer_class(**transformer_args)}

else:
Expand Down
120 changes: 119 additions & 1 deletion tests/unit/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from copulas import univariate
from pandas.api.types import is_float_dtype

from rdt.errors import TransformerInputError
from rdt.errors import InvalidDataError, TransformerInputError
from rdt.transformers.null import NullTransformer
from rdt.transformers.numerical import (
ClusterBasedNormalizer,
FloatFormatter,
GaussianNormalizer,
LogitScaler,
)


Expand Down Expand Up @@ -1863,3 +1864,120 @@ def test__reverse_transform_missing_value_replacement_missing_value_replacement_
call_data,
rtol=1e-1,
)


class TestLogitScaler:
def test___init__super_attrs(self):
"""Test super() arguments are properly passed and set as attributes."""
# Run
ls = LogitScaler(
missing_value_generation='random',
learn_rounding_scheme=False,
)

# Assert
assert ls.missing_value_replacement == 'mean'
assert ls.missing_value_generation == 'random'
assert ls.learn_rounding_scheme is False

def test___init__(self):
"""Test super() arguments are properly passed and set as attributes."""
# Run
ls = LogitScaler(max_value=100.0, min_value=2.0)

# Assert
assert ls.max_value == 100.0
assert ls.min_value == 2.0

def test__validate_logit_inputs(self):
"""Test validating data against input arguments."""
# Setup
ls = LogitScaler()
data = pd.Series([0.0, 0.1, 0.2, 0.3, 1.0])

# Run and Assert
ls._validate_logit_inputs(data)

def test__validate_logit_inputs_errors_invalid_value(self):
"""Test error message contains invalid values."""
# Setup
ls = LogitScaler()
ls.columns = ['column']
data = pd.Series([0.0, 0.1, 0.2, 0.3, 1.0, 2.0])

# Run and Assert
expected_msg = re.escape(
"Unable to apply logit function to column 'column' due to out of range values ([2.0])."
)
with pytest.raises(InvalidDataError, match=expected_msg):
ls._validate_logit_inputs(data)

def test__validate_logit_inputs_errors_many_invalid_values(self):
"""Test error message clips many invalid values."""
# Setup
ls = LogitScaler()
ls.columns = ['column']
data = pd.Series([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])

# Run and Assert
expected_msg = re.escape(
"Unable to apply logit function to column 'column' due to out of range values "
'([1.1, 1.2, 1.3, 2.0, 3.0 + 1 more]).'
)
with pytest.raises(InvalidDataError, match=expected_msg):
ls._validate_logit_inputs(data)

def test__fit(self):
"""Test the ``_fit`` method validates the inputs."""
# Setup
ls = LogitScaler()
ls._validate_logit_inputs = Mock()
data = pd.Series([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])

# Run
ls._fit(data)

# Assert
ls._validate_logit_inputs.assert_called_once_with(data)

@patch('rdt.transformers.numerical.logit')
def test__transform(self, mock_logit):
"""Test the ``transform`` method."""
# Setup
min_value = (1.0,)
max_value = 50.0
ls = LogitScaler(min_value=min_value, max_value=max_value)
ls._validate_logit_inputs = Mock()
data = pd.Series([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
null_transformer_mock = Mock()
null_transformer_mock.transform.return_value = data
ls.null_transformer = null_transformer_mock

# Run
transformed = ls._transform(data)

# Assert
ls._validate_logit_inputs.assert_called_once_with(data)
mock_logit.assert_called_once_with(data, ls.min_value, ls.max_value)
assert transformed == mock_logit.return_value

@patch('rdt.transformers.numerical.FloatFormatter._reverse_transform')
@patch('rdt.transformers.numerical.sigmoid')
def test__reverse_transform(self, mock_sigmoid, ff_reverse_transform_mock):
"""Test the ``transform`` method."""
# Setup
min_value = (1.0,)
max_value = 50.0
ls = LogitScaler(min_value=min_value, max_value=max_value)
data = pd.Series([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
null_transformer_mock = Mock()
null_transformer_mock.reverse_transform.return_value = data
ls.null_transformer = null_transformer_mock

# Run
reversed = ls._reverse_transform(data)

# Assert
mock_sigmoid.assert_called_once_with(data, ls.min_value, ls.max_value)
ff_reverse_transform_mock.assert_called_once_with(mock_sigmoid.return_value)
assert reversed == ff_reverse_transform_mock.return_value
Loading

0 comments on commit 07ac708

Please sign in to comment.