Skip to content

Commit

Permalink
Improved support for the 'ExpressionClassifier.normalization_method' …
Browse files Browse the repository at this point in the history
…attribute
  • Loading branch information
vruusmann committed Mar 20, 2024
1 parent f80ce0f commit 37259d3
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package sklearn2pmml.expression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -83,50 +84,42 @@ public RegressionModel encodeModel(Schema schema){
List<RegressionTable> regressionTables;

switch(normalizationMethod){
case LOGIT:
{
if((categoryRegressionTables.size() != 1) || (categories.size() != 2)){
throw new IllegalArgumentException();
}

Object activeCategory = Iterables.getOnlyElement(categoryRegressionTables.keySet());
Object passiveCategory;
case NONE:
if(categoricalLabel.size() == 2){
regressionTables = encodeBinaryClassifier(categories, categoryRegressionTables);
} else

int index = categories.indexOf(activeCategory);
if(index == 0){
passiveCategory = categories.get(1);
} else
if(categoricalLabel.size() >= 3){
List<RegressionTable> activeRegressionTables = encodeMultinomialClassifier(categories.subList(0, categories.size() - 1), categoryRegressionTables);

if(index == 1){
passiveCategory = categories.get(0);
} else

{
throw new IllegalArgumentException();
}
RegressionTable passiveRegressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), null)
.setTargetCategory(categories.get(categories.size() - 1));

RegressionTable activeRegressionTable = categoryRegressionTables.get(activeCategory)
.setTargetCategory(activeCategory);
regressionTables = new ArrayList<>(activeRegressionTables);
regressionTables.add(passiveRegressionTable);
} else

RegressionTable passiveRegressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), null)
.setTargetCategory(passiveCategory);
{
throw new IllegalArgumentException();
}
break;
case LOGIT:
if(categoricalLabel.size() == 2){
regressionTables = encodeBinaryClassifier(categories, categoryRegressionTables);
} else

regressionTables = Arrays.asList(activeRegressionTable, passiveRegressionTable);
{
throw new IllegalArgumentException();
}
break;
case SIMPLEMAX:
case SOFTMAX:
case SIMPLEMAX:
if(categoricalLabel.size() >= 2){
regressionTables = encodeMultinomialClassifier(categories, categoryRegressionTables);
} else

{
if((categoryRegressionTables.size() != categories.size()) || !(categoryRegressionTables.keySet()).containsAll(categories)){
throw new IllegalArgumentException();
}

regressionTables = categories.stream()
.map(category -> {
return categoryRegressionTables.get(category)
.setTargetCategory(category);
})
.collect(Collectors.toList());
throw new IllegalArgumentException();
}
break;
default:
Expand All @@ -149,10 +142,61 @@ public String getNormalizationMethod(){
return getString("normalization_method");
}

static
private List<RegressionTable> encodeBinaryClassifier(List<?> categories, Map<?, RegressionTable> categoryRegressionTables){

if(categoryRegressionTables.size() != 1){
throw new IllegalArgumentException();
}

Map.Entry<?, RegressionTable> entry = Iterables.getOnlyElement(categoryRegressionTables.entrySet());

Object activeCategory = entry.getKey();
Object passiveCategory;

int index = categories.indexOf(activeCategory);
if(index == 0){
passiveCategory = categories.get(1);
} else

if(index == 1){
passiveCategory = categories.get(0);
} else

{
throw new IllegalArgumentException();
}

RegressionTable activeRegressionTable = entry.getValue()
.setTargetCategory(activeCategory);

RegressionTable passiveRegressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), null)
.setTargetCategory(passiveCategory);

return Arrays.asList(activeRegressionTable, passiveRegressionTable);
}

static
private List<RegressionTable> encodeMultinomialClassifier(List<?> categories, Map<?, RegressionTable> categoryRegressionTables){

if(categoryRegressionTables.size() != categories.size() || !(categoryRegressionTables.keySet()).containsAll(categories)){
throw new IllegalArgumentException();
}

return categories.stream()
.map(category -> {
return categoryRegressionTables.get(category)
.setTargetCategory(category);
})
.collect(Collectors.toList());
}

static
private RegressionModel.NormalizationMethod parseNormalizationMethod(String normalizationMethod){

switch(normalizationMethod){
case "none":
return RegressionModel.NormalizationMethod.NONE;
case "logit":
return RegressionModel.NormalizationMethod.LOGIT;
case "simplemax":
Expand Down
Binary file modified pmml-sklearn/src/test/resources/pkl/ExpressionIris.pkl
Binary file not shown.
Binary file modified pmml-sklearn/src/test/resources/pkl/ExpressionVersicolor.pkl
Binary file not shown.

0 comments on commit 37259d3

Please sign in to comment.