Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for logging of H2O MOJO Models #486

Merged
merged 10 commits into from
Sep 30, 2024
23 changes: 20 additions & 3 deletions rubicon_ml/client/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _get_data(self):
@failsafe
def get_data(
self,
deserialize: Optional[Literal["h2o", "pickle", "xgboost"]] = None,
deserialize: Optional[Literal["h2o", "h2o_binary", "h2o_mojo", "pickle", "xgboost"]] = None,
unpickle: bool = False, # TODO: deprecate & move to `deserialize`
):
"""Loads the data associated with this artifact and
Expand All @@ -82,7 +82,8 @@ def get_data(
deserialize : str, optional
Method to use to deserialize this artifact's data.
* None to disable deserialization and return the raw data.
* "h2o" to use `h2o.load_model` to load the data.
* "h2o" or "h2o_binary" to use `h2o.load_model` to load the data.
* "h2o_mojo" to use `h2o.import_mojo` to load the data.
* "pickle" to use pickles to load the data.
* "xgboost" to use xgboost's JSON loader to load the data as a fitted model.
Defaults to None.
Expand All @@ -101,6 +102,13 @@ def get_data(
)
deserialize = "pickle"

if deserialize == "h2o":
warnings.warn(
"'deserialize' method 'h2o' will be deprecated in a future release,"
" please use 'h2o_binary' instead.",
DeprecationWarning,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

)

for repo in self.repositories or []:
try:
if deserialize == "xgboost":
Expand All @@ -119,12 +127,21 @@ def get_data(
except Exception as err:
return_err = err
else:
if deserialize == "h2o":
if deserialize in [
"h2o",
"h2o_binary",
]: # "h2o" will be deprecated in a future release
import h2o

data = h2o.load_model(
repo._get_artifact_data_path(project_name, experiment_id, self.id)
)
elif deserialize == "h2o_mojo":
import h2o

data = h2o.import_mojo(
repo._get_artifact_data_path(project_name, experiment_id, self.id)
)
elif deserialize == "pickle":
data = pickle.loads(data)

Expand Down
20 changes: 14 additions & 6 deletions rubicon_ml/client/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def log_h2o_model(
h2o_model,
artifact_name: Optional[str] = None,
export_cross_validation_predictions: bool = False,
use_mojo: bool = False,
**log_artifact_kwargs,
) -> Artifact:
"""Log an `h2o` model as an artifact using `h2o.save_model` or, when `use_mojo` is True, `download_mojo`.
Expand All @@ -256,6 +257,9 @@ def log_h2o_model(
The name of the artifact. Defaults to None, using `h2o_model`'s class name.
export_cross_validation_predictions: bool, optional (default False)
Passed directly to `h2o.save_model`. Not used when `use_mojo` is True.
use_mojo: bool, optional (default False)
Whether to log the model in MOJO format. If False, the model will be
logged in binary format.
log_artifact_kwargs : dict
Additional kwargs to be passed directly to `self.log_artifact`.
"""
Expand All @@ -268,12 +272,16 @@ def log_h2o_model(
artifact_name = h2o_model.__class__.__name__

with tempfile.TemporaryDirectory() as temp_dir_name:
model_data_path = h2o.save_model(
h2o_model,
export_cross_validation_predictions=export_cross_validation_predictions,
filename=artifact_name,
path=temp_dir_name,
)
if use_mojo:
model_data_path = f"{temp_dir_name}/{artifact_name}.zip"
h2o_model.download_mojo(path=model_data_path)
else:
model_data_path = h2o.save_model(
h2o_model,
export_cross_validation_predictions=export_cross_validation_predictions,
filename=artifact_name,
path=temp_dir_name,
)

artifact = self.log_artifact(
name=artifact_name,
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,6 @@ def test_estimator_h2o_schema_train(
model_artifact = experiment.artifact(name=schema_cls.__name__)

assert len(project.schema_["parameters"]) == len(experiment.parameters())
assert model_artifact.get_data(deserialize="h2o").__class__.__name__ == schema_cls.__name__
assert (
model_artifact.get_data(deserialize="h2o_binary").__class__.__name__ == schema_cls.__name__
)
23 changes: 19 additions & 4 deletions tests/unit/client/test_artifact_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import xgboost
from h2o import H2OFrame
from h2o.estimators.generic import H2OGenericEstimator
from h2o.estimators.random_forest import H2ORandomForestEstimator

from rubicon_ml import domain
Expand Down Expand Up @@ -159,8 +160,19 @@ def test_download_location(mock_open, project_client):
mock_file().write.assert_called_once_with(data)


@pytest.mark.parametrize(
["use_mojo", "deserialization_method"],
[
(False, "h2o"),
(False, "h2o_binary"),
(True, "h2o_mojo"),
thebrianbn marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_get_data_deserialize_h2o(
make_classification_df, rubicon_local_filesystem_client_with_project
make_classification_df,
rubicon_local_filesystem_client_with_project,
use_mojo,
deserialization_method,
):
"""Test logging `h2o` model data."""
_, project = rubicon_local_filesystem_client_with_project
Expand All @@ -181,10 +193,13 @@ def test_get_data_deserialize_h2o(
y=target_name,
)

artifact = project.log_h2o_model(h2o_model)
artifact_data = artifact.get_data(deserialize="h2o")
artifact = project.log_h2o_model(h2o_model, use_mojo=use_mojo)
artifact_data = artifact.get_data(deserialize=deserialization_method)

assert artifact_data.__class__ == h2o_model.__class__
if use_mojo:
assert isinstance(artifact_data, H2OGenericEstimator)
else:
assert artifact_data.__class__ == h2o_model.__class__


def test_get_data_deserialize_xgboost(
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/client/test_mixin_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,10 @@ def test_log_json(project_client):
assert artifact_b.id in [a.id for a in artifacts]


def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_with_project):
@pytest.mark.parametrize("use_mojo", [False, True])
def test_log_h2o_model(
make_classification_df, rubicon_local_filesystem_client_with_project, use_mojo
):
"""Test logging `h2o` model data."""
_, project = rubicon_local_filesystem_client_with_project
X, y = make_classification_df
Expand All @@ -222,7 +225,7 @@ def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_w
y=target_name,
)

artifact = project.log_h2o_model(h2o_model, tags=["h2o"])
artifact = project.log_h2o_model(h2o_model, use_mojo=use_mojo, tags=["h2o"])
read_artifact = project.artifact(name=artifact.name)

assert artifact.id == read_artifact.id
Expand Down
Loading