From 47bb20d5d541b91d78084273b7d72cdfec835ddd Mon Sep 17 00:00:00 2001 From: Mayuri N Date: Wed, 26 Jul 2023 19:33:37 +0530 Subject: [PATCH 01/12] feat(ingest/snowflake): tables from snowflake shares as siblings 1. add config maps to accept inbound and outbound share details from user 2. emit mirrored database tables as siblings, with tables from share owner(producer) account as primary sibling. 3. push down allow-deny patterns in snowflake_schema --- .../source/snowflake/snowflake_config.py | 21 ++- .../source/snowflake/snowflake_schema.py | 142 ++++++++++++------ .../source/snowflake/snowflake_shares.py | 134 +++++++++++++++++ .../source/snowflake/snowflake_v2.py | 44 +----- 4 files changed, 258 insertions(+), 83 deletions(-) create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index 79bf538af91d2..b56b2dc549bc2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -4,7 +4,7 @@ from pydantic import Field, SecretStr, root_validator, validator -from datahub.configuration.common import AllowDenyPattern +from datahub.configuration.common import AllowDenyPattern, ConfigModel from datahub.configuration.pattern_utils import UUID_REGEX from datahub.configuration.validate_field_removal import pydantic_removed_field from datahub.configuration.validate_field_rename import pydantic_renamed_field @@ -42,6 +42,11 @@ class TagOption(str, Enum): skip = "skip" +class SnowflakeDatabaseDataHubId(ConfigModel): + platform_instance: str + database_name: str + + class SnowflakeV2Config( SnowflakeConfig, SnowflakeUsageConfig, @@ -120,6 +125,20 @@ class SnowflakeV2Config( "upstreams_deny_pattern", "temporary_tables_pattern" ) + inbound_shares_map: Optional[Dict[str, SnowflakeDatabaseDataHubId]] = Field( + default=None, + description="Required if the current account has any database created from inbound snowflake share." + " If specified, connector creates lineage and siblings relationship between current account's database tables and original database tables from which snowflake share was created." + " Map of database name -> (platform instance of snowflake account containing original database, original database name).", + ) + + outbound_shares_map: Optional[Dict[str, List[SnowflakeDatabaseDataHubId]]] = Field( + default=None, + description="Required if the current account has created any outbound snowflake shares and there is at least one consumer account in which database is created from such share." + " If specified, connector creates siblings relationship between current account's database tables and all database tables created in consumer accounts from the share including current account's database." 
+ " Map of database name X -> list of (platform instance of snowflake consumer account who've created database from share, name of database created from share) for all shares created from database name X.", + ) + @validator("include_column_lineage") def validate_include_column_lineage(cls, v, values): if not values.get("include_table_lineage") and v: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py index dab46645bffcc..298a8dbde48c8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py @@ -8,9 +8,15 @@ import pandas as pd from snowflake.connector import SnowflakeConnection +from datahub.configuration.pattern_utils import is_schema_allowed from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain +from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery -from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeQueryMixin +from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report +from datahub.ingestion.source.snowflake.snowflake_utils import ( + SnowflakeCommonMixin, + SnowflakeQueryMixin, +) from datahub.ingestion.source.sql.sql_generic import BaseColumn, BaseTable, BaseView logger: logging.Logger = logging.getLogger(__name__) @@ -177,8 +183,10 @@ def get_column_tags_for_table( ) -class SnowflakeDataDictionary(SnowflakeQueryMixin): - def __init__(self) -> None: +class SnowflakeDataDictionary(SnowflakeQueryMixin, SnowflakeCommonMixin): + def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None: + self.config = config + self.report = report self.logger = logger self.connection: Optional[SnowflakeConnection] = None @@ -221,7 +229,11 @@ def get_databases(self, db_name: str) -> List[SnowflakeDatabase]: last_altered=database["LAST_ALTERED"], comment=database["COMMENT"], ) - databases.append(snowflake_db) + self.report.report_entity_scanned(snowflake_db.name, "database") + if not self.config.database_pattern.allowed(snowflake_db.name): + self.report.report_dropped(f"{snowflake_db.name}.*") + else: + databases.append(snowflake_db) return databases @@ -239,7 +251,16 @@ def get_schemas_for_database(self, db_name: str) -> List[SnowflakeSchema]: last_altered=schema["LAST_ALTERED"], comment=schema["COMMENT"], ) - snowflake_schemas.append(snowflake_schema) + self.report.report_entity_scanned(snowflake_schema.name, "schema") + if not is_schema_allowed( + self.config.schema_pattern, + snowflake_schema.name, + db_name, + self.config.match_fully_qualified_names, + ): + self.report.report_dropped(f"{db_name}.{snowflake_schema.name}.*") + else: + snowflake_schemas.append(snowflake_schema) return snowflake_schemas @lru_cache(maxsize=1) @@ -261,17 +282,25 @@ def get_tables_for_database( for table in cur: if table["TABLE_SCHEMA"] not in tables: tables[table["TABLE_SCHEMA"]] = [] - tables[table["TABLE_SCHEMA"]].append( - SnowflakeTable( - name=table["TABLE_NAME"], - created=table["CREATED"], - last_altered=table["LAST_ALTERED"], - size_in_bytes=table["BYTES"], - rows_count=table["ROW_COUNT"], - comment=table["COMMENT"], - clustering_key=table["CLUSTERING_KEY"], - ) + + table_identifier = self.get_dataset_identifier( + table["TABLE_NAME"], table["TABLE_SCHEMA"], db_name ) + 
self.report.report_entity_scanned(table_identifier) + if not self.config.table_pattern.allowed(table_identifier): + self.report.report_dropped(table_identifier) + else: + tables[table["TABLE_SCHEMA"]].append( + SnowflakeTable( + name=table["TABLE_NAME"], + created=table["CREATED"], + last_altered=table["LAST_ALTERED"], + size_in_bytes=table["BYTES"], + rows_count=table["ROW_COUNT"], + comment=table["COMMENT"], + clustering_key=table["CLUSTERING_KEY"], + ) + ) return tables def get_tables_for_schema( @@ -284,17 +313,24 @@ def get_tables_for_schema( ) for table in cur: - tables.append( - SnowflakeTable( - name=table["TABLE_NAME"], - created=table["CREATED"], - last_altered=table["LAST_ALTERED"], - size_in_bytes=table["BYTES"], - rows_count=table["ROW_COUNT"], - comment=table["COMMENT"], - clustering_key=table["CLUSTERING_KEY"], - ) + table_identifier = self.get_dataset_identifier( + table["TABLE_NAME"], schema_name, db_name ) + self.report.report_entity_scanned(table_identifier) + if not self.config.table_pattern.allowed(table_identifier): + self.report.report_dropped(table_identifier) + else: + tables.append( + SnowflakeTable( + name=table["TABLE_NAME"], + created=table["CREATED"], + last_altered=table["LAST_ALTERED"], + size_in_bytes=table["BYTES"], + rows_count=table["ROW_COUNT"], + comment=table["COMMENT"], + clustering_key=table["CLUSTERING_KEY"], + ) + ) return tables @lru_cache(maxsize=1) @@ -314,18 +350,27 @@ def get_views_for_database( for table in cur: if table["schema_name"] not in views: views[table["schema_name"]] = [] - views[table["schema_name"]].append( - SnowflakeView( - name=table["name"], - created=table["created_on"], - # last_altered=table["last_altered"], - comment=table["comment"], - view_definition=table["text"], - last_altered=table["created_on"], - materialized=table.get("is_materialized", "false").lower() - == "true", - ) + view_name = self.get_dataset_identifier( + table["name"], table["schema_name"], db_name ) + + self.report.report_entity_scanned(view_name, "view") + + if not self.config.view_pattern.allowed(view_name): + self.report.report_dropped(view_name) + else: + views[table["schema_name"]].append( + SnowflakeView( + name=table["name"], + created=table["created_on"], + # last_altered=table["last_altered"], + comment=table["comment"], + view_definition=table["text"], + last_altered=table["created_on"], + materialized=table.get("is_materialized", "false").lower() + == "true", + ) + ) return views def get_views_for_schema( @@ -335,16 +380,23 @@ def get_views_for_schema( cur = self.query(SnowflakeQuery.show_views_for_schema(schema_name, db_name)) for table in cur: - views.append( - SnowflakeView( - name=table["name"], - created=table["created_on"], - # last_altered=table["last_altered"], - comment=table["comment"], - view_definition=table["text"], - last_altered=table["created_on"], + view_name = self.get_dataset_identifier(table["name"], schema_name, db_name) + + self.report.report_entity_scanned(view_name, "view") + + if not self.config.view_pattern.allowed(view_name): + self.report.report_dropped(view_name) + else: + views.append( + SnowflakeView( + name=table["name"], + created=table["created_on"], + # last_altered=table["last_altered"], + comment=table["comment"], + view_definition=table["text"], + last_altered=table["created_on"], + ) ) - ) return views @lru_cache(maxsize=1) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py new file 
mode 100644 index 0000000000000..14a67a3f98578 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py @@ -0,0 +1,134 @@ +import logging +from typing import Callable, Iterable, List + +from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.source.snowflake.snowflake_config import ( + SnowflakeDatabaseDataHubId, + SnowflakeV2Config, +) +from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report +from datahub.ingestion.source.snowflake.snowflake_schema import SnowflakeDatabase +from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeCommonMixin +from datahub.metadata.com.linkedin.pegasus2avro.common import Siblings +from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( + DatasetLineageType, + Upstream, + UpstreamLineage, +) + +logger: logging.Logger = logging.getLogger(__name__) + + +class SnowflakeSharesHandler(SnowflakeCommonMixin): + def __init__( + self, + config: SnowflakeV2Config, + report: SnowflakeV2Report, + dataset_urn_builder: Callable[[str], str], + ) -> None: + self.config = config + self.report = report + self.logger = logger + self.dataset_urn_builder = dataset_urn_builder + + def get_workunits( + self, databases: List[SnowflakeDatabase] + ) -> Iterable[MetadataWorkUnit]: + for db in databases: + inbound = ( + self.config.inbound_shares_map.get(db.name) + if self.config.inbound_shares_map + and db.name in self.config.inbound_shares_map + else None + ) + outbounds = ( + self.config.outbound_shares_map.get(db.name) + if self.config.outbound_shares_map + and db.name in self.config.outbound_shares_map + else None + ) + sibling_dbs: List[SnowflakeDatabaseDataHubId] + if inbound: + sibling_dbs = [inbound] + elif outbounds: + sibling_dbs = outbounds + else: + continue + # TODO: pydantic validator to check that database key in inbound_shares_map is not present in outbound_shares_map + # TODO: logger statements and exception handling + for schema in db.schemas: + for table_name in schema.tables: + yield from self.get_siblings( + db.name, + schema.name, + table_name, + True if outbounds else False, + sibling_dbs, + ) + + if inbound: + # TODO: Should this be governed by any config flag ? Should this be part of lineage_extractor ? 
+ yield self.get_upstream_lineage_with_primary_sibling( + db.name, schema.name, table_name, inbound + ) + + def get_siblings( + self, + database_name: str, + schema_name: str, + table_name: str, + primary: bool, + sibling_databases: List[SnowflakeDatabaseDataHubId], + ) -> Iterable[MetadataWorkUnit]: + if not sibling_databases: + return + dataset_identifier = self.get_dataset_identifier( + table_name, schema_name, database_name + ) + urn = self.dataset_urn_builder(dataset_identifier) + + sibling_urns = [ + make_dataset_urn_with_platform_instance( + self.platform, + self.get_dataset_identifier( + table_name, schema_name, sibling_db.database_name + ), + sibling_db.platform_instance, + ) + for sibling_db in sibling_databases + ] + + sibling_urns.append(urn) + + yield MetadataChangeProposalWrapper( + entityUrn=urn, aspect=Siblings(primary=primary, siblings=sibling_urns) + ).as_workunit() + + def get_upstream_lineage_with_primary_sibling( + self, + database_name: str, + schema_name: str, + table_name: str, + primary_sibling_db: SnowflakeDatabaseDataHubId, + ) -> MetadataWorkUnit: + dataset_identifier = self.get_dataset_identifier( + table_name, schema_name, database_name + ) + urn = self.dataset_urn_builder(dataset_identifier) + + upstream_urn = make_dataset_urn_with_platform_instance( + self.platform, + self.get_dataset_identifier( + table_name, schema_name, primary_sibling_db.database_name + ), + primary_sibling_db.platform_instance, + ) + + return MetadataChangeProposalWrapper( + entityUrn=urn, + aspect=UpstreamLineage( + upstreams=[Upstream(dataset=upstream_urn, type=DatasetLineageType.COPY)] + ), + ).as_workunit() diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index bd3e5782ec2af..46e0b04688547 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -9,7 +9,6 @@ import pandas as pd from snowflake.connector import SnowflakeConnection -from datahub.configuration.pattern_utils import is_schema_allowed from datahub.emitter.mce_builder import ( make_data_platform_urn, make_dataset_urn, @@ -71,6 +70,7 @@ SnowflakeTag, SnowflakeView, ) +from datahub.ingestion.source.snowflake.snowflake_shares import SnowflakeSharesHandler from datahub.ingestion.source.snowflake.snowflake_tag import SnowflakeTagExtractor from datahub.ingestion.source.snowflake.snowflake_usage_v2 import ( SnowflakeUsageExtractor, @@ -238,7 +238,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): ) # For database, schema, tables, views, etc - self.data_dictionary = SnowflakeDataDictionary() + self.data_dictionary = SnowflakeDataDictionary(self.config, self.report) self.lineage_extractor: Union[ SnowflakeLineageExtractor, SnowflakeLineageLegacyExtractor @@ -532,25 +532,22 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # TODO: The checkpoint state for stale entity detection can be committed here. 
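+        # The shares handler emits sibling and lineage workunits for databases
+        # that are created from, or shared via, snowflake shares.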
+ if self.config.inbound_shares_map or self.config.outbound_shares_map: + yield from SnowflakeSharesHandler( + self.config, self.report, self.gen_dataset_urn + ).get_workunits(databases) + discovered_tables: List[str] = [ self.get_dataset_identifier(table_name, schema.name, db.name) for db in databases for schema in db.schemas for table_name in schema.tables - if self._is_dataset_pattern_allowed( - self.get_dataset_identifier(table_name, schema.name, db.name), - SnowflakeObjectDomain.TABLE, - ) ] discovered_views: List[str] = [ self.get_dataset_identifier(table_name, schema.name, db.name) for db in databases for schema in db.schemas for table_name in schema.views - if self._is_dataset_pattern_allowed( - self.get_dataset_identifier(table_name, schema.name, db.name), - SnowflakeObjectDomain.VIEW, - ) ] if len(discovered_tables) == 0 and len(discovered_views) == 0: @@ -654,11 +651,6 @@ def get_databases_from_ischema(self, databases): def _process_database( self, snowflake_db: SnowflakeDatabase ) -> Iterable[MetadataWorkUnit]: - self.report.report_entity_scanned(snowflake_db.name, "database") - if not self.config.database_pattern.allowed(snowflake_db.name): - self.report.report_dropped(f"{snowflake_db.name}.*") - return - db_name = snowflake_db.name try: @@ -733,16 +725,6 @@ def fetch_schemas_for_database(self, snowflake_db, db_name): def _process_schema( self, snowflake_schema: SnowflakeSchema, db_name: str ) -> Iterable[MetadataWorkUnit]: - self.report.report_entity_scanned(snowflake_schema.name, "schema") - if not is_schema_allowed( - self.config.schema_pattern, - snowflake_schema.name, - db_name, - self.config.match_fully_qualified_names, - ): - self.report.report_dropped(f"{db_name}.{snowflake_schema.name}.*") - return - schema_name = snowflake_schema.name if self.config.extract_tags != TagOption.skip: @@ -833,12 +815,6 @@ def _process_table( ) -> Iterable[MetadataWorkUnit]: table_identifier = self.get_dataset_identifier(table.name, schema_name, db_name) - self.report.report_entity_scanned(table_identifier) - - if not self.config.table_pattern.allowed(table_identifier): - self.report.report_dropped(table_identifier) - return - self.fetch_columns_for_table(table, schema_name, db_name, table_identifier) self.fetch_pk_for_table(table, schema_name, db_name, table_identifier) @@ -950,12 +926,6 @@ def _process_view( ) -> Iterable[MetadataWorkUnit]: view_name = self.get_dataset_identifier(view.name, schema_name, db_name) - self.report.report_entity_scanned(view_name, "view") - - if not self.config.view_pattern.allowed(view_name): - self.report.report_dropped(view_name) - return - try: view.columns = self.get_columns_for_table(view.name, schema_name, db_name) if self.config.extract_tags != TagOption.skip: From 752b2dd3a8085072d60e4c97770d4bda39f132a0 Mon Sep 17 00:00:00 2001 From: Mayuri N Date: Thu, 27 Jul 2023 16:43:24 +0530 Subject: [PATCH 02/12] add unit tests, docs, logging --- .../docs/sources/snowflake/snowflake_pre.md | 27 ++ .../source/snowflake/snowflake_config.py | 71 +++- .../source/snowflake/snowflake_shares.py | 69 +++- .../tests/unit/test_snowflake_shares.py | 359 ++++++++++++++++++ 4 files changed, 504 insertions(+), 22 deletions(-) create mode 100644 metadata-ingestion/tests/unit/test_snowflake_shares.py diff --git a/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md b/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md index 9a381fb351aec..7768ab89fdb88 100644 --- a/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md +++ 
b/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md
@@ -99,6 +99,33 @@ The steps slightly differ based on which you decide to use.
 including `client_id` and `client_secret`, plus your Okta user's `Username` and `Password`
 * Note: the `username` and `password` config options are not nested under `oauth_config`
 
+### Snowflake Shares
+If you are using [Snowflake Shares](https://docs.snowflake.com/en/user-guide/data-sharing-provider) to share data across different snowflake accounts, and you have set up DataHub recipes for ingesting metadata from all these accounts, you may end up having multiple similar dataset entities corresponding to virtual versions of same table in different snowflake accounts. DataHub Snowflake connector can automatically link such tables together through Siblings and Lineage relationship if user provides information necessary to establish the relationship using configurations `inbound_shares_map` and `outbound_shares_map` in recipe.
+
+#### Example
+- Snowflake account `account1` (ingested as platform_instance `instance1`) owns a database `db1`. A share `X` is created in `account1` that includes database `db1` along with the schemas and tables inside it.
+- Now, `X` is shared with snowflake account `account2` (ingested as platform_instance `instance2`). A database `db1_from_X` is created in `account2` from the inbound share `X`.
+- In this case, all tables and views included in share `X` will also be present in `instance2`.`db1_from_X`. You would need the following configurations in the snowflake recipes to set up the Siblings and Lineage relationships correctly.
+- In the snowflake recipe of `account1`:
+
+  ```yaml
+  account_id: account1
+  platform_instance: instance1
+  outbound_shares_map:
+    db1:
+      - platform_instance: instance2 # this is a list, as db1 can be shared with multiple snowflake accounts using share X
+        database_name: db1_from_X
+  ```
+- In the snowflake recipe of `account2`:
+
+  ```yaml
+  account_id: account2
+  platform_instance: instance2
+  inbound_shares_map:
+    db1_from_X:
+      platform_instance: instance1
+      database_name: db1
+  ```
 
 ### Caveats
 
 - Some of the features are only available in the Snowflake Enterprise Edition. This doc has notes mentioning where this applies.
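For intuition, here is a rough sketch of the URNs the example above produces. The `public.my_table` schema/table names are illustrative, and the URN shapes assume the default `PROD` environment and the connector's default lowercasing of identifiers:

```python
from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance

# Producer-side table in account1 (platform_instance "instance1"), i.e.
# urn:li:dataset:(urn:li:dataPlatform:snowflake,instance1.db1.public.my_table,PROD)
producer_urn = make_dataset_urn_with_platform_instance(
    platform="snowflake",
    name="db1.public.my_table",
    platform_instance="instance1",
)

# Consumer-side copy in account2, created from the inbound share X, i.e.
# urn:li:dataset:(urn:li:dataPlatform:snowflake,instance2.db1_from_x.public.my_table,PROD)
consumer_urn = make_dataset_urn_with_platform_instance(
    platform="snowflake",
    name="db1_from_x.public.my_table",
    platform_instance="instance2",
)

# The connector links the two with a Siblings aspect (producer side primary)
# and emits an upstreamLineage aspect on the consumer table that names the
# producer table as a COPY-type upstream.
```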
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index b56b2dc549bc2..3f0d323d58586 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -1,10 +1,11 @@ import logging +from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, cast +from typing import Dict, List, Optional, Sequence, cast from pydantic import Field, SecretStr, root_validator, validator -from datahub.configuration.common import AllowDenyPattern, ConfigModel +from datahub.configuration.common import AllowDenyPattern from datahub.configuration.pattern_utils import UUID_REGEX from datahub.configuration.validate_field_removal import pydantic_removed_field from datahub.configuration.validate_field_rename import pydantic_renamed_field @@ -42,7 +43,8 @@ class TagOption(str, Enum): skip = "skip" -class SnowflakeDatabaseDataHubId(ConfigModel): +@dataclass(frozen=True) +class SnowflakeDatabaseDataHubId: platform_instance: str database_name: str @@ -211,3 +213,66 @@ def get_sql_alchemy_url( @property def parse_view_ddl(self) -> bool: return self.include_view_column_lineage + + @root_validator(pre=False) + def validator_inbound_outbound_shares_map(cls, values: Dict) -> Dict: + inbound_shares_map: Dict[str, SnowflakeDatabaseDataHubId] = ( + values.get("inbound_shares_map") or {} + ) + outbound_shares_map: Dict[str, List[SnowflakeDatabaseDataHubId]] = ( + values.get("outbound_shares_map") or {} + ) + + # Check: same database from current instance as inbound and outbound + common_keys = [key for key in inbound_shares_map if key in outbound_shares_map] + + assert ( + len(common_keys) == 0 + ), "Same database can not be present in both `inbound_shares_map` and `outbound_shares_map`." + + current_platform_instance = values.get("platform_instance") + + # Check: current platform_instance present as inbound and outbound + if current_platform_instance and any( + [ + db.platform_instance == current_platform_instance + for db in inbound_shares_map.values() + ] + ): + raise ValueError( + "Current `platform_instance` can not be present as any database in `inbound_shares_map`." + "Self-sharing not supported in Snowflake. Please check your configuration." + ) + + if current_platform_instance and any( + [ + db.platform_instance == current_platform_instance + for dbs in outbound_shares_map.values() + for db in dbs + ] + ): + raise ValueError( + "Current `platform_instance` can not be present as any database in `outbound_shares_map`." + "Self-sharing not supported in Snowflake. Please check your configuration." + ) + + # Check: platform_instance should be present + if ( + inbound_shares_map or outbound_shares_map + ) and not current_platform_instance: + logger.warn( + "Did you forget to set `platform_instance` for current ingestion ?" + "It is advisable to use `platform_instance` when ingesting from multiple snowflake accounts." 
+ ) + + # Check: same database from some platform instance as inbound and outbound + other_platform_instance_databases: Sequence[SnowflakeDatabaseDataHubId] = [ + db for db in inbound_shares_map.values() + ] + [db for dbs in outbound_shares_map.values() for db in dbs] + + for other_instance_db in other_platform_instance_databases: + assert ( + other_platform_instance_databases.count(other_instance_db) == 1 + ), "A database can exist only once either in `inbound_shares_map` or in `outbound_shares_map`." + + return values diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py index 14a67a3f98578..4961e92be71d1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py @@ -36,30 +36,24 @@ def __init__( def get_workunits( self, databases: List[SnowflakeDatabase] ) -> Iterable[MetadataWorkUnit]: + logger.debug("Checking databases for inbound or outbound shares.") for db in databases: - inbound = ( - self.config.inbound_shares_map.get(db.name) - if self.config.inbound_shares_map - and db.name in self.config.inbound_shares_map - else None - ) - outbounds = ( - self.config.outbound_shares_map.get(db.name) - if self.config.outbound_shares_map - and db.name in self.config.outbound_shares_map - else None - ) + inbound, outbounds = self.get_sharing_details(db) + + if not (inbound or outbounds): + logger.debug(f"database {db.name} is not a shared.") + continue + sibling_dbs: List[SnowflakeDatabaseDataHubId] if inbound: sibling_dbs = [inbound] - elif outbounds: + logger.debug(f"database {db.name} is created from inbound share.") + else: # outbounds sibling_dbs = outbounds - else: - continue - # TODO: pydantic validator to check that database key in inbound_shares_map is not present in outbound_shares_map - # TODO: logger statements and exception handling + logger.debug(f"database {db.name} is shared as outbound share.") + for schema in db.schemas: - for table_name in schema.tables: + for table_name in schema.tables + schema.views: yield from self.get_siblings( db.name, schema.name, @@ -69,11 +63,48 @@ def get_workunits( ) if inbound: - # TODO: Should this be governed by any config flag ? Should this be part of lineage_extractor ? 
+ # SnowflakeLineageExtractor is unaware of database->schema->table hierarchy + # hence this lineage code is not written in SnowflakeLineageExtractor + # also this is not governed by configs include_table_lineage and include_view_lineage yield self.get_upstream_lineage_with_primary_sibling( db.name, schema.name, table_name, inbound ) + self.report_missing_databases(databases) + + def get_sharing_details(self, db): + inbound = ( + self.config.inbound_shares_map.get(db.name) + if self.config.inbound_shares_map + and db.name in self.config.inbound_shares_map + else None + ) + outbounds = ( + self.config.outbound_shares_map.get(db.name) + if self.config.outbound_shares_map + and db.name in self.config.outbound_shares_map + else None + ) + + return inbound, outbounds + + def report_missing_databases(self, databases): + db_names = [db.name for db in databases] + missing_dbs = [] + if self.config.inbound_shares_map: + missing_dbs.extend( + [db for db in self.config.inbound_shares_map if db not in db_names] + ) + if self.config.outbound_shares_map: + missing_dbs.extend( + [db for db in self.config.outbound_shares_map if db not in db_names] + ) + if missing_dbs: + self.report_warning( + "snowflake-shares", + f"Databases {missing_dbs} were not ingested. Siblings/Lineage will not be set for these.", + ) + def get_siblings( self, database_name: str, diff --git a/metadata-ingestion/tests/unit/test_snowflake_shares.py b/metadata-ingestion/tests/unit/test_snowflake_shares.py new file mode 100644 index 0000000000000..d80cd2dc07ca1 --- /dev/null +++ b/metadata-ingestion/tests/unit/test_snowflake_shares.py @@ -0,0 +1,359 @@ +from typing import List + +import pytest + +from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.source.snowflake.snowflake_config import ( + SnowflakeDatabaseDataHubId, + SnowflakeV2Config, +) +from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report +from datahub.ingestion.source.snowflake.snowflake_schema import ( + SnowflakeDatabase, + SnowflakeSchema, +) +from datahub.ingestion.source.snowflake.snowflake_shares import SnowflakeSharesHandler +from datahub.metadata.com.linkedin.pegasus2avro.common import Siblings +from datahub.metadata.com.linkedin.pegasus2avro.dataset import UpstreamLineage +from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeProposal + + +@pytest.fixture(scope="module") +def snowflake_databases() -> List[SnowflakeDatabase]: + return [ + SnowflakeDatabase( + name="db1", + created=None, + comment=None, + last_altered=None, + schemas=[ + SnowflakeSchema( + name="schema11", + created=None, + comment=None, + last_altered=None, + tables=["table111", "table112"], + views=["view111"], + ), + SnowflakeSchema( + name="schema12", + created=None, + comment=None, + last_altered=None, + tables=["table121", "table122"], + views=["view121"], + ), + ], + ), + SnowflakeDatabase( + name="db2", + created=None, + comment=None, + last_altered=None, + schemas=[ + SnowflakeSchema( + name="schema21", + created=None, + comment=None, + last_altered=None, + tables=["table211", "table212"], + views=["view211"], + ), + SnowflakeSchema( + name="schema22", + created=None, + comment=None, + last_altered=None, + tables=["table221", "table222"], + views=["view221"], + ), + ], + ), + SnowflakeDatabase( + name="db3", + created=None, + comment=None, + last_altered=None, + schemas=[ + SnowflakeSchema( + name="schema31", + created=None, + 
comment=None, + last_altered=None, + tables=["table311", "table312"], + views=["view311"], + ) + ], + ), + ] + + +def make_snowflake_urn(table_name, instance_name=None): + return make_dataset_urn_with_platform_instance( + "snowflake", table_name, instance_name + ) + + +def test_snowflake_shares_workunit_no_shares( + snowflake_databases: List[SnowflakeDatabase], +) -> None: + config = SnowflakeV2Config(account_id="abc12345", platform_instance="instance1") + + report = SnowflakeV2Report() + shares_handler = SnowflakeSharesHandler( + config, report, lambda x: make_snowflake_urn(x) + ) + + wus = list(shares_handler.get_workunits(snowflake_databases)) + + assert len(wus) == 0 + + +def test_same_database_inbound_and_outbound_invalid_config() -> None: + with pytest.raises(ValueError): + SnowflakeV2Config( + account_id="abc12345", + platform_instance="instance1", + inbound_shares_map={ + "db1": SnowflakeDatabaseDataHubId( + database_name="original_db1", platform_instance="instance2" + ) + }, + outbound_shares_map={ + "db1": [ + SnowflakeDatabaseDataHubId( + database_name="db2_from_share", platform_instance="instance2" + ), + SnowflakeDatabaseDataHubId( + database_name="db2", platform_instance="instance3" + ), + ] + }, + ) + + +def test_current_platform_instance_inbound_and_outbound_invalid() -> None: + with pytest.raises(ValueError): + SnowflakeV2Config( + account_id="abc12345", + platform_instance="instance1", + inbound_shares_map={ + "db1": SnowflakeDatabaseDataHubId( + database_name="db2", platform_instance="instance1" + ) + }, + outbound_shares_map={ + "db2": [ + SnowflakeDatabaseDataHubId( + database_name="db2_from_share", platform_instance="instance2" + ), + SnowflakeDatabaseDataHubId( + database_name="db2", platform_instance="instance3" + ), + ] + }, + ) + + with pytest.raises(ValueError): + SnowflakeV2Config( + account_id="abc12345", + platform_instance="instance1", + inbound_shares_map={ + "db1": SnowflakeDatabaseDataHubId( + database_name="db2", platform_instance="instance2" + ) + }, + outbound_shares_map={ + "db2": [ + SnowflakeDatabaseDataHubId( + database_name="db2_from_share", platform_instance="instance2" + ), + SnowflakeDatabaseDataHubId( + database_name="db2", platform_instance="instance1" + ), + ] + }, + ) + + +def test_another_instance_database_inbound_and_outbound_invalid() -> None: + with pytest.raises(ValueError): + SnowflakeV2Config( + account_id="abc12345", + platform_instance="instance1", + inbound_shares_map={ + "db1": SnowflakeDatabaseDataHubId( + database_name="db2", platform_instance="instance3" + ) + }, + outbound_shares_map={ + "db2": [ + SnowflakeDatabaseDataHubId( + database_name="db2_from_share", platform_instance="instance2" + ), + SnowflakeDatabaseDataHubId( + database_name="db2", platform_instance="instance3" + ), + ] + }, + ) + + +def test_snowflake_shares_workunit_inbound_share( + snowflake_databases: List[SnowflakeDatabase], +) -> None: + config = SnowflakeV2Config( + account_id="abc12345", + platform_instance="instance1", + inbound_shares_map={ + "db1": SnowflakeDatabaseDataHubId( + database_name="original_db1", platform_instance="instance2" + ) + }, + ) + + report = SnowflakeV2Report() + shares_handler = SnowflakeSharesHandler( + config, report, lambda x: make_snowflake_urn(x, "instance1") + ) + + wus = list(shares_handler.get_workunits(snowflake_databases)) + + # 2 schemas - 2 tables and 1 view in each schema making total 6 datasets + # Hence 6 Sibling and 6 upstreamLineage aspects + assert len(wus) == 12 + upstream_lineage_aspect_entity_urns = set() + 
sibling_aspect_entity_urns = set() + + for wu in wus: + assert isinstance( + wu.metadata, (MetadataChangeProposal, MetadataChangeProposalWrapper) + ) + if wu.metadata.aspectName == "upstreamLineage": + upstream_aspect = wu.get_aspect_of_type(UpstreamLineage) + assert upstream_aspect is not None + upstream_lineage_aspect_entity_urns.add(wu.get_urn()) + assert len(upstream_aspect.upstreams) == 1 + assert upstream_aspect.upstreams[0].dataset == wu.get_urn().replace( + "instance1.db1", "instance2.original_db1" + ) + else: + siblings_aspect = wu.get_aspect_of_type(Siblings) + assert siblings_aspect is not None + sibling_aspect_entity_urns.add(wu.get_urn()) + assert len(siblings_aspect.siblings) == 2 # upstream and itself + assert siblings_aspect.siblings == [ + wu.get_urn().replace("instance1.db1", "instance2.original_db1"), + wu.get_urn(), + ] + + assert len(upstream_lineage_aspect_entity_urns) == 6 + assert len(sibling_aspect_entity_urns) == 6 + + +def test_snowflake_shares_workunit_outbound_share( + snowflake_databases: List[SnowflakeDatabase], +) -> None: + config = SnowflakeV2Config( + account_id="abc12345", + platform_instance="instance1", + outbound_shares_map={ + "db2": [ + SnowflakeDatabaseDataHubId( + database_name="db2_from_share", platform_instance="instance2" + ), + SnowflakeDatabaseDataHubId( + database_name="db2", platform_instance="instance3" + ), + ] + }, + ) + + report = SnowflakeV2Report() + shares_handler = SnowflakeSharesHandler( + config, report, lambda x: make_snowflake_urn(x, "instance1") + ) + + wus = list(shares_handler.get_workunits(snowflake_databases)) + + # 2 schemas - 2 tables and 1 view in each schema making total 6 datasets + # Hence 6 Sibling aspects + assert len(wus) == 6 + entity_urns = set() + + for wu in wus: + siblings_aspect = wu.get_aspect_of_type(Siblings) + assert siblings_aspect is not None + entity_urns.add(wu.get_urn()) + assert len(siblings_aspect.siblings) == 3 # 2 consumers and itself + assert siblings_aspect.siblings == [ + wu.get_urn().replace("instance1.db2", "instance2.db2_from_share"), + wu.get_urn().replace("instance1.db2", "instance3.db2"), + wu.get_urn(), + ] + + assert len((entity_urns)) == 6 + + +def test_snowflake_shares_workunit_inbound_and_outbound_share( + snowflake_databases: List[SnowflakeDatabase], +) -> None: + config = SnowflakeV2Config( + account_id="abc12345", + platform_instance="instance1", + inbound_shares_map={ + "db1": SnowflakeDatabaseDataHubId( + database_name="original_db1", platform_instance="instance2" + ) + }, + outbound_shares_map={ + "db2": [ + SnowflakeDatabaseDataHubId( + database_name="db2_from_share", platform_instance="instance2" + ), + SnowflakeDatabaseDataHubId( + database_name="db2", platform_instance="instance3" + ), + ] + }, + ) + + report = SnowflakeV2Report() + shares_handler = SnowflakeSharesHandler( + config, report, lambda x: make_snowflake_urn(x, "instance1") + ) + + wus = list(shares_handler.get_workunits(snowflake_databases)) + + # 6 Sibling and 6 upstreamLineage aspects for db1 tables + # 6 Sibling aspects for db2 tables + assert len(wus) == 18 + + for wu in wus: + assert isinstance( + wu.metadata, (MetadataChangeProposal, MetadataChangeProposalWrapper) + ) + if wu.metadata.aspectName == "upstreamLineage": + upstream_aspect = wu.get_aspect_of_type(UpstreamLineage) + assert upstream_aspect is not None + assert len(upstream_aspect.upstreams) == 1 + assert upstream_aspect.upstreams[0].dataset == wu.get_urn().replace( + "instance1.db1", "instance2.original_db1" + ) + else: + siblings_aspect = 
wu.get_aspect_of_type(Siblings) + assert siblings_aspect is not None + if "db1" in wu.get_urn(): + assert len(siblings_aspect.siblings) == 2 # upstream and itself + assert siblings_aspect.siblings == [ + wu.get_urn().replace("instance1.db1", "instance2.original_db1"), + wu.get_urn(), + ] + else: + assert len(siblings_aspect.siblings) == 3 # 2 consumers and itself + assert siblings_aspect.siblings == [ + wu.get_urn().replace("instance1.db2", "instance2.db2_from_share"), + wu.get_urn().replace("instance1.db2", "instance3.db2"), + wu.get_urn(), + ] From daf87efe0ae2656e1ed9827388fd617a75b29b87 Mon Sep 17 00:00:00 2001 From: Mayuri N Date: Thu, 27 Jul 2023 19:02:52 +0530 Subject: [PATCH 03/12] remove node itself from sibling, more TODO --- .../source/snowflake/snowflake_config.py | 2 +- .../source/snowflake/snowflake_shares.py | 8 ++++++-- .../tests/unit/test_snowflake_shares.py | 16 ++++++---------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index 3f0d323d58586..f1cf1d6f38ae2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -267,7 +267,7 @@ def validator_inbound_outbound_shares_map(cls, values: Dict) -> Dict: # Check: same database from some platform instance as inbound and outbound other_platform_instance_databases: Sequence[SnowflakeDatabaseDataHubId] = [ - db for db in inbound_shares_map.values() + db for db in set(inbound_shares_map.values()) ] + [db for dbs in outbound_shares_map.values() for db in dbs] for other_instance_db in other_platform_instance_databases: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py index 4961e92be71d1..5a61c98fef916 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py @@ -54,6 +54,12 @@ def get_workunits( for schema in db.schemas: for table_name in schema.tables + schema.views: + # TODO: If this is outbound database, + # 1. attempt listing shares using `show shares` to identify name of share associated with this database (cache query result). + # 2. if corresponding share is listed, then run `show grants to share ` to identify exact tables, views included in share. + # 3. emit siblings only for the objects listed above. + # This will work only if the configured role has accountadmin role access OR is owner of share. + # Otherwise ghost nodes will be shown in "Composed Of" section for tables/views in original database which are not granted to share. 
yield from self.get_siblings( db.name, schema.name, @@ -131,8 +137,6 @@ def get_siblings( for sibling_db in sibling_databases ] - sibling_urns.append(urn) - yield MetadataChangeProposalWrapper( entityUrn=urn, aspect=Siblings(primary=primary, siblings=sibling_urns) ).as_workunit() diff --git a/metadata-ingestion/tests/unit/test_snowflake_shares.py b/metadata-ingestion/tests/unit/test_snowflake_shares.py index d80cd2dc07ca1..3d55ffd484a13 100644 --- a/metadata-ingestion/tests/unit/test_snowflake_shares.py +++ b/metadata-ingestion/tests/unit/test_snowflake_shares.py @@ -242,10 +242,9 @@ def test_snowflake_shares_workunit_inbound_share( siblings_aspect = wu.get_aspect_of_type(Siblings) assert siblings_aspect is not None sibling_aspect_entity_urns.add(wu.get_urn()) - assert len(siblings_aspect.siblings) == 2 # upstream and itself + assert len(siblings_aspect.siblings) == 1 assert siblings_aspect.siblings == [ - wu.get_urn().replace("instance1.db1", "instance2.original_db1"), - wu.get_urn(), + wu.get_urn().replace("instance1.db1", "instance2.original_db1") ] assert len(upstream_lineage_aspect_entity_urns) == 6 @@ -286,11 +285,10 @@ def test_snowflake_shares_workunit_outbound_share( siblings_aspect = wu.get_aspect_of_type(Siblings) assert siblings_aspect is not None entity_urns.add(wu.get_urn()) - assert len(siblings_aspect.siblings) == 3 # 2 consumers and itself + assert len(siblings_aspect.siblings) == 2 assert siblings_aspect.siblings == [ wu.get_urn().replace("instance1.db2", "instance2.db2_from_share"), wu.get_urn().replace("instance1.db2", "instance3.db2"), - wu.get_urn(), ] assert len((entity_urns)) == 6 @@ -345,15 +343,13 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share( siblings_aspect = wu.get_aspect_of_type(Siblings) assert siblings_aspect is not None if "db1" in wu.get_urn(): - assert len(siblings_aspect.siblings) == 2 # upstream and itself + assert len(siblings_aspect.siblings) == 1 assert siblings_aspect.siblings == [ - wu.get_urn().replace("instance1.db1", "instance2.original_db1"), - wu.get_urn(), + wu.get_urn().replace("instance1.db1", "instance2.original_db1") ] else: - assert len(siblings_aspect.siblings) == 3 # 2 consumers and itself + assert len(siblings_aspect.siblings) == 2 assert siblings_aspect.siblings == [ wu.get_urn().replace("instance1.db2", "instance2.db2_from_share"), wu.get_urn().replace("instance1.db2", "instance3.db2"), - wu.get_urn(), ] From 8018e5a649e7adadd97e49ce3a644755c87a72c6 Mon Sep 17 00:00:00 2001 From: Mayuri N Date: Mon, 31 Jul 2023 17:47:28 +0530 Subject: [PATCH 04/12] update tests, comments --- .../source/snowflake/snowflake_config.py | 9 +++++--- .../source/snowflake/snowflake_shares.py | 2 +- .../tests/unit/test_snowflake_shares.py | 21 +++++++++++-------- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index f1cf1d6f38ae2..0f708d965c64a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -138,7 +138,7 @@ class SnowflakeV2Config( default=None, description="Required if the current account has created any outbound snowflake shares and there is at least one consumer account in which database is created from such share." 
" If specified, connector creates siblings relationship between current account's database tables and all database tables created in consumer accounts from the share including current account's database." - " Map of database name X -> list of (platform instance of snowflake consumer account who've created database from share, name of database created from share) for all shares created from database name X.", + " Map of database name D -> list of (platform instance of snowflake consumer account who've created database from share, name of database created from share) for all shares created from database name D.", ) @validator("include_column_lineage") @@ -253,7 +253,7 @@ def validator_inbound_outbound_shares_map(cls, values: Dict) -> Dict: ): raise ValueError( "Current `platform_instance` can not be present as any database in `outbound_shares_map`." - "Self-sharing not supported in Snowflake. Please check your configuration." + "Self-sharing is not supported in Snowflake. Please check your configuration." ) # Check: platform_instance should be present @@ -267,7 +267,10 @@ def validator_inbound_outbound_shares_map(cls, values: Dict) -> Dict: # Check: same database from some platform instance as inbound and outbound other_platform_instance_databases: Sequence[SnowflakeDatabaseDataHubId] = [ - db for db in set(inbound_shares_map.values()) + db + for db in set( + inbound_shares_map.values() + ) # using set as multiple inbound shares may be present from same original database ] + [db for dbs in outbound_shares_map.values() for db in dbs] for other_instance_db in other_platform_instance_databases: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py index 5a61c98fef916..0d89b8569ba34 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py @@ -59,7 +59,7 @@ def get_workunits( # 2. if corresponding share is listed, then run `show grants to share ` to identify exact tables, views included in share. # 3. emit siblings only for the objects listed above. # This will work only if the configured role has accountadmin role access OR is owner of share. - # Otherwise ghost nodes will be shown in "Composed Of" section for tables/views in original database which are not granted to share. + # Otherwise ghost nodes may be shown in "Composed Of" section for tables/views in original database which are not granted to share. 
yield from self.get_siblings( db.name, schema.name, diff --git a/metadata-ingestion/tests/unit/test_snowflake_shares.py b/metadata-ingestion/tests/unit/test_snowflake_shares.py index 3d55ffd484a13..8b4611d380f6c 100644 --- a/metadata-ingestion/tests/unit/test_snowflake_shares.py +++ b/metadata-ingestion/tests/unit/test_snowflake_shares.py @@ -111,7 +111,7 @@ def test_snowflake_shares_workunit_no_shares( def test_same_database_inbound_and_outbound_invalid_config() -> None: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Same database can not be present in both"): SnowflakeV2Config( account_id="abc12345", platform_instance="instance1", @@ -134,7 +134,9 @@ def test_same_database_inbound_and_outbound_invalid_config() -> None: def test_current_platform_instance_inbound_and_outbound_invalid() -> None: - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Current `platform_instance` can not be present as" + ): SnowflakeV2Config( account_id="abc12345", platform_instance="instance1", @@ -155,7 +157,9 @@ def test_current_platform_instance_inbound_and_outbound_invalid() -> None: }, ) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Current `platform_instance` can not be present as" + ): SnowflakeV2Config( account_id="abc12345", platform_instance="instance1", @@ -178,7 +182,7 @@ def test_current_platform_instance_inbound_and_outbound_invalid() -> None: def test_another_instance_database_inbound_and_outbound_invalid() -> None: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="A database can exist only once either in"): SnowflakeV2Config( account_id="abc12345", platform_instance="instance1", @@ -233,22 +237,21 @@ def test_snowflake_shares_workunit_inbound_share( if wu.metadata.aspectName == "upstreamLineage": upstream_aspect = wu.get_aspect_of_type(UpstreamLineage) assert upstream_aspect is not None - upstream_lineage_aspect_entity_urns.add(wu.get_urn()) assert len(upstream_aspect.upstreams) == 1 assert upstream_aspect.upstreams[0].dataset == wu.get_urn().replace( "instance1.db1", "instance2.original_db1" ) + upstream_lineage_aspect_entity_urns.add(wu.get_urn()) else: siblings_aspect = wu.get_aspect_of_type(Siblings) assert siblings_aspect is not None - sibling_aspect_entity_urns.add(wu.get_urn()) assert len(siblings_aspect.siblings) == 1 assert siblings_aspect.siblings == [ wu.get_urn().replace("instance1.db1", "instance2.original_db1") ] + sibling_aspect_entity_urns.add(wu.get_urn()) - assert len(upstream_lineage_aspect_entity_urns) == 6 - assert len(sibling_aspect_entity_urns) == 6 + assert upstream_lineage_aspect_entity_urns == sibling_aspect_entity_urns def test_snowflake_shares_workunit_outbound_share( @@ -284,12 +287,12 @@ def test_snowflake_shares_workunit_outbound_share( for wu in wus: siblings_aspect = wu.get_aspect_of_type(Siblings) assert siblings_aspect is not None - entity_urns.add(wu.get_urn()) assert len(siblings_aspect.siblings) == 2 assert siblings_aspect.siblings == [ wu.get_urn().replace("instance1.db2", "instance2.db2_from_share"), wu.get_urn().replace("instance1.db2", "instance3.db2"), ] + entity_urns.add(wu.get_urn()) assert len((entity_urns)) == 6 From 69e37928b731bb0201ce5801a10b9f243cfe571b Mon Sep 17 00:00:00 2001 From: Mayuri N Date: Tue, 1 Aug 2023 17:46:42 +0530 Subject: [PATCH 05/12] revert push down filtering changes near query --- .../source/snowflake/snowflake_schema.py | 141 ++++++------------ .../source/snowflake/snowflake_v2.py | 69 +++++++-- 2 files changed, 102 
insertions(+), 108 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py index 298a8dbde48c8..e5b214ba35e4b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py @@ -8,15 +8,9 @@ import pandas as pd from snowflake.connector import SnowflakeConnection -from datahub.configuration.pattern_utils import is_schema_allowed from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain -from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery -from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report -from datahub.ingestion.source.snowflake.snowflake_utils import ( - SnowflakeCommonMixin, - SnowflakeQueryMixin, -) +from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeQueryMixin from datahub.ingestion.source.sql.sql_generic import BaseColumn, BaseTable, BaseView logger: logging.Logger = logging.getLogger(__name__) @@ -183,10 +177,8 @@ def get_column_tags_for_table( ) -class SnowflakeDataDictionary(SnowflakeQueryMixin, SnowflakeCommonMixin): - def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None: - self.config = config - self.report = report +class SnowflakeDataDictionary(SnowflakeQueryMixin): + def __init__(self) -> None: self.logger = logger self.connection: Optional[SnowflakeConnection] = None @@ -229,11 +221,7 @@ def get_databases(self, db_name: str) -> List[SnowflakeDatabase]: last_altered=database["LAST_ALTERED"], comment=database["COMMENT"], ) - self.report.report_entity_scanned(snowflake_db.name, "database") - if not self.config.database_pattern.allowed(snowflake_db.name): - self.report.report_dropped(f"{snowflake_db.name}.*") - else: - databases.append(snowflake_db) + databases.append(snowflake_db) return databases @@ -251,16 +239,7 @@ def get_schemas_for_database(self, db_name: str) -> List[SnowflakeSchema]: last_altered=schema["LAST_ALTERED"], comment=schema["COMMENT"], ) - self.report.report_entity_scanned(snowflake_schema.name, "schema") - if not is_schema_allowed( - self.config.schema_pattern, - snowflake_schema.name, - db_name, - self.config.match_fully_qualified_names, - ): - self.report.report_dropped(f"{db_name}.{snowflake_schema.name}.*") - else: - snowflake_schemas.append(snowflake_schema) + snowflake_schemas.append(snowflake_schema) return snowflake_schemas @lru_cache(maxsize=1) @@ -283,24 +262,17 @@ def get_tables_for_database( if table["TABLE_SCHEMA"] not in tables: tables[table["TABLE_SCHEMA"]] = [] - table_identifier = self.get_dataset_identifier( - table["TABLE_NAME"], table["TABLE_SCHEMA"], db_name - ) - self.report.report_entity_scanned(table_identifier) - if not self.config.table_pattern.allowed(table_identifier): - self.report.report_dropped(table_identifier) - else: - tables[table["TABLE_SCHEMA"]].append( - SnowflakeTable( - name=table["TABLE_NAME"], - created=table["CREATED"], - last_altered=table["LAST_ALTERED"], - size_in_bytes=table["BYTES"], - rows_count=table["ROW_COUNT"], - comment=table["COMMENT"], - clustering_key=table["CLUSTERING_KEY"], - ) + tables[table["TABLE_SCHEMA"]].append( + SnowflakeTable( + name=table["TABLE_NAME"], + created=table["CREATED"], + last_altered=table["LAST_ALTERED"], + size_in_bytes=table["BYTES"], + 
rows_count=table["ROW_COUNT"], + comment=table["COMMENT"], + clustering_key=table["CLUSTERING_KEY"], ) + ) return tables def get_tables_for_schema( @@ -313,24 +285,17 @@ def get_tables_for_schema( ) for table in cur: - table_identifier = self.get_dataset_identifier( - table["TABLE_NAME"], schema_name, db_name - ) - self.report.report_entity_scanned(table_identifier) - if not self.config.table_pattern.allowed(table_identifier): - self.report.report_dropped(table_identifier) - else: - tables.append( - SnowflakeTable( - name=table["TABLE_NAME"], - created=table["CREATED"], - last_altered=table["LAST_ALTERED"], - size_in_bytes=table["BYTES"], - rows_count=table["ROW_COUNT"], - comment=table["COMMENT"], - clustering_key=table["CLUSTERING_KEY"], - ) + tables.append( + SnowflakeTable( + name=table["TABLE_NAME"], + created=table["CREATED"], + last_altered=table["LAST_ALTERED"], + size_in_bytes=table["BYTES"], + rows_count=table["ROW_COUNT"], + comment=table["COMMENT"], + clustering_key=table["CLUSTERING_KEY"], ) + ) return tables @lru_cache(maxsize=1) @@ -350,27 +315,18 @@ def get_views_for_database( for table in cur: if table["schema_name"] not in views: views[table["schema_name"]] = [] - view_name = self.get_dataset_identifier( - table["name"], table["schema_name"], db_name - ) - - self.report.report_entity_scanned(view_name, "view") - - if not self.config.view_pattern.allowed(view_name): - self.report.report_dropped(view_name) - else: - views[table["schema_name"]].append( - SnowflakeView( - name=table["name"], - created=table["created_on"], - # last_altered=table["last_altered"], - comment=table["comment"], - view_definition=table["text"], - last_altered=table["created_on"], - materialized=table.get("is_materialized", "false").lower() - == "true", - ) + views[table["schema_name"]].append( + SnowflakeView( + name=table["name"], + created=table["created_on"], + # last_altered=table["last_altered"], + comment=table["comment"], + view_definition=table["text"], + last_altered=table["created_on"], + materialized=table.get("is_materialized", "false").lower() + == "true", ) + ) return views def get_views_for_schema( @@ -380,23 +336,16 @@ def get_views_for_schema( cur = self.query(SnowflakeQuery.show_views_for_schema(schema_name, db_name)) for table in cur: - view_name = self.get_dataset_identifier(table["name"], schema_name, db_name) - - self.report.report_entity_scanned(view_name, "view") - - if not self.config.view_pattern.allowed(view_name): - self.report.report_dropped(view_name) - else: - views.append( - SnowflakeView( - name=table["name"], - created=table["created_on"], - # last_altered=table["last_altered"], - comment=table["comment"], - view_definition=table["text"], - last_altered=table["created_on"], - ) + views.append( + SnowflakeView( + name=table["name"], + created=table["created_on"], + # last_altered=table["last_altered"], + comment=table["comment"], + view_definition=table["text"], + last_altered=table["created_on"], ) + ) return views @lru_cache(maxsize=1) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index 46e0b04688547..d237b525cb796 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -9,6 +9,7 @@ import pandas as pd from snowflake.connector import SnowflakeConnection +from datahub.configuration.pattern_utils import is_schema_allowed from 
datahub.emitter.mce_builder import ( make_data_platform_urn, make_dataset_urn, @@ -238,7 +239,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): ) # For database, schema, tables, views, etc - self.data_dictionary = SnowflakeDataDictionary(self.config, self.report) + self.data_dictionary = SnowflakeDataDictionary() self.lineage_extractor: Union[ SnowflakeLineageExtractor, SnowflakeLineageLegacyExtractor @@ -503,9 +504,16 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: return self.data_dictionary.set_connection(self.connection) - databases = self.get_databases() + databases: List[SnowflakeDatabase] = [] - if databases is None or len(databases) == 0: + for database in self.get_databases() or []: + self.report.report_entity_scanned(database.name, "database") + if not self.config.database_pattern.allowed(database.name): + self.report.report_dropped(f"{database.name}.*") + else: + databases.append(database) + + if len(databases) == 0: return for snowflake_db in databases: @@ -696,11 +704,22 @@ def _process_database( if self.config.profiling.enabled and self.db_tables: yield from self.profiler.get_workunits(snowflake_db, self.db_tables) - def fetch_schemas_for_database(self, snowflake_db, db_name): + def fetch_schemas_for_database( + self, snowflake_db: SnowflakeDatabase, db_name: str + ) -> None: + schemas: List[SnowflakeSchema] = [] try: - snowflake_db.schemas = self.data_dictionary.get_schemas_for_database( - db_name - ) + for schema in self.data_dictionary.get_schemas_for_database(db_name): + self.report.report_entity_scanned(schema.name, "schema") + if not is_schema_allowed( + self.config.schema_pattern, + schema.name, + db_name, + self.config.match_fully_qualified_names, + ): + self.report.report_dropped(f"{db_name}.{schema.name}.*") + else: + schemas.append(schema) except Exception as e: if isinstance(e, SnowflakePermissionError): error_msg = f"Failed to get schemas for database {db_name}. Please check permissions." @@ -716,11 +735,13 @@ def fetch_schemas_for_database(self, snowflake_db, db_name): db_name, ) - if not snowflake_db.schemas: + if not schemas: self.report_warning( "No schemas found in database. 
If schemas exist, please grant USAGE permissions on them.", db_name, ) + else: + snowflake_db.schemas = schemas def _process_schema( self, snowflake_schema: SnowflakeSchema, db_name: str @@ -766,9 +787,20 @@ def _process_schema( f"{db_name}.{schema_name}", ) - def fetch_views_for_schema(self, snowflake_schema, db_name, schema_name): + def fetch_views_for_schema( + self, snowflake_schema: SnowflakeSchema, db_name: str, schema_name: str + ) -> List[SnowflakeView]: try: - views = self.get_views_for_schema(schema_name, db_name) + views: List[SnowflakeView] = [] + for view in self.get_views_for_schema(schema_name, db_name): + view_name = self.get_dataset_identifier(view.name, schema_name, db_name) + + self.report.report_entity_scanned(view_name, "view") + + if not self.config.view_pattern.allowed(view_name): + self.report.report_dropped(view_name) + else: + views.append(view) snowflake_schema.views = [view.name for view in views] return views except Exception as e: @@ -786,10 +818,22 @@ def fetch_views_for_schema(self, snowflake_schema, db_name, schema_name): "Failed to get views for schema", f"{db_name}.{schema_name}", ) + return [] - def fetch_tables_for_schema(self, snowflake_schema, db_name, schema_name): + def fetch_tables_for_schema( + self, snowflake_schema: SnowflakeSchema, db_name: str, schema_name: str + ) -> List[SnowflakeTable]: try: - tables = self.get_tables_for_schema(schema_name, db_name) + tables: List[SnowflakeTable] = [] + for table in self.get_tables_for_schema(schema_name, db_name): + table_identifier = self.get_dataset_identifier( + table.name, schema_name, db_name + ) + self.report.report_entity_scanned(table_identifier) + if not self.config.table_pattern.allowed(table_identifier): + self.report.report_dropped(table_identifier) + else: + tables.append(table) snowflake_schema.tables = [table.name for table in tables] return tables except Exception as e: @@ -806,6 +850,7 @@ def fetch_tables_for_schema(self, snowflake_schema, db_name, schema_name): "Failed to get tables for schema", f"{db_name}.{schema_name}", ) + return [] def _process_table( self, From 4ce37e54f74167fc196bd62bd896b8c392140bd2 Mon Sep 17 00:00:00 2001 From: Mayuri N Date: Tue, 1 Aug 2023 17:55:41 +0530 Subject: [PATCH 06/12] change config structure to allow same shares config across accounts --- .../docs/sources/snowflake/snowflake_pre.md | 24 ++- .../source/snowflake/snowflake_config.py | 128 +++++------ .../source/snowflake/snowflake_shares.py | 158 +++++++++----- .../source/snowflake/snowflake_v2.py | 4 +- .../tests/unit/test_snowflake_shares.py | 200 +++++++++--------- 5 files changed, 275 insertions(+), 239 deletions(-) diff --git a/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md b/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md index 7768ab89fdb88..d8edc7c2184cb 100644 --- a/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md +++ b/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md @@ -100,7 +100,7 @@ The steps slightly differ based on which you decide to use. * Note: the `username` and `password` config options are not nested under `oauth_config` ### Snowflake Shares -If you are using [Snowflake Shares](https://docs.snowflake.com/en/user-guide/data-sharing-provider) to share data across different snowflake accounts, and you have set up DataHub recipes for ingesting metadata from all these accounts, you may end up having multiple similar dataset entities corresponding to virtual versions of same table in different snowflake accounts. 
DataHub Snowflake connector can automatically link such tables together through Siblings and Lineage relationship if user provides information necessary to establish the relationship using configurations `inbound_shares_map` and `outbound_shares_map` in recipe.
+If you are using [Snowflake Shares](https://docs.snowflake.com/en/user-guide/data-sharing-provider) to share data across different snowflake accounts, and you have set up DataHub recipes for ingesting metadata from all these accounts, you may end up having multiple similar dataset entities corresponding to virtual versions of the same table in different snowflake accounts. The DataHub Snowflake connector can automatically link such tables together through Siblings and Lineage relationships if the user provides the information necessary to establish the relationship using the `shares` configuration in the recipe.
 
 #### Example
 - Snowflake account `account1` (ingested as platform_instance `instance1`) owns a database `db1`. A share `X` is created in `account1` that includes database `db1` along with schemas and tables inside it.
@@ -111,20 +111,26 @@ If you are using [Snowflake Shares](https://docs.snowflake.com/en/user-guide/dat
   ```yaml
   account_id: account1
   platform_instance: instance1
-  outbound_shares_map:
-    db1:
-      - platform_instance: instance2 # this is a list, as db1 can be shared with multiple snowflake accounts using X
-        database_name: db1_from_X
+  shares:
+    X:
+      platform_instance: instance1
+      database_name: db1
+      consumers:
+        - platform_instance: instance2 # this is a list, as db1 can be shared with multiple snowflake accounts using X
+          database_name: db1_from_X
   ```
 - In snowflake recipe of `account2` :

   ```yaml
   account_id: account2
   platform_instance: instance2
-  inbound_shares_map:
-    db1_from_X:
-      platform_instance: instance1
-      database_name: db1
+  shares:
+    X:
+      platform_instance: instance1
+      database_name: db1
+      consumers:
+        - platform_instance: instance2 # this is a list, as db1 can be shared with multiple snowflake accounts using X
+          database_name: db1_from_X

   ```
 ### Caveats
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
index 0f708d965c64a..c3a71e320dfc5 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
@@ -1,11 +1,11 @@
 import logging
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional, Sequence, cast
+from typing import Dict, List, Optional, Set, cast
 
 from pydantic import Field, SecretStr, root_validator, validator
 
-from datahub.configuration.common import AllowDenyPattern
+from datahub.configuration.common import AllowDenyPattern, ConfigModel
 from datahub.configuration.pattern_utils import UUID_REGEX
 from datahub.configuration.validate_field_removal import pydantic_removed_field
 from datahub.configuration.validate_field_rename import pydantic_renamed_field
@@ -44,9 +44,24 @@ class TagOption(str, Enum):
 
 
 @dataclass(frozen=True)
-class SnowflakeDatabaseDataHubId:
-    platform_instance: str
-    database_name: str
+class DatabaseId:
+    database: str = Field(
+        description="Database created from share in consumer account."
+    )
+    platform_instance: str = Field(
+        description="Platform instance of consumer snowflake account."
+    )
+
+
+class SnowflakeShareConfig(ConfigModel):
+    database: str = Field(description="Database from which share is created.")
+    platform_instance: str = Field(
+        description="Platform instance for snowflake account in which share is created."
+    )
+
+    consumers: Set[DatabaseId] = Field(
+        description="List of databases created in consumer accounts."
+    )
 
 
 class SnowflakeV2Config(
@@ -127,18 +142,11 @@ class SnowflakeV2Config(
         "upstreams_deny_pattern", "temporary_tables_pattern"
     )
 
-    inbound_shares_map: Optional[Dict[str, SnowflakeDatabaseDataHubId]] = Field(
+    shares: Optional[Dict[str, SnowflakeShareConfig]] = Field(
         default=None,
-        description="Required if the current account has any database created from inbound snowflake share."
-        " If specified, connector creates lineage and siblings relationship between current account's database tables and original database tables from which snowflake share was created."
-        " Map of database name -> (platform instance of snowflake account containing original database, original database name).",
-    )
-
-    outbound_shares_map: Optional[Dict[str, List[SnowflakeDatabaseDataHubId]]] = Field(
-        default=None,
-        description="Required if the current account has created any outbound snowflake shares and there is at least one consumer account in which database is created from such share."
-        " If specified, connector creates siblings relationship between current account's database tables and all database tables created in consumer accounts from the share including current account's database."
-        " Map of database name D -> list of (platform instance of snowflake consumer account who've created database from share, name of database created from share) for all shares created from database name D.",
+        description="Required if the current account owns or consumes a snowflake share."
+        " If specified, the connector creates lineage and sibling relationships between the current account's database tables and the consumer/producer accounts' database tables."
+        " Map of share name -> details of share.",
     )
 
     @validator("include_column_lineage")
     def validate_include_column_lineage(cls, v, values):
         if not values.get("include_table_lineage") and v:
@@ -214,68 +222,40 @@ def get_sql_alchemy_url(
     def parse_view_ddl(self) -> bool:
         return self.include_view_column_lineage
 
-    @root_validator(pre=False)
-    def validator_inbound_outbound_shares_map(cls, values: Dict) -> Dict:
-        inbound_shares_map: Dict[str, SnowflakeDatabaseDataHubId] = (
-            values.get("inbound_shares_map") or {}
-        )
-        outbound_shares_map: Dict[str, List[SnowflakeDatabaseDataHubId]] = (
-            values.get("outbound_shares_map") or {}
-        )
-
-        # Check: same database from current instance as inbound and outbound
-        common_keys = [key for key in inbound_shares_map if key in outbound_shares_map]
-
-        assert (
-            len(common_keys) == 0
-        ), "Same database can not be present in both `inbound_shares_map` and `outbound_shares_map`."
-
+    @validator("shares")
+    def validate_shares(
+        cls, shares: Optional[Dict[str, SnowflakeShareConfig]], values: Dict
+    ) -> Optional[Dict[str, SnowflakeShareConfig]]:
         current_platform_instance = values.get("platform_instance")
 
-        # Check: current platform_instance present as inbound and outbound
-        if current_platform_instance and any(
-            [
-                db.platform_instance == current_platform_instance
-                for db in inbound_shares_map.values()
-            ]
-        ):
-            raise ValueError(
-                "Current `platform_instance` can not be present as any database in `inbound_shares_map`."
-                "Self-sharing not supported in Snowflake. Please check your configuration."
- ) - - if current_platform_instance and any( - [ - db.platform_instance == current_platform_instance - for dbs in outbound_shares_map.values() - for db in dbs - ] - ): - raise ValueError( - "Current `platform_instance` can not be present as any database in `outbound_shares_map`." - "Self-sharing is not supported in Snowflake. Please check your configuration." - ) - # Check: platform_instance should be present - if ( - inbound_shares_map or outbound_shares_map - ) and not current_platform_instance: - logger.warn( + if shares: + assert current_platform_instance is not None, ( "Did you forget to set `platform_instance` for current ingestion ?" "It is advisable to use `platform_instance` when ingesting from multiple snowflake accounts." ) - # Check: same database from some platform instance as inbound and outbound - other_platform_instance_databases: Sequence[SnowflakeDatabaseDataHubId] = [ - db - for db in set( - inbound_shares_map.values() - ) # using set as multiple inbound shares may be present from same original database - ] + [db for dbs in outbound_shares_map.values() for db in dbs] - - for other_instance_db in other_platform_instance_databases: - assert ( - other_platform_instance_databases.count(other_instance_db) == 1 - ), "A database can exist only once either in `inbound_shares_map` or in `outbound_shares_map`." - - return values + databases_included_in_share: List[DatabaseId] = [] + databases_created_from_share: List[DatabaseId] = [] + + for _, share_details in shares.items(): + shared_db = DatabaseId( + share_details.database, share_details.platform_instance + ) + assert all( + consumer.platform_instance != share_details.platform_instance + for consumer in share_details.consumers + ), "Share's platform_instance can not be same as consumer's platform instance. Self-sharing not supported in Snowflake." + + databases_included_in_share.append(shared_db) + databases_created_from_share.extend(share_details.consumers) + + for db_from_share in databases_created_from_share: + assert ( + db_from_share not in databases_included_in_share + ), "Database included in a share can not be present as consumer in any share." + assert ( + databases_created_from_share.count(db_from_share) == 1 + ), "Same database can not be present as consumer in more than one share." 
+
+        return shares
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py
index 0d89b8569ba34..57011b72b1a86 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py
@@ -1,11 +1,13 @@
 import logging
-from typing import Callable, Iterable, List
+from dataclasses import dataclass
+from typing import Callable, Dict, Iterable, List, Optional
 
 from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.snowflake.snowflake_config import (
-    SnowflakeDatabaseDataHubId,
+    DatabaseId,
+    SnowflakeShareConfig,
     SnowflakeV2Config,
 )
 from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
@@ -21,6 +23,20 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
+@dataclass
+class SharedDatabase:
+    """
+    Represents a shared database from the current platform instance
+    This is either created from an inbound share or included in an outbound share.
+    """
+
+    name: str
+    created_from_share: bool
+
+    # This will have exactly one entry if created_from_share = True
+    shares: List[str]
+
+
 class SnowflakeSharesHandler(SnowflakeCommonMixin):
     def __init__(
         self,
@@ -33,24 +49,61 @@ def __init__(
         self.logger = logger
         self.dataset_urn_builder = dataset_urn_builder
 
-    def get_workunits(
+    def _get_shared_databases(
+        self, shares: Dict[str, SnowflakeShareConfig], platform_instance: Optional[str]
+    ) -> Dict[str, SharedDatabase]:
+        # this is ensured in config validators
+        assert platform_instance is not None
+
+        shared_databases: Dict[str, SharedDatabase] = {}
+
+        for share_name, share_details in shares.items():
+            if share_details.platform_instance == platform_instance:
+                if share_details.database not in shared_databases:
+                    shared_databases[share_details.database] = SharedDatabase(
+                        name=share_details.database,
+                        created_from_share=False,
+                        shares=[share_name],
+                    )
+
+                else:
+                    shared_databases[share_details.database].shares += share_name
+
+            else:
+                for consumer in share_details.consumers:
+                    if consumer.platform_instance == platform_instance:
+                        shared_databases[consumer.database] = SharedDatabase(
+                            name=share_details.database,
+                            created_from_share=True,
+                            shares=[share_name],
+                        )
+                        break
+                else:
+                    self.report_warning(
+                        f"Skipping Share, as it does not include current platform instance {platform_instance}",
+                        share_name,
+                    )
+
+        return shared_databases
+
+    def get_shares_workunits(
         self, databases: List[SnowflakeDatabase]
     ) -> Iterable[MetadataWorkUnit]:
+        shared_databases = self._get_shared_databases(
+            self.config.shares or {}, self.config.platform_instance
+        )
+
+        # None of the databases are shared
+        if not shared_databases:
+            return
+
         logger.debug("Checking databases for inbound or outbound shares.")
 
         for db in databases:
-            inbound, outbounds = self.get_sharing_details(db)
-
-            if not (inbound or outbounds):
-                logger.debug(f"database {db.name} is not a shared.")
+            if db.name not in shared_databases:
+                logger.debug(f"database {db.name} is not shared.")
                 continue
 
-            sibling_dbs: List[SnowflakeDatabaseDataHubId]
-            if inbound:
-                sibling_dbs = [inbound]
-                logger.debug(f"database {db.name} is created from inbound share.")
-            else:  # outbounds
-                sibling_dbs = outbounds
-                logger.debug(f"database {db.name} is shared
as outbound share.") + sibling_dbs = self.get_sibling_databases(shared_databases[db.name]) for schema in db.schemas: for table_name in schema.tables + schema.views: @@ -60,64 +113,70 @@ def get_workunits( # 3. emit siblings only for the objects listed above. # This will work only if the configured role has accountadmin role access OR is owner of share. # Otherwise ghost nodes may be shown in "Composed Of" section for tables/views in original database which are not granted to share. - yield from self.get_siblings( + yield from self.gen_siblings( db.name, schema.name, table_name, - True if outbounds else False, + not shared_databases[db.name].created_from_share, sibling_dbs, ) - if inbound: + if shared_databases[db.name].created_from_share: + assert len(sibling_dbs) == 1 # SnowflakeLineageExtractor is unaware of database->schema->table hierarchy # hence this lineage code is not written in SnowflakeLineageExtractor # also this is not governed by configs include_table_lineage and include_view_lineage yield self.get_upstream_lineage_with_primary_sibling( - db.name, schema.name, table_name, inbound + db.name, schema.name, table_name, sibling_dbs[0] ) - self.report_missing_databases(databases) + self.report_missing_databases(databases, shared_databases) - def get_sharing_details(self, db): - inbound = ( - self.config.inbound_shares_map.get(db.name) - if self.config.inbound_shares_map - and db.name in self.config.inbound_shares_map - else None - ) - outbounds = ( - self.config.outbound_shares_map.get(db.name) - if self.config.outbound_shares_map - and db.name in self.config.outbound_shares_map - else None - ) + def get_sibling_databases(self, db: SharedDatabase) -> List[DatabaseId]: + assert self.config.shares is not None + sibling_dbs: List[DatabaseId] = [] + if db.created_from_share: + share_details = self.config.shares[db.shares[0]] + logger.debug( + f"database {db.name} is created from inbound share {db.shares[0]}." + ) + sibling_dbs = [ + DatabaseId(share_details.database, share_details.platform_instance) + ] + + else: # not created from share, but is in fact included in share + logger.debug( + f"database {db.name} is included as outbound share(s) {db.shares}." + ) + sibling_dbs = [ + consumer + for share_name in db.shares + for consumer in self.config.shares[share_name].consumers + ] - return inbound, outbounds + return sibling_dbs - def report_missing_databases(self, databases): + def report_missing_databases( + self, + databases: List[SnowflakeDatabase], + shared_databases: Dict[str, SharedDatabase], + ) -> None: db_names = [db.name for db in databases] - missing_dbs = [] - if self.config.inbound_shares_map: - missing_dbs.extend( - [db for db in self.config.inbound_shares_map if db not in db_names] - ) - if self.config.outbound_shares_map: - missing_dbs.extend( - [db for db in self.config.outbound_shares_map if db not in db_names] - ) + missing_dbs = [db for db in shared_databases.keys() if db not in db_names] + if missing_dbs: self.report_warning( "snowflake-shares", f"Databases {missing_dbs} were not ingested. 
Siblings/Lineage will not be set for these.", ) - def get_siblings( + def gen_siblings( self, database_name: str, schema_name: str, table_name: str, primary: bool, - sibling_databases: List[SnowflakeDatabaseDataHubId], + sibling_databases: List[DatabaseId], ) -> Iterable[MetadataWorkUnit]: if not sibling_databases: return @@ -130,7 +189,7 @@ def get_siblings( make_dataset_urn_with_platform_instance( self.platform, self.get_dataset_identifier( - table_name, schema_name, sibling_db.database_name + table_name, schema_name, sibling_db.database ), sibling_db.platform_instance, ) @@ -138,7 +197,8 @@ def get_siblings( ] yield MetadataChangeProposalWrapper( - entityUrn=urn, aspect=Siblings(primary=primary, siblings=sibling_urns) + entityUrn=urn, + aspect=Siblings(primary=primary, siblings=sorted(sibling_urns)), ).as_workunit() def get_upstream_lineage_with_primary_sibling( @@ -146,7 +206,7 @@ def get_upstream_lineage_with_primary_sibling( database_name: str, schema_name: str, table_name: str, - primary_sibling_db: SnowflakeDatabaseDataHubId, + primary_sibling_db: DatabaseId, ) -> MetadataWorkUnit: dataset_identifier = self.get_dataset_identifier( table_name, schema_name, database_name @@ -156,7 +216,7 @@ def get_upstream_lineage_with_primary_sibling( upstream_urn = make_dataset_urn_with_platform_instance( self.platform, self.get_dataset_identifier( - table_name, schema_name, primary_sibling_db.database_name + table_name, schema_name, primary_sibling_db.database ), primary_sibling_db.platform_instance, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index d237b525cb796..99e9f9a62ab16 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -540,10 +540,10 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # TODO: The checkpoint state for stale entity detection can be committed here. 
- if self.config.inbound_shares_map or self.config.outbound_shares_map: + if self.config.shares: yield from SnowflakeSharesHandler( self.config, self.report, self.gen_dataset_urn - ).get_workunits(databases) + ).get_shares_workunits(databases) discovered_tables: List[str] = [ self.get_dataset_identifier(table_name, schema.name, db.name) diff --git a/metadata-ingestion/tests/unit/test_snowflake_shares.py b/metadata-ingestion/tests/unit/test_snowflake_shares.py index 8b4611d380f6c..7de86139baf39 100644 --- a/metadata-ingestion/tests/unit/test_snowflake_shares.py +++ b/metadata-ingestion/tests/unit/test_snowflake_shares.py @@ -5,7 +5,8 @@ from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.source.snowflake.snowflake_config import ( - SnowflakeDatabaseDataHubId, + DatabaseId, + SnowflakeShareConfig, SnowflakeV2Config, ) from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report @@ -105,101 +106,84 @@ def test_snowflake_shares_workunit_no_shares( config, report, lambda x: make_snowflake_urn(x) ) - wus = list(shares_handler.get_workunits(snowflake_databases)) + wus = list(shares_handler.get_shares_workunits(snowflake_databases)) assert len(wus) == 0 def test_same_database_inbound_and_outbound_invalid_config() -> None: - with pytest.raises(ValueError, match="Same database can not be present in both"): - SnowflakeV2Config( - account_id="abc12345", - platform_instance="instance1", - inbound_shares_map={ - "db1": SnowflakeDatabaseDataHubId( - database_name="original_db1", platform_instance="instance2" - ) - }, - outbound_shares_map={ - "db1": [ - SnowflakeDatabaseDataHubId( - database_name="db2_from_share", platform_instance="instance2" - ), - SnowflakeDatabaseDataHubId( - database_name="db2", platform_instance="instance3" - ), - ] - }, - ) - - -def test_current_platform_instance_inbound_and_outbound_invalid() -> None: with pytest.raises( - ValueError, match="Current `platform_instance` can not be present as" + ValueError, + match="Same database can not be present as consumer in more than one share", ): SnowflakeV2Config( account_id="abc12345", platform_instance="instance1", - inbound_shares_map={ - "db1": SnowflakeDatabaseDataHubId( - database_name="db2", platform_instance="instance1" - ) - }, - outbound_shares_map={ - "db2": [ - SnowflakeDatabaseDataHubId( - database_name="db2_from_share", platform_instance="instance2" - ), - SnowflakeDatabaseDataHubId( - database_name="db2", platform_instance="instance3" - ), - ] + shares={ + "share1": SnowflakeShareConfig( + database="db1", + platform_instance="instance2", + consumers=[ + DatabaseId(database="db1", platform_instance="instance1") + ], + ), + "share2": SnowflakeShareConfig( + database="db1", + platform_instance="instance3", + consumers=[ + DatabaseId(database="db1", platform_instance="instance1") + ], + ), }, ) with pytest.raises( - ValueError, match="Current `platform_instance` can not be present as" + ValueError, + match="Database included in a share can not be present as consumer in any share", ): SnowflakeV2Config( account_id="abc12345", platform_instance="instance1", - inbound_shares_map={ - "db1": SnowflakeDatabaseDataHubId( - database_name="db2", platform_instance="instance2" - ) - }, - outbound_shares_map={ - "db2": [ - SnowflakeDatabaseDataHubId( - database_name="db2_from_share", platform_instance="instance2" - ), - SnowflakeDatabaseDataHubId( - database_name="db2", platform_instance="instance1" - ), - ] 
+ shares={ + "share1": SnowflakeShareConfig( + database="db1", + platform_instance="instance2", + consumers=[ + DatabaseId(database="db1", platform_instance="instance1") + ], + ), + "share2": SnowflakeShareConfig( + database="db1", + platform_instance="instance1", + consumers=[ + DatabaseId(database="db1", platform_instance="instance3") + ], + ), }, ) - -def test_another_instance_database_inbound_and_outbound_invalid() -> None: - with pytest.raises(ValueError, match="A database can exist only once either in"): + with pytest.raises( + ValueError, + match="Database included in a share can not be present as consumer in any share", + ): SnowflakeV2Config( account_id="abc12345", platform_instance="instance1", - inbound_shares_map={ - "db1": SnowflakeDatabaseDataHubId( - database_name="db2", platform_instance="instance3" - ) - }, - outbound_shares_map={ - "db2": [ - SnowflakeDatabaseDataHubId( - database_name="db2_from_share", platform_instance="instance2" - ), - SnowflakeDatabaseDataHubId( - database_name="db2", platform_instance="instance3" - ), - ] + shares={ + "share2": SnowflakeShareConfig( + database="db1", + platform_instance="instance1", + consumers=[ + DatabaseId(database="db1", platform_instance="instance3") + ], + ), + "share1": SnowflakeShareConfig( + database="db1", + platform_instance="instance2", + consumers=[ + DatabaseId(database="db1", platform_instance="instance1") + ], + ), }, ) @@ -210,9 +194,11 @@ def test_snowflake_shares_workunit_inbound_share( config = SnowflakeV2Config( account_id="abc12345", platform_instance="instance1", - inbound_shares_map={ - "db1": SnowflakeDatabaseDataHubId( - database_name="original_db1", platform_instance="instance2" + shares={ + "share1": SnowflakeShareConfig( + database="db1", + platform_instance="instance2", + consumers=[DatabaseId(database="db1", platform_instance="instance1")], ) }, ) @@ -222,7 +208,7 @@ def test_snowflake_shares_workunit_inbound_share( config, report, lambda x: make_snowflake_urn(x, "instance1") ) - wus = list(shares_handler.get_workunits(snowflake_databases)) + wus = list(shares_handler.get_shares_workunits(snowflake_databases)) # 2 schemas - 2 tables and 1 view in each schema making total 6 datasets # Hence 6 Sibling and 6 upstreamLineage aspects @@ -239,7 +225,7 @@ def test_snowflake_shares_workunit_inbound_share( assert upstream_aspect is not None assert len(upstream_aspect.upstreams) == 1 assert upstream_aspect.upstreams[0].dataset == wu.get_urn().replace( - "instance1.db1", "instance2.original_db1" + "instance1.db1", "instance2.db1" ) upstream_lineage_aspect_entity_urns.add(wu.get_urn()) else: @@ -247,7 +233,7 @@ def test_snowflake_shares_workunit_inbound_share( assert siblings_aspect is not None assert len(siblings_aspect.siblings) == 1 assert siblings_aspect.siblings == [ - wu.get_urn().replace("instance1.db1", "instance2.original_db1") + wu.get_urn().replace("instance1.db1", "instance2.db1") ] sibling_aspect_entity_urns.add(wu.get_urn()) @@ -260,15 +246,17 @@ def test_snowflake_shares_workunit_outbound_share( config = SnowflakeV2Config( account_id="abc12345", platform_instance="instance1", - outbound_shares_map={ - "db2": [ - SnowflakeDatabaseDataHubId( - database_name="db2_from_share", platform_instance="instance2" - ), - SnowflakeDatabaseDataHubId( - database_name="db2", platform_instance="instance3" - ), - ] + shares={ + "share2": SnowflakeShareConfig( + database="db2", + platform_instance="instance1", + consumers=[ + DatabaseId( + database="db2_from_share", platform_instance="instance2" + ), + 
DatabaseId(database="db2", platform_instance="instance3"), + ], + ) }, ) @@ -277,7 +265,7 @@ def test_snowflake_shares_workunit_outbound_share( config, report, lambda x: make_snowflake_urn(x, "instance1") ) - wus = list(shares_handler.get_workunits(snowflake_databases)) + wus = list(shares_handler.get_shares_workunits(snowflake_databases)) # 2 schemas - 2 tables and 1 view in each schema making total 6 datasets # Hence 6 Sibling aspects @@ -303,20 +291,22 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share( config = SnowflakeV2Config( account_id="abc12345", platform_instance="instance1", - inbound_shares_map={ - "db1": SnowflakeDatabaseDataHubId( - database_name="original_db1", platform_instance="instance2" - ) - }, - outbound_shares_map={ - "db2": [ - SnowflakeDatabaseDataHubId( - database_name="db2_from_share", platform_instance="instance2" - ), - SnowflakeDatabaseDataHubId( - database_name="db2", platform_instance="instance3" - ), - ] + shares={ + "share1": SnowflakeShareConfig( + database="db1", + platform_instance="instance2", + consumers=[DatabaseId(database="db1", platform_instance="instance1")], + ), + "share2": SnowflakeShareConfig( + database="db2", + platform_instance="instance1", + consumers=[ + DatabaseId( + database="db2_from_share", platform_instance="instance2" + ), + DatabaseId(database="db2", platform_instance="instance3"), + ], + ), }, ) @@ -325,7 +315,7 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share( config, report, lambda x: make_snowflake_urn(x, "instance1") ) - wus = list(shares_handler.get_workunits(snowflake_databases)) + wus = list(shares_handler.get_shares_workunits(snowflake_databases)) # 6 Sibling and 6 upstreamLineage aspects for db1 tables # 6 Sibling aspects for db2 tables @@ -340,7 +330,7 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share( assert upstream_aspect is not None assert len(upstream_aspect.upstreams) == 1 assert upstream_aspect.upstreams[0].dataset == wu.get_urn().replace( - "instance1.db1", "instance2.original_db1" + "instance1.db1", "instance2.db1" ) else: siblings_aspect = wu.get_aspect_of_type(Siblings) @@ -348,7 +338,7 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share( if "db1" in wu.get_urn(): assert len(siblings_aspect.siblings) == 1 assert siblings_aspect.siblings == [ - wu.get_urn().replace("instance1.db1", "instance2.original_db1") + wu.get_urn().replace("instance1.db1", "instance2.db1") ] else: assert len(siblings_aspect.siblings) == 2 From b33cafd6e379a7beecc5c3ff86b8d552f7b22ebe Mon Sep 17 00:00:00 2001 From: Mayuri N Date: Tue, 1 Aug 2023 18:37:20 +0530 Subject: [PATCH 07/12] fix --- metadata-ingestion/docs/sources/snowflake/snowflake_pre.md | 2 ++ .../src/datahub/ingestion/source/snowflake/snowflake_shares.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md b/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md index d8edc7c2184cb..aa63c7e35537e 100644 --- a/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md +++ b/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md @@ -132,6 +132,8 @@ If you are using [Snowflake Shares](https://docs.snowflake.com/en/user-guide/dat - platform_instance: instance2 # this is a list, as db1 can be shared with multiple snowflake accounts using X database_name: db1_from_X ``` + +- If share X is shared with more snowflake accounts and database is created from share X in those, additional entries need to be added in `consumers` list for share X, one 
per snowflake account.
 ### Caveats
 - Some of the features are only available in the Snowflake Enterprise Edition. This doc has notes mentioning where this applies.
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py
index 57011b72b1a86..99ee1b4e4a35f 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py
@@ -67,7 +67,7 @@ def _get_shared_databases(
                     )
 
                 else:
-                    shared_databases[share_details.database].shares += share_name
+                    shared_databases[share_details.database].shares.append(share_name)
 
             else:
                 for consumer in share_details.consumers:

From f19d1ed5334a56c08857e8e8aefeebb62a37aa3a Mon Sep 17 00:00:00 2001
From: Mayuri N
Date: Tue, 1 Aug 2023 19:02:07 +0530
Subject: [PATCH 08/12] fix indent

---
 .../src/datahub/ingestion/source/snowflake/snowflake_v2.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py
index 99e9f9a62ab16..6ef26590dec41 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py
@@ -718,8 +718,8 @@ def fetch_schemas_for_database(
                     self.config.match_fully_qualified_names,
                 ):
                     self.report.report_dropped(f"{db_name}.{schema.name}.*")
-            else:
-                schemas.append(schema)
+                else:
+                    schemas.append(schema)
         except Exception as e:
             if isinstance(e, SnowflakePermissionError):
                 error_msg = f"Failed to get schemas for database {db_name}. Please check permissions."

From 78eba3052ed4fb01ee1d79aae43ce3b2b6773c4c Mon Sep 17 00:00:00 2001
From: Mayuri N
Date: Wed, 23 Aug 2023 16:12:19 +0530
Subject: [PATCH 09/12] update doc content, refactor to avoid asserts

---
 .../docs/sources/snowflake/snowflake_pre.md   | 33 +++++--------------
 .../source/snowflake/snowflake_config.py      |  8 ++---
 .../source/snowflake/snowflake_shares.py      | 16 ++++-----
 3 files changed, 19 insertions(+), 38 deletions(-)

diff --git a/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md b/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md
index aa63c7e35537e..2139321d6f185 100644
--- a/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md
+++ b/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md
@@ -104,36 +104,19 @@ If you are using [Snowflake Shares](https://docs.snowflake.com/en/user-guide/dat
 #### Example
 - Snowflake account `account1` (ingested as platform_instance `instance1`) owns a database `db1`. A share `X` is created in `account1` that includes database `db1` along with schemas and tables inside it.
-- Now, `X` is shared with snowflake account `account2` (ingested as platform_instance `instance2`). A database `db1_from_X` is created from inbound share `X` in `account2`.
In this case, all tables and views included in share `X` will also be present in `instance2`.`db1_from_X`.
+- This can be represented in the `shares` configuration section as follows:
+  ```yaml
+  shares:
+    X: # name of the share
+      database_name: db1
+      platform_instance: instance1
+      consumers: # list of all databases created from share Xu
+        - database_name: db1_from_X
+          platform_instance: instance2
+
+  ```
+- If share `X` is shared with more snowflake accounts and a database is created from share `X` in those accounts, then additional entries need to be added to the `consumers` list for share `X`, one per snowflake account. The same `shares` config can then be copied across the recipes of all accounts.
 ### Caveats
 - Some of the features are only available in the Snowflake Enterprise Edition. This doc has notes mentioning where this applies.
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
index d7742c06fab84..597a88d4ed632 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
@@ -233,17 +233,17 @@ def validate_shares(
     ) -> Optional[Dict[str, SnowflakeShareConfig]]:
         current_platform_instance = values.get("platform_instance")
 
-        # Check: platform_instance should be present
         if shares:
+            # Check: platform_instance should be present
             assert current_platform_instance is not None, (
-                "Did you forget to set `platform_instance` for current ingestion ?"
-                "It is advisable to use `platform_instance` when ingesting from multiple snowflake accounts."
+                "Did you forget to set `platform_instance` for current ingestion ? "
+                "It is required to use `platform_instance` when ingesting from multiple snowflake accounts."
) databases_included_in_share: List[DatabaseId] = [] databases_created_from_share: List[DatabaseId] = [] - for _, share_details in shares.items(): + for share_details in shares.values(): shared_db = DatabaseId( share_details.database, share_details.platform_instance ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py index 99ee1b4e4a35f..c43da90ea7cf6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import Callable, Dict, Iterable, List, Optional +from typing import Callable, Dict, Iterable, List from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance from datahub.emitter.mcp import MetadataChangeProposalWrapper @@ -45,16 +45,15 @@ def __init__( dataset_urn_builder: Callable[[str], str], ) -> None: self.config = config + self.shares = self.config.shares or {} + self.platform_instance = self.config.platform_instance or "" self.report = report self.logger = logger self.dataset_urn_builder = dataset_urn_builder def _get_shared_databases( - self, shares: Dict[str, SnowflakeShareConfig], platform_instance: Optional[str] + self, shares: Dict[str, SnowflakeShareConfig], platform_instance: str ) -> Dict[str, SharedDatabase]: - # this is ensured in config validators - assert platform_instance is not None - shared_databases: Dict[str, SharedDatabase] = {} for share_name, share_details in shares.items(): @@ -90,7 +89,7 @@ def get_shares_workunits( self, databases: List[SnowflakeDatabase] ) -> Iterable[MetadataWorkUnit]: shared_databases = self._get_shared_databases( - self.config.shares or {}, self.config.platform_instance + self.shares, self.platform_instance ) # None of the databases are shared @@ -133,10 +132,9 @@ def get_shares_workunits( self.report_missing_databases(databases, shared_databases) def get_sibling_databases(self, db: SharedDatabase) -> List[DatabaseId]: - assert self.config.shares is not None sibling_dbs: List[DatabaseId] = [] if db.created_from_share: - share_details = self.config.shares[db.shares[0]] + share_details = self.shares[db.shares[0]] logger.debug( f"database {db.name} is created from inbound share {db.shares[0]}." 
)
@@ -151,7 +149,7 @@ def get_sibling_databases(self, db: SharedDatabase) -> List[DatabaseId]:
             sibling_dbs = [
                 consumer
                 for share_name in db.shares
-                for consumer in self.config.shares[share_name].consumers
+                for consumer in self.shares[share_name].consumers
             ]
 
         return sibling_dbs

From 1a772455cc0333b668a8d105e15dda8726f3f105 Mon Sep 17 00:00:00 2001
From: Mayuri N
Date: Wed, 23 Aug 2023 16:46:54 +0530
Subject: [PATCH 10/12] simplification refactor

---
 .../source/snowflake/snowflake_config.py      |  41 +++++++
 .../source/snowflake/snowflake_shares.py      | 108 ++++--------------
 2 files changed, 61 insertions(+), 88 deletions(-)

diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
index 597a88d4ed632..0b99520e8df6e 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
@@ -1,4 +1,5 @@
 import logging
+from collections import defaultdict
 from dataclasses import dataclass
 from enum import Enum
 from typing import Dict, List, Optional, Set, cast
@@ -63,6 +64,10 @@ class SnowflakeShareConfig(ConfigModel):
         description="List of databases created in consumer accounts."
     )
 
+    @property
+    def source_database(self) -> DatabaseId:
+        return DatabaseId(self.database, self.platform_instance)
+
 
 class SnowflakeV2Config(
@@ -264,3 +269,39 @@ def validate_shares(
             ), "Same database can not be present as consumer in more than one share."
 
         return shares
+
+    def outbounds(self) -> Dict[str, Set[DatabaseId]]:
+        """
+        Returns mapping of
+        database included in the current account's outbound share -> all databases created from this share in other accounts
+        """
+        outbounds: Dict[str, Set[DatabaseId]] = defaultdict(set)
+        if self.shares:
+            for share_name, share_details in self.shares.items():
+                if share_details.platform_instance == self.platform_instance:
+                    logger.debug(
+                        f"database {share_details.database} is included in outbound share(s) {share_name}."
+                    )
+                    outbounds[share_details.database].update(share_details.consumers)
+        return outbounds
+
+    def inbounds(self) -> Dict[str, DatabaseId]:
+        """
+        Returns mapping of
+        database created from the current account's inbound share -> the other-account database from which this share was created
+        """
+        inbounds: Dict[str, DatabaseId] = {}
+        if self.shares:
+            for share_name, share_details in self.shares.items():
+                for consumer in share_details.consumers:
+                    if consumer.platform_instance == self.platform_instance:
+                        logger.debug(
+                            f"database {consumer.database} is created from inbound share {share_name}."
+                        )
+                        inbounds[consumer.database] = share_details.source_database
+                        break
+                else:
+                    logger.info(
+                        f"Skipping Share {share_name}, as it does not include current platform instance {self.platform_instance}",
+                    )
+        return inbounds
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py
index c43da90ea7cf6..849325ebb4aef 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py
@@ -1,13 +1,11 @@
 import logging
-from dataclasses import dataclass
-from typing import Callable, Dict, Iterable, List
+from typing import Callable, Iterable, List
 
 from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.snowflake.snowflake_config import (
     DatabaseId,
-    SnowflakeShareConfig,
     SnowflakeV2Config,
 )
 from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
@@ -23,20 +21,6 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-@dataclass
-class SharedDatabase:
-    """
-    Represents a shared database from the current platform instance
-    This is either created from an inbound share or included in an outbound share.
-    """
-
-    name: str
-    created_from_share: bool
-
-    # This will have exactly one entry if created_from_share = True
-    shares: List[str]
-
-
 class SnowflakeSharesHandler(SnowflakeCommonMixin):
     def __init__(
         self,
@@ -45,64 +29,32 @@ def __init__(
         dataset_urn_builder: Callable[[str], str],
     ) -> None:
         self.config = config
-        self.shares = self.config.shares or {}
-        self.platform_instance = self.config.platform_instance or ""
         self.report = report
         self.logger = logger
         self.dataset_urn_builder = dataset_urn_builder
 
-    def _get_shared_databases(
-        self, shares: Dict[str, SnowflakeShareConfig], platform_instance: str
-    ) -> Dict[str, SharedDatabase]:
-        shared_databases: Dict[str, SharedDatabase] = {}
-
-        for share_name, share_details in shares.items():
-            if share_details.platform_instance == platform_instance:
-                if share_details.database not in shared_databases:
-                    shared_databases[share_details.database] = SharedDatabase(
-                        name=share_details.database,
-                        created_from_share=False,
-                        shares=[share_name],
-                    )
-
-                else:
-                    shared_databases[share_details.database].shares.append(share_name)
-
-            else:
-                for consumer in share_details.consumers:
-                    if consumer.platform_instance == platform_instance:
-                        shared_databases[consumer.database] = SharedDatabase(
-                            name=share_details.database,
-                            created_from_share=True,
-                            shares=[share_name],
-                        )
-                        break
-                else:
-                    self.report_warning(
-                        f"Skipping Share, as it does not include current platform instance {platform_instance}",
-                        share_name,
-                    )
-
-        return shared_databases
-
     def get_shares_workunits(
         self, databases: List[SnowflakeDatabase]
     ) -> Iterable[MetadataWorkUnit]:
-        shared_databases = self._get_shared_databases(
-            self.shares, self.platform_instance
-        )
-
+        inbounds = self.config.inbounds()
+        outbounds = self.config.outbounds()
         # None of the databases are shared
-        if not shared_databases:
+        if not (inbounds or outbounds):
            return
 
        logger.debug("Checking databases for inbound or outbound shares.")
 
        for db in databases:
-            if db.name not in shared_databases:
+            db.name = db.name
+            is_inbound = db.name in inbounds
+            is_outbound = db.name in outbounds
+
+            if not
(is_inbound or is_outbound): logger.debug(f"database {db.name} is not shared.") continue - sibling_dbs = self.get_sibling_databases(shared_databases[db.name]) + sibling_dbs = ( + list(outbounds[db.name]) if is_outbound else [inbounds[db.name]] + ) for schema in db.schemas: for table_name in schema.tables + schema.views: @@ -116,11 +68,11 @@ def get_shares_workunits( db.name, schema.name, table_name, - not shared_databases[db.name].created_from_share, + is_outbound, sibling_dbs, ) - if shared_databases[db.name].created_from_share: + if is_inbound: assert len(sibling_dbs) == 1 # SnowflakeLineageExtractor is unaware of database->schema->table hierarchy # hence this lineage code is not written in SnowflakeLineageExtractor @@ -129,38 +81,18 @@ def get_shares_workunits( db.name, schema.name, table_name, sibling_dbs[0] ) - self.report_missing_databases(databases, shared_databases) - - def get_sibling_databases(self, db: SharedDatabase) -> List[DatabaseId]: - sibling_dbs: List[DatabaseId] = [] - if db.created_from_share: - share_details = self.shares[db.shares[0]] - logger.debug( - f"database {db.name} is created from inbound share {db.shares[0]}." - ) - sibling_dbs = [ - DatabaseId(share_details.database, share_details.platform_instance) - ] - - else: # not created from share, but is in fact included in share - logger.debug( - f"database {db.name} is included as outbound share(s) {db.shares}." - ) - sibling_dbs = [ - consumer - for share_name in db.shares - for consumer in self.shares[share_name].consumers - ] - - return sibling_dbs + self.report_missing_databases( + databases, list(inbounds.keys()), list(outbounds.keys()) + ) def report_missing_databases( self, databases: List[SnowflakeDatabase], - shared_databases: Dict[str, SharedDatabase], + inbounds: List[str], + outbounds: List[str], ) -> None: db_names = [db.name for db in databases] - missing_dbs = [db for db in shared_databases.keys() if db not in db_names] + missing_dbs = [db for db in inbounds + outbounds if db not in db_names] if missing_dbs: self.report_warning( From afbe095ab59a3137a479a40212ad15a6d078efc6 Mon Sep 17 00:00:00 2001 From: Mayuri Nehate <33225191+mayurinehate@users.noreply.github.com> Date: Wed, 23 Aug 2023 20:04:51 +0530 Subject: [PATCH 11/12] Update snowflake_pre.md --- metadata-ingestion/docs/sources/snowflake/snowflake_pre.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md b/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md index 2139321d6f185..75bd579417a48 100644 --- a/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md +++ b/metadata-ingestion/docs/sources/snowflake/snowflake_pre.md @@ -111,7 +111,7 @@ If you are using [Snowflake Shares](https://docs.snowflake.com/en/user-guide/dat X: # name of the share database_name: db1 platform_instance: instance1 - consumers: # list of all databases created from share Xu + consumers: # list of all databases created from share X - database_name: db1_from_X platform_instance: instance2 From 02585c4ecdd234c8d95a874e2f14d29433bdba60 Mon Sep 17 00:00:00 2001 From: Mayuri N Date: Wed, 23 Aug 2023 21:18:51 +0530 Subject: [PATCH 12/12] remove no-op --- .../src/datahub/ingestion/source/snowflake/snowflake_shares.py | 1 - 1 file changed, 1 deletion(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py index 849325ebb4aef..6f7520bbf1988 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py @@ -44,7 +44,6 @@ def get_shares_workunits( logger.debug("Checking databases for inbound or outbound shares.") for db in databases: - db.name = db.name is_inbound = db.name in inbounds is_outbound = db.name in outbounds
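
To make the net effect of the final `shares` configuration concrete, the following sketch re-derives the inbound/outbound split that `SnowflakeV2Config.inbounds()` and `SnowflakeV2Config.outbounds()` compute in the later patches. It is a simplified, dependency-free illustration: `DatabaseId` and `ShareConfig` here are plain stand-ins for the pydantic models, `resolve()` only approximates the two config helpers, and the sample share `X` mirrors the docs example. None of this is the connector's actual code.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass(frozen=True)
class DatabaseId:
    database: str
    platform_instance: str

@dataclass
class ShareConfig:
    database: str                # database included in the share (producer side)
    platform_instance: str       # platform instance of the producer account
    consumers: List[DatabaseId]  # databases created from this share elsewhere

# One entry per share; the identical mapping can sit in every account's recipe.
SHARES: Dict[str, ShareConfig] = {
    "X": ShareConfig(
        database="db1",
        platform_instance="instance1",
        consumers=[DatabaseId("db1_from_X", "instance2")],
    )
}

def resolve(
    current_instance: str,
) -> Tuple[Dict[str, DatabaseId], Dict[str, List[DatabaseId]]]:
    """Split SHARES into the two views the connector derives per account:
    inbounds:  local db created from a share -> producer db it mirrors
    outbounds: local db included in a share  -> its copies in consumer accounts
    """
    inbounds: Dict[str, DatabaseId] = {}
    outbounds: Dict[str, List[DatabaseId]] = {}
    for share in SHARES.values():
        if share.platform_instance == current_instance:
            # This account is the producer: every consumer copy is a sibling,
            # with the local database as the primary side.
            outbounds.setdefault(share.database, []).extend(share.consumers)
        else:
            for consumer in share.consumers:
                if consumer.platform_instance == current_instance:
                    # This account consumes the share: the producer database is
                    # both the primary sibling and the upstream for lineage.
                    inbounds[consumer.database] = DatabaseId(
                        share.database, share.platform_instance
                    )
    return inbounds, outbounds

# Producer account sees db1 as outbound; consumer account sees db1_from_X as inbound.
print(resolve("instance1"))
print(resolve("instance2"))
```

Because each share entry names both the producer database and every consumer copy, an account's ingestion can infer its own role from `platform_instance` alone: a local database found in the outbound view becomes the primary sibling, while one found in the inbound view gets upstream lineage pointing at the producer. This is why the identical `shares` block can be pasted into every account's recipe.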