Skip to content

Commit

Permalink
add _pipe for redis registry
Browse files Browse the repository at this point in the history
Signed-off-by: s0nicboOm <[email protected]>
  • Loading branch information
s0nicboOm committed Sep 14, 2023
1 parent d8e5282 commit 7ca521a
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 28 deletions.
32 changes: 18 additions & 14 deletions numalogic/registry/redis_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
from datetime import datetime, timedelta
from typing import Optional
import orjson
import redis.client

from redis.exceptions import RedisError

from numalogic.registry.artifact import ArtifactManager, ArtifactData, ArtifactCache
from numalogic.registry._serialize import loads, dumps
from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError
from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT, ArtifactTuple
from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT, KeyedArtifact

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -180,7 +181,7 @@ def __load_version_artifact(self, version: str, key: str) -> tuple[ArtifactData,
_LOGGER.debug("Found cached version artifact for key: %s", version_key)
return cached_artifact, True
if not self.client.exists(version_key):
raise ModelKeyNotFound("Could not find model key with key: %s" % version_key)
raise ModelKeyNotFound(f"Could not find model key with key: {version_key}")
return (
self.__get_artifact_data(
model_key=version_key,
Expand Down Expand Up @@ -253,7 +254,9 @@ def load(
else:
if not is_cached:
if latest:
_LOGGER.info("Saving %s, in cache as %s", self.__construct_latest_key(key), key)
_LOGGER.debug(
"Saving %s, in cache as %s", self.__construct_latest_key(key), key
)
self._save_in_cache(self.__construct_latest_key(key), artifact_data)
else:
_LOGGER.info(
Expand All @@ -269,7 +272,7 @@ def save(
skeys: KEYS,
dkeys: KEYS,
artifact: artifact_t,
pipe=None,
_pipe: Optional[redis.client.Pipeline] = None,
**metadata: META_VT,
) -> Optional[str]:
"""Saves the artifact into redis registry and updates version.
Expand All @@ -279,6 +282,7 @@ def save(
skeys: static key fields as list/tuple of strings
dkeys: dynamic key fields as list/tuple of strings
artifact: primary artifact to be saved
_pipe: RedisPipeline object
metadata: additional metadata surrounding the artifact that needs to be saved.
Returns
Expand All @@ -297,15 +301,15 @@ def save(
_LOGGER.debug("Latest key: %s exists for the model", latest_key)
version_key = self.client.get(name=latest_key)
version = int(self.get_version(version_key.decode())) + 1
redis_pipe = (
self.client.pipeline(transaction=self.transactional) if pipe is None else pipe
_redis_pipe = (
self.client.pipeline(transaction=self.transactional) if _pipe is None else _pipe
)
new_version_key = self.__save_artifact(
pipe=redis_pipe, artifact=artifact, key=key, version=str(version), **metadata
pipe=_redis_pipe, artifact=artifact, key=key, version=str(version), **metadata
)
redis_pipe.expire(name=new_version_key, time=self.ttl)
if pipe is None:
redis_pipe.execute()
_redis_pipe.expire(name=new_version_key, time=self.ttl)
if _pipe is None:
_redis_pipe.execute()
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
Expand Down Expand Up @@ -333,7 +337,7 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None:
self.client.delete(del_key)
else:
raise ModelKeyNotFound(
"Key to delete: %s, Not Found!" % del_key,
f"Key to delete: {del_key}, Not Found!",
)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
Expand Down Expand Up @@ -369,7 +373,7 @@ def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
def save_multiple(
self,
skeys: KEYS,
dict_artifacts: dict[str, ArtifactTuple],
dict_artifacts: dict[str, KeyedArtifact],
**metadata: META_VT,
):
"""
Expand All @@ -391,7 +395,7 @@ def save_multiple(
skeys=skeys,
dkeys=value.dkeys,
artifact=value.artifact,
pipe=pipe,
_pipe=pipe,
**metadata,
)

Expand All @@ -400,7 +404,7 @@ def save_multiple(
skeys=skeys,
dkeys=value.dkeys,
artifact=value.artifact,
pipe=pipe,
_pipe=pipe,
artifact_versions=dict_model_ver,
**metadata,
)
Expand Down
2 changes: 1 addition & 1 deletion numalogic/tools/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
KEYS = TypeVar("KEYS", bound=Sequence[str], covariant=True)


class ArtifactTuple(NamedTuple):
class KeyedArtifact(NamedTuple):
r"""namedtuple for artifacts."""

dkeys: KEYS
Expand Down
2 changes: 1 addition & 1 deletion numalogic/udfs/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def make_stream_payload(
)


# TODO: move to base NumalogicUDF class
# TODO: move to base NumalogicUDF class and look into payload mutation
def _load_artifact(
skeys: KEYS,
dkeys: KEYS,
Expand Down
10 changes: 5 additions & 5 deletions numalogic/udfs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from numalogic.registry import RedisRegistry
from numalogic.tools.data import StreamingDataset
from numalogic.tools.exceptions import ConfigNotFoundError, RedisRegistryError
from numalogic.tools.types import redis_client_t, artifact_t, KEYS, ArtifactTuple
from numalogic.tools.types import redis_client_t, artifact_t, KEYS, KeyedArtifact
from numalogic.udfs import NumalogicUDF
from numalogic.udfs._config import StreamConf, PipelineConf
from numalogic.udfs.entities import TrainerPayload, StreamPayload
Expand Down Expand Up @@ -194,13 +194,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
# TODO perform multi-save here
skeys = payload.composite_keys
dict_artifacts = {
"postproc": ArtifactTuple(
"postproc": KeyedArtifact(
dkeys=[_conf.numalogic_conf.threshold.name], artifact=artifacts["threshold_clf"]
),
"inference": ArtifactTuple(
"inference": KeyedArtifact(
dkeys=[_conf.numalogic_conf.model.name], artifact=artifacts["model"]
),
"preproc": ArtifactTuple(
"preproc": KeyedArtifact(
dkeys=[_conf.name for _conf in _conf.numalogic_conf.preprocess],
artifact=artifacts["preproc_clf"],
),
Expand Down Expand Up @@ -233,7 +233,7 @@ def _construct_preproc_clf(_conf: StreamConf) -> Optional[artifact_t]:
@staticmethod
def artifacts_to_save(
skeys: KEYS,
dict_artifacts: dict[str, ArtifactTuple],
dict_artifacts: dict[str, KeyedArtifact],
model_registry: RedisRegistry,
payload: StreamPayload,
) -> None:
Expand Down
8 changes: 4 additions & 4 deletions tests/registry/test_redis_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.registry import RedisRegistry, LocalLRUCache, ArtifactData
from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError
from numalogic.tools.types import ArtifactTuple
from numalogic.tools.types import KeyedArtifact

logging.basicConfig(level=logging.DEBUG)

Expand Down Expand Up @@ -191,8 +191,8 @@ def test_multiple_save(self):
self.registry.save_multiple(
skeys=self.skeys,
dict_artifacts={
"AE": ArtifactTuple(dkeys=["AE"], artifact=VanillaAE(10)),
"scaler": ArtifactTuple(dkeys=["scaler"], artifact=StandardScaler()),
"AE": KeyedArtifact(dkeys=["AE"], artifact=VanillaAE(10)),
"scaler": KeyedArtifact(dkeys=["scaler"], artifact=StandardScaler()),
},
**{"a": "b"}
)
Expand Down Expand Up @@ -248,5 +248,5 @@ def test_exception_call4(self):
with self.assertRaises(RedisRegistryError):
self.registry.save_multiple(
skeys=self.skeys,
dict_artifacts={"AE": ArtifactTuple(dkeys=self.dkeys, artifact=VanillaAE(10))},
dict_artifacts={"AE": KeyedArtifact(dkeys=self.dkeys, artifact=VanillaAE(10))},
)
6 changes: 3 additions & 3 deletions tests/udfs/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from numalogic.config import PreprocessFactory
from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.tools.types import ArtifactTuple
from numalogic.tools.types import KeyedArtifact


def input_json_from_file(data_path: str) -> Datum:
Expand Down Expand Up @@ -50,8 +50,8 @@ def store_in_redis(pl_conf, registry):
registry.save_multiple(
skeys=pl_conf.stream_confs["druid-config"].composite_keys,
dict_artifacts={
"inference": ArtifactTuple(dkeys=["AE"], artifact=VanillaAE(10)),
"preproc": ArtifactTuple(
"inference": KeyedArtifact(dkeys=["AE"], artifact=VanillaAE(10)),
"preproc": KeyedArtifact(
dkeys=[
_conf.name
for _conf in pl_conf.stream_confs["druid-config"].numalogic_conf.preprocess
Expand Down

0 comments on commit 7ca521a

Please sign in to comment.