diff --git a/numalogic/registry/redis_registry.py b/numalogic/registry/redis_registry.py index 0aec7530..f19c759e 100644 --- a/numalogic/registry/redis_registry.py +++ b/numalogic/registry/redis_registry.py @@ -14,13 +14,14 @@ from datetime import datetime, timedelta from typing import Optional import orjson +import redis.client from redis.exceptions import RedisError from numalogic.registry.artifact import ArtifactManager, ArtifactData, ArtifactCache from numalogic.registry._serialize import loads, dumps from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError -from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT, ArtifactTuple +from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT, KeyedArtifact _LOGGER = logging.getLogger(__name__) @@ -180,7 +181,7 @@ def __load_version_artifact(self, version: str, key: str) -> tuple[ArtifactData, _LOGGER.debug("Found cached version artifact for key: %s", version_key) return cached_artifact, True if not self.client.exists(version_key): - raise ModelKeyNotFound("Could not find model key with key: %s" % version_key) + raise ModelKeyNotFound(f"Could not find model key with key: {version_key}") return ( self.__get_artifact_data( model_key=version_key, @@ -253,7 +254,9 @@ def load( else: if not is_cached: if latest: - _LOGGER.info("Saving %s, in cache as %s", self.__construct_latest_key(key), key) + _LOGGER.debug( + "Saving %s, in cache as %s", self.__construct_latest_key(key), key + ) self._save_in_cache(self.__construct_latest_key(key), artifact_data) else: _LOGGER.info( @@ -269,7 +272,7 @@ def save( skeys: KEYS, dkeys: KEYS, artifact: artifact_t, - pipe=None, + _pipe: Optional[redis.client.Pipeline] = None, **metadata: META_VT, ) -> Optional[str]: """Saves the artifact into redis registry and updates version. @@ -279,6 +282,7 @@ def save( skeys: static key fields as list/tuple of strings dkeys: dynamic key fields as list/tuple of strings artifact: primary artifact to be saved + _pipe: RedisPipeline object metadata: additional metadata surrounding the artifact that needs to be saved. Returns @@ -297,15 +301,15 @@ def save( _LOGGER.debug("Latest key: %s exists for the model", latest_key) version_key = self.client.get(name=latest_key) version = int(self.get_version(version_key.decode())) + 1 - redis_pipe = ( - self.client.pipeline(transaction=self.transactional) if pipe is None else pipe + _redis_pipe = ( + self.client.pipeline(transaction=self.transactional) if _pipe is None else _pipe ) new_version_key = self.__save_artifact( - pipe=redis_pipe, artifact=artifact, key=key, version=str(version), **metadata + pipe=_redis_pipe, artifact=artifact, key=key, version=str(version), **metadata ) - redis_pipe.expire(name=new_version_key, time=self.ttl) - if pipe is None: - redis_pipe.execute() + _redis_pipe.expire(name=new_version_key, time=self.ttl) + if _pipe is None: + _redis_pipe.execute() except RedisError as err: raise RedisRegistryError(f"{err.__class__.__name__} raised") from err else: @@ -333,7 +337,7 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None: self.client.delete(del_key) else: raise ModelKeyNotFound( - "Key to delete: %s, Not Found!" % del_key, + f"Key to delete: {del_key}, Not Found!", ) except RedisError as err: raise RedisRegistryError(f"{err.__class__.__name__} raised") from err @@ -369,7 +373,7 @@ def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool: def save_multiple( self, skeys: KEYS, - dict_artifacts: dict[str, ArtifactTuple], + dict_artifacts: dict[str, KeyedArtifact], **metadata: META_VT, ): """ @@ -391,7 +395,7 @@ def save_multiple( skeys=skeys, dkeys=value.dkeys, artifact=value.artifact, - pipe=pipe, + _pipe=pipe, **metadata, ) @@ -400,7 +404,7 @@ def save_multiple( skeys=skeys, dkeys=value.dkeys, artifact=value.artifact, - pipe=pipe, + _pipe=pipe, artifact_versions=dict_model_ver, **metadata, ) diff --git a/numalogic/tools/types.py b/numalogic/tools/types.py index f50cc42d..22214b3c 100644 --- a/numalogic/tools/types.py +++ b/numalogic/tools/types.py @@ -38,7 +38,7 @@ KEYS = TypeVar("KEYS", bound=Sequence[str], covariant=True) -class ArtifactTuple(NamedTuple): +class KeyedArtifact(NamedTuple): r"""namedtuple for artifacts.""" dkeys: KEYS diff --git a/numalogic/udfs/tools.py b/numalogic/udfs/tools.py index b2dee9de..d9cd39fb 100644 --- a/numalogic/udfs/tools.py +++ b/numalogic/udfs/tools.py @@ -65,7 +65,7 @@ def make_stream_payload( ) -# TODO: move to base NumalogicUDF class +# TODO: move to base NumalogicUDF class and look into payload mutation def _load_artifact( skeys: KEYS, dkeys: KEYS, diff --git a/numalogic/udfs/trainer.py b/numalogic/udfs/trainer.py index 5d53a195..f35b6ca9 100644 --- a/numalogic/udfs/trainer.py +++ b/numalogic/udfs/trainer.py @@ -19,7 +19,7 @@ from numalogic.registry import RedisRegistry from numalogic.tools.data import StreamingDataset from numalogic.tools.exceptions import ConfigNotFoundError, RedisRegistryError -from numalogic.tools.types import redis_client_t, artifact_t, KEYS, ArtifactTuple +from numalogic.tools.types import redis_client_t, artifact_t, KEYS, KeyedArtifact from numalogic.udfs import NumalogicUDF from numalogic.udfs._config import StreamConf, PipelineConf from numalogic.udfs.entities import TrainerPayload, StreamPayload @@ -194,13 +194,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: # TODO perform multi-save here skeys = payload.composite_keys dict_artifacts = { - "postproc": ArtifactTuple( + "postproc": KeyedArtifact( dkeys=[_conf.numalogic_conf.threshold.name], artifact=artifacts["threshold_clf"] ), - "inference": ArtifactTuple( + "inference": KeyedArtifact( dkeys=[_conf.numalogic_conf.model.name], artifact=artifacts["model"] ), - "preproc": ArtifactTuple( + "preproc": KeyedArtifact( dkeys=[_conf.name for _conf in _conf.numalogic_conf.preprocess], artifact=artifacts["preproc_clf"], ), @@ -233,7 +233,7 @@ def _construct_preproc_clf(_conf: StreamConf) -> Optional[artifact_t]: @staticmethod def artifacts_to_save( skeys: KEYS, - dict_artifacts: dict[str, ArtifactTuple], + dict_artifacts: dict[str, KeyedArtifact], model_registry: RedisRegistry, payload: StreamPayload, ) -> None: diff --git a/tests/registry/test_redis_registry.py b/tests/registry/test_redis_registry.py index 09e138d4..de252cf1 100644 --- a/tests/registry/test_redis_registry.py +++ b/tests/registry/test_redis_registry.py @@ -13,7 +13,7 @@ from numalogic.models.autoencoder.variants import VanillaAE from numalogic.registry import RedisRegistry, LocalLRUCache, ArtifactData from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError -from numalogic.tools.types import ArtifactTuple +from numalogic.tools.types import KeyedArtifact logging.basicConfig(level=logging.DEBUG) @@ -191,8 +191,8 @@ def test_multiple_save(self): self.registry.save_multiple( skeys=self.skeys, dict_artifacts={ - "AE": ArtifactTuple(dkeys=["AE"], artifact=VanillaAE(10)), - "scaler": ArtifactTuple(dkeys=["scaler"], artifact=StandardScaler()), + "AE": KeyedArtifact(dkeys=["AE"], artifact=VanillaAE(10)), + "scaler": KeyedArtifact(dkeys=["scaler"], artifact=StandardScaler()), }, **{"a": "b"} ) @@ -248,5 +248,5 @@ def test_exception_call4(self): with self.assertRaises(RedisRegistryError): self.registry.save_multiple( skeys=self.skeys, - dict_artifacts={"AE": ArtifactTuple(dkeys=self.dkeys, artifact=VanillaAE(10))}, + dict_artifacts={"AE": KeyedArtifact(dkeys=self.dkeys, artifact=VanillaAE(10))}, ) diff --git a/tests/udfs/utility.py b/tests/udfs/utility.py index 7587e8e6..bb7da059 100644 --- a/tests/udfs/utility.py +++ b/tests/udfs/utility.py @@ -7,7 +7,7 @@ from numalogic.config import PreprocessFactory from numalogic.models.autoencoder.variants import VanillaAE -from numalogic.tools.types import ArtifactTuple +from numalogic.tools.types import KeyedArtifact def input_json_from_file(data_path: str) -> Datum: @@ -50,8 +50,8 @@ def store_in_redis(pl_conf, registry): registry.save_multiple( skeys=pl_conf.stream_confs["druid-config"].composite_keys, dict_artifacts={ - "inference": ArtifactTuple(dkeys=["AE"], artifact=VanillaAE(10)), - "preproc": ArtifactTuple( + "inference": KeyedArtifact(dkeys=["AE"], artifact=VanillaAE(10)), + "preproc": KeyedArtifact( dkeys=[ _conf.name for _conf in pl_conf.stream_confs["druid-config"].numalogic_conf.preprocess