Skip to content

Commit

Permalink
feat(ingest): clean up DataHubRestEmitter return type
Browse files Browse the repository at this point in the history
Also makes the Airflow hook support arbitrary args.
  • Loading branch information
hsheth2 committed Nov 21, 2023
1 parent 15e68bb commit b22401d
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ export const HomePageOnboardingConfig: OnboardingStep[] = [
<Typography.Paragraph>
Here are your organization&apos;s <strong>Data Platforms</strong>. Data Platforms represent specific
third-party Data Systems or Tools. Examples include Data Warehouses like <strong>Snowflake</strong>,
Orchestrators like
<strong>Airflow</strong>, and Dashboarding tools like <strong>Looker</strong>.
Orchestrators like <strong>Airflow</strong>, and Dashboarding tools like <strong>Looker</strong>.
</Typography.Paragraph>
),
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ def test_connection(self) -> Tuple[bool, str]:
return True, "Successfully connected to DataHub."

def _get_config(self) -> Tuple[str, Optional[str], Optional[int]]:
    """Return ``(host, token, timeout_sec)`` for the configured connection.

    Legacy accessor layered on top of ``_get_config_v2``: the first two
    elements are passed through unchanged, and the third is the
    ``"timeout_sec"`` entry of the extra-args mapping (``None`` when the
    key is absent).
    """
    # Although nominally private, this method is called directly from a few
    # places in the codebase. It is kept for backwards compatibility only;
    # new code should not rely on it.
    host, token, extra_args = self._get_config_v2()
    return host, token, extra_args.get("timeout_sec")

def _get_config_v2(self) -> Tuple[str, Optional[str], Dict]:
conn: "Connection" = self.get_connection(self.datahub_rest_conn_id)

host = conn.host
Expand All @@ -74,14 +81,19 @@ def _get_config(self) -> Tuple[str, Optional[str], Optional[int]]:
"host parameter should not contain a port number if the port is specified separately"
)
host = f"{host}:{conn.port}"
password = conn.password
timeout_sec = conn.extra_dejson.get("timeout_sec")
return (host, password, timeout_sec)
token = conn.password

extra_args = conn.extra_dejson
return (host, token, extra_args)

def make_emitter(self) -> "DatahubRestEmitter":
import datahub.emitter.rest_emitter

return datahub.emitter.rest_emitter.DatahubRestEmitter(*self._get_config())
config = self._get_config_v2()

return datahub.emitter.rest_emitter.DataHubRestEmitter(
config[0], config[1], **config[2]
)

def emit(
self,
Expand Down
7 changes: 2 additions & 5 deletions metadata-ingestion/src/datahub/emitter/generic_emitter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, Union
from typing import Callable, Optional, Union

from typing_extensions import Protocol

Expand All @@ -21,10 +21,7 @@ def emit(
# required. However, this would be a breaking change that may need
# more careful consideration.
callback: Optional[Callable[[Exception, str], None]] = None,
# TODO: The rest emitter returns timestamps as the return type. For now
# we smooth over that detail using Any, but eventually we should
# standardize on a return type.
) -> Any:
) -> None:
raise NotImplementedError

def flush(self) -> None:
Expand Down
7 changes: 2 additions & 5 deletions metadata-ingestion/src/datahub/emitter/rest_emitter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import datetime
import functools
import json
import logging
import os
from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import requests
from deprecated import deprecated
Expand Down Expand Up @@ -208,8 +207,7 @@ def emit(
UsageAggregation,
],
callback: Optional[Callable[[Exception, str], None]] = None,
) -> Tuple[datetime.datetime, datetime.datetime]:
start_time = datetime.datetime.now()
) -> None:
try:
if isinstance(item, UsageAggregation):
self.emit_usage(item)
Expand All @@ -226,7 +224,6 @@ def emit(
else:
if callback:
callback(None, "success") # type: ignore
return start_time, datetime.datetime.now()

def emit_mce(self, mce: MetadataChangeEvent) -> None:
url = f"{self._gms_server}/entities?action=ingest"
Expand Down
20 changes: 16 additions & 4 deletions metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import timedelta
from datetime import datetime, timedelta
from enum import auto
from threading import BoundedSemaphore
from typing import Union
from typing import Tuple, Union

from datahub.cli.cli_utils import set_env_variables_override_config
from datahub.configuration.common import (
Expand Down Expand Up @@ -181,6 +181,18 @@ def _write_done_callback(
self.report.report_failure({"e": e})
write_callback.on_failure(record_envelope, Exception(e), {})

def _emit_wrapper(
    self,
    record: Union[
        MetadataChangeEvent,
        MetadataChangeProposal,
        MetadataChangeProposalWrapper,
    ],
) -> Tuple[datetime, datetime]:
    """Emit ``record`` via ``self.emitter`` and time the call.

    Returns a ``(start, end)`` pair of ``datetime.now()`` readings taken
    immediately before and immediately after ``self.emitter.emit(record)``.
    Any exception raised by the emitter propagates to the caller; in that
    case no timestamps are returned.
    """
    began_at = datetime.now()
    self.emitter.emit(record)
    finished_at = datetime.now()
    return (began_at, finished_at)

def write_record_async(
self,
record_envelope: RecordEnvelope[
Expand All @@ -194,7 +206,7 @@ def write_record_async(
) -> None:
record = record_envelope.record
if self.config.mode == SyncOrAsync.ASYNC:
write_future = self.executor.submit(self.emitter.emit, record)
write_future = self.executor.submit(self._emit_wrapper, record)
write_future.add_done_callback(
functools.partial(
self._write_done_callback, record_envelope, write_callback
Expand All @@ -204,7 +216,7 @@ def write_record_async(
else:
# execute synchronously
try:
(start, end) = self.emitter.emit(record)
(start, end) = self._emit_wrapper(record)
write_callback.on_success(record_envelope, success_metadata={})
except Exception as e:
write_callback.on_failure(record_envelope, e, failure_metadata={})
Expand Down
12 changes: 6 additions & 6 deletions metadata-ingestion/tests/test_helpers/graph_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union

from datahub.emitter.mce_builder import Aspect
from datahub.emitter.mcp import MetadataChangeProposalWrapper
Expand All @@ -22,15 +21,17 @@


class MockDataHubGraph(DataHubGraph):
def __init__(self, entity_graph: Dict[str, Dict[str, Any]] = {}) -> None:
def __init__(
self, entity_graph: Optional[Dict[str, Dict[str, Any]]] = None
) -> None:
self.emitted: List[
Union[
MetadataChangeEvent,
MetadataChangeProposal,
MetadataChangeProposalWrapper,
]
] = []
self.entity_graph = entity_graph
self.entity_graph = entity_graph or {}

def import_file(self, file: Path) -> None:
"""Imports metadata from any MCE/MCP file. Does not clear prior loaded data.
Expand Down Expand Up @@ -110,9 +111,8 @@ def emit(
UsageAggregationClass,
],
callback: Union[Callable[[Exception, str], None], None] = None,
) -> Tuple[datetime, datetime]:
) -> None:
self.emitted.append(item) # type: ignore
return (datetime.now(), datetime.now())

def emit_mce(self, mce: MetadataChangeEvent) -> None:
self.emitted.append(mce)
Expand Down

0 comments on commit b22401d

Please sign in to comment.