Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add multiple save for redis registry #281

Merged
merged 18 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry==1.5.1
run: pipx install poetry==1.6.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry==1.5.1
run: pipx install poetry==1.6.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry==1.5.1
run: pipx install poetry==1.6.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
4 changes: 3 additions & 1 deletion numalogic/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ class TrainerConf:
class NumalogicConf:
"""Top level config schema for numalogic."""

model: ModelInfo = field(default_factory=ModelInfo)
model: ModelInfo = field(
default_factory=lambda: ModelInfo(name="VanillaAE", conf={"n_features": 2, "seq_len": 10})
)
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
trainer: TrainerConf = field(default_factory=TrainerConf)
preprocess: list[ModelInfo] = field(default_factory=list)
threshold: ModelInfo = field(default_factory=lambda: ModelInfo(name="StdDevThreshold"))
Expand Down
2 changes: 1 addition & 1 deletion numalogic/connectors/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def fetch_data(

end_dt = datetime.now(pytz.utc)
start_dt = end_dt - timedelta(hours=hours)
intervals = f"{start_dt.isoformat()}/{end_dt.isoformat()}"
intervals = [f"{start_dt.isoformat()}/{end_dt.isoformat()}"]
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved

params = {
"datasource": datasource,
Expand Down
115 changes: 93 additions & 22 deletions numalogic/registry/redis_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
client: Take in the redis client already established/created
ttl: Total Time to Live (in seconds) for the key when saving in redis (default = 604800)
cache_registry: Cache registry to use (default = None).
transactional: Flag to indicate if the registry should be transactional or
not (default = True).

Examples
--------
Expand All @@ -48,18 +50,20 @@
>>> loaded_artifact = registry.load(skeys, dkeys)
"""

__slots__ = ("client", "ttl", "cache_registry")
__slots__ = ("client", "ttl", "cache_registry", "transactional")

def __init__(
self,
client: redis_client_t,
ttl: int = 604800,
cache_registry: Optional[ArtifactCache] = None,
transactional: bool = True,
):
super().__init__("")
self.client = client
self.ttl = ttl
self.cache_registry = cache_registry
self.transactional = transactional

@staticmethod
def construct_key(skeys: KEYS, dkeys: KEYS) -> str:
Expand Down Expand Up @@ -155,30 +159,37 @@
------
ModelKeyNotFound: If the model key is not found in the registry.
"""
cached_artifact = self._load_from_cache(key)
latest_key = self.__construct_latest_key(key)
cached_artifact = self._load_from_cache(latest_key)
if cached_artifact:
_LOGGER.debug("Found cached artifact for key: %s", key)
_LOGGER.debug("Found cached artifact for key: %s", latest_key)
return cached_artifact, True
latest_key = self.__construct_latest_key(key)
if not self.client.exists(latest_key):
raise ModelKeyNotFound(f"latest key: {latest_key}, Not Found !!!")
model_key = self.client.get(latest_key)
_LOGGER.debug("latest key, %s, is pointing to the key : %s", latest_key, model_key)
return (
self.__load_version_artifact(version=self.get_version(model_key.decode()), key=key),
False,
artifact, _ = self.__load_version_artifact(
version=self.get_version(model_key.decode()), key=key
)
return artifact, False

def __load_version_artifact(self, version: str, key: str) -> ArtifactData:
model_key = self.__construct_version_key(key, version)
if not self.client.exists(model_key):
raise ModelKeyNotFound("Could not find model key with key: %s" % model_key)
return self.__get_artifact_data(
model_key=model_key,
def __load_version_artifact(self, version: str, key: str) -> tuple[ArtifactData, bool]:
version_key = self.__construct_version_key(key, version)
cached_artifact = self._load_from_cache(version_key)
if cached_artifact:
_LOGGER.debug("Found cached version artifact for key: %s", version_key)
return cached_artifact, True
if not self.client.exists(version_key):
raise ModelKeyNotFound("Could not find model key with key: %s" % version_key)
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
return (
self.__get_artifact_data(
model_key=version_key,
),
False,
)

def __save_artifact(
self, pipe, artifact: artifact_t, metadata: META_T, key: KEYS, version: str
self, pipe, artifact: artifact_t, key: KEYS, version: str, **metadata: META_T
) -> str:
new_version_key = self.__construct_version_key(key, version)
latest_key = self.__construct_latest_key(key)
Expand Down Expand Up @@ -210,6 +221,8 @@
is needed to load the respective artifact.

If cache registry is provided, it will first check the cache registry for the artifact.
If latest is passed, latest key is saved otherwise version call saves the respective
version artifact in cache.

Args:
----
Expand All @@ -230,24 +243,33 @@
if (latest and version) or (not latest and not version):
raise ValueError("Either One of 'latest' or 'version' needed in load method call")
key = self.construct_key(skeys, dkeys)
is_cached = False
try:
if latest:
artifact_data, is_cached = self.__load_latest_artifact(key)
else:
artifact_data = self.__load_version_artifact(version, key)
artifact_data, is_cached = self.__load_version_artifact(version, key)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
if (not is_cached) and latest:
self._save_in_cache(key, artifact_data)
if not is_cached:
if latest:
_LOGGER.info("Saving %s, in cache as %s", self.__construct_latest_key(key), key)
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
self._save_in_cache(self.__construct_latest_key(key), artifact_data)
else:
_LOGGER.info(
"Saving %s, in cache as %s",
self.__construct_version_key(key, version),
key,
)
self._save_in_cache(self.__construct_version_key(key, version), artifact_data)
return artifact_data

def save(
self,
skeys: KEYS,
dkeys: KEYS,
artifact: artifact_t,
pipe=None,
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
**metadata: META_VT,
) -> Optional[str]:
"""Saves the artifact into redis registry and updates version.
Expand Down Expand Up @@ -275,10 +297,15 @@
_LOGGER.debug("Latest key: %s exists for the model", latest_key)
version_key = self.client.get(name=latest_key)
version = int(self.get_version(version_key.decode())) + 1
with self.client.pipeline() as pipe:
new_version_key = self.__save_artifact(pipe, artifact, metadata, key, str(version))
pipe.expire(name=new_version_key, time=self.ttl)
pipe.execute()
redis_pipe = (
self.client.pipeline(transaction=self.transactional) if pipe is None else pipe
)
new_version_key = self.__save_artifact(
pipe=redis_pipe, artifact=artifact, key=key, version=str(version), **metadata
)
redis_pipe.expire(name=new_version_key, time=self.ttl)
if pipe is None:
redis_pipe.execute()
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
Expand Down Expand Up @@ -338,3 +365,47 @@
raise RedisRegistryError("Error fetching timestamp information") from err
stale_ts = (datetime.now() - timedelta(hours=freq_hr)).timestamp()
return stale_ts > artifact_ts

def save_multiple(
    self,
    skeys: KEYS,
    list_dkeys: list[KEYS],
    list_artifacts: list[artifact_t],
    **metadata: META_VT,
):
    """
    Saves multiple artifacts into the redis registry inside a single pipeline.

    All artifacts share the same static keys; each artifact has its own dynamic
    keys. After every artifact has been queued, the last artifact is saved once
    more with the collected dkey -> version mapping embedded in its metadata
    under ``artifact_versions``, so the versions of the whole batch can be
    recovered from that final entry.

    Args:
    ----
        skeys: static key fields as list/tuple of strings
        list_dkeys: list of dynamic key fields as list/tuple of strings
        list_artifacts: list of primary artifacts to be saved
        metadata: additional metadata surrounding the artifact that needs to be saved.

    Returns
    -------
        Mapping of ":"-joined dynamic keys to the version returned by save()
        for each artifact.

    Raises
    ------
        IndexError: if list_artifacts and list_dkeys differ in length.
        RedisRegistryError: if any underlying redis operation fails.
    """
    if len(list_artifacts) != len(list_dkeys):
        raise IndexError("artifact list and dkeys list should have same length!")
    dict_model_ver = {}
    try:
        with self.client.pipeline(transaction=self.transactional) as pipe:
            pipe.multi()
            for dkeys, artifact in zip(list_dkeys, list_artifacts):
                dict_model_ver[":".join(dkeys)] = self.save(
                    skeys=skeys, dkeys=dkeys, artifact=artifact, pipe=pipe, **metadata
                )
            if list_artifacts:
                # Re-save the last artifact with the full version map so the
                # batch's versions are persisted together in its metadata.
                self.save(
                    skeys=skeys,
                    dkeys=list_dkeys[-1],
                    artifact=list_artifacts[-1],
                    pipe=pipe,
                    artifact_versions=dict_model_ver,
                    **metadata,
                )
            pipe.execute()
        _LOGGER.info("Successfully saved all the artifacts with: %s", dict_model_ver)
    except RedisError as err:
        raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
    else:
        return dict_model_ver
5 changes: 5 additions & 0 deletions numalogic/udfs/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
CONF_FILE_PATH = os.getenv(
"CONF_PATH", default=os.path.join(BASE_CONF_DIR, "default-configs", "config.yaml")
)
pipeline_conf = load_pipeline_conf(CONF_FILE_PATH)
logging.info("Pipeline config: %s", pipeline_conf)

Check warning on line 16 in numalogic/udfs/__main__.py

View check run for this annotation

Codecov / codecov/patch

numalogic/udfs/__main__.py#L15-L16

Added lines #L15 - L16 were not covered by tests

redis_client = get_redis_client_from_conf(pipeline_conf.redis_conf)

Check warning on line 18 in numalogic/udfs/__main__.py

View check run for this annotation

Codecov / codecov/patch

numalogic/udfs/__main__.py#L18

Added line #L18 was not covered by tests

for key in redis_client.scan_iter("*"):
redis_client.delete(key)

Check warning on line 21 in numalogic/udfs/__main__.py

View check run for this annotation

Codecov / codecov/patch

numalogic/udfs/__main__.py#L21

Added line #L21 was not covered by tests
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
if __name__ == "__main__":
set_logger()
step = sys.argv[1]
Expand Down
3 changes: 3 additions & 0 deletions numalogic/udfs/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@
def get_metadata(self, key: str) -> dict[str, Any]:
return copy(self.metadata[key])

def add_metadata(self, data: dict[str, Any]):
self.metadata.update(data)

Check warning on line 85 in numalogic/udfs/entities.py

View check run for this annotation

Codecov / codecov/patch

numalogic/udfs/entities.py#L85

Added line #L85 was not covered by tests
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved

def __str__(self) -> str:
return (
f'"StreamPayload(header={self.header}, status={self.status}, '
Expand Down
54 changes: 10 additions & 44 deletions numalogic/udfs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@
from pynumaflow.function import Messages, Datum, Message

from numalogic.registry import RedisRegistry, LocalLRUCache, ArtifactData
from numalogic.tools.exceptions import RedisRegistryError, ModelKeyNotFound, ConfigNotFoundError
from numalogic.tools.exceptions import ConfigNotFoundError
from numalogic.tools.types import artifact_t, redis_client_t
from numalogic.udfs._base import NumalogicUDF
from numalogic.udfs._config import StreamConf, PipelineConf
from numalogic.udfs.entities import StreamPayload, Header, Status
from numalogic.udfs.tools import _load_artifact

_LOGGER = logging.getLogger(__name__)

# TODO: move to config
LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600"))
LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000"))
LOAD_LATEST = os.getenv("LOAD_LATEST", "false").lower() == "true"


class InferenceUDF(NumalogicUDF):
Expand Down Expand Up @@ -111,8 +113,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
# Forward payload if a training request is tagged
if payload.header == Header.TRAIN_REQUEST:
return Messages(Message(keys=keys, value=payload.to_json()))

artifact_data = self.load_artifact(keys, payload)
artifact_data, payload = _load_artifact(
skeys=keys,
dkeys=[self.get_conf(payload.config_id).numalogic_conf.model.name],
payload=payload,
model_registry=self.model_registry,
load_latest=LOAD_LATEST,
)

# TODO: revisit retraining logic
# Send training request if artifact loading is not successful
Expand Down Expand Up @@ -162,47 +169,6 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
)
return Messages(Message(keys=keys, value=payload.to_json()))

def load_artifact(self, keys: list[str], payload: StreamPayload) -> Optional[ArtifactData]:
"""
Load inference artifact from the registry.

Args:
keys: List of keys
payload: StreamPayload object

Returns
-------
ArtifactData instance
"""
_conf = self.get_conf(payload.config_id).numalogic_conf
try:
artifact_data = self.model_registry.load(
skeys=keys,
dkeys=[_conf.model.name],
)
except ModelKeyNotFound:
_LOGGER.warning(
"%s - Model key not found for Keys: %s, Metric: %s",
payload.uuid,
payload.composite_keys,
payload.metrics,
)
return None
except RedisRegistryError:
_LOGGER.exception(
"%s - Error while fetching inference artifact, Keys: %s, Metric: %s",
payload.uuid,
payload.composite_keys,
payload.metrics,
)
return None
_LOGGER.info(
"%s - Loaded artifact data from %s",
payload.uuid,
artifact_data.extras.get("source"),
)
return artifact_data

def is_model_stale(self, artifact_data: ArtifactData, payload: StreamPayload) -> bool:
"""
Check if the inference artifact is stale.
Expand Down
11 changes: 8 additions & 3 deletions numalogic/udfs/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from numalogic.udfs import NumalogicUDF
from numalogic.udfs._config import StreamConf, PipelineConf
from numalogic.udfs.entities import StreamPayload, Header, Status, TrainerPayload, OutputPayload
from numalogic.udfs.tools import _load_model
from numalogic.udfs.tools import _load_artifact

# TODO: move to config
LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600"))
LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000"))
LOAD_LATEST = os.getenv("LOAD_LATEST", "false").lower() == "true"

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,8 +89,12 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
postprocess_cfg = self.get_conf(payload.config_id).numalogic_conf.postprocess

# load artifact
thresh_artifact = _load_model(
skeys=keys, dkeys=[thresh_cfg.name], payload=payload, model_registry=self.model_registry
thresh_artifact, payload = _load_artifact(
skeys=keys,
dkeys=[thresh_cfg.name],
payload=payload,
model_registry=self.model_registry,
load_latest=LOAD_LATEST,
)
postproc_clf = self.postproc_factory.get_instance(postprocess_cfg)

Expand Down
Loading