Skip to content

Commit

Permalink
add logic to OneHotEncoder
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Jan 19, 2024
1 parent a55b143 commit 3b2606d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
7 changes: 6 additions & 1 deletion rdt/transformers/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ class OneHotEncoder(BaseTransformer):
_dummy_encoded = False
_indexer = None
_uniques = None
dtype = None

@staticmethod
def _prepare_data(data):
Expand Down Expand Up @@ -588,6 +589,7 @@ def _fit(self, data):
data (pandas.Series or pandas.DataFrame):
Data to fit the transformer to.
"""
self.dtype = data.dtype
data = self._prepare_data(data)

null = pd.isna(data).to_numpy()
Expand Down Expand Up @@ -663,15 +665,18 @@ def _reverse_transform(self, data):
Returns:
pandas.Series
"""
check_nan_in_transform(data, self.dtype)
if not isinstance(data, np.ndarray):
data = data.to_numpy()

if data.ndim == 1:
data = data.reshape(-1, 1)

indices = np.argmax(data, axis=1)
result = pd.Series(indices).map(self.dummies.__getitem__)
result = try_convert_to_dtype(result, self.dtype)

return pd.Series(indices).map(self.dummies.__getitem__)
return result


class LabelEncoder(BaseTransformer):
Expand Down
11 changes: 10 additions & 1 deletion tests/unit/transformers/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,7 @@ def test__fit_dummies_no_nans(self):

# Assert
np.testing.assert_array_equal(ohe.dummies, ['a', 2, 'c'])
assert ohe.dtype == 'object'

def test__fit_dummies_nans(self):
"""Test the ``_fit`` method with nans.
Expand Down Expand Up @@ -1810,11 +1811,14 @@ def test__transform_numeric(self):
assert not ohe._dummy_encoded
np.testing.assert_array_equal(out, expected)

def test__reverse_transform_no_nans(self):
@patch('rdt.transformers.categorical.check_nan_in_transform')
@patch('rdt.transformers.categorical.try_convert_to_dtype')
def test__reverse_transform_no_nans(self, mock_convert_dtype, mock_check_nan):
# Setup
ohe = OneHotEncoder()
data = pd.Series(['a', 'b', 'c'])
ohe._fit(data)
mock_convert_dtype.return_value = data

# Run
transformed = np.array([
Expand All @@ -1827,6 +1831,11 @@ def test__reverse_transform_no_nans(self):
# Assert
expected = pd.Series(['a', 'b', 'c'])
pd.testing.assert_series_equal(out, expected)
mock_input_data = mock_check_nan.call_args.args[0]
mock_input_dtype = mock_check_nan.call_args.args[1]
np.testing.assert_array_equal(mock_input_data, transformed)
assert mock_input_dtype == 'O'
mock_convert_dtype.assert_called_once()

def test__reverse_transform_nans(self):
# Setup
Expand Down

0 comments on commit 3b2606d

Please sign in to comment.