From 2127f26ec67ead8e2caaae92531f7d1a9bf38c1c Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Sun, 21 Apr 2024 22:36:27 +0300 Subject: [PATCH] Added HasTree marker interface --- .../sklearn2pmml/tree/CHAIDClassifier.java | 7 +++-- .../sklearn2pmml/tree/CHAIDRegressor.java | 7 +++-- .../java/sklearn2pmml/tree/CHAIDUtil.java | 5 +++- .../main/java/sklearn2pmml/tree/HasTree.java | 26 +++++++++++++++++++ 4 files changed, 36 insertions(+), 9 deletions(-) create mode 100644 pmml-sklearn/src/main/java/sklearn2pmml/tree/HasTree.java diff --git a/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDClassifier.java b/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDClassifier.java index fc97565a1..ce44fb31c 100644 --- a/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDClassifier.java +++ b/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDClassifier.java @@ -28,7 +28,7 @@ import sklearn.HasApplyField; import treelib.Tree; -public class CHAIDClassifier extends Classifier implements HasApplyField { +public class CHAIDClassifier extends Classifier implements HasApplyField, HasTree { public CHAIDClassifier(String module, String name){ super(module, name); @@ -41,11 +41,9 @@ public String getApplyField(){ @Override public TreeModel encodeModel(Schema schema){ - Tree tree = getTree(); - CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel(); - TreeModel treeModel = CHAIDUtil.encodeModel(MiningFunction.CLASSIFICATION, tree, schema); + TreeModel treeModel = CHAIDUtil.encodeModel(this, MiningFunction.CLASSIFICATION, schema); encodePredictProbaOutput(treeModel, DataType.DOUBLE, categoricalLabel); encodeApplyOutput(treeModel, DataType.INTEGER); @@ -53,6 +51,7 @@ public TreeModel encodeModel(Schema schema){ return treeModel; } + @Override public Tree getTree(){ return get("treelib_tree_", Tree.class); } diff --git a/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDRegressor.java b/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDRegressor.java index 86e1f030c..81e299cee 100644 --- a/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDRegressor.java +++ b/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDRegressor.java @@ -27,7 +27,7 @@ import sklearn.Regressor; import treelib.Tree; -public class CHAIDRegressor extends Regressor implements HasApplyField { +public class CHAIDRegressor extends Regressor implements HasApplyField, HasTree { public CHAIDRegressor(String module, String name){ super(module, name); @@ -40,15 +40,14 @@ public String getApplyField(){ @Override public TreeModel encodeModel(Schema schema){ - Tree tree = getTree(); - - TreeModel treeModel = CHAIDUtil.encodeModel(MiningFunction.REGRESSION, tree, schema); + TreeModel treeModel = CHAIDUtil.encodeModel(this, MiningFunction.REGRESSION, schema); encodeApplyOutput(treeModel, DataType.INTEGER); return treeModel; } + @Override public Tree getTree(){ return get("treelib_tree_", Tree.class); } diff --git a/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDUtil.java b/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDUtil.java index e7e06acb9..e3b7ef44d 100644 --- a/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDUtil.java +++ b/pmml-sklearn/src/main/java/sklearn2pmml/tree/CHAIDUtil.java @@ -52,6 +52,7 @@ import org.jpmml.converter.PredicateManager; import org.jpmml.converter.Schema; import org.jpmml.python.ClassDictUtil; +import sklearn.Estimator; import treelib.Node; import treelib.Tree; @@ -61,7 +62,9 @@ private CHAIDUtil(){ } static - public TreeModel encodeModel(MiningFunction miningFunction, Tree tree, Schema schema){ + public TreeModel encodeModel(E estimator, MiningFunction miningFunction, Schema schema){ + Tree tree = estimator.getTree(); + org.dmg.pmml.tree.Node root = encodeNode(True.INSTANCE, tree.selectRoot(), tree, new PredicateManager(), schema); return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root); diff --git a/pmml-sklearn/src/main/java/sklearn2pmml/tree/HasTree.java b/pmml-sklearn/src/main/java/sklearn2pmml/tree/HasTree.java new file mode 100644 index 000000000..fde24cb20 --- /dev/null +++ b/pmml-sklearn/src/main/java/sklearn2pmml/tree/HasTree.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2024 Villu Ruusmann + * + * This file is part of JPMML-SkLearn + * + * JPMML-SkLearn is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-SkLearn is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-SkLearn. If not, see . + */ +package sklearn2pmml.tree; + +import treelib.Tree; + +public interface HasTree { + + Tree getTree(); +} \ No newline at end of file