Skip to content

Commit

Permalink
Added support for the 'ExpressionRegressor.normalization_method' attr…
Browse files Browse the repository at this point in the history
…ibute
  • Loading branch information
vruusmann committed Mar 20, 2024
1 parent def0c18 commit f80ce0f
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public ExpressionRegressor(String module, String name){
@Override
public RegressionModel encodeModel(Schema schema){
Expression expr = getExpr();
RegressionModel.NormalizationMethod normalizationMethod = parseNormalizationMethod(getNormalizationMethod());

PMMLEncoder encoder = schema.getEncoder();

Expand All @@ -61,7 +62,7 @@ public RegressionModel encodeModel(Schema schema){
RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(exprFeature), Collections.singletonList(1d), 0d);

RegressionModel regressionModel = new RegressionModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel), null)
.setNormalizationMethod(RegressionModel.NormalizationMethod.NONE)
.setNormalizationMethod(normalizationMethod)
.addRegressionTables(regressionTable);

return regressionModel;
Expand All @@ -70,4 +71,27 @@ public RegressionModel encodeModel(Schema schema){
public Expression getExpr(){
return get("expr", Expression.class);
}

public String getNormalizationMethod(){

if(!containsKey("normalization_method")){
return "none";
}

// SkLearn2PMML 0.105.0+
return getString("normalization_method");
}

static
private RegressionModel.NormalizationMethod parseNormalizationMethod(String normalizationMethod){

switch(normalizationMethod){
case "none":
return RegressionModel.NormalizationMethod.NONE;
case "exp":
return RegressionModel.NormalizationMethod.EXP;
default:
throw new IllegalArgumentException(normalizationMethod);
}
}
}
2 changes: 1 addition & 1 deletion pmml-sklearn/src/test/resources/extensions/sklearn2pmml.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def build_expr_auto(auto_df, name):
expr = Expression("-1.724 * _scale_displacement(X['displacement']) + 4.879 * _scale_weight(X['weight']) + 23.45", function_defs = [_scale_displacement, _scale_weight])

pipeline = PMMLPipeline([
("regressor", ExpressionRegressor(expr))
("regressor", ExpressionRegressor(expr, normalization_method = "none"))
])
pipeline.fit(auto_X, auto_y)
store_pkl(pipeline, name)
Expand Down
Binary file modified pmml-sklearn/src/test/resources/pkl/ExpressionAuto.pkl
Binary file not shown.

0 comments on commit f80ce0f

Please sign in to comment.