Skip to content

Commit

Permalink
cross val predict fix
Browse files Browse the repository at this point in the history
  • Loading branch information
perib committed Oct 1, 2024
1 parent 80d8da6 commit cd4f27b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tpot2/builtin_modules/estimatortransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sklearn.utils.validation import check_is_fitted

class EstimatorTransformer(BaseEstimator, TransformerMixin):
def __init__(self, estimator, method='auto', passthrough=False, cross_val_predict_cv=0):
def __init__(self, estimator, method='auto', passthrough=False, cross_val_predict_cv=None):
"""
A class for using a sklearn estimator as a transformer. When calling fit_transform, this class returns the out put of cross_val_predict
and trains the estimator on the full dataset. When calling transform, this class uses the estimator fit on the full dataset to transform the data.
Expand Down Expand Up @@ -83,7 +83,7 @@ def fit_transform(self, X, y=None):
else:
method = self.method

if self.cross_val_predict_cv > 0:
if self.cross_val_predict_cv is not None:
output = cross_val_predict(self.estimator, X, y=y, cv=self.cross_val_predict_cv)
else:
output = getattr(self.estimator, method)(X)
Expand Down

0 comments on commit cd4f27b

Please sign in to comment.