Skip to content

Commit

Permalink
add retrain logic for insufficient data (#321)
Browse files Browse the repository at this point in the history
Adds retry logic for the insufficient-data case: when a training fetch returns fewer than `min_train_size` records, the shortfall is recorded in redis and retraining is deferred until enough new data could have accumulated.

<img width="1791" alt="image"
src="https://github.com/numaproj/numalogic/assets/34571348/5012e0c3-2210-4c7c-a721-e5dc10660e67">

<img width="1756" alt="image"
src="https://github.com/numaproj/numalogic/assets/34571348/9ae30e85-b0d7-428f-bf6c-ebf3c4fa2a7d">

---------

Signed-off-by: s0nicboOm <[email protected]>
  • Loading branch information
s0nicboOm authored Nov 14, 2023
1 parent 0bbb53d commit 46cdfcd
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 13 deletions.
3 changes: 3 additions & 0 deletions numalogic/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,14 @@ class LightningTrainerConf:

@dataclass
class TrainerConf:
    """Schema for defining the trainer config."""

    # Size of the training window fetched from the datastore.
    train_hours: int = 24 * 8  # 8 days worth of data
    # Minimum number of records required before training is attempted;
    # fewer records triggers the insufficient-data retry path.
    min_train_size: int = 2000
    # How often (in hours) a model is retrained once one has been saved.
    retrain_freq_hr: int = 24
    # Backoff (in seconds) before another process may retry a pending train.
    retry_sec: int = 600  # 10 min
    # Batch size passed to the training dataloader.
    batch_size: int = 64
    # Data granularity/frequency in seconds; used to estimate how long to
    # wait for the missing records when data was insufficient.
    data_freq_sec: int = 60
    # Nested PyTorch Lightning trainer configuration.
    pltrainer_conf: LightningTrainerConf = field(default_factory=LightningTrainerConf)


Expand Down
1 change: 0 additions & 1 deletion numalogic/udfs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
load_latest=LOAD_LATEST,
)

# TODO: revisit retraining logic
# Send training request if artifact loading is not successful
if not artifact_data:
payload = replace(
Expand Down
94 changes: 84 additions & 10 deletions numalogic/udfs/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from dataclasses import replace
import time
from typing import Optional
from typing import Optional, NamedTuple

import numpy as np
import pandas as pd
Expand All @@ -17,6 +17,14 @@
_LOGGER = logging.getLogger(__name__)


class _DedupMetadata(NamedTuple):
    """Data Structure for Dedup Metadata."""

    # Timestamp when the training request was last read/acknowledged.
    # Stored in redis and decoded back as a string; None if never set
    # or if the fetch failed.
    msg_read_ts: Optional[str]
    # Timestamp when a model was last trained/saved for this key; None if absent.
    msg_train_ts: Optional[str]
    # Number of records found during the last (insufficient) fetch,
    # kept as a string since it round-trips through a redis hash; None if absent.
    msg_train_records: Optional[str]


def get_df(
data_payload: dict, stream_conf: StreamConf, fill_value: float = 0.0
) -> tuple[DataFrame, list[int]]:
Expand Down Expand Up @@ -164,43 +172,109 @@ def __init__(self, r_client: redis_client_t):
def __construct_key(keys: KEYS) -> str:
    """Return the redis key under which dedup metadata for *keys* is tracked."""
    return "TRAIN::" + ":".join(keys)

def __fetch_ts(self, key: str) -> _DedupMetadata:
    """
    Fetch the dedup metadata hash stored in redis for the given key.

    Args:
        key: fully constructed redis key (as built by __construct_key).

    Returns
    -------
        _DedupMetadata; every field is None if the fetch failed or
        the corresponding hash field is absent.
    """
    try:
        data = self.client.hgetall(key)
    except RedisError:
        _LOGGER.exception("Problem fetching ts information for the key: %s", key)
        return _DedupMetadata(msg_read_ts=None, msg_train_ts=None, msg_train_records=None)
    # Redis returns bytes for both keys and values; decode once instead of
    # looking each value up twice.
    decoded = {k.decode(): v.decode() for k, v in data.items()}
    # .get() yields None for absent fields, matching the previous
    # "str(...) if present else None" behavior (values are already str).
    return _DedupMetadata(
        msg_read_ts=decoded.get("_msg_read_ts"),
        msg_train_ts=decoded.get("_msg_train_ts"),
        msg_train_records=decoded.get("_msg_train_records"),
    )

def ack_insufficient_data(self, key: KEYS, uuid: str, train_records: int) -> bool:
    """
    Acknowledge the insufficient data message. Retry training after a certain period of time.

    Records the observed record count in redis so that ack_read can defer
    the next training attempt until enough new data could have accumulated.

    Args:
        key: key
        uuid: uuid
        train_records: number of train records found.

    Returns
    -------
        True if the ack was persisted, False if the redis write failed.
    """
    _key = self.__construct_key(key)
    try:
        self.client.hset(name=_key, key="_msg_train_records", value=str(train_records))
    except RedisError:
        _LOGGER.exception(
            " %s - Problem while updating _msg_train_records information for the key: %s",
            uuid,
            key,
        )
        return False
    else:
        _LOGGER.info("%s - Acknowledging insufficient data for the key: %s", uuid, key)
        return True

def ack_read(
self,
key: KEYS,
uuid: str,
retrain_freq: int = 24,
retry: int = 600,
min_train_records: int = 180,
data_freq: int = 60,
) -> bool:
"""
Acknowledge the read message. Return True when the msg has to be trained.
Args:
key: key
uuid: uuid.
retrain_freq: retrain frequency for the model in hrs
retry: Time difference(in secs) between triggering retraining and msg read_ack.
min_train_records: minimum number of records required for training.
data_freq: data granularity/frequency in secs.
Returns
-------
bool
"""
_key = self.__construct_key(key)
_msg_read_ts, _msg_train_ts = self.__fetch_ts(key=_key)
metadata = self.__fetch_ts(key=_key)
_msg_read_ts, _msg_train_ts, _msg_train_records = (
metadata.msg_read_ts,
metadata.msg_train_ts,
metadata.msg_train_records,
)
# If insufficient data: retry after (min_train_records-train_records) * data_granularity
if (
_msg_train_records
and _msg_read_ts
and time.time() - float(_msg_read_ts)
< (min_train_records - int(_msg_train_records)) * data_freq
):
_LOGGER.info(
"%s - There was insufficient data for the key in the past: %s. Retrying fetching"
" and training after %s secs",
uuid,
key,
(min_train_records - int(_msg_train_records)) * data_freq,
)

return False

# Check if the model is being trained by another process
if _msg_read_ts and time.time() - float(_msg_read_ts) < retry:
_LOGGER.info("%s - Model with key : %s is being trained by another process", uuid, key)
return False

# This check is needed if there is backpressure in the pl.
# This check is needed if there is backpressure in the pipeline
if _msg_train_ts and time.time() - float(_msg_train_ts) < retrain_freq * 60 * 60:
_LOGGER.info(
"%s - Model was saved for the key: %s in less than %s secs, skipping training",
"%s - Model was saved for the key: %s in less than %s hrs, skipping training",
uuid,
key,
retrain_freq,
Expand Down Expand Up @@ -241,5 +315,5 @@ def ack_train(self, key: KEYS, uuid: str) -> bool:
)
return False
else:
_LOGGER.info("%s - Acknowledging model saving complete for for the key: %s", uuid, key)
_LOGGER.info("%s - Acknowledging model saving complete for the key: %s", uuid, key)
return True
9 changes: 8 additions & 1 deletion numalogic/udfs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
uuid=payload.uuid,
retrain_freq=retrain_freq_ts,
retry=retry_ts,
min_train_records=_conf.numalogic_conf.trainer.min_train_size,
data_freq=_conf.numalogic_conf.trainer.data_freq_sec,
):
return Messages(Message.to_drop())

Expand Down Expand Up @@ -320,7 +322,12 @@ def artifacts_to_save(

def _is_data_sufficient(self, payload: TrainerPayload, df: pd.DataFrame) -> bool:
    """
    Check whether enough rows were fetched to train on.

    When the row count is below the configured minimum, the shortfall is
    recorded via the dedup tracker so the next attempt is deferred.

    Args:
        payload: trainer payload identifying the config and composite keys.
        df: dataframe of fetched training data.

    Returns
    -------
        True if len(df) meets the configured min_train_size, else False.
    """
    _conf = self.get_conf(payload.config_id)
    if len(df) < _conf.numalogic_conf.trainer.min_train_size:
        # Best-effort ack: a redis failure is logged inside and must not
        # change the (insufficient) verdict here.
        self.train_msg_deduplicator.ack_insufficient_data(
            key=payload.composite_keys, uuid=payload.uuid, train_records=len(df)
        )
        return False
    return True

@staticmethod
def get_feature_arr(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.6.0rc0"
version = "0.6.0"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
17 changes: 17 additions & 0 deletions tests/udfs/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,21 @@ def test_trainer_do_not_train_3(self):
),
)

@patch.object(DruidFetcher, "fetch", Mock(return_value=mock_druid_fetch_data(50)))
def test_trainer_do_not_train_4(self):
    # Fetch is patched to return only 50 records — below the default
    # min_train_size — so the trainer should take the insufficient-data path.
    self.udf1.register_conf(
        "druid-config",
        StreamConf(
            numalogic_conf=NumalogicConf(
                model=ModelInfo(name="VanillaAE", conf={"seq_len": 12, "n_features": 2}),
                preprocess=[ModelInfo(name="LogTransformer"), ModelInfo(name="StandardScaler")],
                trainer=TrainerConf(pltrainer_conf=LightningTrainerConf(max_epochs=1)),
            )
        ),
    )
    # First call should ack insufficient data; the second should be
    # deduplicated by the retry window.
    # NOTE(review): no explicit assertions — this is a smoke test that the
    # insufficient-data path does not raise; consider asserting on the
    # emitted messages or redis state.
    self.udf1(self.keys, self.datum)
    self.udf1(self.keys, self.datum)

def test_trainer_conf_err(self):
self.udf1 = TrainerUDF(
REDIS_CLIENT,
Expand Down Expand Up @@ -308,6 +323,8 @@ def test_TrainMsgDeduplicator_exception_1(self):
self.assertLogs("RedisError")
train_dedup.ack_train(self.keys, "some-uuid")
self.assertLogs("RedisError")
train_dedup.ack_insufficient_data(self.keys, "some-uuid", train_records=180)
self.assertLogs("RedisError")

@patch("redis.Redis.hgetall", Mock(side_effect=RedisError))
def test_TrainMsgDeduplicator_exception_2(self):
Expand Down

0 comments on commit 46cdfcd

Please sign in to comment.