diff --git a/numalogic/config/_config.py b/numalogic/config/_config.py
index c9045e0d..42abc89b 100644
--- a/numalogic/config/_config.py
+++ b/numalogic/config/_config.py
@@ -85,11 +85,14 @@ class LightningTrainerConf:
 
 @dataclass
 class TrainerConf:
+    """Schema for defining the trainer config."""
+
     train_hours: int = 24 * 8  # 8 days worth of data
     min_train_size: int = 2000
     retrain_freq_hr: int = 24
     retry_sec: int = 600  # 10 min
     batch_size: int = 64
+    data_freq_sec: int = 60
     pltrainer_conf: LightningTrainerConf = field(default_factory=LightningTrainerConf)
 
 
diff --git a/numalogic/udfs/inference.py b/numalogic/udfs/inference.py
index 9d7e59ab..aa5d10f8 100644
--- a/numalogic/udfs/inference.py
+++ b/numalogic/udfs/inference.py
@@ -130,7 +130,6 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
             load_latest=LOAD_LATEST,
         )
 
-        # TODO: revisit retraining logic
         # Send training request if artifact loading is not successful
         if not artifact_data:
             payload = replace(
diff --git a/numalogic/udfs/tools.py b/numalogic/udfs/tools.py
index 5a31046c..78105b97 100644
--- a/numalogic/udfs/tools.py
+++ b/numalogic/udfs/tools.py
@@ -1,7 +1,7 @@
 import logging
 from dataclasses import replace
 import time
-from typing import Optional
+from typing import Optional, NamedTuple
 
 import numpy as np
 import pandas as pd
@@ -17,6 +17,14 @@
 _LOGGER = logging.getLogger(__name__)
 
 
+class _DedupMetadata(NamedTuple):
+    """Data Structure for Dedup Metadata."""
+
+    msg_read_ts: Optional[str]
+    msg_train_ts: Optional[str]
+    msg_train_records: Optional[str]
+
+
 def get_df(
     data_payload: dict, stream_conf: StreamConf, fill_value: float = 0.0
 ) -> tuple[DataFrame, list[int]]:
@@ -164,27 +172,70 @@ def __init__(self, r_client: redis_client_t):
     @staticmethod
     def __construct_key(keys: KEYS) -> str:
         return f"TRAIN::{':'.join(keys)}"
 
-    def __fetch_ts(self, key: str) -> tuple[Optional[str], Optional[str]]:
+    def __fetch_ts(self, key: str) -> _DedupMetadata:
         try:
             data = self.client.hgetall(key)
         except RedisError:
             _LOGGER.exception("Problem fetching ts information for the key: %s", key)
-            return None, None
+            return _DedupMetadata(msg_read_ts=None, msg_train_ts=None, msg_train_records=None)
         else:
             # decode the key:value pair and update the values
             data = {key.decode(): data.get(key).decode() for key in data}
             _msg_read_ts = str(data["_msg_read_ts"]) if data and "_msg_read_ts" in data else None
             _msg_train_ts = str(data["_msg_train_ts"]) if data and "_msg_train_ts" in data else None
-            return _msg_read_ts, _msg_train_ts
+            _msg_train_records = (
+                str(data["_msg_train_records"]) if data and "_msg_train_records" in data else None
+            )
+            return _DedupMetadata(
+                msg_read_ts=_msg_read_ts,
+                msg_train_ts=_msg_train_ts,
+                msg_train_records=_msg_train_records,
+            )
 
-    def ack_read(self, key: KEYS, uuid: str, retrain_freq: int = 24, retry: int = 600) -> bool:
+    def ack_insufficient_data(self, key: KEYS, uuid: str, train_records: int) -> bool:
         """
-        Acknowledge the read message. Return True when the msg has to be trained.
+        Acknowledge the insufficient data message. Retry training after a certain period of time.
+
+        Args:
+            key: key
+            uuid: uuid
+            train_records: number of train records found.
+
+        Returns
+        -------
+            bool.
+ """ + _key = self.__construct_key(key) + try: + self.client.hset(name=_key, key="_msg_train_records", value=str(train_records)) + except RedisError: + _LOGGER.exception( + " %s - Problem while updating _msg_train_records information for the key: %s", + uuid, + key, + ) + return False + else: + _LOGGER.info("%s - Acknowledging insufficient data for the key: %s", uuid, key) + return True + + def ack_read( + self, + key: KEYS, + uuid: str, + retrain_freq: int = 24, + retry: int = 600, + min_train_records: int = 180, + data_freq: int = 60, + ) -> bool: + """ + Acknowledge the read message. Return True when the msg has to be trained. Args: key: key uuid: uuid. retrain_freq: retrain frequency for the model in hrs retry: Time difference(in secs) between triggering retraining and msg read_ack. + min_train_records: minimum number of records required for training. + data_freq: data granularity/frequency in secs. Returns ------- @@ -192,15 +243,38 @@ def ack_read(self, key: KEYS, uuid: str, retrain_freq: int = 24, retry: int = 60 """ _key = self.__construct_key(key) - _msg_read_ts, _msg_train_ts = self.__fetch_ts(key=_key) + metadata = self.__fetch_ts(key=_key) + _msg_read_ts, _msg_train_ts, _msg_train_records = ( + metadata.msg_read_ts, + metadata.msg_train_ts, + metadata.msg_train_records, + ) + # If insufficient data: retry after (min_train_records-train_records) * data_granularity + if ( + _msg_train_records + and _msg_read_ts + and time.time() - float(_msg_read_ts) + < (min_train_records - int(_msg_train_records)) * data_freq + ): + _LOGGER.info( + "%s - There was insufficient data for the key in the past: %s. Retrying fetching" + " and training after %s secs", + uuid, + key, + (min_train_records - int(_msg_train_records)) * data_freq, + ) + + return False + + # Check if the model is being trained by another process if _msg_read_ts and time.time() - float(_msg_read_ts) < retry: _LOGGER.info("%s - Model with key : %s is being trained by another process", uuid, key) return False - # This check is needed if there is backpressure in the pl. 
+        # This check is needed if there is backpressure in the pipeline
         if _msg_train_ts and time.time() - float(_msg_train_ts) < retrain_freq * 60 * 60:
             _LOGGER.info(
-                "%s - Model was saved for the key: %s in less than %s secs, skipping training",
+                "%s - Model was saved for the key: %s in less than %s hrs, skipping training",
                 uuid,
                 key,
                 retrain_freq,
@@ -241,5 +315,5 @@ def ack_train(self, key: KEYS, uuid: str) -> bool:
             )
             return False
         else:
-            _LOGGER.info("%s - Acknowledging model saving complete for for the key: %s", uuid, key)
+            _LOGGER.info("%s - Acknowledging model saving complete for the key: %s", uuid, key)
             return True
diff --git a/numalogic/udfs/trainer.py b/numalogic/udfs/trainer.py
index bded2481..059975f3 100644
--- a/numalogic/udfs/trainer.py
+++ b/numalogic/udfs/trainer.py
@@ -215,6 +215,8 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
             uuid=payload.uuid,
             retrain_freq=retrain_freq_ts,
             retry=retry_ts,
+            min_train_records=_conf.numalogic_conf.trainer.min_train_size,
+            data_freq=_conf.numalogic_conf.trainer.data_freq_sec,
         ):
             return Messages(Message.to_drop())
 
@@ -320,7 +322,12 @@ def artifacts_to_save(
 
     def _is_data_sufficient(self, payload: TrainerPayload, df: pd.DataFrame) -> bool:
         _conf = self.get_conf(payload.config_id)
-        return len(df) > _conf.numalogic_conf.trainer.min_train_size
+        if len(df) < _conf.numalogic_conf.trainer.min_train_size:
+            _ = self.train_msg_deduplicator.ack_insufficient_data(
+                key=payload.composite_keys, uuid=payload.uuid, train_records=len(df)
+            )
+            return False
+        return True
 
     @staticmethod
     def get_feature_arr(
diff --git a/pyproject.toml b/pyproject.toml
index 6bc7a336..97c9332c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "numalogic"
-version = "0.6.0rc0"
+version = "0.6.0"
 description = "Collection of operational Machine Learning models and tools."
 authors = ["Numalogic Developers"]
 packages = [{ include = "numalogic" }]
diff --git a/tests/udfs/test_trainer.py b/tests/udfs/test_trainer.py
index d690bee1..76667fa9 100644
--- a/tests/udfs/test_trainer.py
+++ b/tests/udfs/test_trainer.py
@@ -251,6 +251,21 @@ def test_trainer_do_not_train_3(self):
             ),
         )
 
+    @patch.object(DruidFetcher, "fetch", Mock(return_value=mock_druid_fetch_data(50)))
+    def test_trainer_do_not_train_4(self):
+        self.udf1.register_conf(
+            "druid-config",
+            StreamConf(
+                numalogic_conf=NumalogicConf(
+                    model=ModelInfo(name="VanillaAE", conf={"seq_len": 12, "n_features": 2}),
+                    preprocess=[ModelInfo(name="LogTransformer"), ModelInfo(name="StandardScaler")],
+                    trainer=TrainerConf(pltrainer_conf=LightningTrainerConf(max_epochs=1)),
+                )
+            ),
+        )
+        self.udf1(self.keys, self.datum)
+        self.udf1(self.keys, self.datum)
+
     def test_trainer_conf_err(self):
         self.udf1 = TrainerUDF(
             REDIS_CLIENT,
@@ -308,6 +323,8 @@ def test_TrainMsgDeduplicator_exception_1(self):
         self.assertLogs("RedisError")
         train_dedup.ack_train(self.keys, "some-uuid")
         self.assertLogs("RedisError")
+        train_dedup.ack_insufficient_data(self.keys, "some-uuid", train_records=180)
+        self.assertLogs("RedisError")
 
     @patch("redis.Redis.hgetall", Mock(side_effect=RedisError))
     def test_TrainMsgDeduplicator_exception_2(self):
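
Note on the retry arithmetic: the wait window introduced in ack_read scales with how much data was missing. If a fetch for a key returned train_records rows but min_train_records are required, the next training attempt is deferred by (min_train_records - train_records) * data_freq seconds, i.e. roughly the time it takes the missing rows to accumulate at the configured data frequency. A minimal standalone sketch of this check (the function names and constants below are illustrative, not part of numalogic):

    import time

    MIN_TRAIN_RECORDS = 2000  # mirrors TrainerConf.min_train_size above
    DATA_FREQ_SEC = 60  # mirrors the new TrainerConf.data_freq_sec (one record per minute)


    def retry_wait_sec(train_records: int) -> int:
        """Seconds until the missing records could have accumulated."""
        return (MIN_TRAIN_RECORDS - train_records) * DATA_FREQ_SEC


    def should_retry(msg_read_ts: float, train_records: int) -> bool:
        """Mirror of the ack_read check: train again only after the wait window elapses."""
        return time.time() - float(msg_read_ts) >= retry_wait_sec(train_records)


    # A fetch that found 1940 of the required 2000 records is retried after
    # (2000 - 1940) * 60 = 3600 secs, i.e. one hour.

A key that was only slightly short of data thus retries quickly, while one with almost no data backs off for hours.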
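Note on the dedup state: all three markers live in one Redis hash per composite key, written with hset and read back in __fetch_ts via hgetall. A sketch of the resulting layout, with made-up key parts, timestamps, and counts (only the TRAIN:: prefix and the field names come from the code above):

    # Hypothetical contents of the dedup hash for one model key.
    composite_keys = ["mycluster", "mynamespace", "myapp"]
    redis_key = f"TRAIN::{':'.join(composite_keys)}"  # -> "TRAIN::mycluster:mynamespace:myapp"

    hash_fields = {
        "_msg_read_ts": "1700000000.0",  # set by ack_read when a training request is accepted
        "_msg_train_ts": "1700000300.0",  # set by ack_train once the retrained model is saved
        "_msg_train_records": "50",  # set by ack_insufficient_data when len(df) < min_train_size
    }

Keeping the record count next to the timestamps is what lets ack_read distinguish a training run that is already in flight from one that was skipped for lack of data.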