Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: test cases mlflow registry #415

Merged
merged 3 commits on Sep 19, 2024 (the base and head branch names are missing from this capture)
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import mlflow.pyfunc
import mlflow.pytorch
import mlflow.sklearn
from mlflow.entities.model_registry import ModelVersion
from mlflow.exceptions import RestException
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.13.1"
version = "0.13.2"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
52 changes: 27 additions & 25 deletions tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import OrderedDict
import unittest
from contextlib import contextmanager
from unittest.mock import patch, Mock
Expand All @@ -11,8 +12,9 @@

from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache


from numalogic.registry.mlflow_registry import ModelStage
from numalogic.tools.exceptions import ModelVersionError
from tests.registry._mlflow_utils import (
model_sklearn,
create_model,
Expand Down Expand Up @@ -71,26 +73,29 @@ def test_save_model(self):
self.assertEqual(mock_status, status.status)

@patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_sklearn_rundata()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version)
def test_save_model_sklearn(self):
model = self.model_sklearn
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="sklearn")

mock_status = "READY"
self.assertEqual(mock_status, status.status)

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_load_model_when_pytorch_model_exist1(self):
Expand All @@ -103,11 +108,12 @@ def test_load_model_when_pytorch_model_exist1(self):
self.assertIsNotNone(data.metadata)
self.assertIsInstance(data.artifact, VanillaAE)

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_load_model_when_pytorch_model_exist2(self):
Expand Down Expand Up @@ -147,12 +153,13 @@ def test_load_model_when_sklearn_model_exist(self):
self.assertIsInstance(data.artifact, StandardScaler)
self.assertEqual(data.metadata, {})

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_empty_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_empty_rundata()))
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_model_version", mock_get_model_version_obj)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_load_model_with_version(self):
Expand All @@ -177,12 +184,16 @@ def test_staging_model_load_error(self):
ml = MLflowRegistry(TRACKING_URI, model_stage=ModelStage.STAGE)
skeys = self.skeys
dkeys = self.dkeys
ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertRaises(ModelVersionError)
with self.assertLogs(level="ERROR") as log:
result = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertIsNone(result) # Ensure the result is None
self.assertTrue(
any("No Model found" in message for message in log.output)
) # Check that the expected log was made

@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version())
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_both_version_latest_model_with_version(self):
Expand Down Expand Up @@ -254,7 +265,7 @@ def test_no_implementation(self):
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.delete_model_version", None)
@patch("mlflow.tracking.MlflowClient.delete_model_version", Mock(return_value=None))
@patch("mlflow.pytorch.load_model", Mock(side_effect=RuntimeError))
def test_delete_model_when_model_exist(self):
model = self.model
Expand Down Expand Up @@ -321,12 +332,13 @@ def test_load_other_mlflow_err(self):
dkeys = self.dkeys
self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch"))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_is_model_stale_true(self):
Expand All @@ -342,12 +354,13 @@ def test_is_model_stale_true(self):
data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
self.assertTrue(ml.is_artifact_stale(data, 12))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_is_model_stale_false(self):
Expand Down Expand Up @@ -381,29 +394,18 @@ def test_cache(self):
self.assertIsNotNone(registry._load_from_cache("key"))
self.assertIsNotNone(registry._clear_cache("key"))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_cache_loading(self):
cache_registry = LocalLRUCache(ttl=50000)
ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry)
ml.save(
skeys=self.skeys,
dkeys=self.dkeys,
artifact=self.model,
**{"lr": 0.01},
artifact_type="pytorch",
)
ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
key = MLflowRegistry.construct_key(self.skeys, self.dkeys)
self.assertIsNotNone(ml._load_from_cache(key))
data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
self.assertIsNotNone(data)


if __name__ == "__main__":
Expand Down
Loading