From 37259d305c4bb1cf0dc2822fc0713b1fe7245c9f Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Wed, 20 Mar 2024 08:57:52 +0200 Subject: [PATCH] Improved support for the 'ExpressionClassifier.normalization_method' attribute --- .../expression/ExpressionClassifier.java | 114 ++++++++++++------ .../src/test/resources/pkl/ExpressionIris.pkl | Bin 907 -> 905 bytes .../resources/pkl/ExpressionVersicolor.pkl | Bin 693 -> 694 bytes 3 files changed, 79 insertions(+), 35 deletions(-) diff --git a/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionClassifier.java b/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionClassifier.java index 404ed7d63..a45847685 100644 --- a/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionClassifier.java +++ b/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionClassifier.java @@ -18,6 +18,7 @@ */ package sklearn2pmml.expression; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -83,50 +84,42 @@ public RegressionModel encodeModel(Schema schema){ List regressionTables; switch(normalizationMethod){ - case LOGIT: - { - if((categoryRegressionTables.size() != 1) || (categories.size() != 2)){ - throw new IllegalArgumentException(); - } - - Object activeCategory = Iterables.getOnlyElement(categoryRegressionTables.keySet()); - Object passiveCategory; + case NONE: + if(categoricalLabel.size() == 2){ + regressionTables = encodeBinaryClassifier(categories, categoryRegressionTables); + } else - int index = categories.indexOf(activeCategory); - if(index == 0){ - passiveCategory = categories.get(1); - } else + if(categoricalLabel.size() >= 3){ + List activeRegressionTables = encodeMultinomialClassifier(categories.subList(0, categories.size() - 1), categoryRegressionTables); - if(index == 1){ - passiveCategory = categories.get(0); - } else - - { - throw new IllegalArgumentException(); - } + RegressionTable passiveRegressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), null) + .setTargetCategory(categories.get(categories.size() - 1)); - RegressionTable activeRegressionTable = categoryRegressionTables.get(activeCategory) - .setTargetCategory(activeCategory); + regressionTables = new ArrayList<>(activeRegressionTables); + regressionTables.add(passiveRegressionTable); + } else - RegressionTable passiveRegressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), null) - .setTargetCategory(passiveCategory); + { + throw new IllegalArgumentException(); + } + break; + case LOGIT: + if(categoricalLabel.size() == 2){ + regressionTables = encodeBinaryClassifier(categories, categoryRegressionTables); + } else - regressionTables = Arrays.asList(activeRegressionTable, passiveRegressionTable); + { + throw new IllegalArgumentException(); } break; - case SIMPLEMAX: case SOFTMAX: + case SIMPLEMAX: + if(categoricalLabel.size() >= 2){ + regressionTables = encodeMultinomialClassifier(categories, categoryRegressionTables); + } else + { - if((categoryRegressionTables.size() != categories.size()) || !(categoryRegressionTables.keySet()).containsAll(categories)){ - throw new IllegalArgumentException(); - } - - regressionTables = categories.stream() - .map(category -> { - return categoryRegressionTables.get(category) - .setTargetCategory(category); - }) - .collect(Collectors.toList()); + throw new IllegalArgumentException(); } break; default: @@ -149,10 +142,61 @@ public String getNormalizationMethod(){ return getString("normalization_method"); } + static + private List encodeBinaryClassifier(List categories, Map categoryRegressionTables){ + + if(categoryRegressionTables.size() != 1){ + throw new IllegalArgumentException(); + } + + Map.Entry entry = Iterables.getOnlyElement(categoryRegressionTables.entrySet()); + + Object activeCategory = entry.getKey(); + Object passiveCategory; + + int index = categories.indexOf(activeCategory); + if(index == 0){ + passiveCategory = categories.get(1); + } else + + if(index == 1){ + passiveCategory = categories.get(0); + } else + + { + throw new IllegalArgumentException(); + } + + RegressionTable activeRegressionTable = entry.getValue() + .setTargetCategory(activeCategory); + + RegressionTable passiveRegressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), null) + .setTargetCategory(passiveCategory); + + return Arrays.asList(activeRegressionTable, passiveRegressionTable); + } + + static + private List encodeMultinomialClassifier(List categories, Map categoryRegressionTables){ + + if(categoryRegressionTables.size() != categories.size() || !(categoryRegressionTables.keySet()).containsAll(categories)){ + throw new IllegalArgumentException(); + } + + return categories.stream() + .map(category -> { + return categoryRegressionTables.get(category) + .setTargetCategory(category); + }) + .collect(Collectors.toList()); + } + static private RegressionModel.NormalizationMethod parseNormalizationMethod(String normalizationMethod){ switch(normalizationMethod){ + case "none": + return RegressionModel.NormalizationMethod.NONE; case "logit": return RegressionModel.NormalizationMethod.LOGIT; case "simplemax": diff --git a/pmml-sklearn/src/test/resources/pkl/ExpressionIris.pkl b/pmml-sklearn/src/test/resources/pkl/ExpressionIris.pkl index f486e5faae1dafc36ec5e9ed804030da35ad37a1..7fe57023bb9f77783b184bd0c7353a49470fdb5c 100644 GIT binary patch delta 895 zcmV-_1AzRC2Z;xe7k?9}T5OZlaTGuh@xrEVl&Xp)h*q<^WADK1WiFf6s7U1@Z6nPa zlfMA}g1;8dj@Or%Sb_&eI`+)D&H29BbKX?0{;DjPrS6NLDMmbQrcp%gl%$xF7;AO+ z`03N5;fqFX53pCTbhkAPl%a-%T(3i5(2iM87NZhSbL1dL0(W3sq;gc#N=uuKHfYW+sa zOcDxrPLfShRPtc0K642t!X1XYy)5>(x<5QR#u8Ec2*-UHK04dhYAv@&NZ&6pCVz5!->Q;cZCX`3__EftcXq*2Xt^Cr>t|#1@RM_ za2Dk?p@y@W&JdH*@@t4Ak4r z-S({vv|D@4CvzCc)MgBAgedt1nGZ0+GE4%ks)9r*#Swa?)vD1BueW@M} z>om8y0J0xPIF_zAkXUHvUgo_i`^Lrg6OQdDqmmeB%jMWlVj+3vON%WGy{w{eii!F9bFcLjh=`O=R1xn)vTYq}XmQUEqJGOels#{J8urjMiXs}uA zPPYvH&-U0|z{l23*av{n-SCN1xzOQW z*mw^2!vnSWqS2gwk%Rn2oVrtSR^cteZ{LT9ImxB#=WrtiNkLG~42TVKROn-nK8OSf zLFW0Nf)J2k$~Rxfg}?ll`RDo1%YSuLY}A@WypQGW|scBUm{#?Al$ delta 897 zcmV-{1AhF82a5-g7k?e8T5OZlaVVft@xmraO_N$B2iWI0$?*rZM(uh_$+X zcyO?Pcy;_;Z~dl!)lIcbF!C_heYFvD>`_mwoPkb+))a;-4@3i5(23boq&t~@-Q1e7J-%CgvLgve_aSVjVAwSFyS zBB>8|PLfSh)Y4$BKC>yO!p_6(VG_E{!4D4(u|&SPkHe894^FnUT1zbw64)LdW8a6#5i~6@TSl+5Rg9e@szuJ;(pZK&#i; z>D|acue;lMG>3siZHCZB$fv)c>;nw2B$208UCdgoTKqEhYp=?rd5$RS-e(DS(H{Qp@&jmaDv_?D8|}aJcE72)5)MJ zPhs!LF@NaE7i15-s?T=y`TKYJNH?uTRe7!Z{r>ltu(BWP8>!)gM!p}tw1WV}`dr=5 zE6nb(0p2JKa42nOEV0nmoy2)lbu$;!jW{-g#FsR4Hf)aFC=`+>t~A&}?qw7u8?0h2 zoUy8FyOD=oVc;;`7;JG+fnn&d+F%*}eIPSzvwx+xZ25?-yko0JtlqRLfRY(S!VI6m zZVjs7|7@4t27GMohMwO)KkUPhsQRt zWOp6&!q4aRLgKvDNG%6Ai1-+$=L*C+k-$LO8iFZwiftH&z9%&Kyl9CI*&&at+LVO8EgXLG}xM2r>eZ zAOvKd|1k&!@uqy!by)bzf0=)t|GfMc`^84BNyJB3-b~R}$xrn~cR~+2NXhwhmK@`< XT+a)vX-#fXL%J8)dzpU&D0F%7YKhB$ diff --git a/pmml-sklearn/src/test/resources/pkl/ExpressionVersicolor.pkl b/pmml-sklearn/src/test/resources/pkl/ExpressionVersicolor.pkl index cbf6db0bc58803b5c93402a1ff246dd12042d4fa..dbf8aaf7551ba31bfd9601cfa6f056d8c890e8f1 100644 GIT binary patch literal 694 zcmV;n0!jUN+Kp0ePZL29rcj_%NF-p4(GM4m(ndJ0rRBw-i6MURO4DcvCWc)0dfV>d z-b;4(z?zuo2NaV1=JN0SFF3bcU#OBe$#XY5^USleyYnr7xjQwXP2S_5C`K%-#z8>M z*pD&wL#*U{Yk$Ad>Mwt(@^|%BmE;V;$iqxE48bQ`l=LGAzxfMM~D46U?~Zj%dW3P7=0DA3ml~`3R*mb`X{u6wQJnKT#%SJnWmL^t9TLNg<{G;%jf%OY zGjj}!goVJYZdYr@6N5%wUzBP)rm=3+Yi1cZm~uSIUE46Xt<9|gxsq~gn<=+X?K5}G zYQ?G!n3ic-R;652C~>)54&fmY^}ir}@&F4Gc}h-eu95})8kjpuE*(cM^<6Veg1G0z zej8pz$;D>6e9ai@9Wt1Y(Wfl%#MSknr>m5l4Ly`{fdrF7P>hv5c?N#O(z%qmJ@7rz zgE)EC_kcU;Z>x**pX!||*;6vNtD4Q`+bb~ru_uIbWf4&tojO5)Vs$PbUiYbYbKrp8 zFu@ASaFy9aX1VNtd7L?k`$Hi)6W}0XJs9;_m;^WCSY5~k1e0=t9hhC} zaYY_qADKKQB~rHYXNo*0t7MI=%c+A(b=W%DfW~dh&f5iOnbgn8X4=dA)rh@^tuY`5 zAOHoJfJPP%0bPIq5WsExLx2zPM!2c9GQRjv{oDAr>0dT7M=l7&x>)@Cmdb-GnmpJV czoAFtH&ncNKXKMmJL@6!xqdS94@||cRbOjcI{*Lx literal 693 zcmV;m0!saO+Kp0eYt%px?zPvqvlUA#BKjdH+B1pzgg z&oK2vtkptuZ?E3$ZGY*~cl}it)eOPN!(2DiGUM3uThfsng`yqtAa&pGxfzdKG!|S! zjM46xb6QBugg(-$(4t5PzwLux-5V(ke+FHMMWI$JZ~L3wK|;U8Td9hZWC(EBIkhw* z6-%GOSp$(lYCcKQs^*+vF8o$RBkr{0uvI+#xZ1=LQL~Q2jwGu`YewI7;CqRuq}5`= zm7wT&xSihUHyFlWYAnffXeza4T<=a3mD0lGOINR3?eDl*m-4C%*`}1 z&#{bo2uy3ssv3_C8g+bGtZtjebE8%@OTfWY;9>6i3v;WyvDqhA5^i-P;r6IK=C)~7 z%2uCQHp}I5rBu);c7<9Ap^=FCUr>7U080{iT1}^1t>)5hVBu)Bd>pydcg-*kSl3~G z3%a7!QX|>E;T&}jILyazE9XS)rum?kW@$AWdMM!n2_}aihP6F;26n`gxm3AbuszuY zKY2>lfi>xG>5KE9`n@jNQ!2Nk8;!=hD_H5ro{-vA1w?6d>I4B|`dmG@t}s2vfva`G z0Eg0XyAlg+--+GZS0ig8Te<9id7K%s`$Ho&6W}1?T^RLQm;^TxSYN0`1e0=t9hhC} z2~8edUzt21MN+c!XPP`CYh<06YHGh?4K6?|DBmvId3z2DC)IOOODb9zuIB?BjR6dR z02E*X>X|