Skip to content

Commit

Permalink
fix: dkeys format
Browse files Browse the repository at this point in the history
Signed-off-by: Leila Wang <[email protected]>
  • Loading branch information
yleilawang committed Sep 25, 2024
1 parent 82ea731 commit 69caf21
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 75 deletions.
97 changes: 39 additions & 58 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@
from mlflow.exceptions import RestException
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
from mlflow.tracking import MlflowClient
from sortedcontainers import SortedSet

from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.registry.artifact import ArtifactCache
from numalogic.tools.exceptions import ModelVersionError
from numalogic.tools.types import KeyedArtifact, artifact_t, KEYS, META_VT
from numalogic.tools.types import artifact_t, KEYS, META_VT

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -192,34 +191,39 @@ def load(
def load_multiple(
    self,
    skeys: KEYS,
    dkeys: KEYS,
) -> Optional[ArtifactData]:
    """
    Load multiple artifacts from the registry for pyfunc models.

    Args:
    ----
        skeys (KEYS): static key fields as list/tuple of strings.
        dkeys (KEYS): dynamic key fields as list/tuple of strings.

    Returns
    -------
        Optional[ArtifactData]: the loaded ArtifactData object if available, otherwise None.
            ArtifactData.artifact is a dictionary mapping artifact names to artifacts.

    Raises
    ------
        TypeError: if the loaded model is not a valid pyfunc Python model.
    """
    # Composite models are always saved by save_multiple with artifact_type="pyfunc".
    loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc")
    if loaded_model is None:
        return None

    # Guard: only pyfunc models created from a PythonModel can be unwrapped back
    # into the CompositeModel that holds the artifact dictionary.
    if loaded_model.artifact.loader_module != "mlflow.pyfunc.model":
        raise TypeError("The loaded model is not a valid pyfunc Python model.")

    try:
        # Recover the CompositeModel instance wrapped inside the pyfunc model.
        unwrapped_composite_model = loaded_model.artifact.unwrap_python_model()
    except AttributeError:
        _LOGGER.exception("The loaded model does not have an unwrap_python_model method")
        return None
    except Exception:
        _LOGGER.exception("Unexpected error occurred while unwrapping python model.")
        return None

    return ArtifactData(
        artifact=unwrapped_composite_model.dict_artifacts,
        metadata=loaded_model.metadata,
        extras=loaded_model.extras,
    )

@staticmethod
def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None:
Expand Down Expand Up @@ -278,7 +282,8 @@ def save(
def save_multiple(
    self,
    skeys: KEYS,
    dkeys: KEYS,
    dict_artifacts: dict[str, artifact_t],
    **metadata: META_VT,
) -> Optional[ModelVersion]:
    """
    Save multiple artifacts into the registry as a single composite pyfunc model.

    Args:
    ----
        skeys (KEYS): static key fields as a list or tuple of strings.
        dkeys (KEYS): dynamic key fields as a list or tuple of strings.
        dict_artifacts (dict[str, artifact_t]): dictionary of artifacts to save,
            keyed by artifact name.
        **metadata (META_VT): additional metadata to be saved with the artifacts.

    Returns
    -------
        Optional[ModelVersion]: an instance of the MLflow ModelVersion.

    Raises
    ------
        ValueError: if dict_artifacts is empty.
    """
    if not dict_artifacts:
        # Saving an empty composite would produce an unusable registry entry.
        raise ValueError("dict_artifacts must contain at least one artifact.")
    if len(dict_artifacts) == 1:
        # A composite wrapper adds overhead for a single artifact; save() is cheaper.
        _LOGGER.warning("Only one element in dict_artifacts. Please use save directly.")
    multiple_artifacts = CompositeModel(skeys=skeys, dict_artifacts=dict_artifacts, **metadata)
    return self.save(
        skeys=skeys,
        dkeys=dkeys,
        artifact=multiple_artifacts,
        artifact_type="pyfunc",
        **metadata,
    )
Expand Down Expand Up @@ -407,23 +414,8 @@ def __load_artifacts(
)
return model, metadata

def __get_sorted_unique_dkeys(self, dkeys_list: list[list]) -> list[str]:
"""
Returns a unique sorted list of all dkeys in the stored artifacts.
Args:
----
dkeys_list: A list of lists containing the destination keys of the artifacts.
Returns
-------
List[str]
A list of all unique dkeys in the stored artifacts, sorted in ascending order.
"""
return list(SortedSet([dkey for dkeys in dkeys_list for dkey in dkeys]))


class CompositeModel(mlflow.pyfunc.PythonModel):
    """A composite model that bundles multiple artifacts into one pyfunc model.

    This class extends `mlflow.pyfunc.PythonModel` and is used to store and load
    multiple artifacts (e.g. inference model, preprocessing transform, threshold
    estimator) under a single MLflow model version.

    Args:
    ----
        skeys (KEYS): static key fields as a list or tuple of strings.
        dict_artifacts (dict[str, artifact_t]): dictionary of artifacts keyed by name.
        **metadata (META_VT): additional metadata surrounding the artifacts.

    Attributes
    ----------
        skeys: static keys the composite model was saved under.
        dict_artifacts: mapping of artifact name to artifact object.
        metadata: metadata dictionary captured at save time.
    """

    __slots__ = ("skeys", "dict_artifacts", "metadata")

    def __init__(self, skeys: KEYS, dict_artifacts: dict[str, artifact_t], **metadata: META_VT):
        self.skeys = skeys
        self.dict_artifacts = dict_artifacts
        self.metadata = metadata

    def predict(self, context=None, model_input=None, params=None):
        """Prediction is intentionally unsupported for the composite wrapper.

        The pyfunc contract calls predict(context, model_input, params); matching
        that signature ensures mlflow's dispatch surfaces NotImplementedError
        instead of a confusing TypeError. Use the individual artifacts obtained
        via load_multiple instead.
        """
        raise NotImplementedError()
12 changes: 6 additions & 6 deletions tests/registry/_mlflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

from numalogic.models.autoencoder.variants.vanilla import VanillaAE
from numalogic.models.threshold import StdDevThreshold
from numalogic.registry.mlflow_registry import CompositeModels
from numalogic.tools.types import KeyedArtifact
from numalogic.registry.mlflow_registry import CompositeModel


def create_model():
Expand Down Expand Up @@ -192,11 +191,12 @@ def mock_load_model_pyfunc(*_, **__):
return mlflow.pyfunc.PyFuncModel(
model_meta=model,
model_impl=TestObject(
python_model=CompositeModels(
skeys=["model"],
python_model=CompositeModel(
skeys=["error"],
dict_artifacts={
"AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
"scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
"inference": VanillaAE(10),
"precrocessing": StandardScaler(),
"threshold": StdDevThreshold(),
},
**{"learning_rate": 0.01},
)
Expand Down
33 changes: 22 additions & 11 deletions tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from sklearn.preprocessing import StandardScaler

from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.models.threshold._std import StdDevThreshold
from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache


from numalogic.registry.mlflow_registry import ModelStage
from numalogic.tools.types import KeyedArtifact
from tests.registry._mlflow_utils import (
mock_load_model_pyfunc,
mock_log_model_pyfunc,
Expand Down Expand Up @@ -93,9 +93,11 @@ def test_save_multiple_models_pyfunc(self):
status = ml.save_multiple(
skeys=self.skeys,
dict_artifacts={
"AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
"scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
"inference": VanillaAE(10),
"precrocessing": StandardScaler(),
"threshold": StdDevThreshold(),
},
dkeys=["unique", "sorted"],
**{"learning_rate": 0.01},
)
self.assertIsNotNone(status)
Expand All @@ -114,13 +116,23 @@ def test_save_multiple_models_pyfunc(self):
def test_load_multiple_models_when_pyfunc_model_exist(self):
    """save_multiple then load_multiple round-trips the artifact dictionary."""
    ml = MLflowRegistry(TRACKING_URI)
    skeys = self.skeys
    dkeys = ["unique", "sorted"]
    # NOTE(review): "precrocessing" looks like a typo for "preprocessing" — the key
    # must match the mock fixture in _mlflow_utils, so fix both together.
    ml.save_multiple(
        skeys=self.skeys,
        dict_artifacts={
            "inference": VanillaAE(10),
            "precrocessing": StandardScaler(),
            "threshold": StdDevThreshold(),
        },
        dkeys=dkeys,
        **{"learning_rate": 0.01},
    )
    data = ml.load_multiple(skeys=skeys, dkeys=dkeys)
    self.assertIsNotNone(data.metadata)
    self.assertIsInstance(data, ArtifactData)
    self.assertIsInstance(data.artifact, dict)
    self.assertIsInstance(data.artifact["inference"], VanillaAE)
    self.assertIsInstance(data.artifact["precrocessing"], StandardScaler)

@patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
@patch("mlflow.log_param", mock_log_state_dict)
Expand Down Expand Up @@ -466,10 +478,9 @@ def test_cache_loading(self):
def test_cache_loading_pyfunc(self):
    """load_multiple populates the LRU cache under the (skeys, dkeys) key."""
    cache_registry = LocalLRUCache(ttl=50000)
    ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry)
    dkeys = ["unique", "sorted"]
    ml.load_multiple(skeys=self.skeys, dkeys=dkeys)
    key = MLflowRegistry.construct_key(self.skeys, dkeys)
    self.assertIsNotNone(ml._load_from_cache(key))


Expand Down

0 comments on commit 69caf21

Please sign in to comment.