diff --git a/.github/workflows/datahub-actions-docker.yml b/.github/workflows/datahub-actions-docker.yml index 4fe4c5e5..f9ab3c57 100644 --- a/.github/workflows/datahub-actions-docker.yml +++ b/.github/workflows/datahub-actions-docker.yml @@ -43,10 +43,17 @@ jobs: env: ENABLE_PUBLISH: ${{ secrets.ACRYL_DOCKER_PASSWORD }} run: | - echo "Enable publish: ${{ env.ENABLE_PUBLISH != '' }}" - echo "publish=${{ env.ENABLE_PUBLISH != '' }}" >> "$GITHUB_OUTPUT" + if [[ -n "$ENABLE_PUBLISH" && "${{ github.repository }}" == "acryldata/datahub-actions" ]]; then + echo "Publishing is enabled" + echo "publish=true" >> "$GITHUB_OUTPUT" + else + echo "Publishing is not enabled" + echo "publish=false" >> "$GITHUB_OUTPUT" + fi + echo "Enable publish: ${{ env.ENABLE_PUBLISH != '' && github.repository == 'acryldata/datahub-actions' }}" regular_image: name: Build & Push Image to DockerHub + if: ${{ needs.setup.outputs.publish == 'true' }} # Only build the regular image if publishing is enabled runs-on: ubuntu-latest needs: setup steps: @@ -180,11 +187,9 @@ jobs: - name: Load Docker image (if not publishing) if: needs.setup.outputs.publish != 'true' run: docker load < image.tar - - name: Download image (if publishing) + - name: Pull Docker image (if publishing) if: needs.setup.outputs.publish == 'true' - uses: ishworkh/docker-image-artifact-download@v1 - with: - image: acryldata/datahub-actions-slim:${{ needs.setup.outputs.unique_tag }} + run: docker pull acryldata/datahub-actions-slim:${{ needs.setup.outputs.unique_tag }} - name: Run Trivy vulnerability scanner (slim) uses: aquasecurity/trivy-action@master env: @@ -198,7 +203,7 @@ jobs: ignore-unfixed: true vuln-type: 'os,library' - name: Upload Trivy scan results to GitHub Security tab (slim) - uses: github/codeql-action/upload-sarif@v2 + uses: github/codeql-action/upload-sarif@v3 with: sarif_file: 'trivy-results.sarif' smoke_test: @@ -233,11 +238,9 @@ jobs: - name: Load Docker image (if not publishing) if: needs.setup.outputs.publish != 'true' run: docker load < image.tar - - name: Download image (if publishing) + - name: Pull Docker image (if publishing) if: needs.setup.outputs.publish == 'true' - uses: ishworkh/docker-image-artifact-download@v1 - with: - image: acryldata/datahub-actions-slim:${{ needs.setup.outputs.unique_tag }} + run: docker pull acryldata/datahub-actions-slim:${{ needs.setup.outputs.unique_tag }} - name: run quickstart env: DATAHUB_TELEMETRY_ENABLED: false diff --git a/datahub-actions/src/datahub_actions/api/action_graph.py b/datahub-actions/src/datahub_actions/api/action_graph.py index 7186a982..680a2073 100644 --- a/datahub-actions/src/datahub_actions/api/action_graph.py +++ b/datahub-actions/src/datahub_actions/api/action_graph.py @@ -176,25 +176,51 @@ def query_ingestion_sources(self) -> List: break return sources - def get_downstreams(self, entity_urn: str) -> List[str]: - url_frag = f"/relationships?direction=INCOMING&types=List(DownstreamOf)&urn={urllib.parse.quote(entity_urn)}" - url = f"{self.graph._gms_server}{url_frag}" - response = self.graph._get_generic(url) - if response["count"] > 0: - relnships = response["relationships"] - entities = [x["entity"] for x in relnships] - return entities - return [] + def get_downstreams( + self, entity_urn: str, max_downstreams: int = 3000 + ) -> List[str]: + start = 0 + count_per_page = 1000 + entities = [] + done = False + total_downstreams = 0 + while not done: + # if start > 0: + # breakpoint() + url_frag = 
f"/relationships?direction=INCOMING&types=List(DownstreamOf)&urn={urllib.parse.quote(entity_urn)}&count={count_per_page}&start={start}" + url = f"{self.graph._gms_server}{url_frag}" + response = self.graph._get_generic(url) + if response["count"] > 0: + relnships = response["relationships"] + entities.extend([x["entity"] for x in relnships]) + start += count_per_page + total_downstreams += response["count"] + if start >= response["total"] or total_downstreams >= max_downstreams: + done = True + else: + done = True + return entities - def get_upstreams(self, entity_urn: str) -> List[str]: - url_frag = f"/relationships?direction=OUTGOING&types=List(DownstreamOf)&urn={urllib.parse.quote(entity_urn)}" - url = f"{self.graph._gms_server}{url_frag}" - response = self.graph._get_generic(url) - if response["count"] > 0: - relnships = response["relationships"] - entities = [x["entity"] for x in relnships] - return entities - return [] + def get_upstreams(self, entity_urn: str, max_upstreams: int = 3000) -> List[str]: + start = 0 + count_per_page = 100 + entities = [] + done = False + total_upstreams = 0 + while not done: + url_frag = f"/relationships?direction=OUTGOING&types=List(DownstreamOf)&urn={urllib.parse.quote(entity_urn)}&count={count_per_page}&start={start}" + url = f"{self.graph._gms_server}{url_frag}" + response = self.graph._get_generic(url) + if response["count"] > 0: + relnships = response["relationships"] + entities.extend([x["entity"] for x in relnships]) + start += count_per_page + total_upstreams += response["count"] + if start >= response["total"] or total_upstreams >= max_upstreams: + done = True + else: + done = True + return entities def get_relationships( self, entity_urn: str, direction: str, relationship_types: List[str] diff --git a/datahub-actions/src/datahub_actions/plugin/action/propagation/docs/propagation_action.py b/datahub-actions/src/datahub_actions/plugin/action/propagation/docs/propagation_action.py index 109597a2..d5b885dc 100644 --- a/datahub-actions/src/datahub_actions/plugin/action/propagation/docs/propagation_action.py +++ b/datahub-actions/src/datahub_actions/plugin/action/propagation/docs/propagation_action.py @@ -15,9 +15,9 @@ import json import logging import time -from typing import Any, Iterable, Optional +from enum import Enum +from typing import Iterable, List, Optional, Tuple -from datahub.configuration.common import ConfigModel from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.metadata.schema_classes import ( AuditStampClass, @@ -32,8 +32,8 @@ MetadataChangeLogClass, ) from datahub.metadata.urns import DatasetUrn -from datahub.utilities.urns.urn import Urn -from pydantic import BaseModel, Field, validator +from datahub.utilities.urns.urn import Urn, guess_entity_type +from pydantic import Field from datahub_actions.action.action import Action from datahub_actions.api.action_graph import AcrylDataHubGraph @@ -41,6 +41,11 @@ from datahub_actions.pipeline.pipeline_context import PipelineContext from datahub_actions.plugin.action.mcl_utils import MCLProcessor from datahub_actions.plugin.action.propagation.propagation_utils import ( + DirectionType, + PropagationConfig, + PropagationDirective, + RelationshipType, + SourceDetails, get_unique_siblings, ) from datahub_actions.plugin.action.stats_util import ( @@ -51,58 +56,19 @@ logger = logging.getLogger(__name__) -class DocPropagationDirective(BaseModel): - propagate: bool = Field( - description="Indicates whether the documentation should be propagated." 
- ) +class DocPropagationDirective(PropagationDirective): doc_string: Optional[str] = Field( default=None, description="Documentation string to be propagated." ) - operation: str = Field( - description="Operation to be performed on the documentation. Can be ADD, MODIFY or REMOVE." - ) - entity: str = Field( - description="Entity URN from which the documentation is propagated. This will either be the same as the origin or the via entity, depending on the propagation path." - ) - origin: str = Field( - description="Origin entity for the documentation. This is the entity that triggered the documentation propagation.", - ) - via: Optional[str] = Field( - None, - description="Via entity for the documentation. This is the direct entity that the documentation was propagated through.", - ) - actor: Optional[str] = Field( - None, - description="Actor that triggered the documentation propagation.", - ) -class SourceDetails(BaseModel): - origin: Optional[str] = Field( - None, - description="Origin entity for the documentation. This is the entity that triggered the documentation propagation.", - ) - via: Optional[str] = Field( - None, - description="Via entity for the documentation. This is the direct entity that the documentation was propagated through.", - ) - propagated: Optional[str] = Field( - None, - description="Indicates whether the documentation was propagated.", - ) - actor: Optional[str] = Field( - None, - description="Actor that triggered the documentation propagation.", - ) - - @validator("propagated", pre=True) - def convert_boolean_to_lowercase_string(cls, v: Any) -> Optional[str]: - if isinstance(v, bool): - return str(v).lower() - return v +class ColumnPropagationRelationships(str, Enum): + UPSTREAM = "upstream" + DOWNSTREAM = "downstream" + SIBLING = "sibling" -class DocPropagationConfig(ConfigModel): +class DocPropagationConfig(PropagationConfig): """ Configuration model for documentation propagation. @@ -131,6 +97,19 @@ class DocPropagationConfig(ConfigModel): description="Indicates whether dataset level documentation propagation is enabled or not.", example=False, ) + column_propagation_relationships: List[ColumnPropagationRelationships] = Field( + [ + ColumnPropagationRelationships.SIBLING, + ColumnPropagationRelationships.DOWNSTREAM, + ColumnPropagationRelationships.UPSTREAM, + ], + description="Relationships for column documentation propagation.", + example=[ + ColumnPropagationRelationships.UPSTREAM, + ColumnPropagationRelationships.SIBLING, + ColumnPropagationRelationships.DOWNSTREAM, + ], + ) def get_field_path(schema_field_urn: str) -> str: @@ -182,6 +161,10 @@ def __init__(self, config: DocPropagationConfig, ctx: PipelineContext): self.refresh_config() self._stats = ActionStageReport() self._stats.start() + assert self.ctx.graph + self._rate_limited_emit_mcp = self.config.get_rate_limited_emit_mcp( + self.ctx.graph.graph + ) def name(self) -> str: return "DocPropagator" @@ -192,6 +175,67 @@ def create(cls, config_dict: dict, ctx: PipelineContext) -> "Action": logger.info(f"Doc Propagation Config action configured with {action_config}") return cls(action_config, ctx) + def should_stop_propagation( + self, source_details: SourceDetails + ) -> Tuple[bool, str]: + """ + Check if the propagation should be stopped based on the source details. + Return result and reason. 
+        """
+        if source_details.propagation_started_at and (
+            int(time.time() * 1000.0) - source_details.propagation_started_at
+            >= self.config.max_propagation_time_millis
+        ):
+            return (True, "Propagation time exceeded.")
+        if (
+            source_details.propagation_depth
+            and source_details.propagation_depth >= self.config.max_propagation_depth
+        ):
+            return (True, "Propagation depth exceeded.")
+        return False, ""
+
+    def get_propagation_relationships(
+        self, entity_type: str, source_details: Optional[SourceDetails]
+    ) -> List[Tuple[RelationshipType, DirectionType]]:
+
+        possible_relationships = []
+        if entity_type == "schemaField":
+            if (source_details is not None) and (
+                source_details.propagation_relationship
+                and source_details.propagation_direction
+            ):
+                restricted_relationship = source_details.propagation_relationship
+                restricted_direction = source_details.propagation_direction
+            else:
+                restricted_relationship = None
+                restricted_direction = None
+
+            for relationship in self.config.column_propagation_relationships:
+                if relationship == ColumnPropagationRelationships.UPSTREAM:
+                    if (
+                        restricted_relationship == RelationshipType.LINEAGE
+                        and restricted_direction == DirectionType.DOWN
+                    ):  # Skip upstream if the propagation has been restricted to downstream
+                        continue
+                    possible_relationships.append(
+                        (RelationshipType.LINEAGE, DirectionType.UP)
+                    )
+                elif relationship == ColumnPropagationRelationships.DOWNSTREAM:
+                    if (
+                        restricted_relationship == RelationshipType.LINEAGE
+                        and restricted_direction == DirectionType.UP
+                    ):  # Skip downstream if the propagation has been restricted to upstream
+                        continue
+                    possible_relationships.append(
+                        (RelationshipType.LINEAGE, DirectionType.DOWN)
+                    )
+                elif relationship == ColumnPropagationRelationships.SIBLING:
+                    possible_relationships.append(
+                        (RelationshipType.SIBLING, DirectionType.ALL)
+                    )
+        logger.debug(f"Possible relationships: {possible_relationships}")
+        return possible_relationships
+
     def process_schema_field_documentation(
         self,
         entity_urn: str,
@@ -199,53 +243,111 @@ def process_schema_field_documentation(
         aspect_name: str,
         aspect_value: GenericAspectClass,
         previous_aspect_value: Optional[GenericAspectClass],
     ) -> Optional[DocPropagationDirective]:
-        if aspect_name == "documentation":
-            logger.debug("Processing 'documentation' MCL")
-            if self.config.columns_enabled:
-                current_docs = DocumentationClass.from_obj(
-                    json.loads(aspect_value.value)
+        """
+        Process changes in the documentation aspect of schemaField entities.
+        Produce a directive to propagate the documentation.
+        Business logic checks:
+        - If the documentation is sourced by this action, then we propagate
+          it.
+        - If the documentation is not sourced by this action, then we log a
+          warning and propagate it.
+        - If we have exceeded the maximum depth of propagation or maximum
+          time for propagation, then we stop propagation and don't return a directive.
+ """ + if ( + aspect_name != "documentation" + or guess_entity_type(entity_urn) != "schemaField" + ): + # not a documentation aspect or not a schemaField entity + return None + + logger.debug("Processing 'documentation' MCL") + if self.config.columns_enabled: + current_docs = DocumentationClass.from_obj(json.loads(aspect_value.value)) + old_docs = ( + None + if previous_aspect_value is None + else DocumentationClass.from_obj( + json.loads(previous_aspect_value.value) ) - old_docs = ( - None - if previous_aspect_value is None - else DocumentationClass.from_obj( - json.loads(previous_aspect_value.value) + ) + if current_docs.documentations: + # get the most recently updated documentation with attribution + current_documentation_instance = sorted( + [doc for doc in current_docs.documentations if doc.attribution], + key=lambda x: x.attribution.time if x.attribution else 0, + )[-1] + assert current_documentation_instance.attribution + if ( + current_documentation_instance.attribution.source is None + or current_documentation_instance.attribution.source + != self.action_urn + ): + logger.warning( + f"Documentation is not sourced by this action which is unexpected. Will be propagating for {entity_urn}" ) + source_details = ( + (current_documentation_instance.attribution.sourceDetail) + if current_documentation_instance.attribution + else {} + ) + source_details_parsed: SourceDetails = SourceDetails.parse_obj( + source_details ) - if current_docs.documentations: - # we assume that the first documentation is the primary one - # we can change this later - current_documentation_instance = current_docs.documentations[0] - source_details = ( - (current_documentation_instance.attribution.sourceDetail) - if current_documentation_instance.attribution - else {} + should_stop_propagation, reason = self.should_stop_propagation( + source_details_parsed + ) + if should_stop_propagation: + logger.warning(f"Stopping propagation for {entity_urn}. 
{reason}") + return None + else: + logger.debug(f"Propagating documentation for {entity_urn}") + propagation_relationships = self.get_propagation_relationships( + entity_type="schemaField", source_details=source_details_parsed + ) + origin_entity = source_details_parsed.origin + if old_docs is None or not old_docs.documentations: + return DocPropagationDirective( + propagate=True, + doc_string=current_documentation_instance.documentation, + operation="ADD", + entity=entity_urn, + origin=origin_entity, + via=entity_urn, + actor=self.actor_urn, + propagation_started_at=source_details_parsed.propagation_started_at, + propagation_depth=( + source_details_parsed.propagation_depth + 1 + if source_details_parsed.propagation_depth + else 1 + ), + relationships=propagation_relationships, ) - origin_entity = source_details.get("origin") - if old_docs is None or not old_docs.documentations: + else: + old_docs_instance = sorted( + old_docs.documentations, + key=lambda x: x.attribution.time if x.attribution else 0, + )[-1] + if ( + current_documentation_instance.documentation + != old_docs_instance.documentation + ): return DocPropagationDirective( propagate=True, doc_string=current_documentation_instance.documentation, - operation="ADD", + operation="MODIFY", entity=entity_urn, origin=origin_entity, via=entity_urn, actor=self.actor_urn, + propagation_started_at=source_details_parsed.propagation_started_at, + propagation_depth=( + source_details_parsed.propagation_depth + 1 + if source_details_parsed.propagation_depth + else 1 + ), + relationships=propagation_relationships, ) - else: - if ( - current_docs.documentations[0].documentation - != old_docs.documentations[0].documentation - ): - return DocPropagationDirective( - propagate=True, - doc_string=current_documentation_instance.documentation, - operation="MODIFY", - entity=entity_urn, - origin=origin_entity, - via=entity_urn, - actor=self.actor_urn, - ) return None def should_propagate( @@ -256,7 +358,6 @@ def should_propagate( return self.mcl_processor.process(event) if event.event_type == "EntityChangeEvent_v1": assert isinstance(event.event, EntityChangeEvent) - # logger.info(f"Received event {event}") assert self.ctx.graph is not None semantic_event = event.event if ( @@ -299,6 +400,12 @@ def should_propagate( if semantic_event.auditStamp else self.actor_urn ), + propagation_started_at=int(time.time() * 1000.0), + propagation_depth=1, # we start at 1 because this is the first propagation + relationships=self.get_propagation_relationships( + entity_type="schemaField", + source_details=None, + ), ) return None @@ -327,7 +434,7 @@ def modify_docs_on_columns( time=int(time.time() * 1000.0), actor=self.actor_urn ) - source_details = context.dict(exclude_none=True) + source_details = context.for_metadata_attribution() attribution: MetadataAttributionClass = MetadataAttributionClass( source=self.action_urn, time=auditStamp.time, @@ -339,11 +446,16 @@ def modify_docs_on_columns( mutation_needed = False action_sourced = False # we check if there are any existing documentations generated by - # this action, if so, we update them + # this action and sourced from the same origin, if so, we update them # otherwise, we add a new documentation entry sourced by this action - for doc_association in documentations.documentations: + for doc_association in documentations.documentations[:]: if doc_association.attribution and doc_association.attribution.source: - if doc_association.attribution.source == self.action_urn: + source_details_parsed: SourceDetails = 
SourceDetails.parse_obj( + doc_association.attribution.sourceDetail + ) + if doc_association.attribution.source == self.action_urn and ( + source_details_parsed.origin == context.origin + ): action_sourced = True if doc_association.documentation != field_doc: mutation_needed = True @@ -351,10 +463,7 @@ def modify_docs_on_columns( doc_association.documentation = field_doc or "" doc_association.attribution = attribution elif operation == "REMOVE": - # TODO : should we remove the documentation or just set it to empty string? - # Ideally we remove it - doc_association.documentation = "" - doc_association.attribution = attribution + documentations.documentations.remove(doc_association) if not action_sourced: documentations.documentations.append( DocumentationAssociationClass( @@ -459,12 +568,12 @@ def _only_one_upstream_field( upstreams = graph.get_upstreams(entity_urn=downstream_field) # Use a set here in case there are duplicated upstream edges upstream_fields = list( - {x for x in upstreams if x.startswith("urn:li:schemaField")} + {x for x in upstreams if guess_entity_type(x) == "schemaField"} ) # If we found no upstreams for the downstream field, simply skip. if not upstream_fields: - logger.warning( + logger.debug( f"No upstream fields found. Skipping propagation to downstream {downstream_field}" ) return False @@ -480,7 +589,7 @@ def _only_one_upstream_field( def act(self, event: EventEnvelope) -> None: assert self.ctx.graph for mcp in self.act_async(event): - self.ctx.graph.graph.emit(mcp) + self._rate_limited_emit_mcp(mcp) def act_async( self, event: EventEnvelope @@ -503,6 +612,7 @@ def act_async( try: doc_propagation_directive = self.should_propagate(event) + # breakpoint() logger.debug( f"Doc propagation directive for {event}: {doc_propagation_directive}" ) @@ -517,21 +627,47 @@ def act_async( via=doc_propagation_directive.via, propagated=True, actor=doc_propagation_directive.actor, + propagation_started_at=doc_propagation_directive.propagation_started_at, + propagation_depth=doc_propagation_directive.propagation_depth, ) assert self.ctx.graph - - # TODO: Put each mechanism behind a config flag to be controlled externally. - - # Step 1: Propagate to downstream entities - yield from self._propagate_to_downstreams( - doc_propagation_directive, context - ) - - # Step 2: Propagate to sibling entities - yield from self._propagate_to_siblings( - doc_propagation_directive, context + logger.debug(f"Doc Propagation Directive: {doc_propagation_directive}") + # TODO: Put each mechanism behind a config flag to be controlled + # externally. 
+ lineage_downstream = ( + RelationshipType.LINEAGE, + DirectionType.DOWN, + ) in doc_propagation_directive.relationships + lineage_upstream = ( + RelationshipType.LINEAGE, + DirectionType.UP, + ) in doc_propagation_directive.relationships + lineage_any = ( + RelationshipType.LINEAGE, + DirectionType.ALL, + ) in doc_propagation_directive.relationships + logger.debug( + f"Lineage Downstream: {lineage_downstream}, Lineage Upstream: {lineage_upstream}, Lineage Any: {lineage_any}" ) + if lineage_downstream or lineage_any: + # Step 1: Propagate to downstream entities + yield from self._propagate_to_downstreams( + doc_propagation_directive, context + ) + if lineage_upstream or lineage_any: + # Step 2: Propagate to upstream entities + yield from self._propagate_to_upstreams( + doc_propagation_directive, context + ) + if ( + RelationshipType.SIBLING, + DirectionType.ALL, + ) in doc_propagation_directive.relationships: + # Step 3: Propagate to sibling entities + yield from self._propagate_to_siblings( + doc_propagation_directive, context + ) stats.end(event, success=True) except Exception: @@ -552,10 +688,14 @@ def _propagate_to_downstreams( f"Downstreams: {downstreams} for {doc_propagation_directive.entity}" ) entity_urn = doc_propagation_directive.entity - - if entity_urn.startswith("urn:li:schemaField"): + propagated_context = SourceDetails.parse_obj(context.dict()) + propagated_context.propagation_relationship = RelationshipType.LINEAGE + propagated_context.propagation_direction = DirectionType.DOWN + propagated_entities_this_hop_count = 0 + # breakpoint() + if guess_entity_type(entity_urn) == "schemaField": downstream_fields = { - x for x in downstreams if x.startswith("urn:li:schemaField") + x for x in downstreams if guess_entity_type(x) == "schemaField" } for field in downstream_fields: schema_field_urn = Urn.create_from_string(field) @@ -566,31 +706,114 @@ def _propagate_to_downstreams( f"Will {doc_propagation_directive.operation} documentation {doc_propagation_directive.doc_string} for {field_path} on {parent_urn}" ) - if parent_urn.startswith("urn:li:dataset"): + parent_entity_type = guess_entity_type(parent_urn) + + if parent_entity_type == "dataset": if self._only_one_upstream_field( self.ctx.graph, downstream_field=str(schema_field_urn), upstream_field=entity_urn, ): + if ( + propagated_entities_this_hop_count + >= self.config.max_propagation_fanout + ): + # breakpoint() + logger.warning( + f"Exceeded max propagation fanout of {self.config.max_propagation_fanout}. Skipping propagation to downstream {field}" + ) + # No need to propagate to more downstreams + return + maybe_mcp = self.modify_docs_on_columns( self.ctx.graph, doc_propagation_directive.operation, field, parent_urn, field_doc=doc_propagation_directive.doc_string, - context=context, + context=propagated_context, ) if maybe_mcp: + propagated_entities_this_hop_count += 1 yield maybe_mcp - elif parent_urn.startswith("urn:li:chart"): + elif parent_entity_type == "chart": logger.warning( "Charts are expected to have fields that are dataset schema fields. Skipping for now..." ) self._stats.increment_assets_impacted(field) - elif entity_urn.startswith("urn:li:dataset"): + elif guess_entity_type(entity_urn) == "dataset": + logger.debug( + "Dataset level documentation propagation is not yet supported!" + ) + + def _propagate_to_upstreams( + self, doc_propagation_directive: DocPropagationDirective, context: SourceDetails + ) -> Iterable[MetadataChangeProposalWrapper]: + """ + Propagate the documentation to upstream entities. 
+ """ + assert self.ctx.graph + upstreams = self.ctx.graph.get_upstreams( + entity_urn=doc_propagation_directive.entity + ) + logger.debug(f"Upstreams: {upstreams} for {doc_propagation_directive.entity}") + entity_urn = doc_propagation_directive.entity + propagated_context = SourceDetails.parse_obj(context.dict()) + propagated_context.propagation_relationship = RelationshipType.LINEAGE + propagated_context.propagation_direction = DirectionType.UP + propagated_entities_this_hop_count = 0 + + if guess_entity_type(entity_urn) == "schemaField": + upstream_fields = { + x for x in upstreams if guess_entity_type(x) == "schemaField" + } + # We only propagate to the upstream field if there is only one + # upstream field + if len(upstream_fields) == 1: + for field in upstream_fields: + schema_field_urn = Urn.create_from_string(field) + parent_urn = schema_field_urn.get_entity_id()[0] + field_path = schema_field_urn.get_entity_id()[1] + + logger.debug( + f"Will {doc_propagation_directive.operation} documentation {doc_propagation_directive.doc_string} for {field_path} on {parent_urn}" + ) + + parent_entity_type = guess_entity_type(parent_urn) + + if parent_entity_type == "dataset": + if ( + propagated_entities_this_hop_count + >= self.config.max_propagation_fanout + ): + logger.warning( + f"Exceeded max propagation fanout of {self.config.max_propagation_fanout}. Skipping propagation to upstream {field}" + ) + # No need to propagate to more upstreams + return + maybe_mcp = self.modify_docs_on_columns( + self.ctx.graph, + doc_propagation_directive.operation, + field, + parent_urn, + field_doc=doc_propagation_directive.doc_string, + context=propagated_context, + ) + if maybe_mcp: + propagated_entities_this_hop_count += 1 + yield maybe_mcp + + elif parent_entity_type == "chart": + logger.warning( + "Charts are expected to have fields that are dataset schema fields. Skipping for now..." + ) + + self._stats.increment_assets_impacted(field) + + elif guess_entity_type(entity_urn) == "dataset": logger.debug( "Dataset level documentation propagation is not yet supported!" 
) @@ -604,12 +827,16 @@ def _propagate_to_siblings( assert self.ctx.graph entity_urn = doc_propagation_directive.entity siblings = get_unique_siblings(self.ctx.graph, entity_urn) + propagated_context = SourceDetails.parse_obj(context.dict()) + propagated_context.propagation_relationship = RelationshipType.SIBLING + propagated_context.propagation_direction = DirectionType.ALL logger.debug(f"Siblings: {siblings} for {doc_propagation_directive.entity}") for sibling in siblings: - if entity_urn.startswith("urn:li:schemaField") and sibling.startswith( - "urn:li:schemaField" + if ( + guess_entity_type(entity_urn) == "schemaField" + and guess_entity_type(sibling) == "schemaField" ): parent_urn = Urn.create_from_string(sibling).get_entity_id()[0] self._stats.increment_assets_impacted(sibling) @@ -619,7 +846,7 @@ def _propagate_to_siblings( schema_field_urn=sibling, dataset_urn=parent_urn, field_doc=doc_propagation_directive.doc_string, - context=context, + context=propagated_context, ) if maybe_mcp: yield maybe_mcp diff --git a/datahub-actions/src/datahub_actions/plugin/action/propagation/propagation_utils.py b/datahub-actions/src/datahub_actions/plugin/action/propagation/propagation_utils.py index b203a952..c27ec602 100644 --- a/datahub-actions/src/datahub_actions/plugin/action/propagation/propagation_utils.py +++ b/datahub-actions/src/datahub_actions/plugin/action/propagation/propagation_utils.py @@ -16,27 +16,33 @@ import time from abc import abstractmethod from enum import Enum -from typing import Dict, Iterable, List, Optional, Tuple +from functools import wraps +from typing import Any, Dict, Iterable, List, Optional, Tuple import datahub.metadata.schema_classes as models +from datahub.configuration.common import ConfigModel from datahub.emitter.mce_builder import make_schema_field_urn -from datahub.ingestion.graph.client import SearchFilterRule +from datahub.ingestion.graph.client import DataHubGraph, SearchFilterRule from datahub.metadata.schema_classes import MetadataAttributionClass -from datahub.utilities.urns.urn import Urn +from datahub.utilities.str_enum import StrEnum +from datahub.utilities.urns.urn import Urn, guess_entity_type +from pydantic import validator from pydantic.fields import Field from pydantic.main import BaseModel +from ratelimit import limits, sleep_and_retry from datahub_actions.api.action_graph import AcrylDataHubGraph SYSTEM_ACTOR = "urn:li:corpuser:__datahub_system" -class RelationshipType(Enum): +class RelationshipType(StrEnum): LINEAGE = "lineage" # signifies all types of lineage HIERARCHY = "hierarchy" # signifies all types of hierarchy + SIBLING = "sibling" # signifies all types of sibling -class DirectionType(Enum): +class DirectionType(StrEnum): UP = "up" # signifies upstream or parent (depending on relationship type) DOWN = "down" # signifies downstream or child (depending on relationship type) ALL = "all" # signifies all directions @@ -45,8 +51,7 @@ class DirectionType(Enum): class PropagationDirective(BaseModel): propagate: bool operation: str - relationship: RelationshipType = RelationshipType.LINEAGE - direction: DirectionType = DirectionType.UP + relationships: List[Tuple[RelationshipType, DirectionType]] entity: str = Field( description="Entity that currently triggered the propagation directive", ) @@ -61,6 +66,104 @@ class PropagationDirective(BaseModel): None, description="Actor that triggered the propagation through the original association.", ) + propagation_started_at: Optional[int] = Field( + None, + description="Timestamp (in millis) when the 
original propagation event happened.", + ) + propagation_depth: Optional[int] = Field( + default=0, + description="Depth of propagation. This is used to track the depth of the propagation.", + ) + + +class SourceDetails(BaseModel): + origin: Optional[str] = Field( + None, + description="Origin entity for the documentation. This is the entity that triggered the documentation propagation.", + ) + via: Optional[str] = Field( + None, + description="Via entity for the documentation. This is the direct entity that the documentation was propagated through.", + ) + propagated: Optional[str] = Field( + None, + description="Indicates whether the metadata element was propagated.", + ) + actor: Optional[str] = Field( + None, + description="Actor that triggered the metadata propagation.", + ) + propagation_started_at: Optional[int] = Field( + None, + description="Timestamp when the metadata propagation event happened.", + ) + propagation_depth: Optional[int] = Field( + default=0, + description="Depth of metadata propagation.", + ) + propagation_relationship: Optional[RelationshipType] = Field( + None, + description="The relationship that the metadata was propagated through.", + ) + propagation_direction: Optional[DirectionType] = Field( + None, + description="The direction that the metadata was propagated through.", + ) + + @validator("propagated", pre=True) + def convert_boolean_to_lowercase_string(cls, v: Any) -> Optional[str]: + if isinstance(v, bool): + return str(v).lower() + return v + + @validator("propagation_depth", "propagation_started_at", pre=True) + def convert_to_int(cls, v: Any) -> Optional[int]: + if v is not None: + return int(v) + return v + + def for_metadata_attribution(self) -> Dict[str, str]: + """ + Convert the SourceDetails object to a dictionary that can be used in + Metadata Attribution MCPs. 
+ """ + result = {} + for k, v in self.dict(exclude_none=True).items(): + if isinstance(v, Enum): + result[k] = v.value # Use the enum's value + elif isinstance(v, int): + result[k] = str(v) # Convert int to string + else: + result[k] = str(v) # Convert everything else to string + return result + + +class PropagationConfig(ConfigModel): + """ + Base class for all propagation configs + """ + + max_propagation_depth: int = 5 + max_propagation_fanout: int = 1000 + max_propagation_time_millis: int = 1000 * 60 * 60 * 1 # 1 hour + rate_limit_propagated_writes: int = 15000 # 15000 writes per 15 seconds (default) + rate_limit_propagated_writes_period: int = 15 # Every 15 seconds + + def get_rate_limited_emit_mcp(self, emitter: DataHubGraph) -> Any: + """ + Returns a rate limited emitter that can be used to emit metadata for propagation + """ + + @sleep_and_retry + @limits( + calls=self.rate_limit_propagated_writes, + period=self.rate_limit_propagated_writes_period, + ) + @wraps(emitter.emit_mcp) + def wrapper(*args, **kwargs): + return emitter.emit_mcp(*args, **kwargs) + + return wrapper def get_attribution_and_context_from_directive( @@ -80,7 +183,20 @@ def get_attribution_and_context_from_directive( source_detail: dict[str, str] = { "origin": propagation_directive.origin, "propagated": "true", + "propagation_depth": str(propagation_directive.propagation_depth), + "propagation_started_at": str( + propagation_directive.propagation_started_at + if propagation_directive.propagation_started_at + else time + ), } + if propagation_directive.relationships: + source_detail["propagation_relationship"] = propagation_directive.relationships[ + 0 + ][0].value + source_detail["propagation_direction"] = propagation_directive.relationships[0][ + 1 + ].value if propagation_directive.actor: source_detail["actor"] = propagation_directive.actor else: @@ -144,7 +260,7 @@ def get_unique_siblings(graph: AcrylDataHubGraph, entity_urn: str) -> list[str]: Get unique siblings for the entity urn """ - if entity_urn.startswith("urn:li:schemaField"): + if guess_entity_type(entity_urn) == "schemaField": parent_urn = Urn.create_from_string(entity_urn).get_entity_id()[0] entity_field_path = Urn.create_from_string(entity_urn).get_entity_id()[1] # Does my parent have siblings? 
@@ -158,7 +274,7 @@ def get_unique_siblings(graph: AcrylDataHubGraph, entity_urn: str) -> list[str]: target_sibling = other_siblings[0] # now we need to find the schema field in this sibling that # matches us - if target_sibling.startswith("urn:li:dataset"): + if guess_entity_type(target_sibling) == "dataset": schema_fields = graph.graph.get_aspect( target_sibling, models.SchemaMetadataClass ) diff --git a/docker/config/doc_propagation_action.yaml b/docker/config/doc_propagation_action.yaml index bb822b66..c570ba18 100644 --- a/docker/config/doc_propagation_action.yaml +++ b/docker/config/doc_propagation_action.yaml @@ -19,6 +19,8 @@ source: connection: bootstrap: ${KAFKA_BOOTSTRAP_SERVER:-localhost:9092} schema_registry_url: ${SCHEMA_REGISTRY_URL:-http://localhost:8081} + consumer_config: + max.poll.interval.ms: ${MAX_POLL_INTERVAL_MS:-60000} # 1 minute per poll topic_routes: mcl: ${METADATA_CHANGE_LOG_VERSIONED_TOPIC_NAME:-MetadataChangeLog_Versioned_v1} pe: ${PLATFORM_EVENT_TOPIC_NAME:-PlatformEvent_v1} @@ -27,6 +29,11 @@ action: config: # Action-specific configs (map) columns_enabled: ${DATAHUB_ACTIONS_DOC_PROPAGATION_COLUMNS_ENABLED:-true} + max_propagation_depth: ${DATAHUB_ACTIONS_DOC_PROPAGATION_MAX_PROPAGATION_DEPTH:-5} + max_propagation_fanout: ${DATAHUB_ACTIONS_DOC_PROPAGATION_MAX_PROPAGATION_FANOUT:-1000} + max_propagation_time_millis: ${DATAHUB_ACTIONS_DOC_PROPAGATION_MAX_PROPAGATION_TIME_MILLIS:-3600000} # 1 hour + rate_limit_propagated_writes: ${DATAHUB_ACTIONS_DOC_PROPAGATION_RATE_LIMIT_PROPAGATED_WRITES:-1500} # 100 writes per second (default) + rate_limit_propagated_writes_period: ${DATAHUB_ACTIONS_DOC_PROPAGATION_RATE_LIMIT_PROPAGATED_WRITES_PERIOD:-15} # Every 15 seconds datahub: server: 'http://${DATAHUB_GMS_HOST:-localhost}:${DATAHUB_GMS_PORT:-8080}' diff --git a/smoke-test/inject_actions_env_file.py b/smoke-test/inject_actions_env_file.py new file mode 100644 index 00000000..0bc00d68 --- /dev/null +++ b/smoke-test/inject_actions_env_file.py @@ -0,0 +1,31 @@ +import yaml +import sys + + +def modify_yaml(file_path): + # Read the existing file + with open(file_path, "r") as file: + content = file.read() + + # Parse the YAML content + data = yaml.safe_load(content) + + # Modify the datahub-actions section + if "services" in data and "datahub-actions" in data["services"]: + datahub_actions = data["services"]["datahub-actions"] + datahub_actions["env_file"] = ["${ACTIONS_ENV_FILE:-}"] + + # Write the modified content back to the file + with open(file_path, "w") as file: + yaml.dump(data, file, sort_keys=False) + + print(f"Successfully added env_file to datahub-actions in {file_path}") + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python script.py ") + sys.exit(1) + + file_path = sys.argv[1] + modify_yaml(file_path) diff --git a/smoke-test/requirements.txt b/smoke-test/requirements.txt index e209324c..3715d174 100644 --- a/smoke-test/requirements.txt +++ b/smoke-test/requirements.txt @@ -19,4 +19,6 @@ types-PyYAML requests<=2.31.0 # Missing numpy requirement in 8.0.0 deepdiff!=8.0.0 -acryl-datahub \ No newline at end of file +acryl-datahub +jinja2 +-e ../datahub-actions diff --git a/smoke-test/run-quickstart.sh b/smoke-test/run-quickstart.sh index e6cfa678..81388631 100755 --- a/smoke-test/run-quickstart.sh +++ b/smoke-test/run-quickstart.sh @@ -14,4 +14,17 @@ echo "DATAHUB_ACTIONS_VERSION = $DATAHUB_ACTIONS_VERSION" DATAHUB_TELEMETRY_ENABLED=false \ ACTIONS_VERSION=${DATAHUB_ACTIONS_VERSION} \ -datahub docker quickstart \ No newline at end of file 
+DATAHUB_ACTIONS_IMAGE=acryldata/datahub-actions-slim \ +datahub docker quickstart + +# After quickstart succeeds, we modify the docker-compose file to inject the env +# file variable +python inject_actions_env_file.py ~/.datahub/quickstart/docker-compose.yml + +# Then we run quickstart again with the modified docker-compose file + +DATAHUB_TELEMETRY_ENABLED=false \ +ACTIONS_VERSION=${DATAHUB_ACTIONS_VERSION} \ +ACTIONS_ENV_FILE=`pwd`/tests/resources/actions.env \ +DATAHUB_ACTIONS_IMAGE=acryldata/datahub-actions-slim \ +datahub docker quickstart -f ~/.datahub/quickstart/docker-compose.yml \ No newline at end of file diff --git a/smoke-test/tests/actions/doc_propagation/resources/datasets.yaml b/smoke-test/tests/actions/doc_propagation/resources/datasets.yaml deleted file mode 100644 index 77e37178..00000000 --- a/smoke-test/tests/actions/doc_propagation/resources/datasets.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# This file is used to define a dataset and provide metadata for it -- urn: urn:li:dataset:(urn:li:dataPlatform:hive,user.clicks,PROD) - subtype: Table - schema: - file: tests/actions/doc_propagation/resources/user_clicks.avsc - -- urn: urn:li:dataset:(urn:li:dataPlatform:events,ClickEvent,PROD) - subtype: Topic - description: | - This is a sample event that is generated when a user clicks on a link. - Do not use this event for any purpose other than testing. - schema: - file: tests/actions/doc_propagation/resources/user_clicks.avsc - fields: - - id: ip - description: 'the ip address of the user' - - id: user_id - description: 'the user id' - downstreams: - - urn:li:dataset:(urn:li:dataPlatform:hive,user.clicks,PROD) diff --git a/smoke-test/tests/actions/doc_propagation/resources/datasets_for_cycles_template.yaml b/smoke-test/tests/actions/doc_propagation/resources/datasets_for_cycles_template.yaml new file mode 100644 index 00000000..ddd1b802 --- /dev/null +++ b/smoke-test/tests/actions/doc_propagation/resources/datasets_for_cycles_template.yaml @@ -0,0 +1,47 @@ +# This file is used to define a dataset and provide metadata for it +- urn: urn:li:dataset:(urn:li:dataPlatform:events,{{ test_id }}.ClickEvent,PROD) + subtype: Topic + description: | + This is a sample event that is generated when a user clicks on a link. + Do not use this event for any purpose other than testing. 
+ schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc + fields: + - id: ip + description: 'the ip address of the user' + - id: user_id + description: 'the user id' + +# dataset 4 has downstream - user.clicks (which creates a cycle) (as long as +# ClickEvent and user.clicks are siblings, the cycle is created) +# Since yaml doesn't support creation of siblings, we create siblings programmatically +- urn: urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_4,PROD) + subtype: Table + schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc + downstreams: + - urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks,PROD) + +- urn: urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_3,PROD) + subtype: Table + schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc + downstreams: + - urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_4,PROD) + +- urn: urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_2,PROD) + subtype: Table + schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc + downstreams: + - urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_3,PROD) + +- urn: urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks,PROD) + subtype: Table + schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc + downstreams: + - urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_2,PROD) + siblings: + - urn: urn:li:dataset:(urn:li:dataPlatform:events,{{ test_id }}.ClickEvent,PROD) + primary: true diff --git a/smoke-test/tests/actions/doc_propagation/resources/datasets_template.yaml b/smoke-test/tests/actions/doc_propagation/resources/datasets_template.yaml new file mode 100644 index 00000000..b5281dda --- /dev/null +++ b/smoke-test/tests/actions/doc_propagation/resources/datasets_template.yaml @@ -0,0 +1,52 @@ +# This file is used to define a dataset and provide metadata for it +- urn: urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_6,PROD) + subtype: Table + schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc +- urn: urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_5,PROD) + subtype: Table + schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc + downstreams: + - urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_6,PROD) +- urn: urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_4,PROD) + subtype: Table + schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc + downstreams: + - urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_5,PROD) +- urn: urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_3,PROD) + subtype: Table + schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc + downstreams: + - urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_4,PROD) + +- urn: urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_2,PROD) + subtype: Table + schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc + downstreams: + - urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_3,PROD) + +- urn: urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks,PROD) + subtype: Table + schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc + downstreams: + - urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks_2,PROD) + +- urn: 
urn:li:dataset:(urn:li:dataPlatform:events,{{ test_id }}.ClickEvent,PROD) + subtype: Topic + description: | + This is a sample event that is generated when a user clicks on a link. + Do not use this event for any purpose other than testing. + schema: + file: tests/actions/doc_propagation/resources/user_clicks.avsc + fields: + - id: ip + description: 'the ip address of the user' + - id: user_id + description: 'the user id' + downstreams: + - urn:li:dataset:(urn:li:dataPlatform:hive,{{ test_id }}.user.clicks,PROD) diff --git a/smoke-test/tests/actions/doc_propagation/test_propagation.py b/smoke-test/tests/actions/doc_propagation/test_propagation.py index e9b70139..c6a9ca79 100644 --- a/smoke-test/tests/actions/doc_propagation/test_propagation.py +++ b/smoke-test/tests/actions/doc_propagation/test_propagation.py @@ -1,20 +1,29 @@ import logging import os import tempfile +import time +import uuid +from contextlib import contextmanager from pathlib import Path -from random import randint -from typing import Any, List +from typing import Any, Dict, Iterable, List, Tuple import datahub.metadata.schema_classes as models +from pydantic import BaseModel import pytest from datahub.api.entities.dataset.dataset import Dataset -from datahub.emitter.mce_builder import make_dataset_urn, make_schema_field_urn +from datahub.emitter.mce_builder import make_schema_field_urn from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.common import PipelineContext, RecordEnvelope from datahub.ingestion.api.sink import NoopWriteCallback from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph from datahub.ingestion.sink.file import FileSink, FileSinkConfig from datahub.utilities.urns.urn import Urn +from jinja2 import Template +import tenacity +from datahub_actions.plugin.action.propagation.docs.propagation_action import ( + DocPropagationConfig, +) + from tests.utils import ( delete_urns_from_file, @@ -25,12 +34,7 @@ logger = logging.getLogger(__name__) - -start_index = randint(10, 10000) -dataset_urns = [ - make_dataset_urn("snowflake", f"table_foo_{i}") - for i in range(start_index, start_index + 10) -] +DELETE_AFTER_TEST = os.getenv("DELETE_AFTER_TEST", "false").lower() == "true" class FileEmitter: @@ -69,7 +73,53 @@ def sanitize(event: Any) -> Any: return event -def create_test_data(filename: str, test_resources_dir: str) -> List[str]: +def generate_temp_yaml(template_path: Path, output_path: Path, test_id: str): + # Load the YAML template + with open(template_path, "r") as file: + template_content = file.read() + + # Render the template with Jinja2 + template = Template(template_content) + rendered_yaml = template.render(test_id=test_id) + + # Write the rendered YAML to a temporary file + with open(output_path, "w") as file: + file.write(rendered_yaml) + + return output_path + + +class ActionTestEnv(BaseModel): + class Config: + allow_extra = True + + DATAHUB_ACTIONS_DOC_PROPAGATION_MAX_PROPAGATION_FANOUT: int + + +@pytest.fixture(scope="module") +def action_env_vars(pytestconfig) -> ActionTestEnv: + common_test_resources_dir = Path(pytestconfig.rootdir) / "tests" / "resources" + env_file = common_test_resources_dir / "actions.env" + # validate the env file exists + assert env_file.exists() + # read the env file, ignore comments and empty lines and convert to dict + env_vars = {} + with open(env_file, "r") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + key, value = line.split("=", 1) + env_vars[key] = value + + return 
ActionTestEnv(**env_vars) + + +@pytest.fixture(scope="function") +def test_id(): + return f"test_{uuid.uuid4().hex[:8]}" + + +def create_test_data(filename: str, template_path: Path, test_id: str) -> List[str]: def get_urns_from_mcp(mcp: MetadataChangeProposalWrapper) -> List[str]: assert mcp.entityUrn urns = [mcp.entityUrn] @@ -82,9 +132,13 @@ def get_urns_from_mcp(mcp: MetadataChangeProposalWrapper) -> List[str]: urns.append(field_urn) return urns + # Generate temporary YAML file + temp_yaml_path = template_path.parent / f"temp_{template_path.name}_{test_id}.yaml" + generate_temp_yaml(template_path, temp_yaml_path, test_id) + mcps = [] all_urns = [] - for dataset in Dataset.from_yaml(file=f"{test_resources_dir}/datasets.yaml"): + for dataset in Dataset.from_yaml(file=str(temp_yaml_path)): mcps.extend([sanitize(event) for event in dataset.generate_mcp()]) file_emitter = FileEmitter(filename) @@ -93,6 +147,10 @@ def get_urns_from_mcp(mcp: MetadataChangeProposalWrapper) -> List[str]: file_emitter.emit(mcp) file_emitter.close() + + # Clean up the temporary YAML file + temp_yaml_path.unlink() + return list(set(all_urns)) @@ -106,21 +164,40 @@ def test_resources_dir(root_dir): return Path(root_dir) / "tests" / "actions" / "doc_propagation" / "resources" -@pytest.fixture(scope="module", autouse=False) -def ingest_cleanup_data(request, test_resources_dir, graph): - new_file, filename = tempfile.mkstemp(suffix=".json") - try: - all_urns = create_test_data(filename, test_resources_dir) - print("ingesting datasets test data") - ingest_file_via_rest(filename) - yield - print("removing test data") - delete_urns_from_file(filename) - for urn in all_urns: - graph.delete_entity(urn, hard=True) - wait_for_writes_to_sync() - finally: - os.remove(filename) +@pytest.fixture(scope="function") +def ingest_cleanup_data_function(request, test_resources_dir, graph, test_id): + @contextmanager + def _ingest_cleanup_data(template_file="datasets_template.yaml"): + new_file, filename = tempfile.mkstemp(suffix=f"_{test_id}.json") + try: + template_path = Path(test_resources_dir) / template_file + all_urns = create_test_data(filename, template_path, test_id) + print( + f"Ingesting datasets test data for test_id: {test_id} using template: {template_file}" + ) + ingest_file_via_rest(filename) + yield all_urns + finally: + if DELETE_AFTER_TEST: + print(f"Removing test data for test_id: {test_id}") + delete_urns_from_file(filename) + for urn in all_urns: + graph.delete_entity(urn, hard=True) + wait_for_writes_to_sync() + os.remove(filename) + + return _ingest_cleanup_data + + +@pytest.fixture(scope="function") +def ingest_cleanup_data(ingest_cleanup_data_function): + """ + This fixture is a wrapper around ingest_cleanup_data_function() that yields + the urns to make default usage easier. 
+ """ + with ingest_cleanup_data_function() as urns: + # Convert the generator to a list to ensure it is fully consumed + yield urns @pytest.fixture(scope="module", autouse=False) @@ -135,24 +212,220 @@ def test_healthchecks(wait_for_healthchecks): pass -def add_col_col_lineage(graph): +@pytest.fixture(scope="function") +def test_data(tmp_path, test_resources_dir, test_id): + filename = tmp_path / f"test_data_{test_id}.json" + return create_test_data(str(filename), test_resources_dir, test_id) + + +@pytest.fixture(scope="function") +def dataset_depth_map(test_id): + return { + 0: f"urn:li:dataset:(urn:li:dataPlatform:events,{test_id}.ClickEvent,PROD)", + 1: f"urn:li:dataset:(urn:li:dataPlatform:hive,{test_id}.user.clicks,PROD)", + 2: f"urn:li:dataset:(urn:li:dataPlatform:hive,{test_id}.user.clicks_2,PROD)", + 3: f"urn:li:dataset:(urn:li:dataPlatform:hive,{test_id}.user.clicks_3,PROD)", + 4: f"urn:li:dataset:(urn:li:dataPlatform:hive,{test_id}.user.clicks_4,PROD)", + 5: f"urn:li:dataset:(urn:li:dataPlatform:hive,{test_id}.user.clicks_5,PROD)", + 6: f"urn:li:dataset:(urn:li:dataPlatform:hive,{test_id}.user.clicks_6,PROD)", + } + + +@pytest.fixture(scope="function") +def ingest_cleanup_data_function(request, test_resources_dir, graph, test_id): + @contextmanager + def _ingest_cleanup_data(template_file="datasets_template.yaml"): + new_file, filename = tempfile.mkstemp(suffix=f"_{test_id}.json") + try: + template_path = Path(test_resources_dir) / template_file + all_urns = create_test_data(filename, template_path, test_id) + print( + f"Ingesting datasets test data for test_id: {test_id} using template: {template_file}" + ) + ingest_file_via_rest(filename) + yield all_urns + finally: + if DELETE_AFTER_TEST: + print(f"Removing test data for test_id: {test_id}") + delete_urns_from_file(filename) + for urn in all_urns: + graph.delete_entity(urn, hard=True) + wait_for_writes_to_sync() + os.remove(filename) + + return _ingest_cleanup_data + + +@pytest.fixture(scope="function") +def large_fanout_graph_function(graph: DataHubGraph): + @contextmanager + def _large_fanout_graph( + test_id: str, max_fanout: int + ) -> Iterable[Tuple[str, List[str]]]: + max_index = max_fanout + 1 + all_urns = [] + dataset_base_name = f"large_fanout_dataset_{test_id}" + try: + delete_prior_to_running = False + if delete_prior_to_running: + for i in range(1, max_index + 1): + dataset_urn = f"urn:li:dataset:(urn:li:dataPlatform:hive,{dataset_base_name}_{i},PROD)" + graph.delete_entity(dataset_urn, hard=True) + graph.delete_entity( + f"urn:li:dataset:(urn:li:dataPlatform:events,{dataset_base_name}_0,PROD)", + hard=True, + ) + graph.delete_entity( + f"urn:li:dataset:(urn:li:dataPlatform:events,{dataset_base_name}_1,PROD)", + hard=True, + ) + wait_for_writes_to_sync() + + dataset_1 = f"urn:li:dataset:(urn:li:dataPlatform:events,{dataset_base_name}_0,PROD)" + schema_metadata_1 = models.SchemaMetadataClass( + schemaName="large_fanout_dataset_0", + platform="urn:li:dataPlatform:events", + version=0, + hash="", + platformSchema=models.OtherSchemaClass(rawSchema=""), + fields=[ + models.SchemaFieldClass( + fieldPath="ip", + type=models.SchemaFieldDataTypeClass( + type=models.StringTypeClass() + ), + description="This is the description", + nativeDataType="string", + ) + ], + ) + graph.emit( + MetadataChangeProposalWrapper( + entityUrn=dataset_1, aspect=schema_metadata_1 + ) + ) + all_urns.append(dataset_1) + + total_fanout = max_index + for i in range(1, total_fanout + 1): + dataset_i = 
f"urn:li:dataset:(urn:li:dataPlatform:hive,{dataset_base_name}_{i},PROD)" + schema_metadata_i = models.SchemaMetadataClass( + schemaName=f"large_fanout_dataset_{i}", + platform="urn:li:dataPlatform:hive", + version=0, + hash="", + platformSchema=models.OtherSchemaClass(rawSchema=""), + fields=[ + models.SchemaFieldClass( + fieldPath="ip", + type=models.SchemaFieldDataTypeClass( + type=models.StringTypeClass() + ), + nativeDataType="string", + ) + ], + ) + upstreams = models.UpstreamLineageClass( + upstreams=[ + models.UpstreamClass( + dataset=dataset_1, + type=models.DatasetLineageTypeClass.COPY, + ) + ], + fineGrainedLineages=[ + models.FineGrainedLineageClass( + upstreamType=models.FineGrainedLineageUpstreamTypeClass.FIELD_SET, + downstreamType=models.FineGrainedLineageDownstreamTypeClass.FIELD, + upstreams=[ + f"urn:li:schemaField:({dataset_1},ip)", + ], + downstreams=[ + f"urn:li:schemaField:({dataset_i},ip)", + ], + ) + ], + ) + for mcp in MetadataChangeProposalWrapper.construct_many( + entityUrn=dataset_i, + aspects=[ + schema_metadata_i, + upstreams, + ], + ): + graph.emit(mcp) + all_urns.append(dataset_i) + + wait_for_writes_to_sync() + yield (dataset_1, all_urns) + finally: + if DELETE_AFTER_TEST: + for urn in all_urns: + graph.delete_entity(urn, hard=True) + + return _large_fanout_graph + + +def add_col_col_lineage( + graph, test_id: str, depth: int, dataset_depth_map: Dict[int, str] +): field_path = "ip" - downstream_field = f"urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:hive,user.clicks,PROD),{field_path})" - upstream_field = f"urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:events,ClickEvent,PROD),{field_path})" - dataset1 = "urn:li:dataset:(urn:li:dataPlatform:hive,user.clicks,PROD)" - upstreams = graph.get_aspect(dataset1, models.UpstreamLineageClass) - upstreams.fineGrainedLineages = [ - models.FineGrainedLineageClass( - upstreamType=models.FineGrainedLineageUpstreamTypeClass.FIELD_SET, - downstreamType=models.FineGrainedLineageDownstreamTypeClass.FIELD, - upstreams=[upstream_field], - downstreams=[downstream_field], + field_pairs = [] + for current_depth in range(depth): + upstream_dataset = dataset_depth_map[current_depth] + downstream_dataset = dataset_depth_map[current_depth + 1] + downstream_field = f"urn:li:schemaField:({downstream_dataset},{field_path})" + upstream_field = f"urn:li:schemaField:({upstream_dataset},{field_path})" + upstreams = graph.get_aspect(downstream_dataset, models.UpstreamLineageClass) + upstreams.fineGrainedLineages = [ + models.FineGrainedLineageClass( + upstreamType=models.FineGrainedLineageUpstreamTypeClass.FIELD_SET, + downstreamType=models.FineGrainedLineageDownstreamTypeClass.FIELD, + upstreams=[upstream_field], + downstreams=[downstream_field], + ) + ] + graph.emit( + MetadataChangeProposalWrapper( + entityUrn=downstream_dataset, aspect=upstreams + ) ) - ] - graph.emit(MetadataChangeProposalWrapper(entityUrn=dataset1, aspect=upstreams)) + field_pairs.append((downstream_field, upstream_field)) wait_for_writes_to_sync() - return downstream_field, upstream_field + return field_pairs + + +def add_col_col_cycle_lineage( + graph, test_id: str, dataset_depth_map: Dict[int, str], cycle: List[int] +): + field_path = "ip" + + lineage_pairs = [(cycle[i], cycle[i + 1]) for i in range(len(cycle) - 1)] + + field_pairs = [] + + for src, dest in lineage_pairs: + upstream_dataset = dataset_depth_map[src] + downstream_dataset = dataset_depth_map[dest] + downstream_field = f"urn:li:schemaField:({downstream_dataset},{field_path})" + 
upstream_field = f"urn:li:schemaField:({upstream_dataset},{field_path})" + upstreams = graph.get_aspect(downstream_dataset, models.UpstreamLineageClass) + upstreams.fineGrainedLineages = [ + models.FineGrainedLineageClass( + upstreamType=models.FineGrainedLineageUpstreamTypeClass.FIELD_SET, + downstreamType=models.FineGrainedLineageDownstreamTypeClass.FIELD, + upstreams=[upstream_field], + downstreams=[downstream_field], + ) + ] + graph.emit( + MetadataChangeProposalWrapper( + entityUrn=downstream_dataset, aspect=upstreams + ) + ) + field_pairs.append((downstream_field, upstream_field)) + wait_for_writes_to_sync() + return field_pairs def add_field_description(f1, description, graph): @@ -167,16 +440,168 @@ def add_field_description(f1, description, graph): wait_for_writes_to_sync() +@tenacity.retry( + wait=tenacity.wait_exponential(multiplier=1, max=10), + stop=tenacity.stop_after_delay(60), +) def check_propagated_description(downstream_field, description, graph): documentation = graph.get_aspect(downstream_field, models.DocumentationClass) assert any(doc.documentation == description for doc in documentation.documentations) +def ensure_no_propagated_description(graph, schema_field): + documentation = graph.get_aspect(schema_field, models.DocumentationClass) + assert documentation is None or not documentation.documentations + + @pytest.mark.dependency(depends=["test_healthchecks"]) -def test_col_col_propagation(ingest_cleanup_data, graph): - downstream_field, upstream_field = add_col_col_lineage(graph) +def test_col_col_propagation_depth_1( + ingest_cleanup_data, graph, test_id, dataset_depth_map +): + downstream_field, upstream_field = add_col_col_lineage( + graph, depth=1, test_id=test_id, dataset_depth_map=dataset_depth_map + )[0] add_field_description(upstream_field, "This is the new description", graph) check_propagated_description(downstream_field, "This is the new description", graph) + + +@pytest.mark.dependency(depends=["test_healthchecks"]) +def test_col_col_propagation_depth_6( + ingest_cleanup_data, graph, test_id, dataset_depth_map +): + field_pairs = add_col_col_lineage( + graph, depth=6, test_id=test_id, dataset_depth_map=dataset_depth_map + ) + upstream_field = field_pairs[0][1] + add_field_description( + upstream_field, f"This is the new description {test_id}", graph + ) + for downstream_field, _ in field_pairs[:-1]: + check_propagated_description( + downstream_field, f"This is the new description {test_id}", graph + ) + + # last hop should NOT be propagated + last_downstream_field = f"urn:li:schemaField:({dataset_depth_map[6]},ip)" + ensure_no_propagated_description(graph, last_downstream_field) # Call to wait_for_healthchecks fixture will do the actual functionality. 
- print("test_col_col_propagation") - pass + # now check upstream propagation + add_field_description( + last_downstream_field, f"This is the new upstream description {test_id}", graph + ) + propagated_upstream_field = f"urn:li:schemaField:({dataset_depth_map[1]},ip)" + check_propagated_description( + propagated_upstream_field, + f"This is the new upstream description {test_id}", + graph, + ) + # propagation depth will prevent the last hop from being propagated + ensure_no_propagated_description(graph, upstream_field) + # also check that the previously propagated descriptions (for downstream + # fields) are still there + for index in [1, 2, 3, 4]: + check_propagated_description( + f"urn:li:schemaField:({dataset_depth_map[index]},ip)", + description=f"This is the new description {test_id}", + graph=graph, + ) + + +@pytest.mark.dependency(depends=["test_healthchecks"]) +def test_col_col_propagation_cycles( + ingest_cleanup_data_function, graph, test_id, dataset_depth_map +): + custom_template = "datasets_for_cycles_template.yaml" + with ingest_cleanup_data_function(custom_template) as urns: + [u for u in urns] + click_dataset_urn = dataset_depth_map[0] + hive_dataset_urn = dataset_depth_map[1] + graph.emit( + MetadataChangeProposalWrapper( + entityUrn=click_dataset_urn, + aspect=models.SiblingsClass(siblings=[hive_dataset_urn], primary=True), + ) + ) + graph.emit( + MetadataChangeProposalWrapper( + entityUrn=hive_dataset_urn, + aspect=models.SiblingsClass( + siblings=[click_dataset_urn], primary=False + ), + ) + ) + # create field level lineage + add_col_col_cycle_lineage( + graph, + test_id=test_id, + dataset_depth_map=dataset_depth_map, + cycle=[1, 2, 3, 4, 1], + ) + wait_for_writes_to_sync() + field = f"urn:li:schemaField:({click_dataset_urn},ip)" + add_field_description( + f1=field, description=f"This is the new description {test_id}", graph=graph + ) + for index in [1, 2, 3, 4]: + check_propagated_description( + f"urn:li:schemaField:({dataset_depth_map[index]},ip)", + description=f"This is the new description {test_id}", + graph=graph, + ) + # make sure the original field does not have a propagated description + ensure_no_propagated_description( + graph, f"urn:li:schemaField:({click_dataset_urn},ip)" + ) + + +@pytest.mark.dependency(depends=["test_healthchecks"]) +def test_col_col_propagation_large_fanout( + large_fanout_graph_function, test_id: str, action_env_vars: ActionTestEnv, graph +): + default_max_fanout = ( + action_env_vars.DATAHUB_ACTIONS_DOC_PROPAGATION_MAX_PROPAGATION_FANOUT + ) + with large_fanout_graph_function(test_id, default_max_fanout) as ( + dataset_1, + all_urns, + ): + new_description = f"This is the new description + {int(time.time())}" + # we change the description of the first field + editable_schema_metadata = models.EditableSchemaMetadataClass( + editableSchemaFieldInfo=[ + models.EditableSchemaFieldInfoClass( + fieldPath="ip", + description=new_description, + ) + ] + ) + graph.emit( + MetadataChangeProposalWrapper( + entityUrn=dataset_1, aspect=editable_schema_metadata + ) + ) + wait_for_writes_to_sync() + # now we check that the description has been propagated to all the + # downstream fields + num_fields_with_propagated_description = 0 + num_fields_missing_descriptions = 0 + for i in range(1, default_max_fanout + 2): + downstream_field = f"urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:hive,large_fanout_dataset_{test_id}_{i},PROD),ip)" + + try: + check_propagated_description(downstream_field, new_description, graph) + 
+                num_fields_with_propagated_description += 1
+            except tenacity.RetryError:
+                logger.error(
+                    f"Field {downstream_field} does not have the propagated description"
+                )
+                num_fields_missing_descriptions += 1
+        logger.warning(
+            f"Number of fields with propagated description: {num_fields_with_propagated_description}"
+        )
+        logger.warning(
+            f"Number of fields missing description: {num_fields_missing_descriptions}"
+        )
+        assert num_fields_missing_descriptions == 1
+        assert num_fields_with_propagated_description == default_max_fanout
+        # fanout is capped at default_max_fanout, so the one extra downstream should not have the propagated description
diff --git a/smoke-test/tests/resources/actions.env b/smoke-test/tests/resources/actions.env
new file mode 100644
index 00000000..db6f97e5
--- /dev/null
+++ b/smoke-test/tests/resources/actions.env
@@ -0,0 +1,6 @@
+# Env vars for smoke tests
+
+## Increase rate limit for propagated writes to speed up tests
+DATAHUB_ACTIONS_DOC_PROPAGATION_RATE_LIMIT_PROPAGATED_WRITES=1500
+DATAHUB_ACTIONS_DOC_PROPAGATION_RATE_LIMIT_PROPAGATED_WRITES_PERIOD=1
+DATAHUB_ACTIONS_DOC_PROPAGATION_MAX_PROPAGATION_FANOUT=20
\ No newline at end of file
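
Note on the new actions.env defaults: test_col_col_propagation_large_fanout reads these limits through the action_env_vars: ActionTestEnv fixture, and the max-fanout value is what drives the default_max_fanout assertions above. The snippet below is a minimal, hypothetical sketch of how such an env file could be parsed into an ActionTestEnv-style object; the field defaults mirror actions.env, but the class name, the from_env_file helper, and the dataclass approach are assumptions rather than the actual helper used by the smoke tests.

# Hypothetical sketch only: loads smoke-test/tests/resources/actions.env into a
# typed object. Everything except the env var names is an assumption.
from dataclasses import dataclass
from pathlib import Path


@dataclass
class ActionTestEnvSketch:
    DATAHUB_ACTIONS_DOC_PROPAGATION_RATE_LIMIT_PROPAGATED_WRITES: int = 1500
    DATAHUB_ACTIONS_DOC_PROPAGATION_RATE_LIMIT_PROPAGATED_WRITES_PERIOD: int = 1
    DATAHUB_ACTIONS_DOC_PROPAGATION_MAX_PROPAGATION_FANOUT: int = 20

    @classmethod
    def from_env_file(cls, path: Path) -> "ActionTestEnvSketch":
        values = {}
        for raw_line in path.read_text().splitlines():
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue  # skip blank lines and comments
            key, _, value = line.partition("=")
            if key.strip() in cls.__dataclass_fields__:
                values[key.strip()] = int(value.strip())
        return cls(**values)


# Example usage (hypothetical path): the max fanout is what the large-fanout
# test uses to size the graph and its assertions.
env = ActionTestEnvSketch.from_env_file(Path("smoke-test/tests/resources/actions.env"))
assert env.DATAHUB_ACTIONS_DOC_PROPAGATION_MAX_PROPAGATION_FANOUT == 20

With these values, the large-fanout test builds max fanout + 1 downstream fields and expects exactly one of them to be left without the propagated description.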