diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index 0bc8bb17934f7..95f6444384408 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -51,15 +51,17 @@ class DatabaseId: database: str = Field( description="Database created from share in consumer account." ) - platform_instance: str = Field( - description="Platform instance of consumer snowflake account." + platform_instance: Optional[str] = Field( + default=None, + description="Platform instance of consumer snowflake account.", ) class SnowflakeShareConfig(ConfigModel): database: str = Field(description="Database from which share is created.") - platform_instance: str = Field( - description="Platform instance for snowflake account in which share is created." + platform_instance: Optional[str] = Field( + default=None, + description="Platform instance for snowflake account in which share is created.", ) consumers: Set[DatabaseId] = Field( @@ -247,10 +249,11 @@ def validate_shares( if shares: # Check: platform_instance should be present - assert current_platform_instance is not None, ( - "Did you forget to set `platform_instance` for current ingestion ? " - "It is required to use `platform_instance` when ingesting from multiple snowflake accounts." - ) + if current_platform_instance is None: + logger.info( + "It is advisable to use `platform_instance` when ingesting from multiple snowflake accounts, if they contain databases with same name. " + "Setting `platform_instance` allows distinguishing such databases without conflict and correctly ingest their metadata." 
+ ) databases_included_in_share: List[DatabaseId] = [] databases_created_from_share: List[DatabaseId] = [] @@ -259,10 +262,11 @@ def validate_shares( shared_db = DatabaseId( share_details.database, share_details.platform_instance ) - assert all( - consumer.platform_instance != share_details.platform_instance - for consumer in share_details.consumers - ), "Share's platform_instance can not be same as consumer's platform instance. Self-sharing not supported in Snowflake." + if current_platform_instance: + assert all( + consumer.platform_instance != share_details.platform_instance + for consumer in share_details.consumers + ), "Share's platform_instance can not be same as consumer's platform instance. Self-sharing not supported in Snowflake." databases_included_in_share.append(shared_db) databases_created_from_share.extend(share_details.consumers) @@ -306,7 +310,11 @@ def inbounds(self) -> Dict[str, DatabaseId]: f"database {consumer.database} is created from inbound share {share_name}." ) inbounds[consumer.database] = share_details.source_database - break + if self.platform_instance: + break + # If not using platform_instance, any one of the consumer databases + # can be the database from this instance, so we include all relevant + # databases in inbounds. 
else: logger.info( f"Skipping Share {share_name}, as it does not include current platform instance {self.platform_instance}", diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py index 6f7520bbf1988..dad0ce7b59ee1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py @@ -93,11 +93,15 @@ def report_missing_databases( db_names = [db.name for db in databases] missing_dbs = [db for db in inbounds + outbounds if db not in db_names] - if missing_dbs: + if missing_dbs and self.config.platform_instance: self.report_warning( "snowflake-shares", f"Databases {missing_dbs} were not ingested. Siblings/Lineage will not be set for these.", ) + elif missing_dbs: + logger.debug( + f"Databases {missing_dbs} were not ingested in this recipe.", + ) def gen_siblings( self, diff --git a/metadata-ingestion/tests/unit/test_snowflake_shares.py b/metadata-ingestion/tests/unit/test_snowflake_shares.py index 7de86139baf39..9e33ba6132e06 100644 --- a/metadata-ingestion/tests/unit/test_snowflake_shares.py +++ b/metadata-ingestion/tests/unit/test_snowflake_shares.py @@ -231,6 +231,7 @@ def test_snowflake_shares_workunit_inbound_share( else: siblings_aspect = wu.get_aspect_of_type(Siblings) assert siblings_aspect is not None + assert not siblings_aspect.primary assert len(siblings_aspect.siblings) == 1 assert siblings_aspect.siblings == [ wu.get_urn().replace("instance1.db1", "instance2.db1") @@ -275,6 +276,7 @@ def test_snowflake_shares_workunit_outbound_share( for wu in wus: siblings_aspect = wu.get_aspect_of_type(Siblings) assert siblings_aspect is not None + assert siblings_aspect.primary assert len(siblings_aspect.siblings) == 2 assert siblings_aspect.siblings == [ wu.get_urn().replace("instance1.db2", "instance2.db2_from_share"), @@ -336,13 +338,85 @@ def 
test_snowflake_shares_workunit_inbound_and_outbound_share( siblings_aspect = wu.get_aspect_of_type(Siblings) assert siblings_aspect is not None if "db1" in wu.get_urn(): + assert not siblings_aspect.primary assert len(siblings_aspect.siblings) == 1 assert siblings_aspect.siblings == [ wu.get_urn().replace("instance1.db1", "instance2.db1") ] else: + assert siblings_aspect.primary assert len(siblings_aspect.siblings) == 2 assert siblings_aspect.siblings == [ wu.get_urn().replace("instance1.db2", "instance2.db2_from_share"), wu.get_urn().replace("instance1.db2", "instance3.db2"), ] + + +def test_snowflake_shares_workunit_inbound_and_outbound_share_no_platform_instance( + snowflake_databases: List[SnowflakeDatabase], +) -> None: + config = SnowflakeV2Config( + account_id="abc12345", + shares={ + "share1": SnowflakeShareConfig( + database="db1", + consumers=[ + DatabaseId(database="db1_from_share"), + DatabaseId(database="db1_other"), + ], + ), + "share2": SnowflakeShareConfig( + database="db2_main", + consumers=[ + DatabaseId(database="db2"), + DatabaseId(database="db2_other"), + ], + ), + }, + ) + + report = SnowflakeV2Report() + shares_handler = SnowflakeSharesHandler( + config, report, lambda x: make_snowflake_urn(x) + ) + + assert sorted(config.outbounds().keys()) == ["db1", "db2_main"] + assert sorted(config.inbounds().keys()) == [ + "db1_from_share", + "db1_other", + "db2", + "db2_other", + ] + wus = list(shares_handler.get_shares_workunits(snowflake_databases)) + + # 6 Sibling aspects for db1 tables + # 6 Sibling aspects and 6 upstreamLineage for db2 tables + assert len(wus) == 18 + + for wu in wus: + assert isinstance( + wu.metadata, (MetadataChangeProposal, MetadataChangeProposalWrapper) + ) + if wu.metadata.aspectName == "upstreamLineage": + upstream_aspect = wu.get_aspect_of_type(UpstreamLineage) + assert upstream_aspect is not None + assert len(upstream_aspect.upstreams) == 1 + assert upstream_aspect.upstreams[0].dataset == wu.get_urn().replace( 
"db2.", "db2_main." + ) + else: + siblings_aspect = wu.get_aspect_of_type(Siblings) + assert siblings_aspect is not None + if "db1" in wu.get_urn(): + assert siblings_aspect.primary + assert len(siblings_aspect.siblings) == 2 + assert siblings_aspect.siblings == [ + wu.get_urn().replace("db1.", "db1_from_share."), + wu.get_urn().replace("db1.", "db1_other."), + ] + else: + assert not siblings_aspect.primary + assert len(siblings_aspect.siblings) == 1 + assert siblings_aspect.siblings == [ + wu.get_urn().replace("db2.", "db2_main.") + ]