Skip to content

Commit

Permalink
Returning trained custom pipeline objects when fitted=True. (#950)
Browse files Browse the repository at this point in the history
  • Loading branch information
kiran-kate authored Jan 15, 2022
1 parent 965f003 commit 9578a4f
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion lale/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,19 @@ def find_lale_wrapper(sklearn_obj):
# This is a custom subclass of sklearn pipeline, so use the wrapper class
# instead of creating a lale pipeline
# We assume it has a hyperparameter `steps`.
lale_op_obj = wrapper_class(steps=nested_pipeline_lale_named_steps)
if (
not fitted
): # If fitted is False, we do not want to return a Trained operator.
lale_op = wrapper_class
else:
lale_op = lale.operators.TrainedIndividualOp(
wrapper_class._name,
wrapper_class._impl,
wrapper_class._schemas,
None,
_lale_trained=True,
)
lale_op_obj = lale_op(steps=nested_pipeline_lale_named_steps)
else: # no conversion to lale if a wrapper is not found for a subclass of pipeline
return sklearn_pipeline
elif isinstance(sklearn_pipeline, sklearn.pipeline.FeatureUnion):
Expand Down

0 comments on commit 9578a4f

Please sign in to comment.