Skip to content

Commit

Permalink
add H2ORandomForestEstimator schema
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanSoley committed Feb 27, 2024
1 parent 727c176 commit d4b3d79
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 4 deletions.
3 changes: 3 additions & 0 deletions rubicon_ml/schema/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
"h2o__H2OGradientBoostingEstimator": lambda: _load_schema(
os.path.join("schema", "h2o__H2OGradientBoostingEstimator.yaml")
),
"h2o__H2ORandomForestEstimator": lambda: _load_schema(
os.path.join("schema", "h2o__H2ORandomForestEstimator.yaml")
),
"lightgbm__LGBMModel": lambda: _load_schema(os.path.join("schema", "lightgbm__LGBMModel.yaml")),
"lightgbm__LGBMClassifier": lambda: _load_schema(
os.path.join("schema", "lightgbm__LGBMClassifier.yaml")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,20 @@ parameters:
value_attr: pred_noise_bandwidth
- name: quantile_alpha
value_attr: quantile_alpha
- name: response_column
value_attr: response_column
- name: r2_stopping
value_attr: r2_stopping
- name: response_column
value_attr: response_column
- name: sample_rate
value_attr: sample_rate
- name: sample_rate_per_class
value_attr: sample_rate_per_class
- name: seed
value_attr: seed
- name: score_each_iteration
value_attr: score_each_iteration
- name: score_tree_interval
value_attr: score_tree_interval
- name: seed
value_attr: seed
- name: stopping_metric
value_attr: stopping_metric
- name: stopping_rounds
Expand Down
104 changes: 104 additions & 0 deletions rubicon_ml/schema/schema/h2o__H2ORandomForestEstimator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
name: h2o__H2ORandomForestEstimator
version: 1.0.0

compatibility:
lightgbm:
max_version:
min_version: 3.44.0.1
docs_url: https://docs.h2o.ai/h2o/latest-stable/h2o-py/docs/modeling.html#h2orandomforestestimator

parameters:
- name: auc_type
value_attr: auc_type
- name: balance_classes
value_attr: balance_classes
- name: binomial_double_trees
value_attr: binomial_double_trees
- name: build_tree_one_node
value_attr: build_tree_one_node
- name: calibrate_model
value_attr: calibrate_model
- name: calibration_method
value_attr: calibration_method
- name: categorical_encoding
value_attr: categorical_encoding
- name: check_constant_response
value_attr: check_constant_response
- name: class_sampling_factors
value_attr: class_sampling_factors
- name: col_sample_rate_change_per_level
value_attr: col_sample_rate_change_per_level
- name: col_sample_rate_per_tree
value_attr: col_sample_rate_per_tree
- name: custom_metric_func
value_attr: custom_metric_func
- name: distribution
value_attr: distribution
- name: export_checkpoints_dir
value_attr: export_checkpoints_dir
- name: fold_assignment
value_attr: fold_assignment
- name: fold_column
value_attr: fold_column
- name: gainslift_bins
value_attr: gainslift_bins
- name: histogram_type
value_attr: histogram_type
- name: ignore_const_cols
value_attr: ignore_const_cols
- name: ignored_columns
value_attr: ignored_columns
- name: keep_cross_validation_fold_assignment
value_attr: keep_cross_validation_fold_assignment
- name: keep_cross_validation_models
value_attr: keep_cross_validation_models
- name: keep_cross_validation_predictions
value_attr: keep_cross_validation_predictions
- name: max_after_balance_size
value_attr: max_after_balance_size
- name: max_confusion_matrix_size
value_attr: max_confusion_matrix_size
- name: max_depth
value_attr: max_depth
- name: max_runtime_secs
value_attr: max_runtime_secs
- name: min_rows
value_attr: min_rows
- name: min_split_improvement
value_attr: min_split_improvement
- name: mtries
value_attr: mtries
- name: nbins
value_attr: nbins
- name: nbins_cats
value_attr: nbins_cats
- name: nbins_top_level
value_attr: nbins_top_level
- name: nfolds
value_attr: nfolds
- name: ntrees
value_attr: ntrees
- name: offset_column
value_attr: offset_column
- name: r2_stopping
value_attr: r2_stopping
- name: response_column
value_attr: response_column
- name: sample_rate
value_attr: sample_rate
- name: sample_rate_per_class
value_attr: sample_rate_per_class
- name: score_each_iteration
value_attr: score_each_iteration
- name: score_tree_interval
value_attr: score_tree_interval
- name: seed
value_attr: seed
- name: stopping_metric
value_attr: stopping_metric
- name: stopping_rounds
value_attr: stopping_rounds
- name: stopping_tolerance
value_attr: stopping_tolerance
- name: weights_column
value_attr: weights_column
2 changes: 2 additions & 0 deletions tests/integration/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from h2o import H2OFrame
from h2o.estimators.gbm import H2OGradientBoostingEstimator
from h2o.estimators.glm import H2OGeneralizedLinearEstimator
from h2o.estimators.random_forest import H2ORandomForestEstimator
from lightgbm import LGBMClassifier, LGBMRegressor
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier, XGBRegressor
Expand All @@ -12,6 +13,7 @@
H2O_SCHEMA_CLS = [
H2OGeneralizedLinearEstimator,
H2OGradientBoostingEstimator,
H2ORandomForestEstimator,
]
PANDAS_SCHEMA_CLS = [
LGBMClassifier,
Expand Down

0 comments on commit d4b3d79

Please sign in to comment.