Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rds trainer1 #2

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion numalogic/config/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def get_cls(cls, name: str):
class ConnectorFactory(_ObjectFactory):
"""Factory class for data connectors."""

_CLS_SET: ClassVar[frozenset] = frozenset({"PrometheusFetcher", "DruidFetcher"})
_CLS_SET: ClassVar[frozenset] = frozenset({"PrometheusFetcher", "DruidFetcher", "RDSFetcher"})

@classmethod
def get_cls(cls, name: str):
Expand Down
7 changes: 6 additions & 1 deletion numalogic/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
DruidConf,
DruidFetcherConf,
ConnectorType,
RDSConf,
RDSFetcherConf,
)
from numalogic.connectors.rds import RDSFetcher
from numalogic.connectors.prometheus import PrometheusFetcher

__all__ = [
Expand All @@ -18,9 +21,11 @@
"DruidFetcherConf",
"ConnectorType",
"PrometheusFetcher",
"RDSFetcher",
"RDSConf",
"RDSFetcherConf",
]


if find_spec("pydruid"):
from numalogic.connectors.druid import DruidFetcher # noqa: F401

Expand Down
67 changes: 67 additions & 0 deletions numalogic/connectors/_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Optional
from numalogic.connectors.utils.aws.config import RDSConnectionConfig
from numalogic.connectors.exceptions import RDSFetcherConfValidationException


class ConnectorType(IntEnum):
Expand Down Expand Up @@ -52,6 +54,49 @@ def __post_init__(self):
self.aggregations = {"count": doublesum("count")}


@dataclass
class RDSFetcherConf:
"""
RDSFetcherConf class represents the configuration for fetching data from an RDS data source.

Args:
datasource (str): The name of the data source.
dimensions (list[str]): A list of dimension column names.
group_by (list[str]): A list of column names to group the data by.
pivot (Pivot): An instance of the Pivot class representing the pivot configuration.
hash_query_type (bool): A boolean indicating whether to use hash query type.
hash_column_name (Optional[str]): The name of the hash column. (default: None)
datetime_column_name (str): The name of the datetime column. (default: "eventdatetime")
metrics (list[str]): A list of metric column names.

Methods
-------
__post_init__(): Performs post-initialization validation checks.

Raises
------
RDSFetcherConfValidationException: If the hash_query_type is enabled
but hash_column_name is not provided.
"""

datasource: str
dimensions: list[str]
# metric column names
metrics: list[str]
group_by: list[str] = field(default_factory=list)
pivot: Pivot = field(default_factory=lambda: Pivot())
hash_query_type: bool = True
hash_column_name: str = "model_md5_hash"
datetime_column_name: str = "eventdatetime"

def __post_init__(self):
if self.hash_query_type:
if self.hash_column_name.strip() == "":
raise RDSFetcherConfValidationException(
"when hash_query_type is enabled, hash_column_name is required property "
)


@dataclass
class DruidConf(ConnectorConf):
"""
Expand All @@ -70,3 +115,25 @@ class DruidConf(ConnectorConf):
delay_hrs: float = 3.0
fetcher: Optional[DruidFetcherConf] = None
id_fetcher: Optional[dict[str, DruidFetcherConf]] = None


@dataclass
class RDSConf:
"""
Class representing the configuration for fetching data from an RDS data source.

Args:
connection_conf (RDSConnectionConfig): An instance of the RDSConnectionConfig class
representing the connection configuration.
delay_hrs (float): The delay in hours for fetching data. Defaults to 3.0.
fetcher (Optional[RDSFetcherConf]): An optional instance of the RDSFetcherConf class
representing the fetcher configuration. Defaults to None.
id_fetcher (Optional[dict[str, RDSFetcherConf]]): An optional dictionary mapping IDs to
instances of the RDSFetcherConf class representing the fetcher configuration.
Defaults to None.
"""

connection_conf: RDSConnectionConfig
delay_hrs: float = 3.0
fetcher: Optional[RDSFetcherConf] = None
id_fetcher: Optional[dict[str, RDSFetcherConf]] = None
10 changes: 10 additions & 0 deletions numalogic/connectors/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class ConnectorFetcherException(Exception):
"""Custom exception class for grouping all Connector Exceptions together."""

pass


class RDSFetcherConfValidationException(ConnectorFetcherException):
"""A custom exception class for handling validation errors in RDSFetcherConf."""

pass
14 changes: 7 additions & 7 deletions numalogic/connectors/rds/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
from typing import Optional
import pandas as pd
from numalogic.connectors.utils.aws.config import DatabaseServiceProvider, RDSConfig
from numalogic.connectors.utils.aws.config import DatabaseServiceProvider, RDSConnectionConfig
from numalogic.connectors.utils.aws.boto3_client_manager import Boto3ClientManager
import logging
from numalogic.connectors._config import Pivot
Expand All @@ -13,7 +13,7 @@
def format_dataframe(
df: pd.DataFrame,
query: str,
datetime_field_name: str,
datetime_column_name: str,
group_by: Optional[list[str]] = None,
pivot: Optional[Pivot] = None,
) -> pd.DataFrame:
Expand All @@ -26,7 +26,7 @@ def format_dataframe(
The input DataFrame to be formatted.
query : str
The SQL query used to retrieve the data.
datetime_field_name : str
datetime_column_name : str
The name of the datetime field in the DataFrame.
group_by : Optional[list[str]], optional
A list of column names to group the DataFrame by, by default None.
Expand All @@ -40,8 +40,8 @@ def format_dataframe(

"""
_start_time = time.perf_counter()
df["timestamp"] = pd.to_datetime(df[datetime_field_name]).astype("int64") // 10**6
df.drop(columns=datetime_field_name, inplace=True)
df["timestamp"] = pd.to_datetime(df[datetime_column_name]).astype("int64") // 10**6
df.drop(columns=datetime_column_name, inplace=True)
if group_by:
df = df.groupby(by=group_by).sum().reset_index()

Expand All @@ -65,12 +65,12 @@ class represents a data fetcher for RDS (Relational Database Service) connection
connection, and executing queries.

Args:
- db_config (RDSConfig): The configuration object for the RDS connection.
- db_config (RDSConnectionConfig): The configuration object for the RDS connection.
- kwargs (dict): Additional keyword arguments.

"""

def __init__(self, db_config: RDSConfig, **kwargs):
def __init__(self, db_config: RDSConnectionConfig, **kwargs):
self.kwargs = kwargs
self.db_config = db_config
self.connection = None
Expand Down
16 changes: 10 additions & 6 deletions numalogic/connectors/rds/_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from numalogic.connectors._base import DataFetcher
from numalogic.connectors._config import Pivot
from numalogic.connectors.rds._base import format_dataframe
from numalogic.connectors.utils.aws.config import RDSConfig
from numalogic.connectors.utils.aws.config import RDSConnectionConfig
import logging
import pandas as pd
from numalogic.connectors.rds.db.factory import RdsFactory
Expand All @@ -19,12 +19,12 @@ class is a subclass of DataFetcher and ABC (Abstract Base Class).

Attributes
----------
db_config (RDSConfig): The configuration object for the RDS instance.
db_config (RDSConnectionConfig): The configuration object for the RDS instance.
fetcher (db.CLASS_TYPE): The fetcher object for the specific database type.

"""

def __init__(self, db_config: RDSConfig):
def __init__(self, db_config: RDSConnectionConfig):
super().__init__(db_config.endpoint)
self.db_config = db_config
factory_object = RdsFactory()
Expand All @@ -34,7 +34,7 @@ def __init__(self, db_config: RDSConfig):
def fetch(
self,
query,
datetime_field_name: str,
datetime_column_name: str,
pivot: Optional[Pivot] = None,
group_by: Optional[list[str]] = None,
) -> pd.DataFrame:
Expand All @@ -43,7 +43,7 @@ def fetch(

Args:
query (str): The SQL query to be executed.
datetime_field_name (str): The name of the datetime field in the fetched data.
datetime_column_name (str): The name of the datetime field in the fetched data.
pivot (Optional[Pivot], optional): The pivot configuration for the fetched data.
Defaults to None.
group_by (Optional[list[str]], optional): The list of fields to group the
Expand All @@ -60,7 +60,11 @@ def fetch(
return pd.DataFrame()

formatted_df = format_dataframe(
df, query=query, datetime_field_name=datetime_field_name, pivot=pivot, group_by=group_by
df,
query=query,
datetime_column_name=datetime_column_name,
pivot=pivot,
group_by=group_by,
)
_end_time = time.perf_counter() - _start_time
_LOGGER.info("RDS Query: %s Fetch Time: %.4fs", query, _end_time)
Expand Down
12 changes: 6 additions & 6 deletions numalogic/connectors/rds/db/mysql_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import logging

from numalogic.connectors.utils.aws.config import DatabaseTypes, RDSConfig
from numalogic.connectors.utils.aws.config import DatabaseTypes, RDSConnectionConfig

_LOGGER = logging.getLogger(__name__)

Expand All @@ -16,8 +16,8 @@ class MysqlFetcher(RDSBase):
"""
class that inherits from RDSBase. It is used to fetch data from a MySQL database.

- __init__(self, db_config: RDSConfig, **kwargs): Initializes the MysqlFetcher object with
the given RDSConfig and additional keyword arguments.
- __init__(self, db_config: RDSConnectionConfig, **kwargs): Initializes the MysqlFetcher object
with the given RDSConnectionConfig and additional keyword arguments.

The MysqlFetcher class is designed to be used as a base class for fetching data from a MySQL
database. It provides methods for establishing a connection, executing queries,
Expand All @@ -27,7 +27,7 @@ class that inherits from RDSBase. It is used to fetch data from a MySQL database

database_type = DatabaseTypes.MYSQL

def __init__(self, db_config: RDSConfig, **kwargs):
def __init__(self, db_config: RDSConnectionConfig, **kwargs):
super().__init__(db_config)
self.db_config = db_config
self.kwargs = kwargs
Expand All @@ -44,8 +44,8 @@ def get_connection(self) -> pymysql.Connection:
------
None

Notes: - If SSL/TLS is enabled and configured in the RDSConfig object, the connection
will be established with SSL/TLS. - If SSL/TLS is not enabled or configured,
Notes: - If SSL/TLS is enabled and configured in the RDSConnectionConfig object,
the connection will be established with SSL/TLS. - If SSL/TLS is not enabled or configured,
the connection will be established without SSL/TLS.

"""
Expand Down
4 changes: 2 additions & 2 deletions numalogic/connectors/utils/aws/boto3_client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from boto3 import Session
import logging

from numalogic.connectors.utils.aws.config import DatabaseServiceProvider, RDSConfig
from numalogic.connectors.utils.aws.config import DatabaseServiceProvider, RDSConnectionConfig
from numalogic.connectors.utils.aws.exceptions import UnRecognizedAWSClientException
from numalogic.connectors.utils.aws.sts_client_manager import STSClientManager

Expand Down Expand Up @@ -30,7 +30,7 @@ class Boto3ClientManager:
methods.
"""

def __init__(self, configurations: RDSConfig):
def __init__(self, configurations: RDSConnectionConfig):
self.rds_client = None
self.athena_client = None
self.configurations = configurations
Expand Down
2 changes: 1 addition & 1 deletion numalogic/connectors/utils/aws/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class RDBMSConfig:


@dataclass
class RDSConfig(AWSConfig, RDBMSConfig):
class RDSConnectionConfig(AWSConfig, RDBMSConfig):
"""
Class representing the configuration for an RDS (Relational Database Service) instance.

Expand Down
6 changes: 3 additions & 3 deletions numalogic/connectors/utils/aws/db_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from numalogic.tools.exceptions import ConfigNotFoundError
from omegaconf import OmegaConf
from numalogic.connectors.utils.aws.config import RDSConfig
from numalogic.connectors.utils.aws.config import RDSConnectionConfig

_LOGGER = logging.getLogger(__name__)


def load_db_conf(*paths: str) -> RDSConfig:
def load_db_conf(*paths: str) -> RDSConnectionConfig:
"""
Load database configuration from one or more YAML files.

Expand Down Expand Up @@ -38,6 +38,6 @@ def load_db_conf(*paths: str) -> RDSConfig:
_err_msg = f"None of the given conf paths exist: {paths}"
raise ConfigNotFoundError(_err_msg)

schema = OmegaConf.structured(RDSConfig)
schema = OmegaConf.structured(RDSConnectionConfig)
conf = OmegaConf.merge(schema, *confs)
return OmegaConf.to_object(conf)
8 changes: 7 additions & 1 deletion numalogic/tools/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ class DataFormatError(Exception):


class DruidFetcherError(Exception):
"""Base class for all exceptions raised by the PrometheusFetcher class."""
"""Base class for all exceptions raised by the DruidFetcher class."""

pass


class RDSFetcherError(Exception):
"""Base class for all exceptions raised by the RDSFetcher class."""

pass
4 changes: 2 additions & 2 deletions numalogic/udfs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from logging import config as logconf
import os


from numalogic._constants import BASE_DIR
from numalogic.udfs._base import NumalogicUDF
from numalogic.udfs._config import StreamConf, PipelineConf, MLPipelineConf, load_pipeline_conf
Expand All @@ -11,7 +10,7 @@
from numalogic.udfs.inference import InferenceUDF
from numalogic.udfs.postprocess import PostprocessUDF
from numalogic.udfs.preprocess import PreprocessUDF
from numalogic.udfs.trainer import TrainerUDF, PromTrainerUDF, DruidTrainerUDF
from numalogic.udfs.trainer import TrainerUDF, PromTrainerUDF, DruidTrainerUDF, RDSTrainerUDF


def set_logger() -> None:
Expand All @@ -32,6 +31,7 @@ def set_logger() -> None:
"TrainerUDF",
"PromTrainerUDF",
"DruidTrainerUDF",
"RDSTrainerUDF",
"PostprocessUDF",
"UDFFactory",
"StreamConf",
Expand Down
3 changes: 3 additions & 0 deletions numalogic/udfs/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
RedisConf,
PrometheusConf,
DruidConf,
RDSConf,
)
from numalogic.tools.exceptions import ConfigNotFoundError

Expand Down Expand Up @@ -68,6 +69,7 @@ class PipelineConf:
registry_conf (Optional[RegistryInfo]): The configuration for the registry.
prometheus_conf (Optional[PrometheusConf]): The configuration for Prometheus.
druid_conf (Optional[DruidConf]): The configuration for Druid.
rds_conf (Optional[RDSConf]): The configuration for RDS.
"""

stream_confs: dict[str, StreamConf] = field(default_factory=dict)
Expand All @@ -77,6 +79,7 @@ class PipelineConf:
)
prometheus_conf: Optional[PrometheusConf] = None
druid_conf: Optional[DruidConf] = None
rds_conf: Optional[RDSConf] = None


def load_pipeline_conf(*paths: str) -> PipelineConf:
Expand Down
Loading