-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(ingest/datahub-gc): gc source to cleanup things (#10085)
- Loading branch information
1 parent
e6e5c09
commit 9659d60
Showing
3 changed files
with
119 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
118 changes: 118 additions & 0 deletions
118
metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import time | ||
from dataclasses import dataclass | ||
from typing import Iterable | ||
|
||
from pydantic import Field | ||
|
||
from datahub.configuration.common import ConfigModel | ||
from datahub.ingestion.api.common import PipelineContext | ||
from datahub.ingestion.api.decorators import ( | ||
SupportStatus, | ||
config_class, | ||
platform_name, | ||
support_status, | ||
) | ||
from datahub.ingestion.api.source import Source, SourceReport | ||
from datahub.ingestion.api.workunit import MetadataWorkUnit | ||
|
||
|
||
class DataHubGcSourceConfig(ConfigModel): | ||
cleanup_expired_tokens: bool = Field( | ||
default=True, | ||
description="Whether to clean up expired tokens or not", | ||
) | ||
|
||
|
||
@dataclass | ||
class DataHubGcSourceReport(SourceReport): | ||
expired_tokens_revoked: int = 0 | ||
|
||
|
||
@platform_name("DataHubGc") | ||
@config_class(DataHubGcSourceConfig) | ||
@support_status(SupportStatus.TESTING) | ||
class DataHubGcSource(Source): | ||
def __init__(self, ctx: PipelineContext, config: DataHubGcSourceConfig): | ||
self.ctx = ctx | ||
self.config = config | ||
self.report = DataHubGcSourceReport() | ||
self.graph = ctx.graph | ||
assert ( | ||
self.graph is not None | ||
), "DataHubGc source requires a graph. Please either use datahub-rest sink or set datahub_api" | ||
|
||
@classmethod | ||
def create(cls, config_dict, ctx): | ||
config = DataHubGcSourceConfig.parse_obj(config_dict) | ||
return cls(ctx, config) | ||
|
||
def get_workunits_internal( | ||
self, | ||
) -> Iterable[MetadataWorkUnit]: | ||
if self.config.cleanup_expired_tokens: | ||
self.revoke_expired_tokens() | ||
yield from [] | ||
|
||
def revoke_expired_tokens(self) -> None: | ||
total = 1 | ||
while total > 0: | ||
expired_tokens_res = self._get_expired_tokens() | ||
list_access_tokens = expired_tokens_res.get("listAccessTokens", {}) | ||
tokens = list_access_tokens.get("tokens", []) | ||
total = list_access_tokens.get("total", 0) | ||
for token in tokens: | ||
self.report.expired_tokens_revoked += 1 | ||
token_id = token["id"] | ||
self._revoke_token(token_id) | ||
|
||
def _revoke_token(self, token_id: str) -> None: | ||
assert self.graph is not None | ||
self.graph.execute_graphql( | ||
query="""mutation revokeAccessToken($tokenId: String!) { | ||
revokeAccessToken(tokenId: $tokenId) | ||
} | ||
""", | ||
variables={"tokenId": token_id}, | ||
) | ||
|
||
def _get_expired_tokens(self) -> dict: | ||
assert self.graph is not None | ||
return self.graph.execute_graphql( | ||
query="""query listAccessTokens($input: ListAccessTokenInput!) { | ||
listAccessTokens(input: $input) { | ||
start | ||
count | ||
total | ||
tokens { | ||
urn | ||
type | ||
id | ||
name | ||
description | ||
actorUrn | ||
ownerUrn | ||
createdAt | ||
expiresAt | ||
__typename | ||
} | ||
__typename | ||
} | ||
} | ||
""", | ||
variables={ | ||
"input": { | ||
"start": 0, | ||
"count": 10, | ||
"filters": [ | ||
{ | ||
"field": "expiresAt", | ||
"values": [str(int(time.time() * 1000))], | ||
"condition": "LESS_THAN", | ||
} | ||
], | ||
} | ||
}, | ||
) | ||
|
||
def get_report(self) -> SourceReport: | ||
return self.report |