Skip to content

Commit

Permalink
feat(ingest/snowflake): initialize schema resolver from datahub for lineage-only ingestion
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurinehate committed Sep 26, 2023
1 parent 0a869dd commit 12a7bfe
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class SnowflakeV2Config(
)

include_view_column_lineage: bool = Field(
default=False,
description="Populates view->view and table->view column lineage.",
default=True,
description="Populates view->view and table->view column lineage using DataHub's sql parser.",
)

_check_role_grants_removed = pydantic_removed_field("check_role_grants")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,11 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config):
# Caches tables for a single database. Consider moving to disk or S3 when possible.
self.db_tables: Dict[str, List[SnowflakeTable]] = {}

self.sql_parser_schema_resolver = SchemaResolver(
platform=self.platform,
platform_instance=self.config.platform_instance,
env=self.config.env,
)
self.view_definitions: FileBackedDict[str] = FileBackedDict()
self.add_config_to_report()

self.sql_parser_schema_resolver = self._init_schema_resolver()

@classmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> "Source":
config = SnowflakeV2Config.parse_obj(config_dict)
Expand Down Expand Up @@ -493,6 +490,24 @@ def query(query):

return _report

def _init_schema_resolver(self) -> SchemaResolver:
    """Build the schema resolver backing the SQL parser.

    For lineage-only runs (technical schema extraction disabled but view DDL
    parsing enabled) the resolver is pre-populated with schema metadata fetched
    from DataHub through the graph client. In every other case — including when
    the graph client is unavailable — an empty resolver is returned, to be
    populated during this ingestion run.
    """
    lineage_only_with_view_parsing = (
        not self.config.include_technical_schema and self.config.parse_view_ddl
    )
    if lineage_only_with_view_parsing:
        graph = self.ctx.graph
        if graph:
            # First element of the returned tuple is the populated resolver.
            return graph.initialize_schema_resolver_from_datahub(
                platform=self.platform,
                platform_instance=self.config.platform_instance,
                env=self.config.env,
            )[0]
        # Best effort: warn and fall through to an empty resolver.
        logger.warning(
            "Failed to load schema info from DataHub as DataHubGraph is missing.",
        )
    return SchemaResolver(
        platform=self.platform,
        platform_instance=self.config.platform_instance,
        env=self.config.env,
    )

def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
return [
*super().get_workunit_processors(),
Expand Down Expand Up @@ -764,7 +779,7 @@ def _process_schema(
)
self.db_tables[schema_name] = tables

if self.config.include_technical_schema or self.config.parse_view_ddl:
if self.config.include_technical_schema:
for table in tables:
yield from self._process_table(table, schema_name, db_name)

Expand All @@ -776,7 +791,7 @@ def _process_schema(
if view.view_definition:
self.view_definitions[key] = view.view_definition

if self.config.include_technical_schema or self.config.parse_view_ddl:
if self.config.include_technical_schema:
for view in views:
yield from self._process_view(view, schema_name, db_name)

Expand Down Expand Up @@ -892,8 +907,6 @@ def _process_table(
yield from self._process_tag(tag)

yield from self.gen_dataset_workunits(table, schema_name, db_name)
elif self.config.parse_view_ddl:
self.gen_schema_metadata(table, schema_name, db_name)

def fetch_sample_data_for_classification(
self, table: SnowflakeTable, schema_name: str, db_name: str, dataset_name: str
Expand Down Expand Up @@ -1004,8 +1017,6 @@ def _process_view(
yield from self._process_tag(tag)

yield from self.gen_dataset_workunits(view, schema_name, db_name)
elif self.config.parse_view_ddl:
self.gen_schema_metadata(view, schema_name, db_name)

def _process_tag(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]:
tag_identifier = tag.identifier()
Expand Down

0 comments on commit 12a7bfe

Please sign in to comment.