Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(metric): numeric metric average calculation #25

Merged
merged 1 commit into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions datachecks/core/datasource/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def query_get_max(self, index_name: str, field: str, filters: str = None) -> int
:return: max value
"""
raise NotImplementedError("query_get_max method is not implemented")

def query_get_avg(self, index_name: str, field: str, filters: str = None) -> float:
    """
    Get the average value of a field.

    This implementation is a placeholder and always raises.

    :param index_name: name of the index
    :param field: field name
    :param filters: optional filter
    :return: average value (fractional, hence ``float`` rather than ``int``)
    :raises NotImplementedError: always
    """
    raise NotImplementedError("query_get_avg method is not implemented")

def query_get_time_diff(self, index_name: str, field: str) -> int:
"""
Expand Down Expand Up @@ -134,6 +144,20 @@ def query_get_max(self, table: str, field: str, filters: str = None) -> int:
query += " WHERE {}".format(filters)

return self.connection.execute(text(query)).fetchone()[0]

def query_get_avg(self, table: str, field: str, filters: str = None) -> int:
    """
    Compute the average of a column, rounded to two decimals by the database.

    :param table: table name
    :param field: column name
    :param filters: optional condition; when truthy it is appended after WHERE
    :return: first column of the first row returned by the query
    """
    # NOTE: table, field and filters are interpolated into the statement
    # verbatim, so they must not originate from untrusted input.
    where_clause = " WHERE {}".format(filters) if filters else ""
    statement = "SELECT ROUND(AVG({}), 2) FROM {}".format(field, table) + where_clause

    first_row = self.connection.execute(text(statement)).fetchone()
    return first_row[0]

def query_get_time_diff(self, table: str, field: str) -> int:
"""
Expand Down
15 changes: 15 additions & 0 deletions datachecks/core/datasource/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ def query_get_max(self, index_name: str, field: str, filters: Dict = None) -> in

response = self.client.search(index=index_name, body=query)
return response["aggregations"]["max_value"]["value"]

def query_get_avg(self, index_name: str, field: str, filters: Dict = None) -> float:
    """
    Get the average value of a field, rounded to two decimal places.

    :param index_name: name of the index to search
    :param field: name of the field to average
    :param filters: optional query clause; when truthy it is sent as the
        ``query`` part of the search body
    :return: the average rounded to 2 decimals, or ``None`` when the
        aggregation reports no value
    """
    query = {"aggs": {"avg_value": {"avg": {"field": field}}}}
    if filters:
        query["query"] = filters

    response = self.client.search(index=index_name, body=query)
    average = response["aggregations"]["avg_value"]["value"]
    # The avg aggregation reports a null value when no document contributes
    # one (e.g. the filter matches nothing); round(None, 2) would raise
    # TypeError, so pass the "no value" result through instead.
    if average is None:
        return None
    return round(average, 2)

def query_get_time_diff(self, index_name: str, field: str) -> int:
"""
Expand Down
1 change: 1 addition & 0 deletions datachecks/core/metric/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class MetricsType(str, Enum):
ROW_COUNT = "row_count"
DOCUMENT_COUNT = "document_count"
MAX = "max"
AVG = "avg"
FRESHNESS = "freshness"


Expand Down
32 changes: 32 additions & 0 deletions datachecks/core/metric/numeric_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,35 @@ def _generate_metric_value(self):
)
else:
raise ValueError("Invalid data source type")

class AvgMetric(FieldMetrics):
    """
    Field metric whose value is the average of a field, as reported by the
    metric's data source.
    """

    def get_metric_identity(self):
        """Build this metric's identity from its type, name, source and location."""
        identity_parts = dict(
            metric_type=MetricsType.AVG,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            # Falsy table/index names are normalised to None.
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )
        return MetricIdentity.generate_identity(**identity_parts)

    def _generate_metric_value(self):
        """
        Ask the data source for the average of the configured field.

        :raises ValueError: if the data source is neither a SQL nor a
            search-index data source
        """
        if isinstance(self.data_source, SQLDatasource):
            return self.data_source.query_get_avg(
                table=self.table_name,
                field=self.field_name,
                filters=self.filter_query or None,
            )

        if isinstance(self.data_source, SearchIndexDataSource):
            return self.data_source.query_get_avg(
                index_name=self.index_name,
                field=self.field_name,
                filters=self.filter_query or None,
            )

        raise ValueError("Invalid data source type")
58 changes: 57 additions & 1 deletion tests/core/metric/test_numeric_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from datachecks.core.datasource.postgres import PostgresSQLDatasource
from datachecks.core.metric.base import MetricsType
from datachecks.core.metric.numeric_metric import (DocumentCountMetric,
MaxMetric, RowCountMetric)
MaxMetric, AvgMetric, RowCountMetric)
from tests.utils import create_opensearch_client, create_postgres_connection


Expand Down Expand Up @@ -206,3 +206,59 @@ def test_should_return_max_column_value_opensearch_with_filter(
)
row_value = row.get_value()
assert row_value["value"] == 110

@pytest.mark.usefixtures("setup_data", "postgres_datasource", "opensearch_datasource")
class TestAvgColumnValueMetric:
    """Checks AvgMetric values on the ``age`` field for Postgres and OpenSearch."""

    @staticmethod
    def _avg_of_age(**metric_kwargs) -> float:
        """Build an AvgMetric over ``age`` and return its value as a float."""
        metric = AvgMetric(
            metric_type=MetricsType.AVG,
            field_name="age",
            **metric_kwargs,
        )
        return float(metric.get_value()["value"])

    def test_should_return_avg_column_value_postgres_without_filter(
        self, postgres_datasource: PostgresSQLDatasource
    ):
        average = self._avg_of_age(
            name="avg_metric_test",
            data_source=postgres_datasource,
            table_name="numeric_metric_test",
        )
        assert average == 141.40

    def test_should_return_avg_column_value_postgres_with_filter(
        self, postgres_datasource: PostgresSQLDatasource
    ):
        average = self._avg_of_age(
            name="avg_metric_test_1",
            data_source=postgres_datasource,
            table_name="numeric_metric_test",
            filters={"where_clause": "age >= 30 AND age <= 200"},
        )
        assert average == 51.50

    def test_should_return_avg_column_value_opensearch_without_filter(
        self, opensearch_datasource: OpenSearchSearchIndexDataSource
    ):
        average = self._avg_of_age(
            name="avg_metric_test",
            data_source=opensearch_datasource,
            index_name="numeric_metric_test",
        )
        assert average == 141.40

    def test_should_return_avg_column_value_opensearch_with_filter(
        self, opensearch_datasource: OpenSearchSearchIndexDataSource
    ):
        average = self._avg_of_age(
            name="avg_metric_test_1",
            data_source=opensearch_datasource,
            index_name="numeric_metric_test",
            filters={"search_query": '{"range": {"age": {"gte": 30, "lte": 200}}}'},
        )
        assert average == 51.50
Loading