Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add multiple save for redis registry #281

Merged
merged 18 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry==1.5.1
run: pipx install poetry==1.6.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry==1.5.1
run: pipx install poetry==1.6.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry==1.5.1
run: pipx install poetry==1.6.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
2 changes: 1 addition & 1 deletion numalogic/connectors/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def fetch(

end_dt = datetime.now(pytz.utc)
start_dt = end_dt - timedelta(hours=hours)
intervals = f"{start_dt.isoformat()}/{end_dt.isoformat()}"
intervals = [f"{start_dt.isoformat()}/{end_dt.isoformat()}"]
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved

dimension_specs = map(lambda d: DimensionSpec(dimension=d, output_name=d), dimensions)

Expand Down
124 changes: 100 additions & 24 deletions numalogic/registry/redis_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
from datetime import datetime, timedelta
from typing import Optional
import orjson
import redis.client

from redis.exceptions import RedisError

from numalogic.registry.artifact import ArtifactManager, ArtifactData, ArtifactCache
from numalogic.registry._serialize import loads, dumps
from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError
from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT
from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT, KeyedArtifact

_LOGGER = logging.getLogger(__name__)

Expand All @@ -33,6 +34,8 @@
client: Take in the redis client already established/created
ttl: Total Time to Live (in seconds) for the key when saving in redis (dafault = 604800)
cache_registry: Cache registry to use (default = None).
transactional: Flag to indicate if the registry should be transactional or
not (default = False).

Examples
--------
Expand All @@ -48,18 +51,20 @@
>>> loaded_artifact = registry.load(skeys, dkeys)
"""

__slots__ = ("client", "ttl", "cache_registry")
__slots__ = ("client", "ttl", "cache_registry", "transactional")

def __init__(
self,
client: redis_client_t,
ttl: int = 604800,
cache_registry: Optional[ArtifactCache] = None,
transactional: bool = True,
):
super().__init__("")
self.client = client
self.ttl = ttl
self.cache_registry = cache_registry
self.transactional = transactional

@staticmethod
def construct_key(skeys: KEYS, dkeys: KEYS) -> str:
Expand Down Expand Up @@ -155,30 +160,37 @@
------
ModelKeyNotFound: If the model key is not found in the registry.
"""
cached_artifact = self._load_from_cache(key)
latest_key = self.__construct_latest_key(key)
cached_artifact = self._load_from_cache(latest_key)
if cached_artifact:
_LOGGER.debug("Found cached artifact for key: %s", key)
_LOGGER.debug("Found cached artifact for key: %s", latest_key)
return cached_artifact, True
latest_key = self.__construct_latest_key(key)
if not self.client.exists(latest_key):
raise ModelKeyNotFound(f"latest key: {latest_key}, Not Found !!!")
model_key = self.client.get(latest_key)
_LOGGER.debug("latest key, %s, is pointing to the key : %s", latest_key, model_key)
return (
self.__load_version_artifact(version=self.get_version(model_key.decode()), key=key),
False,
artifact, _ = self.__load_version_artifact(
version=self.get_version(model_key.decode()), key=key
)
return artifact, False

def __load_version_artifact(self, version: str, key: str) -> ArtifactData:
model_key = self.__construct_version_key(key, version)
if not self.client.exists(model_key):
raise ModelKeyNotFound("Could not find model key with key: %s" % model_key)
return self.__get_artifact_data(
model_key=model_key,
def __load_version_artifact(self, version: str, key: str) -> tuple[ArtifactData, bool]:
version_key = self.__construct_version_key(key, version)
cached_artifact = self._load_from_cache(version_key)
if cached_artifact:
_LOGGER.debug("Found cached version artifact for key: %s", version_key)
return cached_artifact, True
if not self.client.exists(version_key):
raise ModelKeyNotFound(f"Could not find model key with key: {version_key}")
return (
self.__get_artifact_data(
model_key=version_key,
),
False,
)

def __save_artifact(
self, pipe, artifact: artifact_t, metadata: META_T, key: KEYS, version: str
self, pipe, artifact: artifact_t, key: KEYS, version: str, **metadata: META_T
) -> str:
new_version_key = self.__construct_version_key(key, version)
latest_key = self.__construct_latest_key(key)
Expand Down Expand Up @@ -210,6 +222,8 @@
is needed to load the respective artifact.

If cache registry is provided, it will first check the cache registry for the artifact.
If latest is passed, latest key is saved otherwise version call saves the respective
version artifact in cache.

Args:
----
Expand All @@ -230,24 +244,35 @@
if (latest and version) or (not latest and not version):
raise ValueError("Either One of 'latest' or 'version' needed in load method call")
key = self.construct_key(skeys, dkeys)
is_cached = False
try:
if latest:
artifact_data, is_cached = self.__load_latest_artifact(key)
else:
artifact_data = self.__load_version_artifact(version, key)
artifact_data, is_cached = self.__load_version_artifact(version, key)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
if (not is_cached) and latest:
self._save_in_cache(key, artifact_data)
if not is_cached:
if latest:
_LOGGER.debug(
"Saving %s, in cache as %s", self.__construct_latest_key(key), key
)
self._save_in_cache(self.__construct_latest_key(key), artifact_data)
else:
_LOGGER.info(
"Saving %s, in cache as %s",
self.__construct_version_key(key, version),
key,
)
self._save_in_cache(self.__construct_version_key(key, version), artifact_data)
return artifact_data

def save(
self,
skeys: KEYS,
dkeys: KEYS,
artifact: artifact_t,
_pipe: Optional[redis.client.Pipeline] = None,
**metadata: META_VT,
) -> Optional[str]:
"""Saves the artifact into redis registry and updates version.
Expand All @@ -257,6 +282,7 @@
skeys: static key fields as list/tuple of strings
dkeys: dynamic key fields as list/tuple of strings
artifact: primary artifact to be saved
_pipe: RedisPipeline object
metadata: additional metadata surrounding the artifact that needs to be saved.

Returns
Expand All @@ -275,10 +301,15 @@
_LOGGER.debug("Latest key: %s exists for the model", latest_key)
version_key = self.client.get(name=latest_key)
version = int(self.get_version(version_key.decode())) + 1
with self.client.pipeline() as pipe:
new_version_key = self.__save_artifact(pipe, artifact, metadata, key, str(version))
pipe.expire(name=new_version_key, time=self.ttl)
pipe.execute()
_redis_pipe = (
self.client.pipeline(transaction=self.transactional) if _pipe is None else _pipe
)
new_version_key = self.__save_artifact(
pipe=_redis_pipe, artifact=artifact, key=key, version=str(version), **metadata
)
_redis_pipe.expire(name=new_version_key, time=self.ttl)
if _pipe is None:
_redis_pipe.execute()
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
Expand Down Expand Up @@ -306,7 +337,7 @@
self.client.delete(del_key)
else:
raise ModelKeyNotFound(
"Key to delete: %s, Not Found!" % del_key,
f"Key to delete: {del_key}, Not Found!",
)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
Expand Down Expand Up @@ -338,3 +369,48 @@
raise RedisRegistryError("Error fetching timestamp information") from err
stale_ts = (datetime.now() - timedelta(hours=freq_hr)).timestamp()
return stale_ts > artifact_ts

def save_multiple(
self,
skeys: KEYS,
dict_artifacts: dict[str, KeyedArtifact],
**metadata: META_VT,
):
"""
Saves multiple artifacts into redis registry. The last save stores all the
artifact versions in the metadata.

Args:
----
skeys: static key fields as list/tuple of strings
dict_artifacts: dict of artifacts to save
metadata: additional metadata surrounding the artifact that needs to be saved.
"""
dict_model_ver = {}
try:
with self.client.pipeline(transaction=self.transactional) as pipe:
pipe.multi()
for key, value in dict_artifacts.items():
dict_model_ver[":".join(value.dkeys)] = self.save(
skeys=skeys,
dkeys=value.dkeys,
artifact=value.artifact,
_pipe=pipe,
**metadata,
)

if len(dict_artifacts) == len(dict_model_ver):
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
self.save(
skeys=skeys,
dkeys=value.dkeys,
artifact=value.artifact,
_pipe=pipe,
artifact_versions=dict_model_ver,
**metadata,
)
pipe.execute()
_LOGGER.info("Successfully saved all the artifacts with: %s", dict_model_ver)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err

Check warning on line 414 in numalogic/registry/redis_registry.py

View check run for this annotation

Codecov / codecov/patch

numalogic/registry/redis_registry.py#L414

Added line #L414 was not covered by tests
else:
return dict_model_ver
11 changes: 8 additions & 3 deletions numalogic/tools/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections.abc import Sequence
from typing import Union, TypeVar, ClassVar
from typing import Union, TypeVar, ClassVar, NamedTuple

from sklearn.base import BaseEstimator
from torch import Tensor
Expand Down Expand Up @@ -40,6 +38,13 @@
KEYS = TypeVar("KEYS", bound=Sequence[str], covariant=True)


class KeyedArtifact(NamedTuple):
r"""namedtuple for artifacts."""

dkeys: KEYS
artifact: artifact_t


class Singleton(type):
r"""Helper metaclass to use as a Singleton class."""

Expand Down
54 changes: 10 additions & 44 deletions numalogic/udfs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@

from numalogic.config import RegistryFactory
from numalogic.registry import LocalLRUCache, ArtifactData
from numalogic.tools.exceptions import RedisRegistryError, ModelKeyNotFound, ConfigNotFoundError
from numalogic.tools.exceptions import ConfigNotFoundError
from numalogic.tools.types import artifact_t, redis_client_t
from numalogic.udfs._base import NumalogicUDF
from numalogic.udfs._config import StreamConf, PipelineConf
from numalogic.udfs.entities import StreamPayload, Header, Status
from numalogic.udfs.tools import _load_artifact

_LOGGER = logging.getLogger(__name__)

# TODO: move to config
LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600"))
LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000"))
LOAD_LATEST = os.getenv("LOAD_LATEST", "false").lower() == "true"


class InferenceUDF(NumalogicUDF):
Expand Down Expand Up @@ -114,8 +116,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
# Forward payload if a training request is tagged
if payload.header == Header.TRAIN_REQUEST:
return Messages(Message(keys=keys, value=payload.to_json()))

artifact_data = self.load_artifact(keys, payload)
artifact_data, payload = _load_artifact(
skeys=keys,
dkeys=[self.get_conf(payload.config_id).numalogic_conf.model.name],
payload=payload,
model_registry=self.model_registry,
load_latest=LOAD_LATEST,
)

# TODO: revisit retraining logic
# Send training request if artifact loading is not successful
Expand Down Expand Up @@ -165,47 +172,6 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
)
return Messages(Message(keys=keys, value=payload.to_json()))

def load_artifact(self, keys: list[str], payload: StreamPayload) -> Optional[ArtifactData]:
"""
Load inference artifact from the registry.

Args:
keys: List of keys
payload: StreamPayload object

Returns
-------
ArtifactData instance
"""
_conf = self.get_conf(payload.config_id).numalogic_conf
try:
artifact_data = self.model_registry.load(
skeys=keys,
dkeys=[_conf.model.name],
)
except ModelKeyNotFound:
_LOGGER.warning(
"%s - Model key not found for Keys: %s, Metric: %s",
payload.uuid,
payload.composite_keys,
payload.metrics,
)
return None
except RedisRegistryError:
_LOGGER.exception(
"%s - Error while fetching inference artifact, Keys: %s, Metric: %s",
payload.uuid,
payload.composite_keys,
payload.metrics,
)
return None
_LOGGER.info(
"%s - Loaded artifact data from %s",
payload.uuid,
artifact_data.extras.get("source"),
)
return artifact_data

def is_model_stale(self, artifact_data: ArtifactData, payload: StreamPayload) -> bool:
"""
Check if the inference artifact is stale.
Expand Down
Loading