From 0e83975d6dd9796c50684a0659dd496e02f9d5ae Mon Sep 17 00:00:00 2001
From: Ryan Soley
Date: Tue, 27 Feb 2024 14:24:28 -0500
Subject: [PATCH] add `H2ORandomForestEstimator` schema

---
Reviewer notes (free text in this area, between the `---` marker and the
first `diff --git` line, is not part of the commit message or the diff):

- Purpose: registers a schema for `H2ORandomForestEstimator`, adds its YAML
  schema file, re-sorts three parameter entries in the
  `H2OGradientBoostingEstimator` schema alphabetically, and adds the new
  estimator to the H2O integration test class list.
- Fix: the new schema's `compatibility` block was keyed `lightgbm`, although
  the schema describes an h2o estimator, its `docs_url` points at docs.h2o.ai
  and `min_version: 3.44.0.1` looks like an h2o release number. The key is
  now `h2o`. The number of added lines is unchanged, so the hunk headers and
  the diffstat below still hold; the blob id on the new file's `index` line
  was computed before this one-token change.
- NOTE(review): the `compatibility` blocks of the other h2o schema files are
  outside the hunks shown here - check whether they need the same correction.
- NOTE(review): indentation and blank context lines were reconstructed from
  a whitespace-collapsed copy of this patch; confirm with `git apply --check`
  before applying.

 rubicon_ml/schema/registry.py                 |   3 +
 .../h2o__H2OGradientBoostingEstimator.yaml    |   8 +-
 .../schema/h2o__H2ORandomForestEstimator.yaml | 104 ++++++++++++++++++
 tests/integration/test_schema.py              |   2 +
 4 files changed, 113 insertions(+), 4 deletions(-)
 create mode 100644 rubicon_ml/schema/schema/h2o__H2ORandomForestEstimator.yaml

diff --git a/rubicon_ml/schema/registry.py b/rubicon_ml/schema/registry.py
index 5b436709..2dc69955 100644
--- a/rubicon_ml/schema/registry.py
+++ b/rubicon_ml/schema/registry.py
@@ -12,6 +12,9 @@
     "h2o__H2OGradientBoostingEstimator": lambda: _load_schema(
         os.path.join("schema", "h2o__H2OGradientBoostingEstimator.yaml")
     ),
+    "h2o__H2ORandomForestEstimator": lambda: _load_schema(
+        os.path.join("schema", "h2o__H2ORandomForestEstimator.yaml")
+    ),
     "lightgbm__LGBMModel": lambda: _load_schema(os.path.join("schema", "lightgbm__LGBMModel.yaml")),
     "lightgbm__LGBMClassifier": lambda: _load_schema(
         os.path.join("schema", "lightgbm__LGBMClassifier.yaml")
diff --git a/rubicon_ml/schema/schema/h2o__H2OGradientBoostingEstimator.yaml b/rubicon_ml/schema/schema/h2o__H2OGradientBoostingEstimator.yaml
index 3652649a..6690c89e 100644
--- a/rubicon_ml/schema/schema/h2o__H2OGradientBoostingEstimator.yaml
+++ b/rubicon_ml/schema/schema/h2o__H2OGradientBoostingEstimator.yaml
@@ -102,20 +102,20 @@ parameters:
     value_attr: pred_noise_bandwidth
   - name: quantile_alpha
     value_attr: quantile_alpha
-  - name: response_column
-    value_attr: response_column
   - name: r2_stopping
     value_attr: r2_stopping
+  - name: response_column
+    value_attr: response_column
   - name: sample_rate
     value_attr: sample_rate
   - name: sample_rate_per_class
     value_attr: sample_rate_per_class
-  - name: seed
-    value_attr: seed
   - name: score_each_iteration
     value_attr: score_each_iteration
   - name: score_tree_interval
     value_attr: score_tree_interval
+  - name: seed
+    value_attr: seed
   - name: stopping_metric
     value_attr: stopping_metric
   - name: stopping_rounds
diff --git a/rubicon_ml/schema/schema/h2o__H2ORandomForestEstimator.yaml b/rubicon_ml/schema/schema/h2o__H2ORandomForestEstimator.yaml
new file mode 100644
index 00000000..aa859f67
--- /dev/null
+++ b/rubicon_ml/schema/schema/h2o__H2ORandomForestEstimator.yaml
@@ -0,0 +1,104 @@
+name: h2o__H2ORandomForestEstimator
+version: 1.0.0
+
+compatibility:
+  h2o:
+    max_version:
+    min_version: 3.44.0.1
+docs_url: https://docs.h2o.ai/h2o/latest-stable/h2o-py/docs/modeling.html#h2orandomforestestimator
+
+parameters:
+  - name: auc_type
+    value_attr: auc_type
+  - name: balance_classes
+    value_attr: balance_classes
+  - name: binomial_double_trees
+    value_attr: binomial_double_trees
+  - name: build_tree_one_node
+    value_attr: build_tree_one_node
+  - name: calibrate_model
+    value_attr: calibrate_model
+  - name: calibration_method
+    value_attr: calibration_method
+  - name: categorical_encoding
+    value_attr: categorical_encoding
+  - name: check_constant_response
+    value_attr: check_constant_response
+  - name: class_sampling_factors
+    value_attr: class_sampling_factors
+  - name: col_sample_rate_change_per_level
+    value_attr: col_sample_rate_change_per_level
+  - name: col_sample_rate_per_tree
+    value_attr: col_sample_rate_per_tree
+  - name: custom_metric_func
+    value_attr: custom_metric_func
+  - name: distribution
+    value_attr: distribution
+  - name: export_checkpoints_dir
+    value_attr: export_checkpoints_dir
+  - name: fold_assignment
+    value_attr: fold_assignment
+  - name: fold_column
+    value_attr: fold_column
+  - name: gainslift_bins
+    value_attr: gainslift_bins
+  - name: histogram_type
+    value_attr: histogram_type
+  - name: ignore_const_cols
+    value_attr: ignore_const_cols
+  - name: ignored_columns
+    value_attr: ignored_columns
+  - name: keep_cross_validation_fold_assignment
+    value_attr: keep_cross_validation_fold_assignment
+  - name: keep_cross_validation_models
+    value_attr: keep_cross_validation_models
+  - name: keep_cross_validation_predictions
+    value_attr: keep_cross_validation_predictions
+  - name: max_after_balance_size
+    value_attr: max_after_balance_size
+  - name: max_confusion_matrix_size
+    value_attr: max_confusion_matrix_size
+  - name: max_depth
+    value_attr: max_depth
+  - name: max_runtime_secs
+    value_attr: max_runtime_secs
+  - name: min_rows
+    value_attr: min_rows
+  - name: min_split_improvement
+    value_attr: min_split_improvement
+  - name: mtries
+    value_attr: mtries
+  - name: nbins
+    value_attr: nbins
+  - name: nbins_cats
+    value_attr: nbins_cats
+  - name: nbins_top_level
+    value_attr: nbins_top_level
+  - name: nfolds
+    value_attr: nfolds
+  - name: ntrees
+    value_attr: ntrees
+  - name: offset_column
+    value_attr: offset_column
+  - name: r2_stopping
+    value_attr: r2_stopping
+  - name: response_column
+    value_attr: response_column
+  - name: sample_rate
+    value_attr: sample_rate
+  - name: sample_rate_per_class
+    value_attr: sample_rate_per_class
+  - name: score_each_iteration
+    value_attr: score_each_iteration
+  - name: score_tree_interval
+    value_attr: score_tree_interval
+  - name: seed
+    value_attr: seed
+  - name: stopping_metric
+    value_attr: stopping_metric
+  - name: stopping_rounds
+    value_attr: stopping_rounds
+  - name: stopping_tolerance
+    value_attr: stopping_tolerance
+  - name: weights_column
+    value_attr: weights_column
diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py
index 9a0ec663..c6051b6f 100644
--- a/tests/integration/test_schema.py
+++ b/tests/integration/test_schema.py
@@ -4,6 +4,7 @@
 from h2o import H2OFrame
 from h2o.estimators.gbm import H2OGradientBoostingEstimator
 from h2o.estimators.glm import H2OGeneralizedLinearEstimator
+from h2o.estimators.random_forest import H2ORandomForestEstimator
 from lightgbm import LGBMClassifier, LGBMRegressor
 from sklearn.ensemble import RandomForestClassifier
 from xgboost import XGBClassifier, XGBRegressor
@@ -12,6 +13,7 @@ H2O_SCHEMA_CLS = [
     H2OGeneralizedLinearEstimator,
     H2OGradientBoostingEstimator,
+    H2ORandomForestEstimator,
 ]
 
 PANDAS_SCHEMA_CLS = [
     LGBMClassifier,