Skip to content

Commit

Permalink
Merge branch 'master' into add-GroupMembershipFieldResolverProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
amanda-her committed Sep 25, 2023
2 parents 4981c39 + 874109f commit 6681e9b
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,17 @@ class DatabaseId:
database: str = Field(
description="Database created from share in consumer account."
)
platform_instance: str = Field(
description="Platform instance of consumer snowflake account."
platform_instance: Optional[str] = Field(
default=None,
description="Platform instance of consumer snowflake account.",
)


class SnowflakeShareConfig(ConfigModel):
database: str = Field(description="Database from which share is created.")
platform_instance: str = Field(
description="Platform instance for snowflake account in which share is created."
platform_instance: Optional[str] = Field(
default=None,
description="Platform instance for snowflake account in which share is created.",
)

consumers: Set[DatabaseId] = Field(
Expand Down Expand Up @@ -247,10 +249,11 @@ def validate_shares(

if shares:
# Check: platform_instance should be present
assert current_platform_instance is not None, (
"Did you forget to set `platform_instance` for current ingestion ? "
"It is required to use `platform_instance` when ingesting from multiple snowflake accounts."
)
if current_platform_instance is None:
logger.info(
"It is advisable to use `platform_instance` when ingesting from multiple snowflake accounts, if they contain databases with same name. "
"Setting `platform_instance` allows distinguishing such databases without conflict and correctly ingest their metadata."
)

databases_included_in_share: List[DatabaseId] = []
databases_created_from_share: List[DatabaseId] = []
Expand All @@ -259,10 +262,11 @@ def validate_shares(
shared_db = DatabaseId(
share_details.database, share_details.platform_instance
)
assert all(
consumer.platform_instance != share_details.platform_instance
for consumer in share_details.consumers
), "Share's platform_instance can not be same as consumer's platform instance. Self-sharing not supported in Snowflake."
if current_platform_instance:
assert all(
consumer.platform_instance != share_details.platform_instance
for consumer in share_details.consumers
), "Share's platform_instance can not be same as consumer's platform instance. Self-sharing not supported in Snowflake."

databases_included_in_share.append(shared_db)
databases_created_from_share.extend(share_details.consumers)
Expand Down Expand Up @@ -306,7 +310,11 @@ def inbounds(self) -> Dict[str, DatabaseId]:
f"database {consumer.database} is created from inbound share {share_name}."
)
inbounds[consumer.database] = share_details.source_database
break
if self.platform_instance:
break
# If not using platform_instance, any one of consumer databases
# can be the database from this instance. so we include all relevant
# databases in inbounds.
else:
logger.info(
f"Skipping Share {share_name}, as it does not include current platform instance {self.platform_instance}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,15 @@ def report_missing_databases(
db_names = [db.name for db in databases]
missing_dbs = [db for db in inbounds + outbounds if db not in db_names]

if missing_dbs:
if missing_dbs and self.config.platform_instance:
self.report_warning(
"snowflake-shares",
f"Databases {missing_dbs} were not ingested. Siblings/Lineage will not be set for these.",
)
elif missing_dbs:
logger.debug(
f"Databases {missing_dbs} were not ingested in this recipe.",
)

def gen_siblings(
self,
Expand Down
74 changes: 74 additions & 0 deletions metadata-ingestion/tests/unit/test_snowflake_shares.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def test_snowflake_shares_workunit_inbound_share(
else:
siblings_aspect = wu.get_aspect_of_type(Siblings)
assert siblings_aspect is not None
assert not siblings_aspect.primary
assert len(siblings_aspect.siblings) == 1
assert siblings_aspect.siblings == [
wu.get_urn().replace("instance1.db1", "instance2.db1")
Expand Down Expand Up @@ -275,6 +276,7 @@ def test_snowflake_shares_workunit_outbound_share(
for wu in wus:
siblings_aspect = wu.get_aspect_of_type(Siblings)
assert siblings_aspect is not None
assert siblings_aspect.primary
assert len(siblings_aspect.siblings) == 2
assert siblings_aspect.siblings == [
wu.get_urn().replace("instance1.db2", "instance2.db2_from_share"),
Expand Down Expand Up @@ -336,13 +338,85 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share(
siblings_aspect = wu.get_aspect_of_type(Siblings)
assert siblings_aspect is not None
if "db1" in wu.get_urn():
assert not siblings_aspect.primary
assert len(siblings_aspect.siblings) == 1
assert siblings_aspect.siblings == [
wu.get_urn().replace("instance1.db1", "instance2.db1")
]
else:
assert siblings_aspect.primary
assert len(siblings_aspect.siblings) == 2
assert siblings_aspect.siblings == [
wu.get_urn().replace("instance1.db2", "instance2.db2_from_share"),
wu.get_urn().replace("instance1.db2", "instance3.db2"),
]


def test_snowflake_shares_workunit_inbound_and_outbound_share_no_platform_instance(
snowflake_databases: List[SnowflakeDatabase],
) -> None:
config = SnowflakeV2Config(
account_id="abc12345",
shares={
"share1": SnowflakeShareConfig(
database="db1",
consumers=[
DatabaseId(database="db1_from_share"),
DatabaseId(database="db1_other"),
],
),
"share2": SnowflakeShareConfig(
database="db2_main",
consumers=[
DatabaseId(database="db2"),
DatabaseId(database="db2_other"),
],
),
},
)

report = SnowflakeV2Report()
shares_handler = SnowflakeSharesHandler(
config, report, lambda x: make_snowflake_urn(x)
)

assert sorted(config.outbounds().keys()) == ["db1", "db2_main"]
assert sorted(config.inbounds().keys()) == [
"db1_from_share",
"db1_other",
"db2",
"db2_other",
]
wus = list(shares_handler.get_shares_workunits(snowflake_databases))

# 6 Sibling aspects for db1 tables
# 6 Sibling aspects and and 6 upstreamLineage for db2 tables
assert len(wus) == 18

for wu in wus:
assert isinstance(
wu.metadata, (MetadataChangeProposal, MetadataChangeProposalWrapper)
)
if wu.metadata.aspectName == "upstreamLineage":
upstream_aspect = wu.get_aspect_of_type(UpstreamLineage)
assert upstream_aspect is not None
assert len(upstream_aspect.upstreams) == 1
assert upstream_aspect.upstreams[0].dataset == wu.get_urn().replace(
"db2.", "db2_main."
)
else:
siblings_aspect = wu.get_aspect_of_type(Siblings)
assert siblings_aspect is not None
if "db1" in wu.get_urn():
assert siblings_aspect.primary
assert len(siblings_aspect.siblings) == 2
assert siblings_aspect.siblings == [
wu.get_urn().replace("db1.", "db1_from_share."),
wu.get_urn().replace("db1.", "db1_other."),
]
else:
assert not siblings_aspect.primary
assert len(siblings_aspect.siblings) == 1
assert siblings_aspect.siblings == [
wu.get_urn().replace("db2.", "db2_main.")
]

0 comments on commit 6681e9b

Please sign in to comment.