diff --git a/README.md b/README.md index 91dac241..2e965dee 100644 --- a/README.md +++ b/README.md @@ -76,10 +76,10 @@ metrics: metric_type: row_count table: table_1 filter: - sql_query: "category = 'HAT' AND is_valid is True" + where_clause: "category = 'HAT' AND is_valid is True" count_content_non_valid: metric_type: row_count table: table_1 filter: - sql_query: "is_valid is False" + where_clause: "is_valid is False" ``` \ No newline at end of file diff --git a/datachecks/core/configuration/configuration.py b/datachecks/core/configuration/configuration.py index 7e2b7779..896c414c 100644 --- a/datachecks/core/configuration/configuration.py +++ b/datachecks/core/configuration/configuration.py @@ -14,10 +14,9 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import yaml -from yaml import SafeLoader class DatasourceType(Enum): @@ -56,8 +55,8 @@ class MetricsFilterConfiguration: Filter configuration for a metric """ - sql_query: Optional[list] - search_query: Optional[list] + where_clause: Optional[str] = None + search_query: Optional[str] = None @dataclass @@ -70,7 +69,8 @@ class MetricConfiguration: metric_type: str index: Optional[str] = None table: Optional[str] = None - filter: Optional[MetricsFilterConfiguration] = None + field: Optional[str] = None + filters: Optional[MetricsFilterConfiguration] = None @dataclass @@ -125,11 +125,12 @@ def load_configuration_from_yaml_str(yaml_string: str) -> Configuration: metric_type=metric_value["metric_type"], index=metric_value.get("index"), table=metric_value.get("table"), - filter=MetricsFilterConfiguration( - sql_query=metric_value.get("filter", {}).get("sql_query", None), - search_query=metric_value.get("filter", {}).get( + field=metric_value.get("field"), + filters=MetricsFilterConfiguration( + where_clause=metric_value.get("filters", {}).get("where_clause", None), + search_query=metric_value.get("filters", {}).get( "search_query", None - ), + ) ), ) for metric_name, metric_value in metric_list.items() diff --git a/datachecks/core/datasource/base.py b/datachecks/core/datasource/base.py index d036f626..98a8f3a6 100644 --- a/datachecks/core/datasource/base.py +++ b/datachecks/core/datasource/base.py @@ -76,6 +76,16 @@ def query_get_max(self, index_name: str, field: str, filters: str = None) -> int """ raise NotImplementedError("query_get_max method is not implemented") + def query_get_time_diff(self, index_name: str, field: str) -> int: + """ + Get the time difference + :param index_name: name of the index + :param field: field name + :param filters: optional filter + :return: time difference in milliseconds + """ + raise NotImplementedError("query_get_time_diff method is not implemented") + class SQLDatasource(DataSource): """ @@ -109,16 +119,16 @@ def query_get_row_count(self, table: str, filters: str = None) -> int: return self.connection.execute(text(query)).fetchone()[0] - def query_get_max(self, table: str, field: str, filter: str = None) -> int: + def query_get_max(self, table: str, field: str, filters: str = None) -> int: """ Get the max value :param table: table name :param field: column name - :param filter: filter condition + :param filters: filter condition :return: """ query = "SELECT MAX({}) FROM {}".format(field, table) - if filter: - query += " WHERE {}".format(filter) + if filters: + query += " WHERE {}".format(filters) return self.connection.execute(text(query)).fetchone()[0] diff --git a/datachecks/core/datasource/manager.py b/datachecks/core/datasource/manager.py index d8e45791..d7a4fbc2 100644 --- a/datachecks/core/datasource/manager.py +++ b/datachecks/core/datasource/manager.py @@ -41,13 +41,13 @@ def _initialize_data_sources(self): :return: """ for data_source_config in self.data_source_configs: - self.data_sources[data_source_config.name] = self.create_data_source( + self.data_sources[data_source_config.name] = self._create_data_source( data_source_config=data_source_config ) self.data_sources[data_source_config.name].connect() @staticmethod - def create_data_source(data_source_config: DataSourceConfiguration) -> DataSource: + def _create_data_source(data_source_config: DataSourceConfiguration) -> DataSource: """ Create a data source :param data_source_config: data source configuration diff --git a/datachecks/core/datasource/opensearch.py b/datachecks/core/datasource/opensearch.py index ed70dbe8..d9a3d77b 100644 --- a/datachecks/core/datasource/opensearch.py +++ b/datachecks/core/datasource/opensearch.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from datetime import datetime from typing import Dict +from dateutil import parser from opensearchpy import OpenSearch from datachecks.core.datasource.base import SearchIndexDataSource @@ -83,3 +84,23 @@ def query_get_max(self, index_name: str, field: str, filters: Dict = None) -> in response = self.client.search(index=index_name, body=query) return response["aggregations"]["max_value"]["value"] + + def query_get_time_diff(self, index_name: str, field: str) -> int: + """ + Get the time difference between the latest and the now + :param index_name: + :param field: + :return: + """ + query = {"query": {"match_all": {}}, "sort": [{f"{field}": {"order": "desc"}}]} + + response = self.client.search(index=index_name, body=query) + + if response["hits"]["hits"]: + last_updated = response["hits"]["hits"][0]["_source"][field] + + last_updated = parser.parse(timestr=last_updated).timestamp() + now = datetime.utcnow().timestamp() + return int(now - last_updated) + + return 0 diff --git a/datachecks/core/metric/base.py b/datachecks/core/metric/base.py index 75f8ab11..98414b74 100644 --- a/datachecks/core/metric/base.py +++ b/datachecks/core/metric/base.py @@ -26,6 +26,7 @@ class MetricsType(str, Enum): ROW_COUNT = "row_count" DOCUMENT_COUNT = "document_count" MAX = "max" + FRESHNESS = "freshness" class MetricIdentity: @@ -45,7 +46,7 @@ def generate_identity( identifiers.append(metric_name) if index_name: identifiers.append(index_name) - if table_name: + elif table_name: identifiers.append(table_name) if field_name: identifiers.append(field_name) @@ -74,6 +75,7 @@ def __init__( if index_name is None and table_name is None: raise ValueError("Please give a value for table_name or index_name") + self.index_name, self.table_name = None, None if index_name: self.index_name = index_name if table_name: @@ -84,12 +86,17 @@ def __init__( self.metric_type = metric_type self.filter_query = None if filters is not None: - if "search_query" in filters and "sql_query" in filters: + if ( + "search_query" in filters and filters["search_query"] is not None + ) and ( + "where_clause" in filters and filters["where_clause"] is not None + ): raise ValueError( - "Please give a value for search_query or sql_query (but not both)" + "Please give a value for search_query or where_clause (but not both)" ) - if "search_query" in filters: + if "search_query" in filters and filters["search_query"] is not None: + print(filters) self.filter_query = json.loads(filters["search_query"]) elif "where_clause" in filters: self.filter_query = filters["where_clause"] @@ -147,15 +154,6 @@ def __init__( self.field_name = field_name - def get_metric_identity(self): - return MetricIdentity.generate_identity( - metric_type=MetricsType.DOCUMENT_COUNT, - metric_name=self.name, - data_source=self.data_source, - table_name=self.table_name, - field_name=self.field_name, - ) - @property def get_field_name(self): return self.field_name diff --git a/datachecks/core/metric/freshness_metric.py b/datachecks/core/metric/freshness_metric.py new file mode 100644 index 00000000..86fb02a9 --- /dev/null +++ b/datachecks/core/metric/freshness_metric.py @@ -0,0 +1,46 @@ +# Copyright 2022-present, the Waterdip Labs Pvt. Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datachecks.core.datasource.base import (SearchIndexDataSource, + SQLDatasource) +from datachecks.core.metric.base import (FieldMetrics, MetricIdentity, + MetricsType) + + +class FreshnessValueMetric(FieldMetrics): + """ + FreshnessMetric is a class that represents a metric that is generated by a data source. + """ + + def get_metric_identity(self): + return MetricIdentity.generate_identity( + metric_type=MetricsType.FRESHNESS, + metric_name=self.name, + data_source=self.data_source, + field_name=self.field_name, + ) + + def _generate_metric_value(self): + if isinstance(self.data_source, SQLDatasource): + return self.data_source.query_get_max( + table=self.table_name, + field=self.field_name, + filters=self.filter_query if self.filter_query else None, + ) + elif isinstance(self.data_source, SearchIndexDataSource): + return self.data_source.query_get_time_diff( + index_name=self.index_name, field=self.field_name + ) + else: + raise ValueError("Invalid data source type") diff --git a/datachecks/core/metric/manager.py b/datachecks/core/metric/manager.py index 559ed10a..91617c96 100644 --- a/datachecks/core/metric/manager.py +++ b/datachecks/core/metric/manager.py @@ -19,7 +19,7 @@ from datachecks.core.datasource.manager import DataSourceManager from datachecks.core.metric.base import MetricsType from datachecks.core.metric.numeric_metric import (DocumentCountMetric, - RowCountMetric) + MaxMetric, RowCountMetric) class MetricManager: @@ -41,20 +41,37 @@ def _build_metrics(self, config: Dict[str, List[MetricConfiguration]]): data_source=self.data_source_manager.get_data_source( data_source ), - filter=asdict(metric_config.filter), + filters=asdict(metric_config.filters) + if metric_config.filters + else None, index_name=metric_config.index, + metric_type=MetricsType.DOCUMENT_COUNT, ) - self.metrics[metric.metric_identity] = metric + self.metrics[metric.get_metric_identity()] = metric elif metric_config.metric_type == MetricsType.ROW_COUNT: metric = RowCountMetric( name=metric_config.name, data_source=self.data_source_manager.get_data_source( data_source ), - filter=asdict(metric_config.filter), + filters=asdict(metric_config.filters) if metric_config.filters else None, table_name=metric_config.table, + metric_type=MetricsType.ROW_COUNT, ) - self.metrics[metric.metric_identity] = metric + self.metrics[metric.get_metric_identity()] = metric + elif metric_config.metric_type == MetricsType.MAX: + metric = MaxMetric( + name=metric_config.name, + data_source=self.data_source_manager.get_data_source( + data_source + ), + filters=asdict(metric_config.filters) if metric_config.filters else None, + table_name=metric_config.table, + index_name=metric_config.index, + metric_type=MetricsType.MAX, + field_name=metric_config.field, + ) + self.metrics[metric.get_metric_identity()] = metric else: raise ValueError("Invalid metric type") diff --git a/datachecks/core/metric/numeric_metric.py b/datachecks/core/metric/numeric_metric.py index a07f53a0..14975354 100644 --- a/datachecks/core/metric/numeric_metric.py +++ b/datachecks/core/metric/numeric_metric.py @@ -28,6 +28,14 @@ class DocumentCountMetric(Metric): def validate_data_source(self): return isinstance(self.data_source, SearchIndexDataSource) + def get_metric_identity(self): + return MetricIdentity.generate_identity( + metric_type=MetricsType.DOCUMENT_COUNT, + metric_name=self.name, + data_source=self.data_source, + index_name=self.index_name, + ) + def _generate_metric_value(self): if isinstance(self.data_source, SearchIndexDataSource): return self.data_source.query_get_document_count( @@ -77,6 +85,8 @@ def get_metric_identity(self): metric_name=self.name, data_source=self.data_source, field_name=self.field_name, + table_name=self.table_name if self.table_name else None, + index_name=self.index_name if self.index_name else None, ) def _generate_metric_value(self): @@ -84,7 +94,7 @@ def _generate_metric_value(self): return self.data_source.query_get_max( table=self.table_name, field=self.field_name, - filter=self.filter_query if self.filter_query else None, + filters=self.filter_query if self.filter_query else None, ) elif isinstance(self.data_source, SearchIndexDataSource): return self.data_source.query_get_max( diff --git a/example/config.yaml b/example/config.yaml index 600a7ce7..c7dac9f2 100644 --- a/example/config.yaml +++ b/example/config.yaml @@ -12,41 +12,38 @@ data_sources: host: 127.0.0.1 port: 5431 username: postgres - password: changeme - database: postgres - - name: staging - type: postgres - connection: - host: 127.0.0.1 - port: 5430 - username: postgres - password: changeme - database: postgres + password: postgres + database: dc_db metrics: search: - count_search: + count_document: metric_type: document_count index: category_tabel filters: search_query: '{"match_all" : {}}' + max_price: + metric_type: max + index: category_tabel + field: price + last_updated: + metric_type: freshness + index: category_tabel + field: last_updated content: count_content_hat: metric_type: row_count table: table_1 filters: - sql_query: "category = 'HAT' AND is_valid is True" + where_clause: "category = 'HAT' AND is_valid is True" count_content_non_valid: metric_type: row_count table: table_1 filters: where_clause: "is_valid is False" - - staging: - count_staging: - metric_type: row_count - table: table_2 - filters: - where_clause: "category = 'OIL' AND is_valid is True" + max_content: + metric_type: max + table: table_1 + field: price diff --git a/example/data_generator.py b/example/data_generator.py index cdaccb0f..bfa1042f 100644 --- a/example/data_generator.py +++ b/example/data_generator.py @@ -28,6 +28,7 @@ class TestData: name: str = None category: str = None is_valid: bool = True + price: int = None CATEGORIES = ["HAT", "OIL", "FILTER", "BATTERIES", "HARDWARE"] @@ -35,19 +36,12 @@ class TestData: url_content = URL.create( drivername="postgresql", username="postgres", - password="changeme", + password="postgres", host="127.0.0.1", port=5431, - database="postgres", -) -url_staging = URL.create( - drivername="postgresql", - username="postgres", - password="changeme", - host="127.0.0.1", - port=5430, - database="postgres", + database="dc_db", ) + client = OpenSearch( hosts=[{"host": "127.0.0.1", "port": 9201}], http_auth=("admin", "admin"), @@ -56,10 +50,8 @@ class TestData: ca_certs=False, ) engine_content = create_engine(url_content) -engine_staging = create_engine(url_staging) content_connection = engine_content.connect() -staging_connection = engine_staging.connect() def generate_data_content(number_of_data: int): @@ -69,39 +61,29 @@ def generate_data_content(number_of_data: int): name=f"name_{i}", category=random.choice(CATEGORIES), is_valid=bool(random.getrandbits(1)), + price=random.randint(1, 1000), ) try: content_connection.execute( text( """ - INSERT INTO table_1 (id, name, category, is_valid) VALUES (:id, :name, :category, :is_valid) - """ - ), - { - "id": d.id, - "name": d.name, - "category": d.category, - "is_valid": d.is_valid, - }, + CREATE TABLE IF NOT EXISTS table_1 ( + id int, + name varchar(255), + category varchar(255), + is_valid boolean, + price int, + PRIMARY KEY (id) + ) + """ + ) ) content_connection.commit() - except Exception as e: - raise e - - -def generate_data_staging(number_of_data: int): - for i in range(number_of_data): - d = TestData( - id=i + 2000, - name=f"name_{i}", - category=random.choice(CATEGORIES), - is_valid=bool(random.getrandbits(1)), - ) - try: - staging_connection.execute( + content_connection.execute( text( """ - INSERT INTO table_2 (id, name, category, is_valid) VALUES (:id, :name, :category, :is_valid) + INSERT INTO table_1 (id, name, category, is_valid, price) + VALUES (:id, :name, :category, :is_valid, :price) """ ), { @@ -109,11 +91,12 @@ def generate_data_staging(number_of_data: int): "name": d.name, "category": d.category, "is_valid": d.is_valid, + "price": d.price, }, ) - staging_connection.commit() + content_connection.commit() except Exception as e: - print(e) + raise e def generate_open_search(number_of_data: int): @@ -131,6 +114,5 @@ def generate_open_search(number_of_data: int): if __name__ == "__main__": - generate_data_content(100) - generate_data_staging(200) + generate_data_content(1000) generate_open_search(300) diff --git a/poetry.lock b/poetry.lock index cac97ff6..9329fee8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -614,7 +614,7 @@ name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" category = "main" -optional = true +optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, @@ -729,7 +729,7 @@ name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" category = "main" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, @@ -884,4 +884,4 @@ postgresql = ["psycopg2", "sqlalchemy"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "0a7578a54956507d7e21d2eca9393784dda2f6f8fbbaa8ece8da8dbc781a7f3b" +content-hash = "e755ac9e9a445f83b737e4e77e6dceab076fc5e62607d3c462ea96aa38ddb9f6" diff --git a/pyproject.toml b/pyproject.toml index d9f66718..4146427a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ requests = "^2.31.0" opensearch-py = { version="^2.2.0", optional=true } sqlalchemy = { version="^2.0.19", optional=true } psycopg2 = { version="^2.9.6", optional=true } +python-dateutil = "^2.8.2" [tool.poetry.group.dev.dependencies] pytest = "^7.1.3" diff --git a/tests/core/metric/test_freshness_metric.py b/tests/core/metric/test_freshness_metric.py new file mode 100644 index 00000000..4ac6ac93 --- /dev/null +++ b/tests/core/metric/test_freshness_metric.py @@ -0,0 +1,110 @@ +# Copyright 2022-present, the Waterdip Labs Pvt. Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import time + +import pytest +from opensearchpy import OpenSearch + +from datachecks.core.configuration.configuration import \ + DataSourceConnectionConfiguration +from datachecks.core.metric.base import MetricsType +from datachecks.core.metric.freshness_metric import FreshnessValueMetric +from tests.utils import create_opensearch_client, create_postgres_connection + +INDEX_NAME = "freshness_metric_test" + + +@pytest.fixture(scope="class") +def setup_data( + opensearch_client_configuration: DataSourceConnectionConfiguration, + pgsql_connection_configuration: DataSourceConnectionConfiguration, +): + opensearch_client = create_opensearch_client(opensearch_client_configuration) + postgresql_connection = create_postgres_connection(pgsql_connection_configuration) + try: + populate_opensearch_datasource(opensearch_client) + # populate_postgres_datasource(postgresql_connection) + yield True + except Exception as e: + print(e) + finally: + opensearch_client.indices.delete(index=INDEX_NAME, ignore=[400, 404]) + # postgresql_connection.execute(text("DROP TABLE IF EXISTS numeric_metric_test")) + postgresql_connection.commit() + + opensearch_client.close() + postgresql_connection.close() + + +def populate_opensearch_datasource(opensearch_client: OpenSearch): + try: + opensearch_client.indices.delete(index=INDEX_NAME, ignore=[400, 404]) + opensearch_client.indices.create( + index=INDEX_NAME, + body={"mappings": {"properties": {"last_fight": {"type": "date"}}}}, + ) + opensearch_client.index( + index=INDEX_NAME, + body={ + "name": "thor", + "last_fight": datetime.datetime.utcnow() - datetime.timedelta(days=10), + }, + ) + opensearch_client.index( + index=INDEX_NAME, + body={ + "name": "captain america", + "last_fight": datetime.datetime.utcnow() - datetime.timedelta(days=3), + }, + ) + opensearch_client.index( + index=INDEX_NAME, + body={ + "name": "iron man", + "last_fight": datetime.datetime.utcnow() - datetime.timedelta(days=4), + }, + ) + opensearch_client.index( + index=INDEX_NAME, + body={ + "name": "hawk eye", + "last_fight": datetime.datetime.utcnow() - datetime.timedelta(days=5), + }, + ) + opensearch_client.index( + index=INDEX_NAME, + body={ + "name": "black widow", + "last_fight": datetime.datetime.utcnow() - datetime.timedelta(days=6), + }, + ) + opensearch_client.indices.refresh(index=INDEX_NAME) + except Exception as e: + print(e) + + +@pytest.mark.usefixtures("setup_data", "postgres_datasource", "opensearch_datasource") +class TestFreshnessValueMetric: + def test_should_get_freshness_value_from_opensearch(self, opensearch_datasource): + metric = FreshnessValueMetric( + name="freshness_value_metric_test", + data_source=opensearch_datasource, + index_name=INDEX_NAME, + field_name="last_fight", + metric_type=MetricsType.FRESHNESS, + ) + metric_value = metric.get_value() + assert metric_value["value"] == 3 * 3600 * 24 diff --git a/tests/core/metric/test_metric_manager.py b/tests/core/metric/test_metric_manager.py new file mode 100644 index 00000000..85439df7 --- /dev/null +++ b/tests/core/metric/test_metric_manager.py @@ -0,0 +1,179 @@ +# Copyright 2022-present, the Waterdip Labs Pvt. Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from datachecks.core.configuration.configuration import ( + DataSourceConfiguration, DataSourceConnectionConfiguration, DatasourceType, + MetricConfiguration, MetricsFilterConfiguration) +from datachecks.core.datasource.manager import DataSourceManager +from datachecks.core.metric.base import MetricsType +from datachecks.core.metric.manager import MetricManager +from datachecks.core.metric.numeric_metric import DocumentCountMetric + +OPEN_SEARCH_DATA_SOURCE_NAME = "test_open_search_data_source" +POSTGRES_DATA_SOURCE_NAME = "test_postgres_data_source" + + +@pytest.mark.usefixtures( + "opensearch_client_configuration", "pgsql_connection_configuration" +) +@pytest.fixture(scope="class") +def setup_data_source_manager( + opensearch_client_configuration: DataSourceConnectionConfiguration, + pgsql_connection_configuration: DataSourceConnectionConfiguration, +) -> DataSourceManager: + data_source_manager = DataSourceManager( + config=[ + DataSourceConfiguration( + name=OPEN_SEARCH_DATA_SOURCE_NAME, + type=DatasourceType.OPENSEARCH, + connection_config=opensearch_client_configuration, + ), + DataSourceConfiguration( + name=POSTGRES_DATA_SOURCE_NAME, + type=DatasourceType.POSTGRES, + connection_config=pgsql_connection_configuration, + ), + ] + ) + yield data_source_manager + for data_source in data_source_manager.data_sources.values(): + data_source.close() + + +@pytest.mark.usefixtures("setup_data_source_manager") +class TestMetricManager: + def test_should_create_document_count_metric( + self, setup_data_source_manager: DataSourceManager + ): + metric_name, index_name = "test_document_count_metric", "test_index" + metric_config = { + "metric_type": "document_count", + "name": metric_name, + "index": index_name, + } + + metric_config = MetricConfiguration(**metric_config) + metric_manager = MetricManager( + metric_config={OPEN_SEARCH_DATA_SOURCE_NAME: [metric_config]}, + data_source_manager=setup_data_source_manager, + ) + + metric_identity = f"{OPEN_SEARCH_DATA_SOURCE_NAME}.document_count.{metric_name}.{index_name}" + metric = metric_manager.get_metric(metric_identity) + + assert isinstance(metric, DocumentCountMetric) + assert metric.name == "test_document_count_metric" + assert metric.metric_type == MetricsType.DOCUMENT_COUNT + assert metric.index_name == "test_index" + + def test_should_create_document_count_metric_with_filter( + self, setup_data_source_manager: DataSourceManager + ): + metric_name, index_name = "test_document_count_metric", "test_index" + metric_config = { + "metric_type": "document_count", + "name": metric_name, + "index": index_name, + } + filters = {"search_query": '{"range": {"age": {"gte": 30, "lte": 40}}}'} + + metric_config = MetricConfiguration(**metric_config) + metric_config.filters = MetricsFilterConfiguration(**filters) + metric_manager = MetricManager( + metric_config={OPEN_SEARCH_DATA_SOURCE_NAME: [metric_config]}, + data_source_manager=setup_data_source_manager, + ) + + metric_identity = f"{OPEN_SEARCH_DATA_SOURCE_NAME}.document_count.{metric_name}.{index_name}" + metric = metric_manager.get_metric(metric_identity) + + assert isinstance(metric, DocumentCountMetric) + assert metric.name == "test_document_count_metric" + assert metric.metric_type == MetricsType.DOCUMENT_COUNT + assert metric.index_name == "test_index" + assert metric.filter_query == {"range": {"age": {"gte": 30, "lte": 40}}} + + def test_should_create_row_count_metric( + self, setup_data_source_manager: DataSourceManager + ): + metric_name, table_name = "test_row_count_metric", "test_table" + metric_config = { + "metric_type": "row_count", + "name": metric_name, + "table": table_name, + } + + metric_config = MetricConfiguration(**metric_config) + metric_manager = MetricManager( + metric_config={POSTGRES_DATA_SOURCE_NAME: [metric_config]}, + data_source_manager=setup_data_source_manager, + ) + + metric_identity = f"{POSTGRES_DATA_SOURCE_NAME}.row_count.{metric_name}.{table_name}" + metric = metric_manager.get_metric(metric_identity) + + assert metric.name == "test_row_count_metric" + assert metric.metric_type == MetricsType.ROW_COUNT + assert metric.table_name == "test_table" + + def test_should_create_row_count_metric_with_filters( + self, setup_data_source_manager: DataSourceManager + ): + metric_name, table_name = "test_row_count_metric", "test_table" + metric_config = { + "metric_type": "row_count", + "name": metric_name, + "table": table_name, + } + filters = {"where_clause": "age > 30"} + metric_config = MetricConfiguration(**metric_config) + metric_config.filters = MetricsFilterConfiguration(**filters) + metric_manager = MetricManager( + metric_config={POSTGRES_DATA_SOURCE_NAME: [metric_config]}, + data_source_manager=setup_data_source_manager, + ) + + metric_identity = f"{POSTGRES_DATA_SOURCE_NAME}.row_count.{metric_name}.{table_name}" + metric = metric_manager.get_metric(metric_identity) + + assert metric.name == metric_name + assert metric.metric_type == MetricsType.ROW_COUNT + assert metric.table_name == "test_table" + + def test_should_create_max_metric(self, setup_data_source_manager: DataSourceManager): + metric_name, table_name, field_name = "test_max_metric", "test_table", "age" + metric_config = { + "metric_type": "max", + "name": metric_name, + "table": table_name, + "field": field_name, + } + + metric_config = MetricConfiguration(**metric_config) + metric_manager = MetricManager( + metric_config={POSTGRES_DATA_SOURCE_NAME: [metric_config]}, + data_source_manager=setup_data_source_manager, + ) + print("====") + print(metric_manager.metrics) + print("====") + metric_identity = f"{POSTGRES_DATA_SOURCE_NAME}.max.{metric_name}.{table_name}.{field_name}" + metric = metric_manager.get_metric(metric_identity) + + assert metric.name == metric_name + assert metric.metric_type == MetricsType.MAX + assert metric.table_name == "test_table" + assert metric.field_name == "age"