From 0394ccd9dfab82e3356b68f6ff4b103030881aac Mon Sep 17 00:00:00 2001 From: Brandon Bodnar Date: Fri, 13 Sep 2024 05:06:03 -0400 Subject: [PATCH] Support fill_value for SimpleImputer with string data (#1123) Signed-off-by: Brandon Bodnar --- skl2onnx/operator_converters/imputer_op.py | 16 +-- tests/test_sklearn_imputer_converter.py | 143 +++++++++++++++++++++ 2 files changed, 151 insertions(+), 8 deletions(-) diff --git a/skl2onnx/operator_converters/imputer_op.py b/skl2onnx/operator_converters/imputer_op.py index 8b466f674..fc227af3a 100644 --- a/skl2onnx/operator_converters/imputer_op.py +++ b/skl2onnx/operator_converters/imputer_op.py @@ -17,14 +17,6 @@ def convert_sklearn_imputer( op_type = "Imputer" attrs = {"name": scope.get_unique_operator_name(op_type)} op = operator.raw_operator - if ( - hasattr(op, "fill_value") - and isinstance(op.fill_value, str) - and op.fill_value.lower() != "nan" - ): - raise RuntimeError( - "Imputer cannot fill missing values with a string '%s'." % op.fill_value - ) if not hasattr(op, "statistics_"): raise RuntimeError("Member statistics_ is not present, was the model fitted?") @@ -86,6 +78,14 @@ def convert_sklearn_imputer( apply_concat(scope, names, operator.outputs[0].full_name, container, axis=1) else: + if ( + hasattr(op, "fill_value") + and isinstance(op.fill_value, str) + and op.fill_value.lower() != "nan" + ): + raise RuntimeError( + "Imputer cannot fill missing values with a string '%s'." 
% op.fill_value + ) if isinstance(operator.inputs[0].type, Int64TensorType): attrs["imputed_value_int64s"] = op.statistics_.astype(np.int64) use_int = True diff --git a/tests/test_sklearn_imputer_converter.py b/tests/test_sklearn_imputer_converter.py index 331698d65..5e62affe1 100644 --- a/tests/test_sklearn_imputer_converter.py +++ b/tests/test_sklearn_imputer_converter.py @@ -51,6 +51,15 @@ def _check_outputs_ints(self, model, model_onnx, data): exp = model.transform(data) assert_almost_equal(res, exp) + def _check_outputs_floats(self, model, model_onnx, data): + sess = InferenceSession( + model_onnx.SerializeToString(), providers=["CPUExecutionProvider"] + ) + idata = {"input": np.array(data).astype(np.float32)} + res = sess.run(None, idata)[0] + exp = model.transform(data) + assert_almost_equal(res, exp) + def _check_outputs_strings(self, model, model_onnx, data, verbose=0): idata = {"input": np.array(data).astype(np.str_)} sess = InferenceSession( @@ -206,6 +215,140 @@ def test_simple_imputer_string_inputs_int_mostf_default(self): self.assertEqual(len(model_onnx.graph.output), 1) self._check_outputs_strings(model, model_onnx, data) + @unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20") + def test_simple_imputer_float_constant_default_fill_value(self): + model = SimpleImputer(strategy="constant") + data = [[1, 2], [np.nan, 3], [7, 6]] + model.fit(data) + + model_onnx = convert_sklearn( + model, + "scikit-learn simple imputer", + [("input", FloatTensorType([None, 2]))], + target_opset=TARGET_OPSET, + ) + self.assertIsNotNone(model_onnx.graph.node) + + # should contain only one node + self.assertEqual(len(model_onnx.graph.node), 1) + + # last node should contain the Imputer + outputs = model_onnx.graph.output + self.assertEqual(len(outputs), 1) + self._check_outputs_floats(model, model_onnx, data) + + @unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20") + def
test_simple_imputer_float_constant_provided_fill_value(self): + model = SimpleImputer(strategy="constant", fill_value=99.0) + data = [[1, 2], [np.nan, 3], [7, 6]] + model.fit(data) + + model_onnx = convert_sklearn( + model, + "scikit-learn simple imputer", + [("input", FloatTensorType([None, 2]))], + target_opset=TARGET_OPSET, + ) + self.assertIsNotNone(model_onnx.graph.node) + + # should contain only one node + self.assertEqual(len(model_onnx.graph.node), 1) + + # last node should contain the Imputer + outputs = model_onnx.graph.output + self.assertEqual(len(outputs), 1) + self._check_outputs_floats(model, model_onnx, data) + + @unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20") + def test_simple_imputer_int_constant_default_fill_value(self): + model = SimpleImputer(strategy="constant") + data = [[1, 2], [np.nan, 3], [7, 6], [8, np.nan]] + model.fit(data) + + model_onnx = convert_sklearn( + model, + "scikit-learn simple imputer", + [("input", Int64TensorType([None, 2]))], + target_opset=TARGET_OPSET, + ) + self.assertIsNotNone(model_onnx.graph.node) + + # should contain only one node + self.assertEqual(len(model_onnx.graph.node), 1) + + # last node should contain the Imputer + outputs = model_onnx.graph.output + self.assertEqual(len(outputs), 1) + self._check_outputs_ints(model, model_onnx, data) + + @unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20") + def test_simple_imputer_int_constant_provided_fill_value(self): + model = SimpleImputer(strategy="constant", fill_value=99) + data = [[1, 2], [np.nan, 3], [7, 6], [8, np.nan]] + model.fit(data) + + model_onnx = convert_sklearn( + model, + "scikit-learn simple imputer", + [("input", Int64TensorType([None, 2]))], + target_opset=TARGET_OPSET, + ) + self.assertIsNotNone(model_onnx.graph.node) + + # should contain only one node + self.assertEqual(len(model_onnx.graph.node), 1) + + # last node should contain the Imputer + outputs = model_onnx.graph.output +
self.assertEqual(len(outputs), 1) + self._check_outputs_ints(model, model_onnx, data) + + @unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20") + @unittest.skipIf( + pv.Version(skl_ver) < pv.Version("0.24"), + reason="SimpleImputer does not support strings", + ) + def test_simple_imputer_string_inputs_constant_provided_fill_value(self): + model = SimpleImputer( + strategy="constant", missing_values="", fill_value="missing" + ) + data = pd.DataFrame( + [["s1", "s2"], ["s1", "s2"], ["", "s3"], ["s7", "s6"], ["s8", ""]] + ) + model.fit(data) + model_onnx = convert_sklearn( + model, + "scikit-learn simple imputer", + [("input", StringTensorType([None, 2]))], + target_opset=TARGET_OPSET, + ) + self.assertIn("ai.onnx.ml", str(model_onnx)) + self.assertIsNotNone(model_onnx.graph.node) + self.assertEqual(len(model_onnx.graph.output), 1) + self._check_outputs_strings(model, model_onnx, data) + + @unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20") + @unittest.skipIf( + pv.Version(skl_ver) < pv.Version("0.24"), + reason="SimpleImputer does not support strings", + ) + def test_simple_imputer_string_inputs_constant_default_fill_value(self): + model = SimpleImputer(strategy="constant", missing_values="") + data = pd.DataFrame( + [["s1", "s2"], ["s1", "s2"], ["", "s3"], ["s7", "s6"], ["s8", ""]] + ) + model.fit(data) + model_onnx = convert_sklearn( + model, + "scikit-learn simple imputer", + [("input", StringTensorType([None, 2]))], + target_opset=TARGET_OPSET, + ) + self.assertIn("ai.onnx.ml", str(model_onnx)) + self.assertIsNotNone(model_onnx.graph.node) + self.assertEqual(len(model_onnx.graph.output), 1) + self._check_outputs_strings(model, model_onnx, data) + if __name__ == "__main__": unittest.main()