diff --git a/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionClassifier.java b/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionClassifier.java index 404ed7d63..a45847685 100644 --- a/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionClassifier.java +++ b/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionClassifier.java @@ -18,6 +18,7 @@ */ package sklearn2pmml.expression; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -83,50 +84,42 @@ public RegressionModel encodeModel(Schema schema){ List regressionTables; switch(normalizationMethod){ - case LOGIT: - { - if((categoryRegressionTables.size() != 1) || (categories.size() != 2)){ - throw new IllegalArgumentException(); - } - - Object activeCategory = Iterables.getOnlyElement(categoryRegressionTables.keySet()); - Object passiveCategory; + case NONE: + if(categoricalLabel.size() == 2){ + regressionTables = encodeBinaryClassifier(categories, categoryRegressionTables); + } else - int index = categories.indexOf(activeCategory); - if(index == 0){ - passiveCategory = categories.get(1); - } else + if(categoricalLabel.size() >= 3){ + List activeRegressionTables = encodeMultinomialClassifier(categories.subList(0, categories.size() - 1), categoryRegressionTables); - if(index == 1){ - passiveCategory = categories.get(0); - } else - - { - throw new IllegalArgumentException(); - } + RegressionTable passiveRegressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), null) + .setTargetCategory(categories.get(categories.size() - 1)); - RegressionTable activeRegressionTable = categoryRegressionTables.get(activeCategory) - .setTargetCategory(activeCategory); + regressionTables = new ArrayList<>(activeRegressionTables); + regressionTables.add(passiveRegressionTable); + } else - RegressionTable passiveRegressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), null) - .setTargetCategory(passiveCategory); + { + throw new IllegalArgumentException(); + } + break; + case LOGIT: + if(categoricalLabel.size() == 2){ + regressionTables = encodeBinaryClassifier(categories, categoryRegressionTables); + } else - regressionTables = Arrays.asList(activeRegressionTable, passiveRegressionTable); + { + throw new IllegalArgumentException(); } break; - case SIMPLEMAX: case SOFTMAX: + case SIMPLEMAX: + if(categoricalLabel.size() >= 2){ + regressionTables = encodeMultinomialClassifier(categories, categoryRegressionTables); + } else + { - if((categoryRegressionTables.size() != categories.size()) || !(categoryRegressionTables.keySet()).containsAll(categories)){ - throw new IllegalArgumentException(); - } - - regressionTables = categories.stream() - .map(category -> { - return categoryRegressionTables.get(category) - .setTargetCategory(category); - }) - .collect(Collectors.toList()); + throw new IllegalArgumentException(); } break; default: @@ -149,10 +142,61 @@ public String getNormalizationMethod(){ return getString("normalization_method"); } + static + private List encodeBinaryClassifier(List categories, Map categoryRegressionTables){ + + if(categoryRegressionTables.size() != 1){ + throw new IllegalArgumentException(); + } + + Map.Entry entry = Iterables.getOnlyElement(categoryRegressionTables.entrySet()); + + Object activeCategory = entry.getKey(); + Object passiveCategory; + + int index = categories.indexOf(activeCategory); + if(index == 0){ + passiveCategory = categories.get(1); + } else + + if(index == 1){ + passiveCategory = categories.get(0); + } else + + { + throw new IllegalArgumentException(); + } + + RegressionTable activeRegressionTable = entry.getValue() + .setTargetCategory(activeCategory); + + RegressionTable passiveRegressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), null) + .setTargetCategory(passiveCategory); + + return Arrays.asList(activeRegressionTable, passiveRegressionTable); + } + + static + private List encodeMultinomialClassifier(List categories, Map categoryRegressionTables){ + + if(categoryRegressionTables.size() != categories.size() || !(categoryRegressionTables.keySet()).containsAll(categories)){ + throw new IllegalArgumentException(); + } + + return categories.stream() + .map(category -> { + return categoryRegressionTables.get(category) + .setTargetCategory(category); + }) + .collect(Collectors.toList()); + } + static private RegressionModel.NormalizationMethod parseNormalizationMethod(String normalizationMethod){ switch(normalizationMethod){ + case "none": + return RegressionModel.NormalizationMethod.NONE; case "logit": return RegressionModel.NormalizationMethod.LOGIT; case "simplemax": diff --git a/pmml-sklearn/src/test/resources/pkl/ExpressionIris.pkl b/pmml-sklearn/src/test/resources/pkl/ExpressionIris.pkl index f486e5faa..7fe57023b 100644 Binary files a/pmml-sklearn/src/test/resources/pkl/ExpressionIris.pkl and b/pmml-sklearn/src/test/resources/pkl/ExpressionIris.pkl differ diff --git a/pmml-sklearn/src/test/resources/pkl/ExpressionVersicolor.pkl b/pmml-sklearn/src/test/resources/pkl/ExpressionVersicolor.pkl index cbf6db0bc..dbf8aaf75 100644 Binary files a/pmml-sklearn/src/test/resources/pkl/ExpressionVersicolor.pkl and b/pmml-sklearn/src/test/resources/pkl/ExpressionVersicolor.pkl differ