diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py
index 75c4bd7a..65c6812c 100644
--- a/numalogic/registry/mlflow_registry.py
+++ b/numalogic/registry/mlflow_registry.py
@@ -23,12 +23,11 @@
 from mlflow.exceptions import RestException
 from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
 from mlflow.tracking import MlflowClient
-from sortedcontainers import SortedSet
 
 from numalogic.registry import ArtifactManager, ArtifactData
 from numalogic.registry.artifact import ArtifactCache
 from numalogic.tools.exceptions import ModelVersionError
-from numalogic.tools.types import KeyedArtifact, artifact_t, KEYS, META_VT
+from numalogic.tools.types import artifact_t, KEYS, META_VT
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -192,34 +191,39 @@ def load(
     def load_multiple(
         self,
         skeys: KEYS,
-        dkeys_list: list[list[str]],
+        dkeys: KEYS,
     ) -> Optional[ArtifactData]:
         """
         Load multiple artifacts from the registry for pyfunc models.
 
         Args:
             skeys (KEYS): The source keys of the artifacts to load.
-            dkeys_list (list[list[str]]):
-                A list of lists containing the dkeys of the artifacts to load.
+            dkeys: dynamic key fields as list/tuple of strings.
 
         Returns
         -------
             Optional[ArtifactData]: The loaded ArtifactData object if available otherwise None.
             ArtifactData should contain a dictionary of artifacts.
         """
-        dkeys = self.__get_sorted_unique_dkeys(dkeys_list)
         loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc")
-        if loaded_model is not None:
-            try:
-                unwrapped_composite_model = loaded_model.artifact.unwrap_python_model()
-            except Exception:
-                _LOGGER.exception("Error occurred while unwrapping python model")
-                return None
-
-            dict_artifacts = unwrapped_composite_model.dict_artifacts
-            metadata = loaded_model.metadata
-            version_info = loaded_model.extras
-            return ArtifactData(artifact=dict_artifacts, metadata=metadata, extras=version_info)
-        return None
+        if loaded_model is None:
+            return None
+        if loaded_model.artifact.loader_module != "mlflow.pyfunc.model":
+            raise TypeError("The loaded model is not a valid pyfunc Python model.")
+
+        try:
+            unwrapped_composite_model = loaded_model.artifact.unwrap_python_model()
+        except AttributeError:
+            _LOGGER.exception("The loaded model does not have an unwrap_python_model method")
+            return None
+        except Exception:
+            _LOGGER.exception("Unexpected error occurred while unwrapping python model.")
+            return None
+        else:
+            return ArtifactData(
+                artifact=unwrapped_composite_model.dict_artifacts,
+                metadata=loaded_model.metadata,
+                extras=loaded_model.extras,
+            )
 
     @staticmethod
     def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None:
@@ -278,7 +282,8 @@ def save(
     def save_multiple(
         self,
         skeys: KEYS,
-        dict_artifacts: dict[str, KeyedArtifact],
+        dkeys: KEYS,
+        dict_artifacts: dict[str, artifact_t],
         **metadata: META_VT,
     ) -> Optional[ModelVersion]:
         """
@@ -287,20 +292,22 @@ def save_multiple(
 
         Args:
         ----
-            skeys: static key fields as list/tuple of strings
-            dict_artifacts: dict of artifacts to save
-            metadata: additional metadata surrounding the artifact that needs to be saved.
+            skeys (KEYS): Static key fields as a list or tuple of strings.
+            dkeys (KEYS): Dynamic key fields as a list or tuple of strings.
+            dict_artifacts (dict[str, artifact_t]): Dictionary of artifacts to save.
+            **metadata (META_VT): Additional metadata to be saved with the artifacts.
 
         Returns
         -------
-            mlflow ModelVersion instance
+            Optional[ModelVersion]: An instance of the MLflow ModelVersion.
+ """ - multiple_artifacts = CompositeModels(skeys=skeys, dict_artifacts=dict_artifacts, **metadata) - dkeys_list = multiple_artifacts._get_dkeys_list() - sorted_dkeys = self.__get_sorted_unique_dkeys(dkeys_list) + if len(dict_artifacts) == 1: + _LOGGER.warning("Only one element in dict_artifacts. Please use save directly.") + multiple_artifacts = CompositeModel(skeys=skeys, dict_artifacts=dict_artifacts, **metadata) return self.save( skeys=skeys, - dkeys=sorted_dkeys, + dkeys=dkeys, artifact=multiple_artifacts, artifact_type="pyfunc", **metadata, @@ -407,23 +414,8 @@ def __load_artifacts( ) return model, metadata - def __get_sorted_unique_dkeys(self, dkeys_list: list[list]) -> list[str]: - """ - Returns a unique sorted list of all dkeys in the stored artifacts. - - Args: - ---- - dkeys_list: A list of lists containing the destination keys of the artifacts. - - Returns - ------- - List[str] - A list of all unique dkeys in the stored artifacts, sorted in ascending order. - """ - return list(SortedSet([dkey for dkeys in dkeys_list for dkey in dkeys])) - -class CompositeModels(mlflow.pyfunc.PythonModel): +class CompositeModel(mlflow.pyfunc.PythonModel): """A composite model that represents multiple artifacts. This class extends the `mlflow.pyfunc.PythonModel` class and is used to store and load @@ -438,7 +430,7 @@ class CompositeModels(mlflow.pyfunc.PythonModel): Methods ------- - get_dkeys_list(): Returns a list of all dynamic keys in the stored artifacts. + predict: Not implemented for our use case. Attributes ---------- @@ -450,21 +442,10 @@ class CompositeModels(mlflow.pyfunc.PythonModel): __slots__ = ("skeys", "dict_artifacts", "metadata") - def __init__(self, skeys: KEYS, dict_artifacts: dict[str, KeyedArtifact], **metadata: META_VT): + def __init__(self, skeys: KEYS, dict_artifacts: dict[str, artifact_t], **metadata: META_VT): self.skeys = skeys self.dict_artifacts = dict_artifacts self.metadata = metadata - def _get_dkeys_list(self): - """ - Returns a list of all dynamic keys in the stored artifacts. - - Returns - ------- - list[list[str]]: A list of all dynamic keys in the stored artifacts. 
- """ - dkeys_list = [] - artifacts = self.dict_artifacts.values() - for artifact in artifacts: - dkeys_list.append(artifact.dkeys) - return dkeys_list + def predict(self): + raise NotImplementedError() diff --git a/tests/registry/_mlflow_utils.py b/tests/registry/_mlflow_utils.py index 1ffc0830..ca82f246 100644 --- a/tests/registry/_mlflow_utils.py +++ b/tests/registry/_mlflow_utils.py @@ -13,8 +13,7 @@ from numalogic.models.autoencoder.variants.vanilla import VanillaAE from numalogic.models.threshold import StdDevThreshold -from numalogic.registry.mlflow_registry import CompositeModels -from numalogic.tools.types import KeyedArtifact +from numalogic.registry.mlflow_registry import CompositeModel def create_model(): @@ -192,11 +191,12 @@ def mock_load_model_pyfunc(*_, **__): return mlflow.pyfunc.PyFuncModel( model_meta=model, model_impl=TestObject( - python_model=CompositeModels( - skeys=["model"], + python_model=CompositeModel( + skeys=["error"], dict_artifacts={ - "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)), - "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()), + "inference": VanillaAE(10), + "precrocessing": StandardScaler(), + "threshold": StdDevThreshold(), }, **{"learning_rate": 0.01}, ) diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index 611d803a..e0fa686d 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -11,11 +11,11 @@ from sklearn.preprocessing import StandardScaler from numalogic.models.autoencoder.variants import VanillaAE +from numalogic.models.threshold._std import StdDevThreshold from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache from numalogic.registry.mlflow_registry import ModelStage -from numalogic.tools.types import KeyedArtifact from tests.registry._mlflow_utils import ( mock_load_model_pyfunc, mock_log_model_pyfunc, @@ -93,9 +93,11 @@ def test_save_multiple_models_pyfunc(self): status = ml.save_multiple( skeys=self.skeys, dict_artifacts={ - "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)), - "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()), + "inference": VanillaAE(10), + "precrocessing": StandardScaler(), + "threshold": StdDevThreshold(), }, + dkeys=["unique", "sorted"], **{"learning_rate": 0.01}, ) self.assertIsNotNone(status) @@ -114,13 +116,23 @@ def test_save_multiple_models_pyfunc(self): def test_load_multiple_models_when_pyfunc_model_exist(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys - dkeys_list = [["AE", "infer"], ["scaler", "infer"]] - data = ml.load_multiple(skeys=skeys, dkeys_list=dkeys_list) + dkeys = ["unique", "sorted"] + ml.save_multiple( + skeys=self.skeys, + dict_artifacts={ + "inference": VanillaAE(10), + "precrocessing": StandardScaler(), + "threshold": StdDevThreshold(), + }, + dkeys=["unique", "sorted"], + **{"learning_rate": 0.01}, + ) + data = ml.load_multiple(skeys=skeys, dkeys=dkeys) self.assertIsNotNone(data.metadata) self.assertIsInstance(data, ArtifactData) self.assertIsInstance(data.artifact, dict) - self.assertIsInstance(data.artifact["AE"].artifact, VanillaAE) - self.assertIsInstance(data.artifact["scaler"].artifact, StandardScaler) + self.assertIsInstance(data.artifact["inference"], VanillaAE) + self.assertIsInstance(data.artifact["precrocessing"], StandardScaler) @patch("mlflow.sklearn.log_model", mock_log_model_sklearn) @patch("mlflow.log_param", mock_log_state_dict) @@ -466,10 +478,9 @@ def 
     def test_cache_loading_pyfunc(self):
         cache_registry = LocalLRUCache(ttl=50000)
         ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry)
-        dkeys_list = [["AE", "infer"], ["scaler", "infer"]]
-        ml.load_multiple(skeys=self.skeys, dkeys_list=dkeys_list)
-        unique_sorted_dkeys = ["AE", "infer", "scaler"]
-        key = MLflowRegistry.construct_key(self.skeys, unique_sorted_dkeys)
+        dkeys = ["unique", "sorted"]
+        ml.load_multiple(skeys=self.skeys, dkeys=dkeys)
+        key = MLflowRegistry.construct_key(self.skeys, dkeys)
         self.assertIsNotNone(ml._load_from_cache(key))
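
For reviewers, a minimal usage sketch of the reworked multi-artifact API (not part of the patch). The signatures follow the diff above; the tracking URI, skeys/dkeys values, and artifact names are illustrative assumptions only.

    from sklearn.preprocessing import StandardScaler

    from numalogic.models.autoencoder.variants import VanillaAE
    from numalogic.models.threshold import StdDevThreshold
    from numalogic.registry import MLflowRegistry

    # Hypothetical tracking URI; point this at a real MLflow server.
    registry = MLflowRegistry("http://localhost:5000")

    # Save several artifacts as a single pyfunc (CompositeModel) version.
    # dkeys is now passed explicitly instead of being derived and sorted
    # from per-artifact KeyedArtifact dkeys.
    registry.save_multiple(
        skeys=["my", "service"],
        dkeys=["composite"],
        dict_artifacts={
            "inference": VanillaAE(10),
            "preprocessing": StandardScaler(),
            "threshold": StdDevThreshold(),
        },
        learning_rate=0.01,
    )

    # Load the bundle back; ArtifactData.artifact is a plain dict keyed as above.
    data = registry.load_multiple(skeys=["my", "service"], dkeys=["composite"])
    if data is not None:
        model = data.artifact["inference"]
        scaler = data.artifact["preprocessing"]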