Skip to content

Commit

Permalink
add H2OXGBoostEstimator schema
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanSoley committed Feb 27, 2024
1 parent 55615d8 commit a38d34f
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 8 deletions.
3 changes: 3 additions & 0 deletions rubicon_ml/schema/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
"h2o__H2OTargetEncoderEstimator": lambda: _load_schema(
os.path.join("schema", "h2o__H2OTargetEncoderEstimator.yaml")
),
"h2o__H2OXGBoostEstimator": lambda: _load_schema(
os.path.join("schema", "h2o__H2OXGBoostEstimator.yaml")
),
"lightgbm__LGBMModel": lambda: _load_schema(os.path.join("schema", "lightgbm__LGBMModel.yaml")),
"lightgbm__LGBMClassifier": lambda: _load_schema(
os.path.join("schema", "lightgbm__LGBMClassifier.yaml")
Expand Down
143 changes: 143 additions & 0 deletions rubicon_ml/schema/schema/h2o__H2OXGBoostEstimator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
name: h2o__H2OXGBoostEstimator
version: 1.0.0

compatibility:
lightgbm:
max_version:
min_version: 3.44.0.1
docs_url: https://docs.h2o.ai/h2o/latest-stable/h2o-py/docs/modeling.html#h2oxgboostestimator

parameters:
- name: auc_type
value_attr: auc_type
- name: backend
value_attr: backend
- name: booster
value_attr: booster
- name: build_tree_one_node
value_attr: build_tree_one_node
- name: calibrate_model
value_attr: calibrate_model
- name: calibration_method
value_attr: calibration_method
- name: categorical_encoding
value_attr: categorical_encoding
- name: col_sample_rate
value_attr: col_sample_rate
- name: col_sample_rate_per_tree
value_attr: col_sample_rate_per_tree
- name: colsample_bylevel
value_attr: colsample_bylevel
- name: colsample_bynode
value_attr: colsample_bynode
- name: colsample_bytree
value_attr: colsample_bytree
- name: distribution
value_attr: distribution
- name: dmatrix_type
value_attr: dmatrix_type
- name: eta
value_attr: eta
- name: eval_metric
value_attr: eval_metric
- name: export_checkpoints_dir
value_attr: export_checkpoints_dir
- name: fold_assignment
value_attr: fold_assignment
- name: fold_column
value_attr: fold_column
- name: gainslift_bins
value_attr: gainslift_bins
- name: gamma
value_attr: gamma
- name: gpu_id
value_attr: gpu_id
- name: grow_policy
value_attr: grow_policy
- name: ignore_const_cols
value_attr: ignore_const_cols
- name: ignored_columns
value_attr: ignored_columns
- name: interaction_constraints
value_attr: interaction_constraints
- name: keep_cross_validation_fold_assignment
value_attr: keep_cross_validation_fold_assignment
- name: keep_cross_validation_models
value_attr: keep_cross_validation_models
- name: keep_cross_validation_predictions
value_attr: keep_cross_validation_predictions
- name: learn_rate
value_attr: learn_rate
- name: max_abs_leafnode_pred
value_attr: max_abs_leafnode_pred
- name: max_bins
value_attr: max_bins
- name: max_delta_step
value_attr: max_delta_step
- name: max_depth
value_attr: max_depth
- name: max_leaves
value_attr: max_leaves
- name: max_runtime_secs
value_attr: max_runtime_secs
- name: min_child_weight
value_attr: min_child_weight
- name: min_rows
value_attr: min_rows
- name: monotone_constraints
value_attr: monotone_constraints
- name: nfolds
value_attr: nfolds
- name: normalize_type
value_attr: normalize_type
- name: nthread
value_attr: nthread
- name: ntrees
value_attr: ntrees
- name: offset_column
value_attr: offset_column
- name: one_drop
value_attr: one_drop
- name: parallelize_cross_validation
value_attr: parallelize_cross_validation
- name: quiet_mode
value_attr: quiet_mode
- name: rate_drop
value_attr: rate_drop
- name: reg_alpha
value_attr: reg_alpha
- name: reg_lambda
value_attr: reg_lambda
- name: response_column
value_attr: response_column
- name: sample_rate
value_attr: sample_rate
- name: sample_type
value_attr: sample_type
- name: save_matrix_directory
value_attr: save_matrix_directory
- name: scale_pos_weight
value_attr: scale_pos_weight
- name: score_each_iteration
value_attr: score_each_iteration
- name: score_eval_metric_only
value_attr: score_eval_metric_only
- name: score_tree_interval
value_attr: score_tree_interval
- name: seed
value_attr: seed
- name: skip_drop
value_attr: skip_drop
- name: stopping_metric
value_attr: stopping_metric
- name: stopping_rounds
value_attr: stopping_rounds
- name: stopping_tolerance
value_attr: stopping_tolerance
- name: subsample
value_attr: subsample
- name: tree_method
value_attr: tree_method
- name: tweedie_power
value_attr: tweedie_power
- name: weights_column
20 changes: 12 additions & 8 deletions tests/integration/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,12 @@
from h2o.estimators.glm import H2OGeneralizedLinearEstimator
from h2o.estimators.random_forest import H2ORandomForestEstimator
from h2o.estimators.targetencoder import H2OTargetEncoderEstimator
from h2o.estimators.xgboost import H2OXGBoostEstimator
from lightgbm import LGBMClassifier, LGBMRegressor
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier, XGBRegressor
from xgboost.dask import DaskXGBClassifier, DaskXGBRegressor

H2O_SCHEMA_CLS = [
H2OGeneralizedLinearEstimator,
H2OGradientBoostingEstimator,
H2ORandomForestEstimator,
H2OTargetEncoderEstimator,
]
PANDAS_SCHEMA_CLS = [
LGBMClassifier,
LGBMRegressor,
Expand All @@ -25,6 +20,17 @@
XGBRegressor,
]
DASK_SCHEMA_CLS = [DaskXGBClassifier, DaskXGBRegressor]
H2O_SCHEMA_CLS = [
H2OGeneralizedLinearEstimator,
H2OGradientBoostingEstimator,
H2ORandomForestEstimator,
H2OTargetEncoderEstimator,
]

h2o.init()

if H2OXGBoostEstimator.available():
H2O_SCHEMA_CLS.append(H2OXGBoostEstimator)


def _fit_and_log(X, y, schema_cls, rubicon_project):
Expand Down Expand Up @@ -95,8 +101,6 @@ def test_estimator_h2o_schema_train(schema_cls, make_classification_df, rubicon_
X, y = make_classification_df
y = y > y.mean()

h2o.init(nthreads=-1)

experiment = _train_and_log(X, y, schema_cls, rubicon_project)

assert len(rubicon_project.schema_["parameters"]) == len(experiment.parameters())

0 comments on commit a38d34f

Please sign in to comment.