Skip to content

Commit

Permalink
feat: add multiple save for redis registry (#281)
Browse files Browse the repository at this point in the history
1) Option to multisave in redis registry
2) support caching for version calls.
3) In udfs/ load models using version call instead of latest calls.
4) call multiple_save in artifacts to save artifacts in a transaction.

---------

Signed-off-by: s0nicboOm <[email protected]>
  • Loading branch information
s0nicboOm authored Sep 18, 2023
1 parent a364721 commit bc1c627
Show file tree
Hide file tree
Showing 19 changed files with 333 additions and 172 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry==1.5.1
run: pipx install poetry==1.6.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry==1.5.1
run: pipx install poetry==1.6.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry==1.5.1
run: pipx install poetry==1.6.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
2 changes: 1 addition & 1 deletion numalogic/connectors/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def fetch(

end_dt = datetime.now(pytz.utc)
start_dt = end_dt - timedelta(hours=hours)
intervals = f"{start_dt.isoformat()}/{end_dt.isoformat()}"
intervals = [f"{start_dt.isoformat()}/{end_dt.isoformat()}"]

dimension_specs = map(lambda d: DimensionSpec(dimension=d, output_name=d), dimensions)

Expand Down
124 changes: 100 additions & 24 deletions numalogic/registry/redis_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
from datetime import datetime, timedelta
from typing import Optional
import orjson
import redis.client

from redis.exceptions import RedisError

from numalogic.registry.artifact import ArtifactManager, ArtifactData, ArtifactCache
from numalogic.registry._serialize import loads, dumps
from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError
from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT
from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT, KeyedArtifact

_LOGGER = logging.getLogger(__name__)

Expand All @@ -33,6 +34,8 @@ class RedisRegistry(ArtifactManager):
client: Take in the redis client already established/created
ttl: Total Time to Live (in seconds) for the key when saving in redis (dafault = 604800)
cache_registry: Cache registry to use (default = None).
transactional: Flag to indicate if the registry should be transactional or
not (default = False).
Examples
--------
Expand All @@ -48,18 +51,20 @@ class RedisRegistry(ArtifactManager):
>>> loaded_artifact = registry.load(skeys, dkeys)
"""

__slots__ = ("client", "ttl", "cache_registry")
__slots__ = ("client", "ttl", "cache_registry", "transactional")

def __init__(
self,
client: redis_client_t,
ttl: int = 604800,
cache_registry: Optional[ArtifactCache] = None,
transactional: bool = True,
):
super().__init__("")
self.client = client
self.ttl = ttl
self.cache_registry = cache_registry
self.transactional = transactional

@staticmethod
def construct_key(skeys: KEYS, dkeys: KEYS) -> str:
Expand Down Expand Up @@ -155,30 +160,37 @@ def __load_latest_artifact(self, key: str) -> tuple[ArtifactData, bool]:
------
ModelKeyNotFound: If the model key is not found in the registry.
"""
cached_artifact = self._load_from_cache(key)
latest_key = self.__construct_latest_key(key)
cached_artifact = self._load_from_cache(latest_key)
if cached_artifact:
_LOGGER.debug("Found cached artifact for key: %s", key)
_LOGGER.debug("Found cached artifact for key: %s", latest_key)
return cached_artifact, True
latest_key = self.__construct_latest_key(key)
if not self.client.exists(latest_key):
raise ModelKeyNotFound(f"latest key: {latest_key}, Not Found !!!")
model_key = self.client.get(latest_key)
_LOGGER.debug("latest key, %s, is pointing to the key : %s", latest_key, model_key)
return (
self.__load_version_artifact(version=self.get_version(model_key.decode()), key=key),
False,
artifact, _ = self.__load_version_artifact(
version=self.get_version(model_key.decode()), key=key
)
return artifact, False

def __load_version_artifact(self, version: str, key: str) -> ArtifactData:
model_key = self.__construct_version_key(key, version)
if not self.client.exists(model_key):
raise ModelKeyNotFound("Could not find model key with key: %s" % model_key)
return self.__get_artifact_data(
model_key=model_key,
def __load_version_artifact(self, version: str, key: str) -> tuple[ArtifactData, bool]:
version_key = self.__construct_version_key(key, version)
cached_artifact = self._load_from_cache(version_key)
if cached_artifact:
_LOGGER.debug("Found cached version artifact for key: %s", version_key)
return cached_artifact, True
if not self.client.exists(version_key):
raise ModelKeyNotFound(f"Could not find model key with key: {version_key}")
return (
self.__get_artifact_data(
model_key=version_key,
),
False,
)

def __save_artifact(
self, pipe, artifact: artifact_t, metadata: META_T, key: KEYS, version: str
self, pipe, artifact: artifact_t, key: KEYS, version: str, **metadata: META_T
) -> str:
new_version_key = self.__construct_version_key(key, version)
latest_key = self.__construct_latest_key(key)
Expand Down Expand Up @@ -210,6 +222,8 @@ def load(
is needed to load the respective artifact.
If cache registry is provided, it will first check the cache registry for the artifact.
If latest is passed, latest key is saved otherwise version call saves the respective
version artifact in cache.
Args:
----
Expand All @@ -230,24 +244,35 @@ def load(
if (latest and version) or (not latest and not version):
raise ValueError("Either One of 'latest' or 'version' needed in load method call")
key = self.construct_key(skeys, dkeys)
is_cached = False
try:
if latest:
artifact_data, is_cached = self.__load_latest_artifact(key)
else:
artifact_data = self.__load_version_artifact(version, key)
artifact_data, is_cached = self.__load_version_artifact(version, key)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
if (not is_cached) and latest:
self._save_in_cache(key, artifact_data)
if not is_cached:
if latest:
_LOGGER.debug(
"Saving %s, in cache as %s", self.__construct_latest_key(key), key
)
self._save_in_cache(self.__construct_latest_key(key), artifact_data)
else:
_LOGGER.info(
"Saving %s, in cache as %s",
self.__construct_version_key(key, version),
key,
)
self._save_in_cache(self.__construct_version_key(key, version), artifact_data)
return artifact_data

def save(
self,
skeys: KEYS,
dkeys: KEYS,
artifact: artifact_t,
_pipe: Optional[redis.client.Pipeline] = None,
**metadata: META_VT,
) -> Optional[str]:
"""Saves the artifact into redis registry and updates version.
Expand All @@ -257,6 +282,7 @@ def save(
skeys: static key fields as list/tuple of strings
dkeys: dynamic key fields as list/tuple of strings
artifact: primary artifact to be saved
_pipe: RedisPipeline object
metadata: additional metadata surrounding the artifact that needs to be saved.
Returns
Expand All @@ -275,10 +301,15 @@ def save(
_LOGGER.debug("Latest key: %s exists for the model", latest_key)
version_key = self.client.get(name=latest_key)
version = int(self.get_version(version_key.decode())) + 1
with self.client.pipeline() as pipe:
new_version_key = self.__save_artifact(pipe, artifact, metadata, key, str(version))
pipe.expire(name=new_version_key, time=self.ttl)
pipe.execute()
_redis_pipe = (
self.client.pipeline(transaction=self.transactional) if _pipe is None else _pipe
)
new_version_key = self.__save_artifact(
pipe=_redis_pipe, artifact=artifact, key=key, version=str(version), **metadata
)
_redis_pipe.expire(name=new_version_key, time=self.ttl)
if _pipe is None:
_redis_pipe.execute()
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
Expand Down Expand Up @@ -306,7 +337,7 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None:
self.client.delete(del_key)
else:
raise ModelKeyNotFound(
"Key to delete: %s, Not Found!" % del_key,
f"Key to delete: {del_key}, Not Found!",
)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
Expand Down Expand Up @@ -338,3 +369,48 @@ def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
raise RedisRegistryError("Error fetching timestamp information") from err
stale_ts = (datetime.now() - timedelta(hours=freq_hr)).timestamp()
return stale_ts > artifact_ts

def save_multiple(
self,
skeys: KEYS,
dict_artifacts: dict[str, KeyedArtifact],
**metadata: META_VT,
):
"""
Saves multiple artifacts into redis registry. The last save stores all the
artifact versions in the metadata.
Args:
----
skeys: static key fields as list/tuple of strings
dict_artifacts: dict of artifacts to save
metadata: additional metadata surrounding the artifact that needs to be saved.
"""
dict_model_ver = {}
try:
with self.client.pipeline(transaction=self.transactional) as pipe:
pipe.multi()
for key, value in dict_artifacts.items():
dict_model_ver[":".join(value.dkeys)] = self.save(
skeys=skeys,
dkeys=value.dkeys,
artifact=value.artifact,
_pipe=pipe,
**metadata,
)

if len(dict_artifacts) == len(dict_model_ver):
self.save(
skeys=skeys,
dkeys=value.dkeys,
artifact=value.artifact,
_pipe=pipe,
artifact_versions=dict_model_ver,
**metadata,
)
pipe.execute()
_LOGGER.info("Successfully saved all the artifacts with: %s", dict_model_ver)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
return dict_model_ver
11 changes: 8 additions & 3 deletions numalogic/tools/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections.abc import Sequence
from typing import Union, TypeVar, ClassVar
from typing import Union, TypeVar, ClassVar, NamedTuple

from sklearn.base import BaseEstimator
from torch import Tensor
Expand Down Expand Up @@ -40,6 +38,13 @@
KEYS = TypeVar("KEYS", bound=Sequence[str], covariant=True)


class KeyedArtifact(NamedTuple):
r"""namedtuple for artifacts."""

dkeys: KEYS
artifact: artifact_t


class Singleton(type):
r"""Helper metaclass to use as a Singleton class."""

Expand Down
54 changes: 10 additions & 44 deletions numalogic/udfs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@

from numalogic.config import RegistryFactory
from numalogic.registry import LocalLRUCache, ArtifactData
from numalogic.tools.exceptions import RedisRegistryError, ModelKeyNotFound, ConfigNotFoundError
from numalogic.tools.exceptions import ConfigNotFoundError
from numalogic.tools.types import artifact_t, redis_client_t
from numalogic.udfs._base import NumalogicUDF
from numalogic.udfs._config import StreamConf, PipelineConf
from numalogic.udfs.entities import StreamPayload, Header, Status
from numalogic.udfs.tools import _load_artifact

_LOGGER = logging.getLogger(__name__)

# TODO: move to config
LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600"))
LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000"))
LOAD_LATEST = os.getenv("LOAD_LATEST", "false").lower() == "true"


class InferenceUDF(NumalogicUDF):
Expand Down Expand Up @@ -114,8 +116,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
# Forward payload if a training request is tagged
if payload.header == Header.TRAIN_REQUEST:
return Messages(Message(keys=keys, value=payload.to_json()))

artifact_data = self.load_artifact(keys, payload)
artifact_data, payload = _load_artifact(
skeys=keys,
dkeys=[self.get_conf(payload.config_id).numalogic_conf.model.name],
payload=payload,
model_registry=self.model_registry,
load_latest=LOAD_LATEST,
)

# TODO: revisit retraining logic
# Send training request if artifact loading is not successful
Expand Down Expand Up @@ -165,47 +172,6 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
)
return Messages(Message(keys=keys, value=payload.to_json()))

def load_artifact(self, keys: list[str], payload: StreamPayload) -> Optional[ArtifactData]:
"""
Load inference artifact from the registry.
Args:
keys: List of keys
payload: StreamPayload object
Returns
-------
ArtifactData instance
"""
_conf = self.get_conf(payload.config_id).numalogic_conf
try:
artifact_data = self.model_registry.load(
skeys=keys,
dkeys=[_conf.model.name],
)
except ModelKeyNotFound:
_LOGGER.warning(
"%s - Model key not found for Keys: %s, Metric: %s",
payload.uuid,
payload.composite_keys,
payload.metrics,
)
return None
except RedisRegistryError:
_LOGGER.exception(
"%s - Error while fetching inference artifact, Keys: %s, Metric: %s",
payload.uuid,
payload.composite_keys,
payload.metrics,
)
return None
_LOGGER.info(
"%s - Loaded artifact data from %s",
payload.uuid,
artifact_data.extras.get("source"),
)
return artifact_data

def is_model_stale(self, artifact_data: ArtifactData, payload: StreamPayload) -> bool:
"""
Check if the inference artifact is stale.
Expand Down
Loading

0 comments on commit bc1c627

Please sign in to comment.