From bc1c627cf89e9a41fa2194df3a2a1ed33e5d989e Mon Sep 17 00:00:00 2001 From: Kushal Batra <34571348+s0nicboOm@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:46:33 -0700 Subject: [PATCH] feat: add multiple save for redis registry (#281) 1) Option to multisave in redis registry 2) support caching for version calls. 3) In udfs/ load models using version call instead of latest calls. 4) call multiple_save in artifacts to save artifacts in a transaction. --------- Signed-off-by: s0nicboOm --- .github/workflows/ci.yml | 2 +- .github/workflows/coverage.yml | 2 +- .github/workflows/pypi.yml | 2 +- numalogic/connectors/druid.py | 2 +- numalogic/registry/redis_registry.py | 124 ++++++++++++++++++++----- numalogic/tools/types.py | 11 ++- numalogic/udfs/inference.py | 54 ++--------- numalogic/udfs/postprocess.py | 11 ++- numalogic/udfs/preprocess.py | 6 +- numalogic/udfs/tools.py | 64 ++++++++++--- numalogic/udfs/trainer.py | 90 +++++++++--------- poetry.lock | 43 +++++++-- pyproject.toml | 2 +- tests/registry/test_mlflow_registry.py | 21 +++-- tests/registry/test_redis_registry.py | 27 +++++- tests/udfs/test_inference.py | 16 +++- tests/udfs/test_postprocess.py | 3 +- tests/udfs/test_preprocess.py | 6 +- tests/udfs/utility.py | 19 ++-- 19 files changed, 333 insertions(+), 172 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 95a3e3e4..be49353b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: - uses: actions/checkout@v3 - name: Install poetry - run: pipx install poetry==1.5.1 + run: pipx install poetry==1.6.1 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index e6096865..2f1b6b00 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -19,7 +19,7 @@ jobs: - uses: actions/checkout@v3 - name: Install poetry - run: pipx install poetry==1.5.1 + run: pipx install poetry==1.6.1 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 0189ea15..3e435c39 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -19,7 +19,7 @@ jobs: - uses: actions/checkout@v3 - name: Install poetry - run: pipx install poetry==1.5.1 + run: pipx install poetry==1.6.1 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 diff --git a/numalogic/connectors/druid.py b/numalogic/connectors/druid.py index 28449d6a..5a528138 100644 --- a/numalogic/connectors/druid.py +++ b/numalogic/connectors/druid.py @@ -52,7 +52,7 @@ def fetch( end_dt = datetime.now(pytz.utc) start_dt = end_dt - timedelta(hours=hours) - intervals = f"{start_dt.isoformat()}/{end_dt.isoformat()}" + intervals = [f"{start_dt.isoformat()}/{end_dt.isoformat()}"] dimension_specs = map(lambda d: DimensionSpec(dimension=d, output_name=d), dimensions) diff --git a/numalogic/registry/redis_registry.py b/numalogic/registry/redis_registry.py index 9ef98e34..f19c759e 100644 --- a/numalogic/registry/redis_registry.py +++ b/numalogic/registry/redis_registry.py @@ -14,13 +14,14 @@ from datetime import datetime, timedelta from typing import Optional import orjson +import redis.client from redis.exceptions import RedisError from numalogic.registry.artifact import ArtifactManager, ArtifactData, ArtifactCache from numalogic.registry._serialize import loads, dumps from numalogic.tools.exceptions import ModelKeyNotFound, 
RedisRegistryError -from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT +from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT, KeyedArtifact _LOGGER = logging.getLogger(__name__) @@ -33,6 +34,8 @@ class RedisRegistry(ArtifactManager): client: Take in the redis client already established/created ttl: Total Time to Live (in seconds) for the key when saving in redis (dafault = 604800) cache_registry: Cache registry to use (default = None). + transactional: Flag to indicate if the registry should be transactional or + not (default = False). Examples -------- @@ -48,18 +51,20 @@ class RedisRegistry(ArtifactManager): >>> loaded_artifact = registry.load(skeys, dkeys) """ - __slots__ = ("client", "ttl", "cache_registry") + __slots__ = ("client", "ttl", "cache_registry", "transactional") def __init__( self, client: redis_client_t, ttl: int = 604800, cache_registry: Optional[ArtifactCache] = None, + transactional: bool = True, ): super().__init__("") self.client = client self.ttl = ttl self.cache_registry = cache_registry + self.transactional = transactional @staticmethod def construct_key(skeys: KEYS, dkeys: KEYS) -> str: @@ -155,30 +160,37 @@ def __load_latest_artifact(self, key: str) -> tuple[ArtifactData, bool]: ------ ModelKeyNotFound: If the model key is not found in the registry. """ - cached_artifact = self._load_from_cache(key) + latest_key = self.__construct_latest_key(key) + cached_artifact = self._load_from_cache(latest_key) if cached_artifact: - _LOGGER.debug("Found cached artifact for key: %s", key) + _LOGGER.debug("Found cached artifact for key: %s", latest_key) return cached_artifact, True - latest_key = self.__construct_latest_key(key) if not self.client.exists(latest_key): raise ModelKeyNotFound(f"latest key: {latest_key}, Not Found !!!") model_key = self.client.get(latest_key) _LOGGER.debug("latest key, %s, is pointing to the key : %s", latest_key, model_key) - return ( - self.__load_version_artifact(version=self.get_version(model_key.decode()), key=key), - False, + artifact, _ = self.__load_version_artifact( + version=self.get_version(model_key.decode()), key=key ) + return artifact, False - def __load_version_artifact(self, version: str, key: str) -> ArtifactData: - model_key = self.__construct_version_key(key, version) - if not self.client.exists(model_key): - raise ModelKeyNotFound("Could not find model key with key: %s" % model_key) - return self.__get_artifact_data( - model_key=model_key, + def __load_version_artifact(self, version: str, key: str) -> tuple[ArtifactData, bool]: + version_key = self.__construct_version_key(key, version) + cached_artifact = self._load_from_cache(version_key) + if cached_artifact: + _LOGGER.debug("Found cached version artifact for key: %s", version_key) + return cached_artifact, True + if not self.client.exists(version_key): + raise ModelKeyNotFound(f"Could not find model key with key: {version_key}") + return ( + self.__get_artifact_data( + model_key=version_key, + ), + False, ) def __save_artifact( - self, pipe, artifact: artifact_t, metadata: META_T, key: KEYS, version: str + self, pipe, artifact: artifact_t, key: KEYS, version: str, **metadata: META_T ) -> str: new_version_key = self.__construct_version_key(key, version) latest_key = self.__construct_latest_key(key) @@ -210,6 +222,8 @@ def load( is needed to load the respective artifact. If cache registry is provided, it will first check the cache registry for the artifact. 
+ If latest is passed, latest key is saved otherwise version call saves the respective + version artifact in cache. Args: ---- @@ -230,17 +244,27 @@ def load( if (latest and version) or (not latest and not version): raise ValueError("Either One of 'latest' or 'version' needed in load method call") key = self.construct_key(skeys, dkeys) - is_cached = False try: if latest: artifact_data, is_cached = self.__load_latest_artifact(key) else: - artifact_data = self.__load_version_artifact(version, key) + artifact_data, is_cached = self.__load_version_artifact(version, key) except RedisError as err: raise RedisRegistryError(f"{err.__class__.__name__} raised") from err else: - if (not is_cached) and latest: - self._save_in_cache(key, artifact_data) + if not is_cached: + if latest: + _LOGGER.debug( + "Saving %s, in cache as %s", self.__construct_latest_key(key), key + ) + self._save_in_cache(self.__construct_latest_key(key), artifact_data) + else: + _LOGGER.info( + "Saving %s, in cache as %s", + self.__construct_version_key(key, version), + key, + ) + self._save_in_cache(self.__construct_version_key(key, version), artifact_data) return artifact_data def save( @@ -248,6 +272,7 @@ def save( skeys: KEYS, dkeys: KEYS, artifact: artifact_t, + _pipe: Optional[redis.client.Pipeline] = None, **metadata: META_VT, ) -> Optional[str]: """Saves the artifact into redis registry and updates version. @@ -257,6 +282,7 @@ def save( skeys: static key fields as list/tuple of strings dkeys: dynamic key fields as list/tuple of strings artifact: primary artifact to be saved + _pipe: RedisPipeline object metadata: additional metadata surrounding the artifact that needs to be saved. Returns @@ -275,10 +301,15 @@ def save( _LOGGER.debug("Latest key: %s exists for the model", latest_key) version_key = self.client.get(name=latest_key) version = int(self.get_version(version_key.decode())) + 1 - with self.client.pipeline() as pipe: - new_version_key = self.__save_artifact(pipe, artifact, metadata, key, str(version)) - pipe.expire(name=new_version_key, time=self.ttl) - pipe.execute() + _redis_pipe = ( + self.client.pipeline(transaction=self.transactional) if _pipe is None else _pipe + ) + new_version_key = self.__save_artifact( + pipe=_redis_pipe, artifact=artifact, key=key, version=str(version), **metadata + ) + _redis_pipe.expire(name=new_version_key, time=self.ttl) + if _pipe is None: + _redis_pipe.execute() except RedisError as err: raise RedisRegistryError(f"{err.__class__.__name__} raised") from err else: @@ -306,7 +337,7 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None: self.client.delete(del_key) else: raise ModelKeyNotFound( - "Key to delete: %s, Not Found!" % del_key, + f"Key to delete: {del_key}, Not Found!", ) except RedisError as err: raise RedisRegistryError(f"{err.__class__.__name__} raised") from err @@ -338,3 +369,48 @@ def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool: raise RedisRegistryError("Error fetching timestamp information") from err stale_ts = (datetime.now() - timedelta(hours=freq_hr)).timestamp() return stale_ts > artifact_ts + + def save_multiple( + self, + skeys: KEYS, + dict_artifacts: dict[str, KeyedArtifact], + **metadata: META_VT, + ): + """ + Saves multiple artifacts into redis registry. The last save stores all the + artifact versions in the metadata. + + Args: + ---- + skeys: static key fields as list/tuple of strings + dict_artifacts: dict of artifacts to save + metadata: additional metadata surrounding the artifact that needs to be saved. 
+ """ + dict_model_ver = {} + try: + with self.client.pipeline(transaction=self.transactional) as pipe: + pipe.multi() + for key, value in dict_artifacts.items(): + dict_model_ver[":".join(value.dkeys)] = self.save( + skeys=skeys, + dkeys=value.dkeys, + artifact=value.artifact, + _pipe=pipe, + **metadata, + ) + + if len(dict_artifacts) == len(dict_model_ver): + self.save( + skeys=skeys, + dkeys=value.dkeys, + artifact=value.artifact, + _pipe=pipe, + artifact_versions=dict_model_ver, + **metadata, + ) + pipe.execute() + _LOGGER.info("Successfully saved all the artifacts with: %s", dict_model_ver) + except RedisError as err: + raise RedisRegistryError(f"{err.__class__.__name__} raised") from err + else: + return dict_model_ver diff --git a/numalogic/tools/types.py b/numalogic/tools/types.py index b22e4bc4..e4c9f461 100644 --- a/numalogic/tools/types.py +++ b/numalogic/tools/types.py @@ -8,10 +8,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from collections.abc import Sequence -from typing import Union, TypeVar, ClassVar +from typing import Union, TypeVar, ClassVar, NamedTuple from sklearn.base import BaseEstimator from torch import Tensor @@ -40,6 +38,13 @@ KEYS = TypeVar("KEYS", bound=Sequence[str], covariant=True) +class KeyedArtifact(NamedTuple): + r"""namedtuple for artifacts.""" + + dkeys: KEYS + artifact: artifact_t + + class Singleton(type): r"""Helper metaclass to use as a Singleton class.""" diff --git a/numalogic/udfs/inference.py b/numalogic/udfs/inference.py index 77eed00a..285ca539 100644 --- a/numalogic/udfs/inference.py +++ b/numalogic/udfs/inference.py @@ -12,17 +12,19 @@ from numalogic.config import RegistryFactory from numalogic.registry import LocalLRUCache, ArtifactData -from numalogic.tools.exceptions import RedisRegistryError, ModelKeyNotFound, ConfigNotFoundError +from numalogic.tools.exceptions import ConfigNotFoundError from numalogic.tools.types import artifact_t, redis_client_t from numalogic.udfs._base import NumalogicUDF from numalogic.udfs._config import StreamConf, PipelineConf from numalogic.udfs.entities import StreamPayload, Header, Status +from numalogic.udfs.tools import _load_artifact _LOGGER = logging.getLogger(__name__) # TODO: move to config LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600")) LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000")) +LOAD_LATEST = os.getenv("LOAD_LATEST", "false").lower() == "true" class InferenceUDF(NumalogicUDF): @@ -114,8 +116,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: # Forward payload if a training request is tagged if payload.header == Header.TRAIN_REQUEST: return Messages(Message(keys=keys, value=payload.to_json())) - - artifact_data = self.load_artifact(keys, payload) + artifact_data, payload = _load_artifact( + skeys=keys, + dkeys=[self.get_conf(payload.config_id).numalogic_conf.model.name], + payload=payload, + model_registry=self.model_registry, + load_latest=LOAD_LATEST, + ) # TODO: revisit retraining logic # Send training request if artifact loading is not successful @@ -165,47 +172,6 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: ) return Messages(Message(keys=keys, value=payload.to_json())) - def load_artifact(self, keys: list[str], payload: StreamPayload) -> Optional[ArtifactData]: - """ - Load inference artifact from the registry. 
- - Args: - keys: List of keys - payload: StreamPayload object - - Returns - ------- - ArtifactData instance - """ - _conf = self.get_conf(payload.config_id).numalogic_conf - try: - artifact_data = self.model_registry.load( - skeys=keys, - dkeys=[_conf.model.name], - ) - except ModelKeyNotFound: - _LOGGER.warning( - "%s - Model key not found for Keys: %s, Metric: %s", - payload.uuid, - payload.composite_keys, - payload.metrics, - ) - return None - except RedisRegistryError: - _LOGGER.exception( - "%s - Error while fetching inference artifact, Keys: %s, Metric: %s", - payload.uuid, - payload.composite_keys, - payload.metrics, - ) - return None - _LOGGER.info( - "%s - Loaded artifact data from %s", - payload.uuid, - artifact_data.extras.get("source"), - ) - return artifact_data - def is_model_stale(self, artifact_data: ArtifactData, payload: StreamPayload) -> bool: """ Check if the inference artifact is stale. diff --git a/numalogic/udfs/postprocess.py b/numalogic/udfs/postprocess.py index 1eae5fee..f365039b 100644 --- a/numalogic/udfs/postprocess.py +++ b/numalogic/udfs/postprocess.py @@ -16,11 +16,12 @@ from numalogic.udfs import NumalogicUDF from numalogic.udfs._config import StreamConf, PipelineConf from numalogic.udfs.entities import StreamPayload, Header, Status, TrainerPayload, OutputPayload -from numalogic.udfs.tools import _load_model +from numalogic.udfs.tools import _load_artifact # TODO: move to config LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600")) LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000")) +LOAD_LATEST = os.getenv("LOAD_LATEST", "false").lower() == "true" _LOGGER = logging.getLogger(__name__) @@ -88,8 +89,12 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: postprocess_cfg = self.get_conf(payload.config_id).numalogic_conf.postprocess # load artifact - thresh_artifact = _load_model( - skeys=keys, dkeys=[thresh_cfg.name], payload=payload, model_registry=self.model_registry + thresh_artifact, payload = _load_artifact( + skeys=keys, + dkeys=[thresh_cfg.name], + payload=payload, + model_registry=self.model_registry, + load_latest=LOAD_LATEST, ) postproc_clf = self.postproc_factory.get_instance(postprocess_cfg) diff --git a/numalogic/udfs/preprocess.py b/numalogic/udfs/preprocess.py index 8c1ab887..1b0e3dd5 100644 --- a/numalogic/udfs/preprocess.py +++ b/numalogic/udfs/preprocess.py @@ -16,11 +16,12 @@ from numalogic.udfs import NumalogicUDF from numalogic.udfs._config import StreamConf, PipelineConf from numalogic.udfs.entities import Status, Header -from numalogic.udfs.tools import make_stream_payload, get_df, _load_model +from numalogic.udfs.tools import make_stream_payload, get_df, _load_artifact # TODO: move to config LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600")) LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000")) +LOAD_LATEST = os.getenv("LOAD_LATEST", "false").lower() == "true" _LOGGER = logging.getLogger(__name__) @@ -108,7 +109,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: if any( [_conf.stateful for _conf in self.get_conf(payload.config_id).numalogic_conf.preprocess] ): - preproc_artifact = _load_model( + preproc_artifact, payload = _load_artifact( skeys=keys, dkeys=[ _conf.name @@ -116,6 +117,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: ], payload=payload, model_registry=self.model_registry, + load_latest=LOAD_LATEST, ) if preproc_artifact: preproc_clf = preproc_artifact.artifact diff --git a/numalogic/udfs/tools.py b/numalogic/udfs/tools.py index 04515a6c..d9cd39fb 
100644 --- a/numalogic/udfs/tools.py +++ b/numalogic/udfs/tools.py @@ -1,4 +1,5 @@ import logging +from dataclasses import replace from typing import Optional import numpy as np @@ -64,10 +65,14 @@ def make_stream_payload( ) -# TODO: move to base NumalogicUDF class -def _load_model( - skeys: KEYS, dkeys: KEYS, payload: StreamPayload, model_registry: ArtifactManager -) -> Optional[ArtifactData]: +# TODO: move to base NumalogicUDF class and look into payload mutation +def _load_artifact( + skeys: KEYS, + dkeys: KEYS, + payload: StreamPayload, + model_registry: ArtifactManager, + load_latest: bool, +) -> tuple[Optional[ArtifactData], StreamPayload]: """ Load artifact from redis Args: @@ -79,26 +84,35 @@ def _load_model( Returns ------- artifact_t object + StreamPayload object """ - try: - artifact = model_registry.load(skeys, dkeys) + version_to_load = "-1" + if payload.metadata and "artifact_versions" in payload.metadata: + version_to_load = payload.metadata["artifact_versions"][":".join(dkeys)] + _LOGGER.info("%s - Found version info for keys: %s, %s", payload.uuid, skeys, dkeys) + else: _LOGGER.info( - "%s - Loaded Model. Source: %s , version: %s, Keys: %s, %s", + "%s - No version info passed on! Loading latest artifact version for Keys: %s", payload.uuid, - artifact.extras.get("source"), - artifact.extras.get("version"), skeys, - dkeys, ) + load_latest = True + try: + if load_latest: + artifact = model_registry.load(skeys=skeys, dkeys=dkeys) + else: + artifact = model_registry.load( + skeys=skeys, dkeys=dkeys, latest=False, version=version_to_load + ) except RedisRegistryError: - _LOGGER.exception( - "%s - Error while fetching preproc artifact, Keys: %s, Metrics: %s", + _LOGGER.warning( + "%s - Error while fetching artifact, Keys: %s, Metrics: %s", payload.uuid, skeys, payload.metrics, ) - return None + return None, payload except Exception: _LOGGER.exception( @@ -107,6 +121,26 @@ def _load_model( payload.composite_keys, payload.metrics, ) - return None + return None, payload else: - return artifact + _LOGGER.info( + "%s - Loaded Model. 
Source: %s , version: %s, Keys: %s, %s", + payload.uuid, + artifact.extras.get("source"), + artifact.extras.get("version"), + skeys, + dkeys, + ) + if ( + artifact.metadata + and "artifact_versions" in artifact.metadata + and "artifact_versions" not in payload.metadata + ): + payload = replace( + payload, + metadata={ + "artifact_versions": artifact.metadata["artifact_versions"], + **payload.metadata, + }, + ) + return artifact, payload diff --git a/numalogic/udfs/trainer.py b/numalogic/udfs/trainer.py index 63bc77b2..29ab27b4 100644 --- a/numalogic/udfs/trainer.py +++ b/numalogic/udfs/trainer.py @@ -18,10 +18,10 @@ from numalogic.models.autoencoder import AutoencoderTrainer from numalogic.tools.data import StreamingDataset from numalogic.tools.exceptions import ConfigNotFoundError, RedisRegistryError -from numalogic.tools.types import redis_client_t, artifact_t +from numalogic.tools.types import redis_client_t, artifact_t, KEYS, KeyedArtifact from numalogic.udfs import NumalogicUDF from numalogic.udfs._config import StreamConf, PipelineConf -from numalogic.udfs.entities import TrainerPayload +from numalogic.udfs.entities import TrainerPayload, StreamPayload _LOGGER = logging.getLogger(__name__) @@ -194,24 +194,24 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: # Save artifacts # TODO perform multi-save here - self.save_artifact( - artifacts["preproc_clf"], - skeys=payload.composite_keys, - dkeys=[_conf.name for _conf in _conf.numalogic_conf.preprocess], - uuid=payload.uuid, - ) - self.save_artifact( - artifacts["model"], - skeys=payload.composite_keys, - dkeys=[_conf.numalogic_conf.model.name], - uuid=payload.uuid, - train_size=x_train.shape[0], - ) - self.save_artifact( - artifacts["threshold_clf"], - skeys=payload.composite_keys, - dkeys=[_conf.numalogic_conf.threshold.name], - uuid=payload.uuid, + skeys = payload.composite_keys + dict_artifacts = { + "postproc": KeyedArtifact( + dkeys=[_conf.numalogic_conf.threshold.name], artifact=artifacts["threshold_clf"] + ), + "inference": KeyedArtifact( + dkeys=[_conf.numalogic_conf.model.name], artifact=artifacts["model"] + ), + "preproc": KeyedArtifact( + dkeys=[_conf.name for _conf in _conf.numalogic_conf.preprocess], + artifact=artifacts["preproc_clf"], + ), + } + self.artifacts_to_save( + skeys=skeys, + dict_artifacts=dict_artifacts, + model_registry=self.model_registry, + payload=payload, ) _LOGGER.debug( @@ -230,39 +230,41 @@ def _construct_preproc_clf(self, _conf: StreamConf) -> Optional[artifact_t]: return preproc_clfs[0] return make_pipeline(*preproc_clfs) - def save_artifact( - self, artifact: artifact_t, skeys: list[str], dkeys: list[str], uuid: str, **metadata + @staticmethod + def artifacts_to_save( + skeys: KEYS, + dict_artifacts: dict[str, KeyedArtifact], + model_registry, + payload: StreamPayload, ) -> None: """ - Save artifact to the registry. - + Save artifacts. Args: - artifact: Artifact to save - skeys: List of keys - dkeys: List of dkeys - uuid: UUID - **metadata: Additional metadata + _______ + skeys: list keys + dict_artifacts: artifact_tuple which has dkeys and artifact as fields + model_registry: registry that supports multiple_save + payload: payload. 
+ + Returns + ------- + Tuple of keys and artifacts + """ - if not artifact: - return - # TODO check for statelessness from config - if isinstance(artifact, StatelessTransformer): - _LOGGER.info("%s - Skipping save for stateless artifact with dkeys: %s", uuid, dkeys) - return + for key, value in dict_artifacts.items(): + if value.artifact: + if isinstance(value.artifact, StatelessTransformer): + del dict_artifacts[key] try: - version = self.model_registry.save( + ver_dict = model_registry.save_multiple( skeys=skeys, - dkeys=dkeys, - artifact=artifact, - uuid=uuid, - **metadata, + dict_artifacts=dict_artifacts, + uuid=payload.uuid, ) except RedisRegistryError: - _LOGGER.exception("%s - Error while saving Model with skeys: %s", uuid, skeys) + _LOGGER.exception("%s - Error while saving Model with skeys: %s", payload.uuid, skeys) else: - _LOGGER.info( - "%s - Artifact saved with dkeys: %s with version: %s", uuid, dkeys, version - ) + _LOGGER.info("%s - Artifact saved with with versions: %s", payload.uuid, ver_dict) def _is_data_sufficient(self, payload: TrainerPayload, df: pd.DataFrame) -> bool: _conf = self.get_conf(payload.config_id) diff --git a/poetry.lock b/poetry.lock index 45e0604c..2ec60308 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1925,6 +1925,7 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, + {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -2154,13 +2155,13 @@ test = ["coverage", "jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-cov", [[package]] name = "jupyterlab" -version = "4.0.5" +version = "4.0.6" description = "JupyterLab computational environment" optional = false python-versions = ">=3.8" files = [ - {file = "jupyterlab-4.0.5-py3-none-any.whl", hash = "sha256:13b3a326e7b95d72746fe20dbe80ee1e71165d6905e01ceaf1320eb809cb1b47"}, - {file = "jupyterlab-4.0.5.tar.gz", hash = "sha256:de49deb75f9b9aec478ed04754cbefe9c5d22fd796a5783cdc65e212983d3611"}, + {file = "jupyterlab-4.0.6-py3-none-any.whl", hash = "sha256:7d9dacad1e3f30fe4d6d4efc97fda25fbb5012012b8f27cc03a2283abcdee708"}, + {file = "jupyterlab-4.0.6.tar.gz", hash = "sha256:6c43ae5a6a1fd2fdfafcb3454004958bde6da76331abb44cffc6f9e436b19ba1"}, ] [package.dependencies] @@ -2179,8 +2180,8 @@ tornado = ">=6.2.0" traitlets = "*" [package.extras] -dev = ["black[jupyter] (==23.3.0)", "build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.0.271)"] -docs = ["jsx-lexer", "myst-parser", "pydata-sphinx-theme (>=0.13.0)", "pytest", "pytest-check-links", "pytest-tornasync", "sphinx (>=1.8)", "sphinx-copybutton"] +dev = ["black[jupyter] (==23.7.0)", "build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.0.286)"] +docs = ["jsx-lexer", "myst-parser", "pydata-sphinx-theme (>=0.13.0)", "pytest", "pytest-check-links", "pytest-tornasync", "sphinx (>=1.8,<7.2.0)", "sphinx-copybutton"] docs-screenshots = ["altair (==5.0.1)", "ipython (==8.14.0)", "ipywidgets (==8.0.6)", "jupyterlab-geojson (==3.4.0)", "jupyterlab-language-pack-zh-cn (==4.0.post0)", "matplotlib (==3.7.1)", "nbconvert (>=7.0.0)", "pandas (==2.0.2)", "scipy (==1.10.1)", "vega-datasets (==0.9.0)"] test = ["coverage", "pytest (>=7.0)", "pytest-check-links (>=0.7)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter (>=0.5.3)", 
"pytest-timeout", "pytest-tornasync", "requests", "requests-cache", "virtualenv"] @@ -2425,6 +2426,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -2553,13 +2564,13 @@ files = [ [[package]] name = "mlflow-skinny" -version = "2.6.0" +version = "2.7.0" description = "MLflow: A Platform for ML Development and Productionization" optional = true python-versions = ">=3.8" files = [ - {file = "mlflow-skinny-2.6.0.tar.gz", hash = "sha256:c1a71bc4abb83169d4cff012c1c050c203775d4a7462d059c4523a481f0a9688"}, - {file = "mlflow_skinny-2.6.0-py3-none-any.whl", hash = "sha256:b63c6555645af0196b3c92da849968189dfd72d2c1397f7b149593db6b404d66"}, + {file = "mlflow-skinny-2.7.0.tar.gz", hash = "sha256:c4ce24817227219ac03f674bd3d0f371b274807e7e3845b3d46ba79ca86dad08"}, + {file = "mlflow_skinny-2.7.0-py3-none-any.whl", hash = "sha256:0dd3855b07bfce5f1bc403b6505e84f0e165780aba5d7edcca2024cff7ca7309"}, ] [package.dependencies] @@ -2580,7 +2591,7 @@ sqlparse = ">=0.4.0,<1" aliyun-oss = ["aliyunstoreplugin"] databricks = ["azure-storage-file-datalake (>12)", "boto3 (>1)", "google-cloud-storage (>=1.30.0)"] extras = 
["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,!=1.3.1)", "mlserver-mlflow (>=1.2.0,!=1.3.1)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"] -gateway = ["aiohttp (<4)", "fastapi (<1)", "psutil (<6)", "pydantic (>=1.0,<2)", "uvicorn[standard] (<1)", "watchfiles (<1)"] +gateway = ["aiohttp (<4)", "fastapi (<1)", "psutil (<6)", "pydantic (>=1.0,<3)", "uvicorn[standard] (<1)", "watchfiles (<1)"] sqlserver = ["mlflow-dbstore"] [[package]] @@ -3302,6 +3313,7 @@ files = [ {file = "Pillow-10.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3b08d4cc24f471b2c8ca24ec060abf4bebc6b144cb89cba638c720546b1cf538"}, {file = "Pillow-10.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d737a602fbd82afd892ca746392401b634e278cb65d55c4b7a8f48e9ef8d008d"}, {file = "Pillow-10.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:3a82c40d706d9aa9734289740ce26460a11aeec2d9c79b7af87bb35f0073c12f"}, + {file = "Pillow-10.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:bc2ec7c7b5d66b8ec9ce9f720dbb5fa4bace0f545acd34870eff4a369b44bf37"}, {file = "Pillow-10.0.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:d80cf684b541685fccdd84c485b31ce73fc5c9b5d7523bf1394ce134a60c6883"}, {file = "Pillow-10.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:76de421f9c326da8f43d690110f0e79fe3ad1e54be811545d7d91898b4c8493e"}, {file = "Pillow-10.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81ff539a12457809666fef6624684c008e00ff6bf455b4b89fd00a140eecd640"}, @@ -3311,6 +3323,7 @@ files = [ {file = "Pillow-10.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d50b6aec14bc737742ca96e85d6d0a5f9bfbded018264b3b70ff9d8c33485551"}, {file = "Pillow-10.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:00e65f5e822decd501e374b0650146063fbb30a7264b4d2744bdd7b913e0cab5"}, {file = "Pillow-10.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:f31f9fdbfecb042d046f9d91270a0ba28368a723302786c0009ee9b9f1f60199"}, + {file = "Pillow-10.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:1ce91b6ec08d866b14413d3f0bbdea7e24dfdc8e59f562bb77bc3fe60b6144ca"}, {file = "Pillow-10.0.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:349930d6e9c685c089284b013478d6f76e3a534e36ddfa912cde493f235372f3"}, {file = "Pillow-10.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3a684105f7c32488f7153905a4e3015a3b6c7182e106fe3c37fbb5ef3e6994c3"}, {file = "Pillow-10.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4f69b3700201b80bb82c3a97d5e9254084f6dd5fb5b16fc1a7b974260f89f43"}, @@ -3779,6 +3792,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = 
"sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -3786,8 +3800,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -3804,6 +3825,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -3811,6 +3833,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -4982,4 +5005,4 @@ redis = ["redis"] [metadata] lock-version = "2.0" python-versions = ">=3.9, <3.12" -content-hash = "292d29c73f389c027915552f126fb3dc3ebeaca536ca7244ffa83d94841483a6" +content-hash = "abcfdf1c496915cc32f58d90aa82892c956789bc55c045a814481445ab3f4f65" diff --git a/pyproject.toml b/pyproject.toml index c709b900..c907e947 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ pynumaflow = "~0.4" # extras -mlflow-skinny = { version = "~2.6", optional = true } +mlflow-skinny = { version = "^2.0", optional = true } redis = {extras = ["hiredis"], version = "^5.0", optional = true} boto3 = { version = "^1.24.64", optional = true } pydruid = {version = "^0.6", optional = true} diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index 6c96a513..dcf41c37 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -7,7 +7,7 @@ from mlflow.exceptions import RestException from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST, ErrorCode, RESOURCE_LIMIT_EXCEEDED from mlflow.store.entities import PagedList -from sklearn.ensemble import RandomForestRegressor +from sklearn.preprocessing import StandardScaler from numalogic.models.autoencoder.variants import VanillaAE from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache @@ -124,16 +124,25 @@ def test_load_model_when_pytorch_model_exist2(self): @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) - @patch("mlflow.sklearn.load_model", Mock(return_value=RandomForestRegressor())) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata())) + @patch.object( + MLflowRegistry, + "load", + Mock( + return_value=ArtifactData( + artifact=StandardScaler(), extras={"metric": ["error"]}, metadata={} + ) + ), + ) def test_load_model_when_sklearn_model_exist(self): - model = self.model_sklearn ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys - ml.save(skeys=skeys, dkeys=dkeys, artifact=model) - data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="sklearn") - self.assertIsInstance(data.artifact, RandomForestRegressor) + scaler = StandardScaler() + ml.save(skeys=skeys, 
dkeys=dkeys, artifact=scaler) + data = ml.load(skeys=skeys, dkeys=dkeys) + print(data) + self.assertIsInstance(data.artifact, StandardScaler) self.assertEqual(data.metadata, {}) @patch("mlflow.pytorch.log_model", mock_log_model_pytorch()) diff --git a/tests/registry/test_redis_registry.py b/tests/registry/test_redis_registry.py index 59c3fdb6..de252cf1 100644 --- a/tests/registry/test_redis_registry.py +++ b/tests/registry/test_redis_registry.py @@ -13,6 +13,7 @@ from numalogic.models.autoencoder.variants import VanillaAE from numalogic.registry import RedisRegistry, LocalLRUCache, ArtifactData from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError +from numalogic.tools.types import KeyedArtifact logging.basicConfig(level=logging.DEBUG) @@ -62,9 +63,10 @@ def test_save_model_without_metadata_cache_hit(self): data = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) self.assertEqual(data.extras["version"], save_version) resave_version1 = self.registry.save( - skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model + skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model, **{"lr": 0.01} ) resave_data = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + print(resave_data.extras) self.assertEqual(save_version, "0") self.assertEqual(resave_version1, "1") self.assertEqual(resave_data.extras["version"], "0") @@ -185,6 +187,19 @@ def test_load_latest_cache_ttl_expire(self): self.assertEqual("registry", artifact_data_1.extras["source"]) self.assertEqual("registry", artifact_data_2.extras["source"]) + def test_multiple_save(self): + self.registry.save_multiple( + skeys=self.skeys, + dict_artifacts={ + "AE": KeyedArtifact(dkeys=["AE"], artifact=VanillaAE(10)), + "scaler": KeyedArtifact(dkeys=["scaler"], artifact=StandardScaler()), + }, + **{"a": "b"} + ) + artifact_data = self.registry.load(skeys=self.skeys, dkeys=["AE"]) + self.assertEqual("registry", artifact_data.extras["source"]) + self.assertIsNotNone(artifact_data.artifact) + def test_load_non_latest_model_twice(self): old_version = self.registry.save( skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model @@ -198,7 +213,7 @@ def test_load_non_latest_model_twice(self): skeys=self.skeys, dkeys=self.dkeys, latest=False, version=old_version ) self.assertEqual("registry", artifact_data_1.extras["source"]) - self.assertEqual("registry", artifact_data_2.extras["source"]) + self.assertEqual("cache", artifact_data_2.extras["source"]) def test_delete_version(self): version = self.registry.save( @@ -227,3 +242,11 @@ def test_exception_call2(self): def test_exception_call3(self): with self.assertRaises(RedisRegistryError): self.registry.delete(skeys=self.skeys, dkeys=self.dkeys, version="0") + + @patch("redis.Redis.set", Mock(side_effect=ConnectionError)) + def test_exception_call4(self): + with self.assertRaises(RedisRegistryError): + self.registry.save_multiple( + skeys=self.skeys, + dict_artifacts={"AE": KeyedArtifact(dkeys=self.dkeys, artifact=VanillaAE(10))}, + ) diff --git a/tests/udfs/test_inference.py b/tests/udfs/test_inference.py index 9ebc3cfa..feab6b7a 100644 --- a/tests/udfs/test_inference.py +++ b/tests/udfs/test_inference.py @@ -8,6 +8,7 @@ from orjson import orjson from pynumaflow.function import Datum, DatumMetadata +from numalogic.config import NumalogicConf, ModelInfo, TrainerConf, LightningTrainerConf from numalogic.models.autoencoder.variants import VanillaAE from numalogic.registry import RedisRegistry, ArtifactData from numalogic.tools.exceptions import RedisRegistryError @@ 
-70,11 +71,12 @@ ], "header": "model_inference", "metadata": { + "artifact_versions": {"VanillaAE": "0"}, "tags": { "asset_alias": "some-alias", "asset_id": "362557362191815079", "env": "prd", - } + }, }, } @@ -82,7 +84,15 @@ class TestInferenceUDF(unittest.TestCase): def setUp(self) -> None: self.udf = InferenceUDF(REDIS_CLIENT) - self.udf.register_conf("conf1", StreamConf(config_id="conf1")) + self.udf.register_conf( + "conf1", + StreamConf( + numalogic_conf=NumalogicConf( + model=ModelInfo(name="VanillaAE", conf={"seq_len": 12, "n_features": 2}), + trainer=TrainerConf(pltrainer_conf=LightningTrainerConf(max_epochs=1)), + ) + ), + ) @patch.object( RedisRegistry, @@ -90,7 +100,7 @@ def setUp(self) -> None: Mock( return_value=ArtifactData( artifact=VanillaAE(seq_len=12, n_features=2), - extras=dict(version="1", timestamp=time.time(), source="registry"), + extras=dict(version="0", timestamp=time.time(), source="registry"), metadata={}, ) ), diff --git a/tests/udfs/test_postprocess.py b/tests/udfs/test_postprocess.py index a13924ca..4df25ca7 100644 --- a/tests/udfs/test_postprocess.py +++ b/tests/udfs/test_postprocess.py @@ -77,7 +77,7 @@ "status": "artifact_found", "header": "model_inference", "metadata": { - "model_version": 0, + "artifact_versions": {"StdDevThreshold": "0"}, "tags": {"asset_alias": "data", "asset_id": "123456789", "env": "prd"}, }, } @@ -116,7 +116,6 @@ def test_postprocess_infer_model_stale(self): self.registry.save( KEYS, ["StdDevThreshold"], StdDevThreshold().fit(np.asarray([[0, 1], [1, 2]])) ) - msg = self.udf(KEYS, Datum(keys=KEYS, value=orjson.dumps(data), **DATUM_KW)) self.assertEqual(2, len(msg)) diff --git a/tests/udfs/test_preprocess.py b/tests/udfs/test_preprocess.py index a9802a70..487bf524 100644 --- a/tests/udfs/test_preprocess.py +++ b/tests/udfs/test_preprocess.py @@ -45,13 +45,13 @@ def setUp(self) -> None: self.udf1 = PreprocessUDF(REDIS_CLIENT, pl_conf=pl_conf) self.udf2 = PreprocessUDF(REDIS_CLIENT, pl_conf=pl_conf_2) self.udf1.register_conf("druid-config", pl_conf.stream_confs["druid-config"]) - self.udf1.register_conf("druid-config", pl_conf_2.stream_confs["druid-config"]) + self.udf2.register_conf("druid-config", pl_conf_2.stream_confs["druid-config"]) def tearDown(self) -> None: REDIS_CLIENT.flushall() def test_preprocess_load_from_registry(self): - msgs = self.udf1( + msgs = self.udf2( KEYS, DATUM, ) @@ -61,7 +61,7 @@ def test_preprocess_load_from_registry(self): self.assertEqual(payload.header, Header.MODEL_INFERENCE) def test_preprocess_load_from_config(self): - msgs = self.udf2( + msgs = self.udf1( KEYS, DATUM, ) diff --git a/tests/udfs/utility.py b/tests/udfs/utility.py index a0299fa3..bb7da059 100644 --- a/tests/udfs/utility.py +++ b/tests/udfs/utility.py @@ -6,6 +6,8 @@ from sklearn.pipeline import make_pipeline from numalogic.config import PreprocessFactory +from numalogic.models.autoencoder.variants import VanillaAE +from numalogic.tools.types import KeyedArtifact def input_json_from_file(data_path: str) -> Datum: @@ -45,11 +47,16 @@ def store_in_redis(pl_conf, registry): ): preproc_clf = make_pipeline(*preproc_clfs) preproc_clf.fit(np.asarray([[1, 3], [4, 6]])) - registry.save( + registry.save_multiple( skeys=pl_conf.stream_confs["druid-config"].composite_keys, - dkeys=[ - _conf.name - for _conf in pl_conf.stream_confs["druid-config"].numalogic_conf.preprocess - ], - artifact=preproc_clf, + dict_artifacts={ + "inference": KeyedArtifact(dkeys=["AE"], artifact=VanillaAE(10)), + "preproc": KeyedArtifact( + dkeys=[ + _conf.name + for _conf in 
pl_conf.stream_confs["druid-config"].numalogic_conf.preprocess + ], + artifact=preproc_clf, + ), + }, )
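
For reference, a minimal usage sketch of the flow this patch introduces: save_multiple() with KeyedArtifact to write several artifacts in one Redis pipeline, then a version-pinned load. The registry classes, KeyedArtifact, save_multiple, and the latest=False load path are taken from the diff above; the Redis connection details, skeys/dkeys values, and the uuid are illustrative assumptions, not part of the patch.

# Sketch of the multi-save + version-pinned load API added in this patch.
# Assumptions for illustration only: a locally reachable Redis instance,
# example skeys/dkeys, and a made-up request uuid.
from redis import Redis
from sklearn.preprocessing import StandardScaler

from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.registry import RedisRegistry
from numalogic.tools.types import KeyedArtifact

client = Redis(host="localhost", port=6379)  # assumed connection details
registry = RedisRegistry(client=client, transactional=True)

skeys = ["service-mesh", "1"]  # illustrative static keys

# All artifacts are queued on the same pipeline; the final save also records the
# full {dkeys: version} mapping as "artifact_versions" metadata.
versions = registry.save_multiple(
    skeys=skeys,
    dict_artifacts={
        "inference": KeyedArtifact(dkeys=["VanillaAE"], artifact=VanillaAE(seq_len=12)),
        "preproc": KeyedArtifact(dkeys=["StandardScaler"], artifact=StandardScaler()),
    },
    uuid="example-uuid",  # extra metadata, mirroring how TrainerUDF passes payload.uuid
)
print(versions)  # e.g. {"VanillaAE": "0", "StandardScaler": "0"}

# A downstream UDF that received "artifact_versions" in its payload metadata can
# pin the load to that exact version instead of the latest key; with this patch,
# version loads are cached as well.
artifact_data = registry.load(
    skeys=skeys,
    dkeys=["VanillaAE"],
    latest=False,
    version=versions["VanillaAE"],
)
print(artifact_data.extras["source"], artifact_data.extras["version"])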