diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py index 33580d9637de48..d71c7d21dbc87e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py @@ -20,12 +20,12 @@ import datahub.emitter.mce_builder as builder from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.emitter.sql_parsing_builder import SqlParsingBuilder from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.aws.s3_util import make_s3_urn from datahub.ingestion.source.snowflake.constants import ( LINEAGE_PERMISSION_ERROR, SnowflakeEdition, - SnowflakeObjectDomain, ) from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery @@ -53,7 +53,6 @@ sqlglot_lineage, ) from datahub.utilities.time import ts_millis_to_datetime -from datahub.utilities.urns.dataset_urn import DatasetUrn logger: logging.Logger = logging.getLogger(__name__) @@ -196,19 +195,6 @@ def get_table_upstream_workunits( f"Upstream lineage detected for {self.report.num_tables_with_upstreams} tables.", ) - def _gen_workunit_from_sql_parsing_result( - self, - dataset_identifier: str, - result: SqlParsingResult, - ) -> MetadataWorkUnit: - upstreams, fine_upstreams = self.get_upstreams_from_sql_parsing_result( - self.dataset_urn_builder(dataset_identifier), result - ) - self.report.num_views_with_upstreams += 1 - return self._create_upstream_lineage_workunit( - dataset_identifier, upstreams, fine_upstreams - ) - def _gen_workunits_from_query_result( self, discovered_assets: Collection[str], @@ -242,18 +228,31 @@ def get_view_upstream_workunits( schema_resolver: SchemaResolver, view_definitions: MutableMapping[str, str], ) -> Iterable[MetadataWorkUnit]: - views_processed = set() + views_failed_parsing = set() if self.config.include_view_column_lineage: with PerfTimer() as timer: + builder = SqlParsingBuilder( + generate_lineage=True, + generate_usage_statistics=False, + generate_operations=False, + ) for view_identifier, view_definition in view_definitions.items(): result = self._run_sql_parser( view_identifier, view_definition, schema_resolver ) - if result: - views_processed.add(view_identifier) - yield self._gen_workunit_from_sql_parsing_result( - view_identifier, result + if result and result.out_tables: + self.report.num_views_with_upstreams += 1 + # This does not yield any workunits but we use + # yield here to execute this method + yield from builder.process_sql_parsing_result( + result=result, + query=view_definition, + is_view_ddl=True, ) + else: + views_failed_parsing.add(view_identifier) + + yield from builder.gen_workunits() self.report.view_lineage_parse_secs = timer.elapsed_seconds() with PerfTimer() as timer: @@ -261,7 +260,7 @@ def get_view_upstream_workunits( if results: yield from self._gen_workunits_from_query_result( - set(discovered_views) - views_processed, + views_failed_parsing, results, upstream_for_view=True, ) @@ -349,47 +348,6 @@ def get_upstreams_from_query_result_row( return upstreams, fine_upstreams - def get_upstreams_from_sql_parsing_result( - self, downstream_table_urn: str, result: SqlParsingResult - ) -> Tuple[List[UpstreamClass], List[FineGrainedLineage]]: - # Note: This ignores the out_tables section of the sql parsing result. - upstreams = [ - UpstreamClass(dataset=upstream_table_urn, type=DatasetLineageTypeClass.VIEW) - for upstream_table_urn in set(result.in_tables) - ] - - # Maps downstream_col -> [upstream_col] - fine_lineage: Dict[str, Set[SnowflakeColumnId]] = defaultdict(set) - for column_lineage in result.column_lineage or []: - out_column = column_lineage.downstream.column - for upstream_column_info in column_lineage.upstreams: - upstream_table_id = DatasetUrn.create_from_string( - upstream_column_info.table - ).get_dataset_name() - if self.config.platform_instance and upstream_table_id.startswith( - f"{self.config.platform_instance}." - ): - upstream_table_name = upstream_table_id[ - len(f"{self.config.platform_instance}.") : - ] - else: - upstream_table_name = upstream_table_id - fine_lineage[out_column].add( - SnowflakeColumnId( - columnName=upstream_column_info.column, - objectName=upstream_table_name, - objectDomain=SnowflakeObjectDomain.VIEW.value, - ) - ) - fine_upstreams = [ - self.build_finegrained_lineage( - downstream_table_urn, downstream_col, upstream_cols - ) - for downstream_col, upstream_cols in fine_lineage.items() - ] - - return upstreams, list(filter(None, fine_upstreams)) - def _populate_external_lineage_map(self, discovered_tables: List[str]) -> None: with PerfTimer() as timer: self.report.num_external_table_edges_scanned = 0