Skip to content

Commit

Permalink
Added support for the 'ObliqueRandomForestClassifier' model type
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Apr 21, 2024
1 parent 583a4b1 commit 36580db
Show file tree
Hide file tree
Showing 13 changed files with 2,304 additions and 125 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ Java library and command-line application for converting [Scikit-Learn](https://

Examples: [extensions/sktree.py](https://github.com/jpmml/jpmml-sklearn/blob/master/pmml-sklearn-extension/src/test/resources/extensions/sktree.py)

* [`sktree.ensemble.ObliqueRandomForestClassifier`](https://docs.neurodata.io/scikit-tree/dev/generated/sktree.ObliqueRandomForestClassifier.html)
* [`sktree.ensemble.ObliqueRandomForestRegressor`](https://docs.neurodata.io/scikit-tree/dev/generated/sktree.ObliqueRandomForestRegressor.html)
* [`sktree.tree.ObliqueDecisionTreeClassifier`](https://docs.neurodata.io/scikit-tree/dev/generated/sktree.tree.ObliqueDecisionTreeClassifier.html)
* [`sktree.tree.ObliqueDecisionTreeRegressor`](https://docs.neurodata.io/scikit-tree/dev/generated/sktree.tree.ObliqueDecisionTreeRegressor.html)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) 2024 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sktree.ensemble.forest;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import sklearn.Estimator;
import sklearn.HasEstimatorEnsemble;
import sklearn.tree.HasTree;

/**
 * Static helpers shared by the oblique forest converters in this package.
 * Not instantiable.
 */
public class ObliqueForestUtil {

	private ObliqueForestUtil(){
	}

	/**
	 * Encodes an ensemble of tree estimators as a segmented {@link MiningModel}.
	 *
	 * <p>
	 * Each member estimator is encoded against an anonymous copy of the schema
	 * (see {@link Schema#toAnonymousSchema()}), and its one-based position in the
	 * ensemble is passed as the first argument of the member's {@code encode} call.
	 * The resulting tree models are combined using the requested multiple model method,
	 * with {@link Segmentation.MissingPredictionTreatment#RETURN_MISSING} as the missing prediction treatment.
	 * </p>
	 *
	 * @param estimator The ensemble estimator whose {@code getEstimators()} supplies the member trees.
	 * @param miningFunction The mining function of the resulting mining model.
	 * @param multipleModelMethod The method for combining the member tree models.
	 * @param schema The schema of the ensemble; its label is used for the mining schema of the resulting model.
	 *
	 * @return The mining model holding one segment per member estimator, in list order.
	 */
	static
	public <E extends Estimator & HasEstimatorEnsemble<T>, T extends Estimator & HasTree> MiningModel encodeBaseObliqueForest(E estimator, MiningFunction miningFunction, Segmentation.MultipleModelMethod multipleModelMethod, Schema schema){
		List<? extends T> estimators = estimator.getEstimators();

		Schema segmentSchema = schema.toAnonymousSchema();

		List<TreeModel> treeModels = new ArrayList<>(estimators.size());

		// Iterate by position instead of looking the position up via List#indexOf(Object):
		// the lookup is an equals()-based linear scan per element (quadratic overall),
		// and would hand out the same number to list entries that compare equal.
		for(int i = 0; i < estimators.size(); i++){
			T treeEstimator = estimators.get(i);

			// Member numbering is one-based
			TreeModel treeModel = (TreeModel)treeEstimator.encode((i + 1), segmentSchema);

			treeModels.add(treeModel);
		}

		MiningModel miningModel = new MiningModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()))
			.setSegmentation(MiningModelUtil.createSegmentation(multipleModelMethod, Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels));

		return miningModel;
	}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2024 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sktree.ensemble.forest;

import java.util.List;

import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Schema;
import sklearn.Classifier;
import sklearn.HasEstimatorEnsemble;
import sktree.tree.ObliqueDecisionTreeClassifier;

/**
 * Converter class for oblique random forest classifiers.
 *
 * <p>
 * The member trees are read from the {@code estimators_} attribute and
 * combined into a single averaging {@link MiningModel}.
 * </p>
 */
public class ObliqueRandomForestClassifier extends Classifier implements HasEstimatorEnsemble<ObliqueDecisionTreeClassifier> {

	public ObliqueRandomForestClassifier(String module, String name){
		super(module, name);
	}

	/**
	 * @return The member tree classifiers, as stored in the {@code estimators_} attribute.
	 */
	@Override
	public List<ObliqueDecisionTreeClassifier> getEstimators(){
		return getList("estimators_", ObliqueDecisionTreeClassifier.class);
	}

	/**
	 * @return Always {@link DataType#FLOAT}.
	 */
	@Override
	public DataType getDataType(){
		return DataType.FLOAT;
	}

	/**
	 * Encodes the forest as a classification mining model that averages its member trees.
	 *
	 * @param schema The schema; its label must be a {@link CategoricalLabel}.
	 *
	 * @return The mining model, after {@code encodePredictProbaOutput} has been applied to it.
	 */
	@Override
	public MiningModel encodeModel(Schema schema){
		// Resolve the label first, so that an unsupported label type fails before any tree is encoded
		final CategoricalLabel label = (CategoricalLabel)schema.getLabel();

		final Segmentation.MultipleModelMethod averaging = Segmentation.MultipleModelMethod.AVERAGE;

		final MiningModel forestModel = ObliqueForestUtil.encodeBaseObliqueForest(this, MiningFunction.CLASSIFICATION, averaging, schema);

		encodePredictProbaOutput(forestModel, DataType.DOUBLE, label);

		return forestModel;
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,13 @@
*/
package sktree.ensemble.forest;

import java.util.ArrayList;
import java.util.List;

import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import sklearn.HasEstimatorEnsemble;
import sklearn.Regressor;
import sktree.tree.ObliqueDecisionTreeRegressor;
Expand All @@ -40,25 +36,13 @@ public ObliqueRandomForestRegressor(String module, String name){
}

@Override
public Model encodeModel(Schema schema){
List<ObliqueDecisionTreeRegressor> estimators = getEstimators();

List<TreeModel> treeModels = new ArrayList<>();

Schema segmentSchema = schema.toAnonymousSchema();

for(int i = 0; i < estimators.size(); i++){
ObliqueDecisionTreeRegressor estimator = estimators.get(i);

TreeModel treeModel = (TreeModel)estimator.encode((i + 1), segmentSchema);

treeModels.add(treeModel);
}

MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels));
public DataType getDataType(){
return DataType.FLOAT;
}

return miningModel;
@Override
public MiningModel encodeModel(Schema schema){
return ObliqueForestUtil.encodeBaseObliqueForest(this, MiningFunction.REGRESSION, Segmentation.MultipleModelMethod.AVERAGE, schema);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ sklego.preprocessing.identitytransformer.IdentityTransformer = sklearn.IdentityT
sktree._lib.sklearn.tree._criterion.Entropy = sklearn.tree.ClassificationCriterion
sktree._lib.sklearn.tree._criterion.Gini = sklearn.tree.ClassificationCriterion
sktree._lib.sklearn.tree._tree.DepthFirstTreeBuilder = sklearn.tree.DepthFirstTreeBuilder
sktree.ensemble._supervised_forest.ObliqueRandomForestClassifier = sktree.ensemble.forest.ObliqueRandomForestClassifier
sktree.ensemble._supervised_forest.ObliqueRandomForestRegressor = sktree.ensemble.forest.ObliqueRandomForestRegressor
sktree.tree._classes.ObliqueDecisionTreeClassifier = sktree.tree.ObliqueDecisionTreeClassifier
sktree.tree._classes.ObliqueDecisionTreeRegressor = sktree.tree.ObliqueDecisionTreeRegressor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,18 @@ public void evaluateObliqueDecisionTreeIris() throws Exception {
evaluate("ObliqueDecisionTree", IRIS);
}

/**
 * Runs {@code evaluate} for the "ObliqueRandomForest" algorithm with the {@code AUDIT} constant.
 *
 * NOTE(review): AUDIT is presumably a dataset identifier declared by the test base class - confirm there.
 *
 * @throws Exception If {@code evaluate} fails.
 */
@Test
public void evaluateObliqueRandomForestAudit() throws Exception {
evaluate("ObliqueRandomForest", AUDIT);
}

/**
 * Runs {@code evaluate} for the "ObliqueRandomForest" algorithm with the {@code AUTO} constant.
 *
 * NOTE(review): AUTO is presumably a dataset identifier declared by the test base class - confirm there.
 *
 * @throws Exception If {@code evaluate} fails.
 */
@Test
public void evaluateObliqueRandomForestAuto() throws Exception {
evaluate("ObliqueRandomForest", AUTO);
}

/**
 * Runs {@code evaluate} for the "ObliqueRandomForest" algorithm with the {@code IRIS} constant.
 *
 * NOTE(review): IRIS is presumably a dataset identifier declared by the test base class - confirm there.
 *
 * @throws Exception If {@code evaluate} fails.
 */
@Test
public void evaluateObliqueRandomForestIris() throws Exception {
evaluate("ObliqueRandomForest", IRIS);
}
}
Loading

0 comments on commit 36580db

Please sign in to comment.