From 9578a4f68d71883c1617b3311979c4643c352884 Mon Sep 17 00:00:00 2001 From: kiran-kate <40038037+kiran-kate@users.noreply.github.com> Date: Sat, 15 Jan 2022 07:13:13 -0500 Subject: [PATCH] Returning trained custom pipeline objects when fitted=True. (#950) --- lale/helpers.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/lale/helpers.py b/lale/helpers.py index 87d072e30..518c7e44c 100644 --- a/lale/helpers.py +++ b/lale/helpers.py @@ -640,7 +640,19 @@ def find_lale_wrapper(sklearn_obj): # This is a custom subclass of sklearn pipeline, so use the wrapper class # instead of creating a lale pipeline # We assume it has a hyperparameter `steps`. - lale_op_obj = wrapper_class(steps=nested_pipeline_lale_named_steps) + if ( + not fitted + ): # If fitted is False, we do not want to return a Trained operator. + lale_op = wrapper_class + else: + lale_op = lale.operators.TrainedIndividualOp( + wrapper_class._name, + wrapper_class._impl, + wrapper_class._schemas, + None, + _lale_trained=True, + ) + lale_op_obj = lale_op(steps=nested_pipeline_lale_named_steps) else: # no conversion to lale if a wrapper is not found for a subclass of pipeline return sklearn_pipeline elif isinstance(sklearn_pipeline, sklearn.pipeline.FeatureUnion):