From 533130408a28c036f4bdf4c2d7289311d28bf906 Mon Sep 17 00:00:00 2001 From: sid-acryl <155424659+sid-acryl@users.noreply.github.com> Date: Fri, 2 Feb 2024 02:17:09 +0530 Subject: [PATCH 1/2] feat(ingestion/redshift): collapse lineage to permanent table (#9704) Co-authored-by: Harshal Sheth Co-authored-by: treff7es --- .../src/datahub/emitter/mce_builder.py | 1 + .../src/datahub/ingestion/api/common.py | 2 +- .../ingestion/source/redshift/config.py | 17 +- .../ingestion/source/redshift/lineage.py | 554 ++++++++++++++- .../ingestion/source/redshift/query.py | 136 +++- .../source/redshift/redshift_schema.py | 76 +- .../ingestion/source/redshift/report.py | 2 + .../source/snowflake/snowflake_config.py | 11 +- .../src/datahub/utilities/sqlglot_lineage.py | 8 + .../tests/unit/redshift_query_mocker.py | 104 +++ .../tests/unit/test_redshift_lineage.py | 663 +++++++++++++++++- 11 files changed, 1515 insertions(+), 59 deletions(-) create mode 100644 metadata-ingestion/tests/unit/redshift_query_mocker.py diff --git a/metadata-ingestion/src/datahub/emitter/mce_builder.py b/metadata-ingestion/src/datahub/emitter/mce_builder.py index 9da1b0ab56f89..fe9ecee8f80d0 100644 --- a/metadata-ingestion/src/datahub/emitter/mce_builder.py +++ b/metadata-ingestion/src/datahub/emitter/mce_builder.py @@ -1,4 +1,5 @@ """Convenience functions for creating MCEs""" + import hashlib import json import logging diff --git a/metadata-ingestion/src/datahub/ingestion/api/common.py b/metadata-ingestion/src/datahub/ingestion/api/common.py index a6761a3c77d5e..097859939cfea 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/common.py +++ b/metadata-ingestion/src/datahub/ingestion/api/common.py @@ -64,7 +64,7 @@ def _set_dataset_urn_to_lower_if_needed(self) -> None: # TODO: Get rid of this function once lower-casing is the standard. if self.graph: server_config = self.graph.get_config() - if server_config and server_config.get("datasetUrnNameCasing"): + if server_config and server_config.get("datasetUrnNameCasing") is True: set_dataset_urn_to_lower(True) def register_checkpointer(self, committable: Committable) -> None: diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py index 540adbf4bfd15..fe66ef006ec69 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py @@ -94,10 +94,10 @@ class RedshiftConfig( description="The default schema to use if the sql parser fails to parse the schema with `sql_based` lineage collector", ) - include_table_lineage: Optional[bool] = Field( + include_table_lineage: bool = Field( default=True, description="Whether table lineage should be ingested." ) - include_copy_lineage: Optional[bool] = Field( + include_copy_lineage: bool = Field( default=True, description="Whether lineage should be collected from copy commands", ) @@ -107,17 +107,15 @@ class RedshiftConfig( description="Generate usage statistic. 
email_domain config parameter needs to be set if enabled", ) - include_unload_lineage: Optional[bool] = Field( + include_unload_lineage: bool = Field( default=True, description="Whether lineage should be collected from unload commands", ) - capture_lineage_query_parser_failures: Optional[bool] = Field( - hide_from_schema=True, + include_table_rename_lineage: bool = Field( default=False, - description="Whether to capture lineage query parser errors with dataset properties for debugging", + description="Whether we should follow `alter table ... rename to` statements when computing lineage. ", ) - table_lineage_mode: Optional[LineageMode] = Field( default=LineageMode.STL_SCAN_BASED, description="Which table lineage collector mode to use. Available modes are: [stl_scan_based, sql_based, mixed]", @@ -139,6 +137,11 @@ class RedshiftConfig( description="When enabled, emits lineage as incremental to existing lineage already in DataHub. When disabled, re-states lineage on each run. This config works with rest-sink only.", ) + resolve_temp_table_in_lineage: bool = Field( + default=False, + description="Whether to resolve temp table appear in lineage to upstream permanent tables.", + ) + @root_validator(pre=True) def check_email_is_set_on_usage(cls, values): if values.get("include_usage_statistics"): diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py index 3efef58737c6e..898e6db0b14b0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py @@ -4,11 +4,12 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Dict, List, Optional, Set, Tuple, Union, cast from urllib.parse import urlparse import humanfriendly import redshift_connector +import sqlglot import datahub.emitter.mce_builder as builder import datahub.utilities.sqlglot_lineage as sqlglot_l @@ -24,17 +25,24 @@ RedshiftSchema, RedshiftTable, RedshiftView, + TempTableRow, ) from datahub.ingestion.source.redshift.report import RedshiftReport from datahub.ingestion.source.state.redundant_run_skip_handler import ( RedundantLineageRunSkipHandler, ) +from datahub.metadata._schema_classes import SchemaFieldDataTypeClass from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( FineGrainedLineage, FineGrainedLineageDownstreamType, FineGrainedLineageUpstreamType, UpstreamLineage, ) +from datahub.metadata.com.linkedin.pegasus2avro.schema import ( + OtherSchema, + SchemaField, + SchemaMetadata, +) from datahub.metadata.schema_classes import ( DatasetLineageTypeClass, UpstreamClass, @@ -111,6 +119,34 @@ def merge_lineage( self.cll = self.cll or None +def parse_alter_table_rename(default_schema: str, query: str) -> Tuple[str, str, str]: + """ + Parses an ALTER TABLE ... RENAME TO ... query and returns the schema, previous table name, and new table name. 
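+
+    Example (illustrative; this behaviour is pinned by test_parse_alter_table_rename
+    in tests/unit/test_redshift_lineage.py added in this change):
+
+        >>> parse_alter_table_rename("public", "alter table foo rename to bar")
+        ('public', 'foo', 'bar')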
+ """ + + parsed_query = sqlglot.parse_one(query, dialect="redshift") + assert isinstance(parsed_query, sqlglot.exp.AlterTable) + prev_name = parsed_query.this.name + rename_clause = parsed_query.args["actions"][0] + assert isinstance(rename_clause, sqlglot.exp.RenameTable) + new_name = rename_clause.this.name + + schema = parsed_query.this.db or default_schema + + return schema, prev_name, new_name + + +def split_qualified_table_name(urn: str) -> Tuple[str, str, str]: + qualified_table_name = dataset_urn.DatasetUrn.create_from_string( + urn + ).get_entity_id()[1] + + # -3 because platform instance is optional and that can cause the split to have more than 3 elements + db, schema, table = qualified_table_name.split(".")[-3:] + + return db, schema, table + + class RedshiftLineageExtractor: def __init__( self, @@ -130,6 +166,95 @@ def __init__( self.report.lineage_end_time, ) = self.get_time_window() + self.temp_tables: Dict[str, TempTableRow] = {} + + def _init_temp_table_schema( + self, database: str, temp_tables: List[TempTableRow] + ) -> None: + if self.context.graph is None: # to silent lint + return + + schema_resolver: sqlglot_l.SchemaResolver = ( + self.context.graph._make_schema_resolver( + platform=LineageDatasetPlatform.REDSHIFT.value, + platform_instance=self.config.platform_instance, + env=self.config.env, + ) + ) + + dataset_vs_columns: Dict[str, List[SchemaField]] = {} + # prepare dataset_urn vs List of schema fields + for table in temp_tables: + logger.debug( + f"Processing temp table: {table.create_command} with query text {table.query_text}" + ) + result = sqlglot_l.create_lineage_sql_parsed_result( + platform=LineageDatasetPlatform.REDSHIFT.value, + platform_instance=self.config.platform_instance, + env=self.config.env, + default_db=database, + default_schema=self.config.default_schema, + query=table.query_text, + graph=self.context.graph, + ) + + if ( + result is None + or result.column_lineage is None + or result.query_type != sqlglot_l.QueryType.CREATE + or not result.out_tables + ): + logger.debug(f"Unsupported temp table query found: {table.query_text}") + continue + + table.parsed_result = result + if result.column_lineage[0].downstream.table: + table.urn = result.column_lineage[0].downstream.table + + self.temp_tables[result.out_tables[0]] = table + + for table in self.temp_tables.values(): + if ( + table.parsed_result is None + or table.parsed_result.column_lineage is None + ): + continue + for column_lineage in table.parsed_result.column_lineage: + if column_lineage.downstream.table not in dataset_vs_columns: + dataset_vs_columns[cast(str, column_lineage.downstream.table)] = [] + # Initialise the temp table urn, we later need this to merge CLL + + dataset_vs_columns[cast(str, column_lineage.downstream.table)].append( + SchemaField( + fieldPath=column_lineage.downstream.column, + type=cast( + SchemaFieldDataTypeClass, + column_lineage.downstream.column_type, + ), + nativeDataType=cast( + str, column_lineage.downstream.native_column_type + ), + ) + ) + + # Add datasets, and it's respective fields in schema_resolver, so that later schema_resolver would be able + # correctly generates the upstreams for temporary tables + for urn in dataset_vs_columns: + db, schema, table_name = split_qualified_table_name(urn) + schema_resolver.add_schema_metadata( + urn=urn, + schema_metadata=SchemaMetadata( + schemaName=table_name, + platform=builder.make_data_platform_urn( + LineageDatasetPlatform.REDSHIFT.value + ), + version=0, + hash="", + 
platformSchema=OtherSchema(rawSchema=""), + fields=dataset_vs_columns[urn], + ), + ) + def get_time_window(self) -> Tuple[datetime, datetime]: if self.redundant_run_skip_handler: self.report.stateful_lineage_ingestion_enabled = True @@ -157,25 +282,32 @@ def _get_s3_path(self, path: str) -> str: return path def _get_sources_from_query( - self, db_name: str, query: str + self, + db_name: str, + query: str, + parsed_result: Optional[sqlglot_l.SqlParsingResult] = None, ) -> Tuple[List[LineageDataset], Optional[List[sqlglot_l.ColumnLineageInfo]]]: sources: List[LineageDataset] = list() - parsed_result: Optional[ - sqlglot_l.SqlParsingResult - ] = sqlglot_l.create_lineage_sql_parsed_result( - query=query, - platform=LineageDatasetPlatform.REDSHIFT.value, - platform_instance=self.config.platform_instance, - default_db=db_name, - default_schema=str(self.config.default_schema), - graph=self.context.graph, - env=self.config.env, - ) + if parsed_result is None: + parsed_result = sqlglot_l.create_lineage_sql_parsed_result( + query=query, + platform=LineageDatasetPlatform.REDSHIFT.value, + platform_instance=self.config.platform_instance, + default_db=db_name, + default_schema=str(self.config.default_schema), + graph=self.context.graph, + env=self.config.env, + ) if parsed_result is None: logger.debug(f"native query parsing failed for {query}") return sources, None + elif parsed_result.debug_info.table_error: + logger.debug( + f"native query parsing failed for {query} with error: {parsed_result.debug_info.table_error}" + ) + return sources, None logger.debug(f"parsed_result = {parsed_result}") @@ -277,7 +409,7 @@ def _populate_lineage_map( database: str, lineage_type: LineageCollectorType, connection: redshift_connector.Connection, - all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]], + all_tables_set: Dict[str, Dict[str, Set[str]]], ) -> None: """ This method generate table level lineage based with the given query. @@ -292,7 +424,10 @@ def _populate_lineage_map( return: The method does not return with anything as it directly modify the self._lineage_map property. 
:rtype: None """ + + logger.info(f"Extracting {lineage_type.name} lineage for db {database}") try: + logger.debug(f"Processing lineage query: {query}") cll: Optional[List[sqlglot_l.ColumnLineageInfo]] = None raw_db_name = database alias_db_name = self.config.database @@ -301,11 +436,18 @@ def _populate_lineage_map( conn=connection, query=query ): target = self._get_target_lineage( - alias_db_name, lineage_row, lineage_type + alias_db_name, + lineage_row, + lineage_type, + all_tables_set=all_tables_set, ) if not target: continue + logger.debug( + f"Processing {lineage_type.name} lineage row: {lineage_row}" + ) + sources, cll = self._get_sources( lineage_type, alias_db_name, @@ -318,9 +460,12 @@ def _populate_lineage_map( target.upstreams.update( self._get_upstream_lineages( sources=sources, - all_tables=all_tables, + target_table=target.dataset.urn, + target_dataset_cll=cll, + all_tables_set=all_tables_set, alias_db_name=alias_db_name, raw_db_name=raw_db_name, + connection=connection, ) ) target.cll = cll @@ -344,21 +489,50 @@ def _populate_lineage_map( ) self.report_status(f"extract-{lineage_type.name}", False) + def _update_lineage_map_for_table_renames( + self, table_renames: Dict[str, str] + ) -> None: + if not table_renames: + return + + logger.info(f"Updating lineage map for {len(table_renames)} table renames") + for new_table_urn, prev_table_urn in table_renames.items(): + # This table was renamed from some other name, copy in the lineage + # for the previous name as well. + prev_table_lineage = self._lineage_map.get(prev_table_urn) + if prev_table_lineage: + logger.debug( + f"including lineage for {prev_table_urn} in {new_table_urn} due to table rename" + ) + self._lineage_map[new_table_urn].merge_lineage( + upstreams=prev_table_lineage.upstreams, + cll=prev_table_lineage.cll, + ) + def _get_target_lineage( self, alias_db_name: str, lineage_row: LineageRow, lineage_type: LineageCollectorType, + all_tables_set: Dict[str, Dict[str, Set[str]]], ) -> Optional[LineageItem]: if ( lineage_type != LineageCollectorType.UNLOAD and lineage_row.target_schema and lineage_row.target_table ): - if not self.config.schema_pattern.allowed( - lineage_row.target_schema - ) or not self.config.table_pattern.allowed( - f"{alias_db_name}.{lineage_row.target_schema}.{lineage_row.target_table}" + if ( + not self.config.schema_pattern.allowed(lineage_row.target_schema) + or not self.config.table_pattern.allowed( + f"{alias_db_name}.{lineage_row.target_schema}.{lineage_row.target_table}" + ) + ) and not ( + # We also check the all_tables_set, since this might be a renamed table + # that we don't want to drop lineage for. 
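+            # (_process_table_renames adds the pre-rename table name to
+            # all_tables_set, so lineage rows targeting the old name of a
+            # renamed table are not filtered out here.)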
+ alias_db_name in all_tables_set + and lineage_row.target_schema in all_tables_set[alias_db_name] + and lineage_row.target_table + in all_tables_set[alias_db_name][lineage_row.target_schema] ): return None # Target @@ -400,18 +574,19 @@ def _get_target_lineage( def _get_upstream_lineages( self, sources: List[LineageDataset], - all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]], + target_table: str, + all_tables_set: Dict[str, Dict[str, Set[str]]], alias_db_name: str, raw_db_name: str, + connection: redshift_connector.Connection, + target_dataset_cll: Optional[List[sqlglot_l.ColumnLineageInfo]], ) -> List[LineageDataset]: - targe_source = [] + target_source = [] + probable_temp_tables: List[str] = [] + for source in sources: if source.platform == LineageDatasetPlatform.REDSHIFT: - qualified_table_name = dataset_urn.DatasetUrn.create_from_string( - source.urn - ).get_entity_id()[1] - # -3 because platform instance is optional and that can cause the split to have more than 3 elements - db, schema, table = qualified_table_name.split(".")[-3:] + db, schema, table = split_qualified_table_name(source.urn) if db == raw_db_name: db = alias_db_name path = f"{db}.{schema}.{table}" @@ -427,19 +602,40 @@ def _get_upstream_lineages( # Filtering out tables which does not exist in Redshift # It was deleted in the meantime or query parser did not capture well the table name + # Or it might be a temp table if ( - db not in all_tables - or schema not in all_tables[db] - or not any(table == t.name for t in all_tables[db][schema]) + db not in all_tables_set + or schema not in all_tables_set[db] + or table not in all_tables_set[db][schema] ): logger.debug( - f"{source.urn} missing table, dropping from lineage.", + f"{source.urn} missing table. Adding it to temp table list for target table {target_table}.", ) + probable_temp_tables.append(f"{schema}.{table}") self.report.num_lineage_tables_dropped += 1 continue - targe_source.append(source) - return targe_source + target_source.append(source) + + if probable_temp_tables and self.config.resolve_temp_table_in_lineage: + self.report.num_lineage_processed_temp_tables += len(probable_temp_tables) + # Generate lineage dataset from temporary tables + number_of_permanent_dataset_found: int = ( + self.update_table_and_column_lineage( + db_name=raw_db_name, + connection=connection, + temp_table_names=probable_temp_tables, + target_source_dataset=target_source, + target_dataset_cll=target_dataset_cll, + ) + ) + + logger.debug( + f"Number of permanent datasets found for {target_table} = {number_of_permanent_dataset_found} in " + f"temp tables {probable_temp_tables}" + ) + + return target_source def populate_lineage( self, @@ -447,8 +643,27 @@ def populate_lineage( connection: redshift_connector.Connection, all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]], ) -> None: + if self.config.resolve_temp_table_in_lineage: + self._init_temp_table_schema( + database=database, + temp_tables=self.get_temp_tables(connection=connection), + ) + populate_calls: List[Tuple[str, LineageCollectorType]] = [] + all_tables_set: Dict[str, Dict[str, Set[str]]] = { + db: {schema: {t.name for t in tables} for schema, tables in schemas.items()} + for db, schemas in all_tables.items() + } + + table_renames: Dict[str, str] = {} + if self.config.include_table_rename_lineage: + table_renames, all_tables_set = self._process_table_renames( + database=database, + connection=connection, + all_tables=all_tables_set, + ) + if self.config.table_lineage_mode in { 
LineageMode.STL_SCAN_BASED, LineageMode.MIXED, @@ -504,9 +719,12 @@ def populate_lineage( database=database, lineage_type=lineage_type, connection=connection, - all_tables=all_tables, + all_tables_set=all_tables_set, ) + # Handling for alter table statements. + self._update_lineage_map_for_table_renames(table_renames=table_renames) + self.report.lineage_mem_size[self.config.database] = humanfriendly.format_size( memory_footprint.total_size(self._lineage_map) ) @@ -613,3 +831,271 @@ def get_lineage( def report_status(self, step: str, status: bool) -> None: if self.redundant_run_skip_handler: self.redundant_run_skip_handler.report_current_run_status(step, status) + + def _process_table_renames( + self, + database: str, + connection: redshift_connector.Connection, + all_tables: Dict[str, Dict[str, Set[str]]], + ) -> Tuple[Dict[str, str], Dict[str, Dict[str, Set[str]]]]: + logger.info(f"Processing table renames for db {database}") + + # new urn -> prev urn + table_renames: Dict[str, str] = {} + + query = RedshiftQuery.alter_table_rename_query( + db_name=database, + start_time=self.start_time, + end_time=self.end_time, + ) + + for rename_row in RedshiftDataDictionary.get_alter_table_commands( + connection, query + ): + schema, prev_name, new_name = parse_alter_table_rename( + default_schema=self.config.default_schema, + query=rename_row.query_text, + ) + + prev_urn = make_dataset_urn_with_platform_instance( + platform=LineageDatasetPlatform.REDSHIFT.value, + platform_instance=self.config.platform_instance, + name=f"{database}.{schema}.{prev_name}", + env=self.config.env, + ) + new_urn = make_dataset_urn_with_platform_instance( + platform=LineageDatasetPlatform.REDSHIFT.value, + platform_instance=self.config.platform_instance, + name=f"{database}.{schema}.{new_name}", + env=self.config.env, + ) + + table_renames[new_urn] = prev_urn + + # We want to generate lineage for the previous name too. + all_tables[database][schema].add(prev_name) + + logger.info(f"Discovered {len(table_renames)} table renames") + return table_renames, all_tables + + def get_temp_tables( + self, connection: redshift_connector.Connection + ) -> List[TempTableRow]: + ddl_query: str = RedshiftQuery.temp_table_ddl_query( + start_time=self.config.start_time, + end_time=self.config.end_time, + ) + + logger.debug(f"Temporary table ddl query = {ddl_query}") + + temp_table_rows: List[TempTableRow] = [] + + for row in RedshiftDataDictionary.get_temporary_rows( + conn=connection, + query=ddl_query, + ): + temp_table_rows.append(row) + + return temp_table_rows + + def find_temp_tables( + self, temp_table_rows: List[TempTableRow], temp_table_names: List[str] + ) -> List[TempTableRow]: + matched_temp_tables: List[TempTableRow] = [] + + for table_name in temp_table_names: + prefixes = RedshiftQuery.get_temp_table_clause(table_name) + prefixes.extend( + RedshiftQuery.get_temp_table_clause(table_name.split(".")[-1]) + ) + + for row in temp_table_rows: + if any( + row.create_command.lower().startswith(prefix) for prefix in prefixes + ): + matched_temp_tables.append(row) + + return matched_temp_tables + + def resolve_column_refs( + self, column_refs: List[sqlglot_l.ColumnRef], depth: int = 0 + ) -> List[sqlglot_l.ColumnRef]: + """ + This method resolves the column reference to the original column reference. + For example, if the column reference is to a temporary table, it will be resolved to the original column + reference. 
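+
+        Resolution is recursive, since a temp table may itself be created from
+        another temp table; it is capped at max_depth (10) so that circular
+        references between temp tables cannot recurse indefinitely.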
+ """ + max_depth = 10 + + resolved_column_refs: List[sqlglot_l.ColumnRef] = [] + if not column_refs: + return column_refs + + if depth >= max_depth: + logger.warning( + f"Max depth reached for resolving temporary columns: {column_refs}" + ) + self.report.num_unresolved_temp_columns += 1 + return column_refs + + for ref in column_refs: + resolved = False + if ref.table in self.temp_tables: + table = self.temp_tables[ref.table] + if table.parsed_result and table.parsed_result.column_lineage: + for column_lineage in table.parsed_result.column_lineage: + if ( + column_lineage.downstream.table == ref.table + and column_lineage.downstream.column == ref.column + ): + resolved_column_refs.extend( + self.resolve_column_refs( + column_lineage.upstreams, depth=depth + 1 + ) + ) + resolved = True + break + # If we reach here, it means that we were not able to resolve the column reference. + if resolved is False: + logger.warning( + f"Unable to resolve column reference {ref} to a permanent table" + ) + else: + logger.debug( + f"Resolved column reference {ref} is not resolved because referenced table {ref.table} is not a temp table or not found. Adding reference as non-temp table. This is normal." + ) + resolved_column_refs.append(ref) + return resolved_column_refs + + def _update_target_dataset_cll( + self, + temp_table_urn: str, + target_dataset_cll: List[sqlglot_l.ColumnLineageInfo], + source_dataset_cll: List[sqlglot_l.ColumnLineageInfo], + ) -> None: + for target_column_lineage in target_dataset_cll: + upstreams: List[sqlglot_l.ColumnRef] = [] + # Look for temp_table_urn in upstream of column_lineage, if found then we need to replace it with + # column of permanent table + for target_column_ref in target_column_lineage.upstreams: + if target_column_ref.table == temp_table_urn: + # Look for column_ref.table and column_ref.column in downstream of source_dataset_cll. 
+ # The source_dataset_cll contains CLL generated from create statement of temp table (temp_table_urn) + for source_column_lineage in source_dataset_cll: + if ( + source_column_lineage.downstream.table + == target_column_ref.table + and source_column_lineage.downstream.column + == target_column_ref.column + ): + resolved_columns = self.resolve_column_refs( + source_column_lineage.upstreams + ) + # Add all upstream of above temporary column into upstream of target column + upstreams.extend(resolved_columns) + continue + + upstreams.append(target_column_ref) + + if upstreams: + # update the upstreams + target_column_lineage.upstreams = upstreams + + def _add_permanent_datasets_recursively( + self, + db_name: str, + temp_table_rows: List[TempTableRow], + visited_tables: Set[str], + connection: redshift_connector.Connection, + permanent_lineage_datasets: List[LineageDataset], + target_dataset_cll: Optional[List[sqlglot_l.ColumnLineageInfo]], + ) -> None: + transitive_temp_tables: List[TempTableRow] = [] + + for temp_table in temp_table_rows: + logger.debug( + f"Processing temp table with transaction id: {temp_table.transaction_id} and query text {temp_table.query_text}" + ) + + intermediate_l_datasets, cll = self._get_sources_from_query( + db_name=db_name, + query=temp_table.query_text, + parsed_result=temp_table.parsed_result, + ) + + if ( + temp_table.urn is not None + and target_dataset_cll is not None + and cll is not None + ): # condition to silent the lint + self._update_target_dataset_cll( + temp_table_urn=temp_table.urn, + target_dataset_cll=target_dataset_cll, + source_dataset_cll=cll, + ) + + # make sure lineage dataset should not contain a temp table + # if such dataset is present then add it to transitive_temp_tables to resolve it to original permanent table + for lineage_dataset in intermediate_l_datasets: + db, schema, table = split_qualified_table_name(lineage_dataset.urn) + + if table in visited_tables: + # The table is already processed + continue + + # Check if table found is again a temp table + repeated_temp_table: List[TempTableRow] = self.find_temp_tables( + temp_table_rows=list(self.temp_tables.values()), + temp_table_names=[table], + ) + + if not repeated_temp_table: + logger.debug(f"Unable to find table {table} in temp tables.") + + if repeated_temp_table: + transitive_temp_tables.extend(repeated_temp_table) + visited_tables.add(table) + continue + + permanent_lineage_datasets.append(lineage_dataset) + + if transitive_temp_tables: + # recursive call + self._add_permanent_datasets_recursively( + db_name=db_name, + temp_table_rows=transitive_temp_tables, + visited_tables=visited_tables, + connection=connection, + permanent_lineage_datasets=permanent_lineage_datasets, + target_dataset_cll=target_dataset_cll, + ) + + def update_table_and_column_lineage( + self, + db_name: str, + temp_table_names: List[str], + connection: redshift_connector.Connection, + target_source_dataset: List[LineageDataset], + target_dataset_cll: Optional[List[sqlglot_l.ColumnLineageInfo]], + ) -> int: + permanent_lineage_datasets: List[LineageDataset] = [] + + temp_table_rows: List[TempTableRow] = self.find_temp_tables( + temp_table_rows=list(self.temp_tables.values()), + temp_table_names=temp_table_names, + ) + + visited_tables: Set[str] = set(temp_table_names) + + self._add_permanent_datasets_recursively( + db_name=db_name, + temp_table_rows=temp_table_rows, + visited_tables=visited_tables, + connection=connection, + permanent_lineage_datasets=permanent_lineage_datasets, + 
target_dataset_cll=target_dataset_cll, + ) + + target_source_dataset.extend(permanent_lineage_datasets) + + return len(permanent_lineage_datasets) diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/query.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/query.py index 92e36fffd6bb4..93beb5980ea62 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/query.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/query.py @@ -1,9 +1,14 @@ from datetime import datetime +from typing import List redshift_datetime_format = "%Y-%m-%d %H:%M:%S" class RedshiftQuery: + CREATE_TEMP_TABLE_CLAUSE = "create temp table" + CREATE_TEMPORARY_TABLE_CLAUSE = "create temporary table" + CREATE_TABLE_CLAUSE = "create table" + list_databases: str = """SELECT datname FROM pg_database WHERE (datname <> ('padb_harvest')::name) AND (datname <> ('template0')::name) @@ -97,7 +102,7 @@ class RedshiftQuery: NULL as table_description FROM pg_catalog.svv_external_tables ORDER BY "schema", - "relname"; + "relname" """ list_columns: str = """ SELECT @@ -379,7 +384,8 @@ def list_insert_create_queries_sql( target_schema, target_table, username, - querytxt as ddl + query as query_id, + LISTAGG(CASE WHEN LEN(RTRIM(querytxt)) = 0 THEN querytxt ELSE RTRIM(querytxt) END) WITHIN GROUP (ORDER BY sequence) as ddl from ( select @@ -388,7 +394,9 @@ def list_insert_create_queries_sql( sti.table as target_table, sti.database as cluster, usename as username, - querytxt, + text as querytxt, + sq.query, + sequence, si.starttime as starttime from stl_insert as si @@ -396,19 +404,20 @@ def list_insert_create_queries_sql( sti.table_id = tbl left join svl_user_info sui on si.userid = sui.usesysid - left join stl_query sq on + left join STL_QUERYTEXT sq on si.query = sq.query left join stl_load_commits slc on slc.query = si.query where sui.usename <> 'rdsdb' - and sq.aborted = 0 and slc.query IS NULL and cluster = '{db_name}' and si.starttime >= '{start_time}' and si.starttime < '{end_time}' + and sequence < 320 ) as target_tables - order by cluster, target_schema, target_table, starttime asc + group by cluster, query_id, target_schema, target_table, username, starttime + order by cluster, query_id, target_schema, target_table, starttime asc """.format( # We need the original database name for filtering db_name=db_name, @@ -443,3 +452,118 @@ def list_copy_commands_sql( start_time=start_time.strftime(redshift_datetime_format), end_time=end_time.strftime(redshift_datetime_format), ) + + @staticmethod + def get_temp_table_clause(table_name: str) -> List[str]: + return [ + f"{RedshiftQuery.CREATE_TABLE_CLAUSE} {table_name}", + f"{RedshiftQuery.CREATE_TEMP_TABLE_CLAUSE} {table_name}", + f"{RedshiftQuery.CREATE_TEMPORARY_TABLE_CLAUSE} {table_name}", + ] + + @staticmethod + def temp_table_ddl_query(start_time: datetime, end_time: datetime) -> str: + start_time_str: str = start_time.strftime(redshift_datetime_format) + + end_time_str: str = end_time.strftime(redshift_datetime_format) + + return rf"""-- DataHub Redshift Source temp table DDL query + select + * + from + ( + select + session_id, + transaction_id, + start_time, + userid, + REGEXP_REPLACE(REGEXP_SUBSTR(REGEXP_REPLACE(query_text,'\\\\n','\\n'), '(CREATE(?:[\\n\\s\\t]+(?:temp|temporary))?(?:[\\n\\s\\t]+)table(?:[\\n\\s\\t]+)[^\\n\\s\\t()-]+)', 0, 1, 'ipe'),'[\\n\\s\\t]+',' ',1,'p') as create_command, + query_text, + row_number() over ( + partition by TRIM(query_text) + order by start_time desc + ) rn + from + ( + select + pid as 
session_id, + xid as transaction_id, + starttime as start_time, + type, + query_text, + userid + from + ( + select + starttime, + pid, + xid, + type, + userid, + LISTAGG(case + when LEN(RTRIM(text)) = 0 then text + else RTRIM(text) + end, + '') within group ( + order by sequence + ) as query_text + from + SVL_STATEMENTTEXT + where + type in ('DDL', 'QUERY') + AND starttime >= '{start_time_str}' + AND starttime < '{end_time_str}' + -- See https://stackoverflow.com/questions/72770890/redshift-result-size-exceeds-listagg-limit-on-svl-statementtext + AND sequence < 320 + group by + starttime, + pid, + xid, + type, + userid + order by + starttime, + pid, + xid, + type, + userid + asc) + where + type in ('DDL', 'QUERY') + ) + where + (create_command ilike 'create temp table %' + or create_command ilike 'create temporary table %' + -- we want to get all the create table statements and not just temp tables if non temp table is created and dropped in the same transaction + or create_command ilike 'create table %') + -- Redshift creates temp tables with the following names: volt_tt_%. We need to filter them out. + and query_text not ilike 'CREATE TEMP TABLE volt_tt_%' + and create_command not like 'CREATE TEMP TABLE volt_tt_' + -- We need to filter out our query and it was not possible earlier when we did not have any comment in the query + and query_text not ilike '%https://stackoverflow.com/questions/72770890/redshift-result-size-exceeds-listagg-limit-on-svl-statementtext%' + + ) + where + rn = 1; + """ + + @staticmethod + def alter_table_rename_query( + db_name: str, start_time: datetime, end_time: datetime + ) -> str: + start_time_str: str = start_time.strftime(redshift_datetime_format) + end_time_str: str = end_time.strftime(redshift_datetime_format) + + return f""" + SELECT transaction_id, + session_id, + start_time, + query_text + FROM sys_query_history SYS + WHERE SYS.status = 'success' + AND SYS.query_type = 'DDL' + AND SYS.database_name = '{db_name}' + AND SYS.start_time >= '{start_time_str}' + AND SYS.end_time < '{end_time_str}' + AND SYS.query_text ILIKE 'alter table % rename to %' + """ diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_schema.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_schema.py index ca81682ae00e4..0ea073c050502 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_schema.py @@ -9,6 +9,7 @@ from datahub.ingestion.source.sql.sql_generic import BaseColumn, BaseTable from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField from datahub.utilities.hive_schema_to_avro import get_schema_fields_for_hive_column +from datahub.utilities.sqlglot_lineage import SqlParsingResult logger: logging.Logger = logging.getLogger(__name__) @@ -80,6 +81,26 @@ class LineageRow: filename: Optional[str] +@dataclass +class TempTableRow: + transaction_id: int + session_id: str + query_text: str + create_command: str + start_time: datetime + urn: Optional[str] + parsed_result: Optional[SqlParsingResult] = None + + +@dataclass +class AlterTableRow: + # TODO unify this type with TempTableRow + transaction_id: int + session_id: str + query_text: str + start_time: datetime + + # this is a class to be a proxy to query Redshift class RedshiftDataDictionary: @staticmethod @@ -359,9 +380,62 @@ def get_lineage_rows( target_table=row[field_names.index("target_table")] if "target_table" in field_names else None, - 
ddl=row[field_names.index("ddl")] if "ddl" in field_names else None, + # See https://docs.aws.amazon.com/redshift/latest/dg/r_STL_QUERYTEXT.html + # for why we need to remove the \\n. + ddl=row[field_names.index("ddl")].replace("\\n", "\n") + if "ddl" in field_names + else None, filename=row[field_names.index("filename")] if "filename" in field_names else None, ) rows = cursor.fetchmany() + + @staticmethod + def get_temporary_rows( + conn: redshift_connector.Connection, + query: str, + ) -> Iterable[TempTableRow]: + cursor = conn.cursor() + + cursor.execute(query) + + field_names = [i[0] for i in cursor.description] + + rows = cursor.fetchmany() + while rows: + for row in rows: + yield TempTableRow( + transaction_id=row[field_names.index("transaction_id")], + session_id=row[field_names.index("session_id")], + # See https://docs.aws.amazon.com/redshift/latest/dg/r_STL_QUERYTEXT.html + # for why we need to replace the \n with a newline. + query_text=row[field_names.index("query_text")].replace( + r"\n", "\n" + ), + create_command=row[field_names.index("create_command")], + start_time=row[field_names.index("start_time")], + urn=None, + ) + rows = cursor.fetchmany() + + @staticmethod + def get_alter_table_commands( + conn: redshift_connector.Connection, + query: str, + ) -> Iterable[AlterTableRow]: + # TODO: unify this with get_temporary_rows + cursor = RedshiftDataDictionary.get_query_result(conn, query) + + field_names = [i[0] for i in cursor.description] + + rows = cursor.fetchmany() + while rows: + for row in rows: + yield AlterTableRow( + transaction_id=row[field_names.index("transaction_id")], + session_id=row[field_names.index("session_id")], + query_text=row[field_names.index("query_text")], + start_time=row[field_names.index("start_time")], + ) + rows = cursor.fetchmany() diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py index 333c851650fb3..36ac7955f15d5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py @@ -35,6 +35,7 @@ class RedshiftReport(ProfilingSqlReport, IngestionStageReport, BaseTimeWindowRep num_lineage_tables_dropped: int = 0 num_lineage_dropped_query_parser: int = 0 num_lineage_dropped_not_support_copy_path: int = 0 + num_lineage_processed_temp_tables = 0 lineage_start_time: Optional[datetime] = None lineage_end_time: Optional[datetime] = None @@ -43,6 +44,7 @@ class RedshiftReport(ProfilingSqlReport, IngestionStageReport, BaseTimeWindowRep usage_start_time: Optional[datetime] = None usage_end_time: Optional[datetime] = None stateful_usage_ingestion_enabled: bool = False + num_unresolved_temp_columns: int = 0 def report_dropped(self, key: str) -> None: self.filtered.append(key) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index b896df1fa340e..aad4a6ed27cb8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -140,7 +140,9 @@ class SnowflakeV2Config( # This is required since access_history table does not capture whether the table was temporary table. temporary_tables_pattern: List[str] = Field( default=DEFAULT_TABLES_DENY_LIST, - description="[Advanced] Regex patterns for temporary tables to filter in lineage ingestion. 
Specify regex to match the entire table name in database.schema.table format. Defaults are to set in such a way to ignore the temporary staging tables created by known ETL tools.", + description="[Advanced] Regex patterns for temporary tables to filter in lineage ingestion. Specify regex to " + "match the entire table name in database.schema.table format. Defaults are to set in such a way " + "to ignore the temporary staging tables created by known ETL tools.", ) rename_upstreams_deny_pattern_to_temporary_table_pattern = pydantic_renamed_field( @@ -150,13 +152,16 @@ class SnowflakeV2Config( shares: Optional[Dict[str, SnowflakeShareConfig]] = Field( default=None, description="Required if current account owns or consumes snowflake share." - " If specified, connector creates lineage and siblings relationship between current account's database tables and consumer/producer account's database tables." + "If specified, connector creates lineage and siblings relationship between current account's database tables " + "and consumer/producer account's database tables." " Map of share name -> details of share.", ) email_as_user_identifier: bool = Field( default=True, - description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is provided, generates email addresses for snowflake users with unset emails, based on their username.", + description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is " + "provided, generates email addresses for snowflake users with unset emails, based on their " + "username.", ) @validator("convert_urns_to_lowercase") diff --git a/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py b/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py index abe4f82673777..5b063451df9cf 100644 --- a/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py +++ b/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py @@ -1037,6 +1037,14 @@ def _sqlglot_lineage_inner( default_db = default_db.upper() if default_schema: default_schema = default_schema.upper() + if _is_dialect_instance(dialect, "redshift") and not default_schema: + # On Redshift, there's no "USE SCHEMA " command. The default schema + # is public, and "current schema" is the one at the front of the search path. + # See https://docs.aws.amazon.com/redshift/latest/dg/r_search_path.html + # and https://stackoverflow.com/questions/9067335/how-does-the-search-path-influence-identifier-resolution-and-the-current-schema?noredirect=1&lq=1 + # default_schema = "public" + # TODO: Re-enable this. 
+ pass logger.debug("Parsing lineage from sql statement: %s", sql) statement = _parse_statement(sql, dialect=dialect) diff --git a/metadata-ingestion/tests/unit/redshift_query_mocker.py b/metadata-ingestion/tests/unit/redshift_query_mocker.py new file mode 100644 index 0000000000000..631e6e7ceaf1f --- /dev/null +++ b/metadata-ingestion/tests/unit/redshift_query_mocker.py @@ -0,0 +1,104 @@ +from datetime import datetime +from unittest.mock import MagicMock + + +def mock_temp_table_cursor(cursor: MagicMock) -> None: + cursor.description = [ + ["transaction_id"], + ["session_id"], + ["query_text"], + ["create_command"], + ["start_time"], + ] + + cursor.fetchmany.side_effect = [ + [ + ( + 126, + "abc", + "CREATE TABLE #player_price distkey(player_id) AS SELECT player_id, SUM(price) AS " + "price_usd from player_activity group by player_id", + "CREATE TABLE #player_price", + datetime.now(), + ) + ], + [ + # Empty result to stop the while loop + ], + ] + + +def mock_stl_insert_table_cursor(cursor: MagicMock) -> None: + cursor.description = [ + ["source_schema"], + ["source_table"], + ["target_schema"], + ["target_table"], + ["ddl"], + ] + + cursor.fetchmany.side_effect = [ + [ + ( + "public", + "#player_price", + "public", + "player_price_with_hike_v6", + "INSERT INTO player_price_with_hike_v6 SELECT (price_usd + 0.2 * price_usd) as price, '20%' FROM " + "#player_price", + ) + ], + [ + # Empty result to stop the while loop + ], + ] + + +query_vs_cursor_mocker = { + ( + "-- DataHub Redshift Source temp table DDL query\n select\n *\n " + "from\n (\n select\n session_id,\n " + " transaction_id,\n start_time,\n userid,\n " + " REGEXP_REPLACE(REGEXP_SUBSTR(REGEXP_REPLACE(query_text,'\\\\\\\\n','\\\\n'), '(CREATE(?:[" + "\\\\n\\\\s\\\\t]+(?:temp|temporary))?(?:[\\\\n\\\\s\\\\t]+)table(?:[\\\\n\\\\s\\\\t]+)[" + "^\\\\n\\\\s\\\\t()-]+)', 0, 1, 'ipe'),'[\\\\n\\\\s\\\\t]+',' ',1,'p') as create_command,\n " + " query_text,\n row_number() over (\n partition " + "by TRIM(query_text)\n order by start_time desc\n ) rn\n " + " from\n (\n select\n pid " + "as session_id,\n xid as transaction_id,\n starttime " + "as start_time,\n type,\n query_text,\n " + " userid\n from\n (\n " + "select\n starttime,\n pid,\n " + " xid,\n type,\n userid,\n " + " LISTAGG(case\n when LEN(RTRIM(text)) = 0 then text\n " + " else RTRIM(text)\n end,\n " + " '') within group (\n order by sequence\n " + " ) as query_text\n from\n " + "SVL_STATEMENTTEXT\n where\n type in ('DDL', " + "'QUERY')\n AND starttime >= '2024-01-01 12:00:00'\n " + " AND starttime < '2024-01-10 12:00:00'\n -- See " + "https://stackoverflow.com/questions/72770890/redshift-result-size-exceeds-listagg-limit-on-svl" + "-statementtext\n AND sequence < 320\n group by\n " + " starttime,\n pid,\n " + "xid,\n type,\n userid\n " + " order by\n starttime,\n pid,\n " + " xid,\n type,\n userid\n " + " asc)\n where\n type in ('DDL', " + "'QUERY')\n )\n where\n (create_command ilike " + "'create temp table %'\n or create_command ilike 'create temporary table %'\n " + " -- we want to get all the create table statements and not just temp tables " + "if non temp table is created and dropped in the same transaction\n or " + "create_command ilike 'create table %')\n -- Redshift creates temp tables with " + "the following names: volt_tt_%. 
We need to filter them out.\n and query_text not " + "ilike 'CREATE TEMP TABLE volt_tt_%'\n and create_command not like 'CREATE TEMP " + "TABLE volt_tt_'\n -- We need to filter out our query and it was not possible " + "earlier when we did not have any comment in the query\n and query_text not ilike " + "'%https://stackoverflow.com/questions/72770890/redshift-result-size-exceeds-listagg-limit-on-svl" + "-statementtext%'\n\n )\n where\n rn = 1;\n " + ): mock_temp_table_cursor, + "select * from test_collapse_temp_lineage": mock_stl_insert_table_cursor, +} + + +def mock_cursor(cursor: MagicMock, query: str) -> None: + query_vs_cursor_mocker[query](cursor=cursor) diff --git a/metadata-ingestion/tests/unit/test_redshift_lineage.py b/metadata-ingestion/tests/unit/test_redshift_lineage.py index db5af3a71efb9..6a3e6e47bd96a 100644 --- a/metadata-ingestion/tests/unit/test_redshift_lineage.py +++ b/metadata-ingestion/tests/unit/test_redshift_lineage.py @@ -1,8 +1,31 @@ +from datetime import datetime +from functools import partial +from typing import List +from unittest.mock import MagicMock + +import datahub.utilities.sqlglot_lineage as sqlglot_l from datahub.ingestion.api.common import PipelineContext +from datahub.ingestion.graph.client import DataHubGraph from datahub.ingestion.source.redshift.config import RedshiftConfig -from datahub.ingestion.source.redshift.lineage import RedshiftLineageExtractor +from datahub.ingestion.source.redshift.lineage import ( + LineageCollectorType, + LineageDataset, + LineageDatasetPlatform, + LineageItem, + RedshiftLineageExtractor, + parse_alter_table_rename, +) +from datahub.ingestion.source.redshift.redshift_schema import TempTableRow from datahub.ingestion.source.redshift.report import RedshiftReport -from datahub.utilities.sqlglot_lineage import ColumnLineageInfo, DownstreamColumnRef +from datahub.metadata._schema_classes import NumberTypeClass, SchemaFieldDataTypeClass +from datahub.utilities.sqlglot_lineage import ( + ColumnLineageInfo, + DownstreamColumnRef, + QueryType, + SqlParsingDebugInfo, + SqlParsingResult, +) +from tests.unit.redshift_query_mocker import mock_cursor def test_get_sources_from_query(): @@ -120,16 +143,45 @@ def test_get_sources_from_query_with_only_table(): ) -def test_cll(): - config = RedshiftConfig(host_port="localhost:5439", database="test") +def test_parse_alter_table_rename(): + assert parse_alter_table_rename("public", "alter table foo rename to bar") == ( + "public", + "foo", + "bar", + ) + assert parse_alter_table_rename( + "public", "alter table second_schema.storage_v2_stg rename to storage_v2; " + ) == ( + "second_schema", + "storage_v2_stg", + "storage_v2", + ) + + +def get_lineage_extractor() -> RedshiftLineageExtractor: + config = RedshiftConfig( + host_port="localhost:5439", + database="test", + resolve_temp_table_in_lineage=True, + start_time=datetime(2024, 1, 1, 12, 0, 0).isoformat() + "Z", + end_time=datetime(2024, 1, 10, 12, 0, 0).isoformat() + "Z", + ) report = RedshiftReport() + lineage_extractor = RedshiftLineageExtractor( + config, report, PipelineContext(run_id="foo", graph=mock_graph()) + ) + + return lineage_extractor + + +def test_cll(): test_query = """ select a,b,c from db.public.customer inner join db.public.order on db.public.customer.id = db.public.order.customer_id """ - lineage_extractor = RedshiftLineageExtractor( - config, report, PipelineContext(run_id="foo") - ) + + lineage_extractor = get_lineage_extractor() + _, cll = lineage_extractor._get_sources_from_query(db_name="db", query=test_query) 
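+    # The assertion below pins the column-level lineage the sqlglot parser is
+    # expected to produce for this join query.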
assert cll == [ @@ -149,3 +201,600 @@ def test_cll(): logic=None, ), ] + + +def cursor_execute_side_effect(cursor: MagicMock, query: str) -> None: + mock_cursor(cursor=cursor, query=query) + + +def mock_redshift_connection() -> MagicMock: + connection = MagicMock() + + cursor = MagicMock() + + connection.cursor.return_value = cursor + + cursor.execute.side_effect = partial(cursor_execute_side_effect, cursor) + + return connection + + +def mock_graph() -> DataHubGraph: + + graph = MagicMock() + + graph._make_schema_resolver.return_value = sqlglot_l.SchemaResolver( + platform="redshift", + env="PROD", + platform_instance=None, + graph=None, + ) + + return graph + + +def test_collapse_temp_lineage(): + lineage_extractor = get_lineage_extractor() + + connection: MagicMock = mock_redshift_connection() + + lineage_extractor._init_temp_table_schema( + database=lineage_extractor.config.database, + temp_tables=lineage_extractor.get_temp_tables(connection=connection), + ) + + lineage_extractor._populate_lineage_map( + query="select * from test_collapse_temp_lineage", + database=lineage_extractor.config.database, + all_tables_set={ + lineage_extractor.config.database: {"public": {"player_price_with_hike_v6"}} + }, + connection=connection, + lineage_type=LineageCollectorType.QUERY_SQL_PARSER, + ) + + print(lineage_extractor._lineage_map) + + target_urn: str = "urn:li:dataset:(urn:li:dataPlatform:redshift,test.public.player_price_with_hike_v6,PROD)" + + assert lineage_extractor._lineage_map.get(target_urn) is not None + + lineage_item: LineageItem = lineage_extractor._lineage_map[target_urn] + + assert list(lineage_item.upstreams)[0].urn == ( + "urn:li:dataset:(urn:li:dataPlatform:redshift," + "test.public.player_activity,PROD)" + ) + + assert lineage_item.cll is not None + + assert lineage_item.cll[0].downstream.table == ( + "urn:li:dataset:(urn:li:dataPlatform:redshift," + "test.public.player_price_with_hike_v6,PROD)" + ) + + assert lineage_item.cll[0].downstream.column == "price" + + assert lineage_item.cll[0].upstreams[0].table == ( + "urn:li:dataset:(urn:li:dataPlatform:redshift," + "test.public.player_activity,PROD)" + ) + + assert lineage_item.cll[0].upstreams[0].column == "price" + + +def test_collapse_temp_recursive_cll_lineage(): + lineage_extractor = get_lineage_extractor() + + temp_table: TempTableRow = TempTableRow( + transaction_id=126, + query_text="CREATE TABLE #player_price distkey(player_id) AS SELECT player_id, SUM(price_usd) AS price_usd " + "from #player_activity_temp group by player_id", + start_time=datetime.now(), + session_id="abc", + create_command="CREATE TABLE #player_price", + parsed_result=SqlParsingResult( + query_type=QueryType.CREATE, + in_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + ], + out_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)" + ], + debug_info=SqlParsingDebugInfo(), + column_lineage=[ + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + column="player_id", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="INTEGER", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="player_id", + ) + ], + logic=None, + ), + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + 
column="price_usd", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="BIGINT", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="price_usd", + ) + ], + logic=None, + ), + ], + ), + urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + ) + + temp_table_activity: TempTableRow = TempTableRow( + transaction_id=127, + query_text="CREATE TABLE #player_activity_temp SELECT player_id, SUM(price) AS price_usd " + "from player_activity", + start_time=datetime.now(), + session_id="abc", + create_command="CREATE TABLE #player_activity_temp", + parsed_result=SqlParsingResult( + query_type=QueryType.CREATE, + in_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)" + ], + out_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + ], + debug_info=SqlParsingDebugInfo(), + column_lineage=[ + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="player_id", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="INTEGER", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", + column="player_id", + ) + ], + logic=None, + ), + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="price_usd", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="BIGINT", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", + column="price", + ) + ], + logic=None, + ), + ], + ), + urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + ) + + assert temp_table.urn + assert temp_table_activity.urn + + lineage_extractor.temp_tables[temp_table.urn] = temp_table + lineage_extractor.temp_tables[temp_table_activity.urn] = temp_table_activity + + target_dataset_cll: List[sqlglot_l.ColumnLineageInfo] = [ + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v6,PROD)", + column="price", + column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()), + native_column_type="DOUBLE PRECISION", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + column="price_usd", + ) + ], + logic=None, + ) + ] + + datasets = lineage_extractor._get_upstream_lineages( + sources=[ + LineageDataset( + platform=LineageDatasetPlatform.REDSHIFT, + urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + ) + ], + target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)", + raw_db_name="dev", + alias_db_name="dev", + all_tables_set={ + "dev": { + "public": set(), + } + }, + connection=MagicMock(), + target_dataset_cll=target_dataset_cll, + ) + + assert len(datasets) == 1 + + assert ( + datasets[0].urn + == "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)" + ) + + assert target_dataset_cll[0].upstreams[0].table == ( + "urn:li:dataset:(urn:li:dataPlatform:redshift," + "dev.public.player_activity,PROD)" + ) + assert 
target_dataset_cll[0].upstreams[0].column == "price" + + +def test_collapse_temp_recursive_with_compex_column_cll_lineage(): + lineage_extractor = get_lineage_extractor() + + temp_table: TempTableRow = TempTableRow( + transaction_id=126, + query_text="CREATE TABLE #player_price distkey(player_id) AS SELECT player_id, SUM(price+tax) AS price_usd " + "from #player_activity_temp group by player_id", + start_time=datetime.now(), + session_id="abc", + create_command="CREATE TABLE #player_price", + parsed_result=SqlParsingResult( + query_type=QueryType.CREATE, + in_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + ], + out_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)" + ], + debug_info=SqlParsingDebugInfo(), + column_lineage=[ + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + column="player_id", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="INTEGER", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="player_id", + ) + ], + logic=None, + ), + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + column="price_usd", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="BIGINT", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="price", + ), + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="tax", + ), + ], + logic=None, + ), + ], + ), + urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + ) + + temp_table_activity: TempTableRow = TempTableRow( + transaction_id=127, + query_text="CREATE TABLE #player_activity_temp SELECT player_id, price, tax " + "from player_activity", + start_time=datetime.now(), + session_id="abc", + create_command="CREATE TABLE #player_activity_temp", + parsed_result=SqlParsingResult( + query_type=QueryType.CREATE, + in_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)" + ], + out_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + ], + debug_info=SqlParsingDebugInfo(), + column_lineage=[ + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="player_id", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="INTEGER", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", + column="player_id", + ) + ], + logic=None, + ), + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="price", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="BIGINT", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", + column="price", + ) + ], + logic=None, + ), + ColumnLineageInfo( + downstream=DownstreamColumnRef( + 
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="tax", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="BIGINT", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", + column="tax", + ) + ], + logic=None, + ), + ], + ), + urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + ) + assert temp_table.urn + assert temp_table_activity.urn + + lineage_extractor.temp_tables[temp_table.urn] = temp_table + lineage_extractor.temp_tables[temp_table_activity.urn] = temp_table_activity + + target_dataset_cll: List[sqlglot_l.ColumnLineageInfo] = [ + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v6,PROD)", + column="price", + column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()), + native_column_type="DOUBLE PRECISION", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + column="price_usd", + ) + ], + logic=None, + ), + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v6,PROD)", + column="player_id", + column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()), + native_column_type="BIGINT", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + column="player_id", + ) + ], + logic=None, + ), + ] + + datasets = lineage_extractor._get_upstream_lineages( + sources=[ + LineageDataset( + platform=LineageDatasetPlatform.REDSHIFT, + urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + ) + ], + target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)", + raw_db_name="dev", + alias_db_name="dev", + all_tables_set={ + "dev": { + "public": set(), + } + }, + connection=MagicMock(), + target_dataset_cll=target_dataset_cll, + ) + + assert len(datasets) == 1 + + assert ( + datasets[0].urn + == "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)" + ) + + assert target_dataset_cll[0].upstreams[0].table == ( + "urn:li:dataset:(urn:li:dataPlatform:redshift," + "dev.public.player_activity,PROD)" + ) + assert target_dataset_cll[0].upstreams[0].column == "price" + assert target_dataset_cll[0].upstreams[1].column == "tax" + assert target_dataset_cll[1].upstreams[0].column == "player_id" + + +def test_collapse_temp_recursive_cll_lineage_with_circular_reference(): + lineage_extractor = get_lineage_extractor() + + temp_table: TempTableRow = TempTableRow( + transaction_id=126, + query_text="CREATE TABLE #player_price distkey(player_id) AS SELECT player_id, SUM(price_usd) AS price_usd " + "from #player_activity_temp group by player_id", + start_time=datetime.now(), + session_id="abc", + create_command="CREATE TABLE #player_price", + parsed_result=SqlParsingResult( + query_type=QueryType.CREATE, + in_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + ], + out_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)" + ], + debug_info=SqlParsingDebugInfo(), + column_lineage=[ + ColumnLineageInfo( + downstream=DownstreamColumnRef( + 
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + column="player_id", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="INTEGER", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="player_id", + ) + ], + logic=None, + ), + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + column="price_usd", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="BIGINT", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="price_usd", + ) + ], + logic=None, + ), + ], + ), + urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + ) + + temp_table_activity: TempTableRow = TempTableRow( + transaction_id=127, + query_text="CREATE TABLE #player_activity_temp SELECT player_id, SUM(price) AS price_usd " + "from #player_price", + start_time=datetime.now(), + session_id="abc", + create_command="CREATE TABLE #player_activity_temp", + parsed_result=SqlParsingResult( + query_type=QueryType.CREATE, + in_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)" + ], + out_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + ], + debug_info=SqlParsingDebugInfo(), + column_lineage=[ + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="player_id", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="INTEGER", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="player_id", + ) + ], + logic=None, + ), + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="price_usd", + column_type=SchemaFieldDataTypeClass(NumberTypeClass()), + native_column_type="BIGINT", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + column="price_usd", + ) + ], + logic=None, + ), + ], + ), + urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", + ) + + assert temp_table.urn + assert temp_table_activity.urn + + lineage_extractor.temp_tables[temp_table.urn] = temp_table + lineage_extractor.temp_tables[temp_table_activity.urn] = temp_table_activity + + target_dataset_cll: List[sqlglot_l.ColumnLineageInfo] = [ + ColumnLineageInfo( + downstream=DownstreamColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v6,PROD)", + column="price", + column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()), + native_column_type="DOUBLE PRECISION", + ), + upstreams=[ + sqlglot_l.ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + column="price_usd", + ) + ], + logic=None, + ) + ] + + datasets = lineage_extractor._get_upstream_lineages( + sources=[ + LineageDataset( + platform=LineageDatasetPlatform.REDSHIFT, + urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", + ) + ], + 
target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)", + raw_db_name="dev", + alias_db_name="dev", + all_tables_set={ + "dev": { + "public": set(), + } + }, + connection=MagicMock(), + target_dataset_cll=target_dataset_cll, + ) + + assert len(datasets) == 1 + # Here we only interested if it fails or not From 0e418b527e64b9314c2a4da1df7794b129ac21cb Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Thu, 1 Feb 2024 16:33:15 -0800 Subject: [PATCH 2/2] fix(ingest): upgrade pytest-docker (#9765) --- metadata-ingestion/setup.py | 2 +- .../tests/test_helpers/docker_helpers.py | 20 +++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index f8d51997330a9..d4e2ada1fc68f 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -468,7 +468,7 @@ pytest_dep, "pytest-asyncio>=0.16.0", "pytest-cov>=2.8.1", - "pytest-docker>=1.0.1", + "pytest-docker>=1.1.0", deepdiff_dep, "requests-mock", "freezegun", diff --git a/metadata-ingestion/tests/test_helpers/docker_helpers.py b/metadata-ingestion/tests/test_helpers/docker_helpers.py index 2eb61068196a2..bacb8d80b9e72 100644 --- a/metadata-ingestion/tests/test_helpers/docker_helpers.py +++ b/metadata-ingestion/tests/test_helpers/docker_helpers.py @@ -2,7 +2,7 @@ import logging import os import subprocess -from typing import Callable, Optional, Union +from typing import Callable, Iterator, List, Optional, Union import pytest import pytest_docker.plugin @@ -37,9 +37,11 @@ def wait_for_port( docker_services.wait_until_responsive( timeout=timeout, pause=pause, - check=checker - if checker - else lambda: is_responsive(container_name, container_port, hostname), + check=( + checker + if checker + else lambda: is_responsive(container_name, container_port, hostname) + ), ) logger.info(f"Container {container_name} is ready!") finally: @@ -62,14 +64,16 @@ def docker_compose_runner( ): @contextlib.contextmanager def run( - compose_file_path: Union[str, list], key: str, cleanup: bool = True - ) -> pytest_docker.plugin.Services: + compose_file_path: Union[str, List[str]], key: str, cleanup: bool = True + ) -> Iterator[pytest_docker.plugin.Services]: with pytest_docker.plugin.get_docker_services( docker_compose_command=docker_compose_command, - docker_compose_file=compose_file_path, + # We can remove the type ignore once this is merged: + # https://github.com/avast/pytest-docker/pull/108 + docker_compose_file=compose_file_path, # type: ignore docker_compose_project_name=f"{docker_compose_project_name}-{key}", docker_setup=docker_setup, - docker_cleanup=docker_cleanup if cleanup else False, + docker_cleanup=docker_cleanup if cleanup else [], ) as docker_services: yield docker_services