Skip to content

Commit

Permalink
[MRG+1] issue scikit-learn#6532 Add inverse_transform function (sci…
Browse files Browse the repository at this point in the history
…kit-learn#6570)

* [MRG+1] scikit-learn#6532 Add inverse_func argument to FunctionTransformer


* modify test:inverse_func is not true inverse
  • Loading branch information
facaiy authored and MechCoder committed Apr 12, 2016
1 parent 45ff64b commit 1d487fb
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
26 changes: 22 additions & 4 deletions sklearn/preprocessing/_function_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ class FunctionTransformer(BaseEstimator, TransformerMixin):
the same arguments as transform, with args and kwargs forwarded.
If func is None, then func will be the identity function.
inverse_func : callable, optional default=None
The callable to use for the inverse transformation. This will be
passed the same arguments as inverse transform, with args and
kwargs forwarded. If inverse_func is None, then inverse_func
will be the identity function.
validate : bool, optional default=True
Indicate that the input X array should be checked before calling
func. If validate is false, there will be no input validation.
Expand All @@ -49,26 +55,38 @@ class FunctionTransformer(BaseEstimator, TransformerMixin):
kw_args : dict, optional
Dictionary of additional keyword arguments to pass to func.
inv_kw_args : dict, optional
Dictionary of additional keyword arguments to pass to inverse_func.
"""
def __init__(self, func=None, validate=True,
def __init__(self, func=None, inverse_func=None, validate=True,
accept_sparse=False, pass_y=False,
kw_args=None):
kw_args=None, inv_kw_args=None):
self.func = func
self.inverse_func = inverse_func
self.validate = validate
self.accept_sparse = accept_sparse
self.pass_y = pass_y
self.kw_args = kw_args
self.inv_kw_args = inv_kw_args

def fit(self, X, y=None):
if self.validate:
check_array(X, self.accept_sparse)
return self

def transform(self, X, y=None):
return self._transform(X, y, self.func, self.kw_args)

def inverse_transform(self, X, y=None):
return self._transform(X, y, self.inverse_func, self.inv_kw_args)

def _transform(self, X, y=None, func=None, kw_args=None):
if self.validate:
X = check_array(X, self.accept_sparse)
func = self.func if self.func is not None else _identity

if func is None:
func = _identity

return func(X, *((y,) if self.pass_y else ()),
**(self.kw_args if self.kw_args else {}))
**(kw_args if kw_args else {}))
14 changes: 13 additions & 1 deletion sklearn/preprocessing/tests/test_function_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,16 @@ def test_kw_arg_reset():

# Test that rounding is correct
testing.assert_array_equal(F.transform(X),
np.around(X, decimals=1))
np.around(X, decimals=1))


def test_inverse_transform():
X = np.array([1, 4, 9, 16]).reshape((2, 2))

# Test that inverse_transform works correctly
F = FunctionTransformer(
func=np.sqrt,
inverse_func=np.around, inv_kw_args=dict(decimals=3))
testing.assert_array_equal(
F.inverse_transform(F.transform(X)),
np.around(np.sqrt(X), decimals=3))

0 comments on commit 1d487fb

Please sign in to comment.