fix(ingest/bigquery): fix partition and median queries for profiling (#…
mayurinehate authored Sep 6, 2023
1 parent 8bf28bf commit e680a97
Showing 6 changed files with 82 additions and 33 deletions.
@@ -61,7 +61,15 @@ class BigQueryV2Report(ProfilingSqlReport, IngestionStageReport, BaseTimeWindowR
partition_info: Dict[str, str] = field(default_factory=TopKDict)
profile_table_selection_criteria: Dict[str, str] = field(default_factory=TopKDict)
selected_profile_tables: Dict[str, List[str]] = field(default_factory=TopKDict)
invalid_partition_ids: Dict[str, str] = field(default_factory=TopKDict)
profiling_skipped_invalid_partition_ids: Dict[str, str] = field(
default_factory=TopKDict
)
profiling_skipped_invalid_partition_type: Dict[str, str] = field(
default_factory=TopKDict
)
profiling_skipped_partition_profiling_disabled: List[str] = field(
default_factory=LossyList
)
allow_pattern: Optional[str] = None
deny_pattern: Optional[str] = None
num_usage_workunits_emitted: int = 0
@@ -72,7 +72,7 @@ def from_range_partitioning(

return cls(
field=field,
type="RANGE",
type=RANGE_PARTITION_NAME,
)

@classmethod
@@ -151,6 +151,9 @@ class BigqueryQuery:
"""

# https://cloud.google.com/bigquery/docs/information-schema-table-storage?hl=en
# Note for max_partition_id -
# should we instead pick the partition with latest LAST_MODIFIED_TIME ?
# for range partitioning max may not be latest partition
tables_for_dataset = f"""
SELECT
t.table_catalog as table_catalog,
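
A minimal sketch of the alternative raised in the note above, not part of this commit: ranking partitions by LAST_MODIFIED_TIME via BigQuery's INFORMATION_SCHEMA.PARTITIONS view instead of taking the maximum partition id. The placeholder project/dataset names and the standalone query shape are assumptions.

# Hypothetical sketch only (not in this change): pick the most recently
# modified partition per table rather than max(partition_id).
latest_partition_by_modified_time = """
SELECT
    table_name,
    partition_id,
    last_modified_time
FROM
    `{project_id}.{dataset_name}.INFORMATION_SCHEMA.PARTITIONS`
WHERE
    partition_id IS NOT NULL
    AND partition_id NOT IN ('__NULL__', '__UNPARTITIONED__')
QUALIFY
    ROW_NUMBER() OVER (PARTITION BY table_name ORDER BY last_modified_time DESC) = 1
"""
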
@@ -94,14 +94,17 @@ def generate_partition_profiler_query(
partition_where_clause: str

if table.partition_info.type == RANGE_PARTITION_NAME:
if table.partition_info and table.partition_info.column:
if table.partition_info.column:
partition_where_clause = (
f"{table.partition_info.column.name} >= {partition}"
)
else:
logger.warning(
f"Partitioned table {table.name} without partiton column"
)
self.report.profiling_skipped_invalid_partition_ids[
f"{project}.{schema}.{table.name}"
] = partition
return None, None
else:
logger.debug(
@@ -118,8 +121,8 @@
logger.error(
f"Unable to get partition range for partition id: {partition} it failed with exception {e}"
)
self.report.invalid_partition_ids[
f"{schema}.{table.name}"
self.report.profiling_skipped_invalid_partition_ids[
f"{project}.{schema}.{table.name}"
] = partition
return None, None

@@ -132,11 +135,14 @@
partition_column_name = table.partition_info.column.name
partition_data_type = table.partition_info.column.data_type
if table.partition_info.type in ("HOUR", "DAY", "MONTH", "YEAR"):
partition_where_clause = f"{partition_data_type}(`{partition_column_name}`) BETWEEN {partition_data_type}('{partition_datetime}') AND {partition_data_type}('{upper_bound_partition_datetime}')"
partition_where_clause = f"`{partition_column_name}` BETWEEN {partition_data_type}('{partition_datetime}') AND {partition_data_type}('{upper_bound_partition_datetime}')"
else:
logger.warning(
f"Not supported partition type {table.partition_info.type}"
)
self.report.profiling_skipped_invalid_partition_type[
f"{project}.{schema}.{table.name}"
] = table.partition_info.type
return None, None
custom_sql = """
SELECT
@@ -153,7 +159,7 @@
)

return (partition, custom_sql)
if table.max_shard_id:
elif table.max_shard_id:
# For a sharded table we want the partition id, but we don't need to generate a custom query
return table.max_shard_id, None
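
For reference, after this change the time-partition filter compares the raw partition column against typed bounds instead of wrapping the column itself in a conversion call, which matches the updated unit tests below. A minimal sketch of the resulting clause for a DAY partition on a TIMESTAMP column, with illustrative names and dates:

# Illustrative only: the shape of the generated filter for a DAY-partitioned
# table whose TIMESTAMP partition column is named `date`.
partition_where_clause = (
    "`date` BETWEEN TIMESTAMP('2020-01-01 00:00:00') "
    "AND TIMESTAMP('2020-01-02 00:00:00')"
)
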

@@ -162,15 +168,9 @@ def get_workunits(
def get_workunits(
self, project_id: str, tables: Dict[str, List[BigqueryTable]]
) -> Iterable[MetadataWorkUnit]:
# Otherwise, if column level profiling is enabled, use GE profiler.
if not self.config.project_id_pattern.allowed(project_id):
return
profile_requests = []

for dataset in tables:
if not self.config.schema_pattern.allowed(dataset):
continue

for table in tables[dataset]:
normalized_table_name = BigqueryTableIdentifier(
project_id=project_id, dataset=dataset, table=table.name
@@ -253,24 +253,26 @@ def get_bigquery_profile_request(
if self.config.profiling.report_dropped_profiles:
self.report.report_dropped(f"profile of {dataset_name}")
return None

(partition, custom_sql) = self.generate_partition_profiler_query(
project, dataset, table, self.config.profiling.partition_datetime
)

if partition is None and table.partition_info:
self.report.report_warning(
"profile skipped as partitioned table is empty or partition id was invalid",
"profile skipped as partitioned table is empty or partition id or type was invalid",
dataset_name,
)
return None

if (
partition is not None
and not self.config.profiling.partition_profiling_enabled
):
logger.debug(
f"{dataset_name} and partition {partition} is skipped because profiling.partition_profiling_enabled property is disabled"
)
self.report.profiling_skipped_partition_profiling_disabled.append(
dataset_name
)
return None

self.report.report_entity_profiled(dataset_name)
51 changes: 38 additions & 13 deletions metadata-ingestion/src/datahub/ingestion/source/ge_data_profiler.py
@@ -5,6 +5,7 @@
import contextlib
import dataclasses
import functools
import json
import logging
import threading
import traceback
@@ -27,7 +28,7 @@
import sqlalchemy as sa
import sqlalchemy.sql.compiler
from great_expectations.core.util import convert_to_json_serializable
from great_expectations.data_context import BaseDataContext
from great_expectations.data_context import AbstractDataContext, BaseDataContext
from great_expectations.data_context.types.base import (
DataContextConfig,
DatasourceConfig,
@@ -55,6 +56,7 @@
DatasetProfileClass,
HistogramClass,
PartitionSpecClass,
PartitionTypeClass,
QuantileClass,
ValueFrequencyClass,
)
@@ -70,6 +72,12 @@
logger: logging.Logger = logging.getLogger(__name__)

P = ParamSpec("P")
POSTGRESQL = "postgresql"
MYSQL = "mysql"
SNOWFLAKE = "snowflake"
BIGQUERY = "bigquery"
REDSHIFT = "redshift"
TRINO = "trino"

# The reason for this wacky structure is quite fun. GE basically assumes that
# the config structures were generated directly from YML and further assumes that
@@ -113,14 +121,14 @@ class GEProfilerRequest:


def get_column_unique_count_patch(self: SqlAlchemyDataset, column: str) -> int:
if self.engine.dialect.name.lower() == "redshift":
if self.engine.dialect.name.lower() == REDSHIFT:
element_values = self.engine.execute(
sa.select(
[sa.text(f'APPROXIMATE count(distinct "{column}")')] # type:ignore
).select_from(self._table)
)
return convert_to_json_serializable(element_values.fetchone()[0])
elif self.engine.dialect.name.lower() == "bigquery":
elif self.engine.dialect.name.lower() == BIGQUERY:
element_values = self.engine.execute(
sa.select(
[
@@ -131,7 +139,7 @@ def get_column_unique_count_patch(self: SqlAlchemyDataset, column: str) -> int:
).select_from(self._table)
)
return convert_to_json_serializable(element_values.fetchone()[0])
elif self.engine.dialect.name.lower() == "snowflake":
elif self.engine.dialect.name.lower() == SNOWFLAKE:
element_values = self.engine.execute(
sa.select(sa.func.APPROX_COUNT_DISTINCT(sa.column(column))).select_from(
self._table
@@ -361,7 +369,7 @@ def _get_column_cardinality(
def _get_dataset_rows(self, dataset_profile: DatasetProfileClass) -> None:
if self.config.profile_table_row_count_estimate_only:
dialect_name = self.dataset.engine.dialect.name.lower()
if dialect_name == "postgresql":
if dialect_name == POSTGRESQL:
schema_name = self.dataset_name.split(".")[1]
table_name = self.dataset_name.split(".")[2]
logger.debug(
@@ -370,7 +378,7 @@ def _get_dataset_rows(self, dataset_profile: DatasetProfileClass) -> None:
get_estimate_script = sa.text(
f"SELECT c.reltuples AS estimate FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.relname = '{table_name}' AND n.nspname = '{schema_name}'"
)
elif dialect_name == "mysql":
elif dialect_name == MYSQL:
schema_name = self.dataset_name.split(".")[0]
table_name = self.dataset_name.split(".")[1]
logger.debug(
@@ -421,14 +429,22 @@ def _get_dataset_column_median(
if not self.config.include_field_median_value:
return
try:
if self.dataset.engine.dialect.name.lower() == "snowflake":
if self.dataset.engine.dialect.name.lower() == SNOWFLAKE:
column_profile.median = str(
self.dataset.engine.execute(
sa.select([sa.func.median(sa.column(column))]).select_from(
self.dataset._table
)
).scalar()
)
elif self.dataset.engine.dialect.name.lower() == BIGQUERY:
column_profile.median = str(
self.dataset.engine.execute(
sa.select(
sa.text(f"approx_quantiles(`{column}`, 2) [OFFSET (1)]")
).select_from(self.dataset._table)
).scalar()
)
else:
column_profile.median = str(self.dataset.get_column_median(column))
except Exception as e:
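
The new BigQuery branch approximates the median with APPROX_QUANTILES over two buckets, which yields roughly [min, median, max], so OFFSET(1) picks the middle value. A hedged sketch of the SQL this effectively issues, with placeholder identifiers:

# Illustrative only, with placeholder table and column names.
bigquery_median_sql = """
SELECT approx_quantiles(`my_column`, 2)[OFFSET(1)] AS approx_median
FROM `my_project.my_dataset.my_table`
"""
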
@@ -583,6 +599,13 @@ def generate_dataset_profile( # noqa: C901 (complexity)
profile = DatasetProfileClass(timestampMillis=get_sys_time())
if self.partition:
profile.partitionSpec = PartitionSpecClass(partition=self.partition)
elif self.config.limit and self.config.offset:
profile.partitionSpec = PartitionSpecClass(
type=PartitionTypeClass.QUERY,
partition=json.dumps(
dict(limit=self.config.limit, offset=self.config.offset)
),
)
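# Not part of this commit, illustrative values only: a sampled profile with
# limit=1000 and offset=500 would be recorded roughly as
#   PartitionSpecClass(type=PartitionTypeClass.QUERY,
#                      partition='{"limit": 1000, "offset": 500}')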
profile.fieldProfiles = []
self._get_dataset_rows(profile)

Expand Down Expand Up @@ -717,7 +740,7 @@ def generate_dataset_profile( # noqa: C901 (complexity)

@dataclasses.dataclass
class GEContext:
data_context: BaseDataContext
data_context: AbstractDataContext
datasource_name: str


@@ -935,7 +958,7 @@ def _generate_single_profile(
}

bigquery_temp_table: Optional[str] = None
if platform == "bigquery" and (
if platform == BIGQUERY and (
custom_sql or self.config.limit or self.config.offset
):
# On BigQuery, we need to bypass GE's mechanism for creating temporary tables because
@@ -950,6 +973,8 @@
)
if custom_sql is not None:
# Note that limit and offset are not supported for custom SQL.
# Presence of custom SQL represents that the bigquery table
# is either partitioned or sharded
bq_sql = custom_sql
else:
bq_sql = f"SELECT * FROM `{table}`"
@@ -1015,7 +1040,7 @@ def _generate_single_profile(
finally:
raw_connection.close()

if platform == "bigquery":
if platform == BIGQUERY:
if bigquery_temp_table:
ge_config["table"] = bigquery_temp_table
ge_config["schema"] = None
@@ -1066,7 +1091,7 @@ def _generate_single_profile(
self.report.report_warning(pretty_name, f"Profiling exception {e}")
return None
finally:
if self.base_engine.engine.name == "trino":
if self.base_engine.engine.name == TRINO:
self._drop_trino_temp_table(batch)

def _get_ge_dataset(
@@ -1103,7 +1128,7 @@ def _get_ge_dataset(
**batch_kwargs,
},
)
if platform is not None and platform == "bigquery":
if platform == BIGQUERY:
# This is done as GE makes the name as DATASET.TABLE
# but we want it to be PROJECT.DATASET.TABLE instead for multi-project setups
name_parts = pretty_name.split(".")
@@ -1124,7 +1149,7 @@
# Stringified types are used to avoid dialect specific import errors
@lru_cache(maxsize=1)
def _get_column_types_to_ignore(dialect_name: str) -> List[str]:
if dialect_name.lower() == "postgresql":
if dialect_name.lower() == POSTGRESQL:
return ["JSON"]

return []
8 changes: 4 additions & 4 deletions metadata-ingestion/tests/unit/test_bigquery_profiler.py
@@ -64,7 +64,7 @@ def test_generate_day_partitioned_partition_profiler_query():
FROM
`test_project.test_dataset.test_table`
WHERE
TIMESTAMP(`date`) BETWEEN TIMESTAMP('2020-01-01 00:00:00') AND TIMESTAMP('2020-01-02 00:00:00')
`date` BETWEEN TIMESTAMP('2020-01-01 00:00:00') AND TIMESTAMP('2020-01-02 00:00:00')
""".strip()

assert "20200101" == query[0]
@@ -107,7 +107,7 @@ def test_generate_day_partitioned_partition_profiler_query_with_set_partition_ti
FROM
`test_project.test_dataset.test_table`
WHERE
TIMESTAMP(`date`) BETWEEN TIMESTAMP('2020-01-01 00:00:00') AND TIMESTAMP('2020-01-02 00:00:00')
`date` BETWEEN TIMESTAMP('2020-01-01 00:00:00') AND TIMESTAMP('2020-01-02 00:00:00')
""".strip()

assert "20200101" == query[0]
@@ -150,7 +150,7 @@ def test_generate_hour_partitioned_partition_profiler_query():
FROM
`test_project.test_dataset.test_table`
WHERE
TIMESTAMP(`partition_column`) BETWEEN TIMESTAMP('2020-01-01 03:00:00') AND TIMESTAMP('2020-01-01 04:00:00')
`partition_column` BETWEEN TIMESTAMP('2020-01-01 03:00:00') AND TIMESTAMP('2020-01-01 04:00:00')
""".strip()

assert "2020010103" == query[0]
@@ -183,7 +183,7 @@ def test_generate_ingestion_partitioned_partition_profiler_query():
FROM
`test_project.test_dataset.test_table`
WHERE
TIMESTAMP(`_PARTITIONTIME`) BETWEEN TIMESTAMP('2020-01-01 00:00:00') AND TIMESTAMP('2020-01-02 00:00:00')
`_PARTITIONTIME` BETWEEN TIMESTAMP('2020-01-01 00:00:00') AND TIMESTAMP('2020-01-02 00:00:00')
""".strip()

assert "20200101" == query[0]
11 changes: 11 additions & 0 deletions metadata-ingestion/tests/unit/test_bigquery_source.py
@@ -132,6 +132,17 @@ def test_get_projects_with_project_ids_overrides_project_id_pattern():
]


def test_platform_instance_config_always_none():
config = BigQueryV2Config.parse_obj(
{"include_data_platform_instance": True, "platform_instance": "something"}
)
assert config.platform_instance is None

config = BigQueryV2Config(platform_instance="something", project_id="project_id")
assert config.project_id == "project_id"
assert config.platform_instance is None


def test_get_dataplatform_instance_aspect_returns_project_id():
project_id = "project_id"
expected_instance = (
