From 9659d6086739990dff7aa4551361b30a9e7993ee Mon Sep 17 00:00:00 2001 From: Aseem Bansal Date: Thu, 21 Mar 2024 15:21:17 +0530 Subject: [PATCH] feat(ingest/datahub-gc): gc source to cleanup things (#10085) --- metadata-ingestion/setup.py | 1 + .../datahub/ingestion/source/gc/__init__.py | 0 .../datahub/ingestion/source/gc/datahub_gc.py | 118 ++++++++++++++++++ 3 files changed, 119 insertions(+) create mode 100644 metadata-ingestion/src/datahub/ingestion/source/gc/__init__.py create mode 100644 metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 5570893b7d1df..b2b451c04f44c 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -609,6 +609,7 @@ "ldap = datahub.ingestion.source.ldap:LDAPSource", "looker = datahub.ingestion.source.looker.looker_source:LookerDashboardSource", "lookml = datahub.ingestion.source.looker.lookml_source:LookMLSource", + "datahub-gc = datahub.ingestion.source.gc.datahub_gc:DataHubGcSource", "datahub-lineage-file = datahub.ingestion.source.metadata.lineage:LineageFileSource", "datahub-business-glossary = datahub.ingestion.source.metadata.business_glossary:BusinessGlossaryFileSource", "mlflow = datahub.ingestion.source.mlflow:MLflowSource", diff --git a/metadata-ingestion/src/datahub/ingestion/source/gc/__init__.py b/metadata-ingestion/src/datahub/ingestion/source/gc/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py b/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py new file mode 100644 index 0000000000000..bf21e293e6a2f --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py @@ -0,0 +1,118 @@ +import time +from dataclasses import dataclass +from typing import Iterable + +from pydantic import Field + +from datahub.configuration.common import ConfigModel +from datahub.ingestion.api.common import PipelineContext +from datahub.ingestion.api.decorators import ( + SupportStatus, + config_class, + platform_name, + support_status, +) +from datahub.ingestion.api.source import Source, SourceReport +from datahub.ingestion.api.workunit import MetadataWorkUnit + + +class DataHubGcSourceConfig(ConfigModel): + cleanup_expired_tokens: bool = Field( + default=True, + description="Whether to clean up expired tokens or not", + ) + + +@dataclass +class DataHubGcSourceReport(SourceReport): + expired_tokens_revoked: int = 0 + + +@platform_name("DataHubGc") +@config_class(DataHubGcSourceConfig) +@support_status(SupportStatus.TESTING) +class DataHubGcSource(Source): + def __init__(self, ctx: PipelineContext, config: DataHubGcSourceConfig): + self.ctx = ctx + self.config = config + self.report = DataHubGcSourceReport() + self.graph = ctx.graph + assert ( + self.graph is not None + ), "DataHubGc source requires a graph. Please either use datahub-rest sink or set datahub_api" + + @classmethod + def create(cls, config_dict, ctx): + config = DataHubGcSourceConfig.parse_obj(config_dict) + return cls(ctx, config) + + def get_workunits_internal( + self, + ) -> Iterable[MetadataWorkUnit]: + if self.config.cleanup_expired_tokens: + self.revoke_expired_tokens() + yield from [] + + def revoke_expired_tokens(self) -> None: + total = 1 + while total > 0: + expired_tokens_res = self._get_expired_tokens() + list_access_tokens = expired_tokens_res.get("listAccessTokens", {}) + tokens = list_access_tokens.get("tokens", []) + total = list_access_tokens.get("total", 0) + for token in tokens: + self.report.expired_tokens_revoked += 1 + token_id = token["id"] + self._revoke_token(token_id) + + def _revoke_token(self, token_id: str) -> None: + assert self.graph is not None + self.graph.execute_graphql( + query="""mutation revokeAccessToken($tokenId: String!) { + revokeAccessToken(tokenId: $tokenId) +} +""", + variables={"tokenId": token_id}, + ) + + def _get_expired_tokens(self) -> dict: + assert self.graph is not None + return self.graph.execute_graphql( + query="""query listAccessTokens($input: ListAccessTokenInput!) { + listAccessTokens(input: $input) { + start + count + total + tokens { + urn + type + id + name + description + actorUrn + ownerUrn + createdAt + expiresAt + __typename + } + __typename + } +} +""", + variables={ + "input": { + "start": 0, + "count": 10, + "filters": [ + { + "field": "expiresAt", + "values": [str(int(time.time() * 1000))], + "condition": "LESS_THAN", + } + ], + } + }, + ) + + def get_report(self) -> SourceReport: + return self.report