diff --git a/python_modules/dagster/dagster/_core/pipes/context.py b/python_modules/dagster/dagster/_core/pipes/context.py
index a08f8c743116f..14fdd1318e3ce 100644
--- a/python_modules/dagster/dagster/_core/pipes/context.py
+++ b/python_modules/dagster/dagster/_core/pipes/context.py
@@ -161,7 +161,9 @@ def _resolve_metadata_value(
     # Type ignores because we currently validate in individual handlers
     def handle_message(self, message: PipesMessage) -> None:
         if self._received_closed_msg:
-            self._context.log.warn(f"[pipes] unexpected message received after closed: `{message}`")
+            self._context.log.warning(
+                f"[pipes] unexpected message received after closed: `{message}`"
+            )
 
         method = cast(Method, message["method"])
         if method == "opened":
diff --git a/python_modules/dagster/dagster/_core/pipes/utils.py b/python_modules/dagster/dagster/_core/pipes/utils.py
index d403234a8e501..420762a8a7362 100644
--- a/python_modules/dagster/dagster/_core/pipes/utils.py
+++ b/python_modules/dagster/dagster/_core/pipes/utils.py
@@ -1,3 +1,4 @@
+import collections.abc as abc
 import datetime
 import json
 import os
@@ -8,7 +9,7 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from threading import Event, Thread
-from typing import Iterator, Optional, Sequence, TextIO
+from typing import IO, Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple, TypeVar, cast
 
 from dagster_pipes import (
     PIPES_PROTOCOL_VERSION_FIELD,
@@ -25,7 +26,7 @@
     _check as check,
 )
 from dagster._core.errors import DagsterInvariantViolationError, DagsterPipesExecutionError
-from dagster._core.pipes.client import PipesContextInjector, PipesMessageReader
+from dagster._core.pipes.client import PipesContextInjector, PipesLaunchedData, PipesMessageReader
 from dagster._core.pipes.context import (
     PipesMessageHandler,
     PipesSession,
@@ -135,6 +136,9 @@ class PipesFileMessageReader(PipesMessageReader):
     def __init__(self, path: str):
         self._path = check.str_param(path, "path")
 
+    def on_launched(self, params: PipesLaunchedData) -> None:
+        self.launched_payload = params
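+        # Illustrative note (an assumption, not part of this diff): the launching client
+        # is expected to deliver this payload once the external process has been
+        # submitted, so the reader can pick up late-bound params such as an external job_id.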
- """ +class PipesThreadedMessageReader(PipesMessageReader): interval: float - counter: int - log_readers: Sequence["PipesLogReader"] - opened_payload: Optional[PipesOpenedData] + log_readers: Dict[str, "PipesLogReader"] def __init__( self, interval: float = 10, - log_readers: Optional[Sequence["PipesLogReader"]] = None, + log_readers: Optional[Mapping[str, "PipesLogReader"]] = None, ): self.interval = interval - self.counter = 1 - self.log_readers = check.opt_sequence_param( - log_readers, "log_readers", of_type=PipesLogReader + self.log_readers = ( + {name: reader for name, reader in log_readers.items()} if log_readers else {} ) + self.extra_log_readers = {} self.opened_payload = None + self.launched_payload = None + + @abstractmethod + def can_start(self, params: PipesParams) -> bool: ... @contextmanager def read_messages( @@ -317,6 +306,14 @@ def read_messages( def on_opened(self, opened_payload: PipesOpenedData) -> None: self.opened_payload = opened_payload + def add_log_reader(self, name: str, log_reader: "PipesLogReader") -> None: + """Can be used to attach extra log readers to the message reader. + Typically called when the target for reading logs is not known until after the external + process has started (for example, when the target depends on an external job_id). + The LogReader will be eventually started by the PipesThreadedMessageReader. + """ + self.extra_log_readers[name] = log_reader + @abstractmethod @contextmanager def get_params(self) -> Iterator[PipesParams]: @@ -328,7 +325,18 @@ def get_params(self) -> Iterator[PipesParams]: """ @abstractmethod - def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ... + def download_messages_parts( + self, cursor: Optional[TCursor], params: PipesParams + ) -> Optional[Tuple[TCursor, str]]: + """Download a chunk of messages from the target location. + + Args: + cursor (Optional[Any]): A cursor that specifies where to start downloading messages from. + it can be set according to specifics of each message reader implementation, for example. + it can be an index for a line in a log file, or a timestamp for a message in a stream. + params (PipesParams): A dict of parameters that specifies where to download messages from. + """ + ... def _messages_thread( self, @@ -337,22 +345,61 @@ def _messages_thread( is_session_closed: Event, ) -> None: try: + # first, check if the target is readable. If it's now, we may need to receive extra params from the external process + if not self.can_start(params): + while not is_session_closed.is_set() and not self.launched_payload is not None: + time.sleep(DEFAULT_SLEEP_INTERVAL) + + # hurray! we have some extra params, let's check if the target is readable now + params = {**params, **(self.launched_payload or {}).get("extras", {})} + + if not self.can_start(params): + if self.launched_payload is not None: + handler._context.log.warning( # noqa: SLF001 + f"[pipes] Target of {self.__class__.__name__} is not readable after receiving extra params from the external process (`on_launched` has been called)" + ) + else: + handler._context.log.warning( # noqa: SLF001 + f"[pipes] Target of {self.__class__.__name__} is not readable." 
+
     @abstractmethod
     @contextmanager
     def get_params(self) -> Iterator[PipesParams]:
@@ -328,7 +325,18 @@ def get_params(self) -> Iterator[PipesParams]:
         """
 
     @abstractmethod
-    def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ...
+    def download_messages_parts(
+        self, cursor: Optional[TCursor], params: PipesParams
+    ) -> Optional[Tuple[TCursor, str]]:
+        """Download a chunk of messages from the target location.
+
+        Args:
+            cursor (Optional[TCursor]): A cursor specifying where to start downloading messages
+                from. Its meaning is specific to each message reader implementation; for example,
+                it can be the index of a line in a log file, or the timestamp of a message in a
+                stream.
+            params (PipesParams): A dict of parameters specifying where to download messages from.
+        """
+        ...
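+
+    # Illustrative sketch (hypothetical subclass, not part of this diff): a reader tailing
+    # a single file could use the number of lines already consumed as its cursor:
+    #
+    #     def download_messages_parts(self, cursor, params):
+    #         lines = open(params["path"]).read().splitlines()  # "path" is an assumed param
+    #         start = cursor or 0
+    #         if start < len(lines):
+    #             return len(lines), "\n".join(lines[start:])
+    #         return None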
 
     def _messages_thread(
         self,
@@ -337,22 +345,61 @@
         handler: "PipesMessageHandler",
         params: PipesParams,
         is_session_closed: Event,
     ) -> None:
         try:
+            # first, check if the target is readable. If it's not, we may need to receive
+            # extra params from the external process before reading can start
+            if not self.can_start(params):
+                while not is_session_closed.is_set() and self.launched_payload is None:
+                    time.sleep(DEFAULT_SLEEP_INTERVAL)
+
+                # hurray! we may have extra params now; merge them in and check if the
+                # target is readable
+                params = {**params, **(self.launched_payload or {}).get("extras", {})}
+
+                if not self.can_start(params):
+                    if self.launched_payload is not None:
+                        handler._context.log.warning(  # noqa: SLF001
+                            f"[pipes] Target of {self.__class__.__name__} is not readable after receiving extra params from the external process (`on_launched` has been called)."
+                        )
+                    else:
+                        handler._context.log.warning(  # noqa: SLF001
+                            f"[pipes] Target of {self.__class__.__name__} is not readable."
+                        )
+
+                    return
+
             start_or_last_download = datetime.datetime.now()
+            session_closed_at = None
+            cursor = None
             while True:
+                if handler.received_closed_message:
+                    return
+
                 now = datetime.datetime.now()
                 if (
                     now - start_or_last_download
                 ).seconds > self.interval or is_session_closed.is_set():
                     start_or_last_download = now
-                    chunk = self.download_messages_chunk(self.counter, params)
-                    if chunk:
+                    result = self.download_messages_parts(cursor, params)
+                    if result is not None:
+                        cursor, chunk = result
                         for line in chunk.split("\n"):
-                            message = json.loads(line)
-                            handler.handle_message(message)
-                        self.counter += 1
-                    elif is_session_closed.is_set():
-                        break
+                            # only JSON lines carrying the Pipes protocol version field are
+                            # treated as Pipes messages; everything else is ignored
+                            try:
+                                message = json.loads(line)
+                                if PIPES_PROTOCOL_VERSION_FIELD in message.keys():
+                                    handler.handle_message(message)
+                            except json.JSONDecodeError:
+                                pass
+
                 time.sleep(DEFAULT_SLEEP_INTERVAL)
+
+                if is_session_closed.is_set():
+                    if session_closed_at is None:
+                        session_closed_at = datetime.datetime.now()
+
+                    # after the external process has completed, don't exit immediately:
+                    # messages written shortly before close may not be readable yet
+                    if (
+                        datetime.datetime.now() - session_closed_at
+                    ).seconds > WAIT_FOR_LOGS_AFTER_EXECUTION_INTERVAL:
+                        return
+
         except:
             handler.report_pipes_framework_exception(
                 f"{self.__class__.__name__} messages thread",
@@ -377,22 +424,31 @@ def _logs_thread(
                 return
             time.sleep(DEFAULT_SLEEP_INTERVAL)
 
-        # Logs are started with a merge of the params generated by the message reader and the opened
-        # payload.
+        # Logs are started with a merge of the params generated by the message reader and
+        # the opened payload.
         log_params = {**params, **self.opened_payload}
 
+        wait_for_logs_start = None
+
         # Loop over all log readers and start them if the target is readable, which typically means
         # a file exists at the target location. Different execution environments may write logs at
         # different times (e.g., some may write logs periodically during execution, while others may
         # only write logs after the process has completed).
         try:
-            unstarted_log_readers = list(self.log_readers)
-            wait_for_logs_start = None
-            while unstarted_log_readers:
-                # iterate in reverse so we can pop off elements as we go
-                for i in reversed(range(len(unstarted_log_readers))):
-                    if unstarted_log_readers[i].target_is_readable(log_params):
-                        reader = unstarted_log_readers.pop(i)
+            unstarted_log_readers = {**self.log_readers, **self.extra_log_readers}
+            failed_log_readers = set()
+
+            while True:
+                # periodically check for extra log readers, which may be added after the
+                # external process has started, and move them into the unstarted set
+                for key in list(self.extra_log_readers):
+                    new_log_reader = self.extra_log_readers.pop(key)
+                    unstarted_log_readers[key] = new_log_reader
+                    self.log_readers[key] = new_log_reader
+
+                for key in list(unstarted_log_readers):
+                    if unstarted_log_readers[key].can_start(log_params):
+                        reader = unstarted_log_readers.pop(key)
                         reader.start(log_params, is_session_closed)
 
                 # In some cases logs might not be written out until after the external process has
@@ -403,19 +459,24 @@
                 if is_session_closed.is_set():
                     if wait_for_logs_start is None:
                         wait_for_logs_start = datetime.datetime.now()
-                    if (
-                        datetime.datetime.now() - wait_for_logs_start
-                    ).seconds > WAIT_FOR_LOGS_TIMEOUT:
-                        for log_reader in unstarted_log_readers:
-                            warnings.warn(
-                                f"Attempted to read log for reader {log_reader.name} but log was"
-                                " still not written {WAIT_FOR_LOGS_TIMEOUT} seconds after session close. Abandoning log."
-                            )
-                        break
-                time.sleep(DEFAULT_SLEEP_INTERVAL)
 
-            # Wait for the external process to complete
-            is_session_closed.wait()
+                    if not unstarted_log_readers:
+                        return
+                    elif (
+                        unstarted_log_readers
+                        and (datetime.datetime.now() - wait_for_logs_start).seconds
+                        > WAIT_FOR_LOGS_TIMEOUT
+                    ):
+                        for key, log_reader in unstarted_log_readers.items():
+                            if key not in failed_log_readers:
+                                failed_log_readers.add(key)
+                                warnings.warn(
+                                    f"[pipes] Attempted to read log for reader {key}:{log_reader.name} but log was"
+                                    f" still not written {WAIT_FOR_LOGS_TIMEOUT} seconds after session close. Abandoning reader {key}."
+                                )
+                        return
+
+                time.sleep(DEFAULT_SLEEP_INTERVAL)
         except:
             handler.report_pipes_framework_exception(
                 f"{self.__class__.__name__} logs thread",
@@ -423,11 +484,80 @@
             )
             raise
         finally:
-            for log_reader in self.log_readers:
+            for log_reader in self.log_readers.values():
                 if log_reader.is_running():
                     log_reader.stop()
 
 
+class PipesBlobStoreMessageReader(PipesThreadedMessageReader):
+    """Message reader that reads a sequence of message chunks written by an external process into a
+    blob store such as S3, Azure blob storage, or GCS.
+
+    The reader maintains a counter, starting at 1, that is synchronized with a message writer in
+    some pipes process. The reader starts a thread that periodically attempts to read a chunk
+    indexed by the counter at some location expected to be written by the pipes process. The chunk
+    should be a file with each line corresponding to a JSON-encoded pipes message. When a chunk is
+    successfully read, the messages are processed and the counter is incremented. The
+    :py:class:`PipesBlobStoreMessageWriter` on the other end is expected to similarly increment a
+    counter (starting from 1) on successful write, keeping counters on the read and write end in
+    sync.
+
+    If `log_readers` is passed, the message reader will start the passed log readers when the
+    `opened` message is received from the external process.
+
+    Args:
+        interval (float): interval in seconds between attempts to download a chunk.
+        log_readers (Optional[Mapping[str, PipesLogReader]]): A mapping of arbitrary names to readers for logs.
+    """
+
+    counter: int
+
+    def __init__(
+        self,
+        interval: float = 10,
+        log_readers: Optional[Mapping[str, "PipesLogReader"]] = None,
+    ):
+        if isinstance(log_readers, abc.Mapping):
+            log_readers = check.dict_param(
+                log_readers, "log_readers", key_type=str, value_type=PipesLogReader
+            )
+        elif isinstance(log_readers, abc.Sequence):
+            # backcompat: accept a plain sequence of log readers and key them by index
+            log_readers = {
+                str(i): lr
+                for i, lr in enumerate(
+                    check.opt_sequence_param(
+                        cast(Sequence["PipesLogReader"], log_readers),
+                        "log_readers",
+                        of_type=PipesLogReader,
+                    )
+                )
+            }
+
+        super().__init__(interval=interval, log_readers=log_readers)
+
+        self.counter = 1
+
+    @abstractmethod
+    def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
+        # kept for historical reasons, preserving the original interface of
+        # PipesBlobStoreMessageReader
+        ...
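+
+    # Illustrative sketch (hypothetical S3-backed subclass, not part of this diff; `client`
+    # and `bucket` are assumed attributes):
+    #
+    #     def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
+    #         key = f"{params['key_prefix']}/{index}.json"  # assumes a "key_prefix" param
+    #         try:
+    #             obj = self.client.get_object(Bucket=self.bucket, Key=key)
+    #             return obj["Body"].read().decode("utf-8")
+    #         except self.client.exceptions.NoSuchKey:
+    #             return None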
+
+    def download_messages_parts(
+        self, cursor: Any, params: PipesParams
+    ) -> Optional[Tuple[Any, str]]:
+        # maps the new interface onto the old one: the old interface doesn't use the
+        # cursor parameter; instead, it tracks its position in the "counter" attribute
+        chunk = self.download_messages_chunk(self.counter, params)
+        if chunk:
+            self.counter += 1
+            return None, chunk
+
+    def can_start(self, params: PipesParams) -> bool:
+        # always True for historical reasons: this method was introduced after the
+        # initial implementation, which assumed the target is always readable
+        return True
+
+
 class PipesLogReader(ABC):
     @abstractmethod
     def start(self, params: PipesParams, is_session_closed: Event) -> None: ...
@@ -439,7 +569,7 @@ def stop(self) -> None: ...
     def is_running(self) -> bool: ...
 
     @abstractmethod
-    def target_is_readable(self, params: PipesParams) -> bool: ...
+    def can_start(self, params: PipesParams) -> bool: ...
 
     @property
     def name(self) -> str:
@@ -452,10 +582,10 @@ class PipesChunkedLogReader(PipesLogReader):
 
     Args:
         interval (float): interval in seconds between attempts to download a chunk.
-        target_stream (TextIO): The stream to which to write the logs. Typcially `sys.stdout` or `sys.stderr`.
+        target_stream (IO[str]): The stream to which to write the logs. Typically `sys.stdout` or `sys.stderr`.
     """
 
-    def __init__(self, *, interval: float = 10, target_stream: TextIO):
+    def __init__(self, *, interval: float = 10, target_stream: IO[str]):
         self.interval = interval
         self.target_stream = target_stream
         self.thread: Optional[Thread] = None
@@ -512,7 +642,21 @@ def _join_thread(thread: Thread, thread_name: str) -> None:
     raise DagsterPipesExecutionError(f"Timed out waiting for {thread_name} thread to finish.")
 
 
-def extract_message_or_forward_to_file(handler: "PipesMessageHandler", log_line: str, file: TextIO):
+def forward_only_logs_to_file(log_line: str, file: IO[str]):
+    """Write the log line to the file unless it is a Pipes message."""
+    try:
+        message = json.loads(log_line)
+        if PIPES_PROTOCOL_VERSION_FIELD in message.keys():
+            return
+        else:
+            file.writelines((log_line, "\n"))
+    except Exception:
+        file.writelines((log_line, "\n"))
+
+
+def extract_message_or_forward_to_file(
+    handler: "PipesMessageHandler", log_line: str, file: IO[str]
+):
     # exceptions as control flow, you love to see it
     try:
         message = json.loads(log_line)
@@ -521,7 +665,7 @@ def extract_message_or_forward_to_file(handler: "PipesMessageHandler", log_line:
         else:
             file.writelines((log_line, "\n"))
     except Exception:
-        # move non-message logs in to stdout for compute log capture
+        # move non-message logs into the file for compute log capture
         file.writelines((log_line, "\n"))
 
 
@@ -612,7 +756,7 @@ def ext_asset(context: OpExecutionContext):
         )
     finally:
         if not message_handler.received_opened_message:
-            context.log.warn(
+            context.log.warning(
                 "[pipes] did not receive any messages from external process. Check stdout / stderr"
                 " logs from the external process if"
                 f" possible.\n{context_injector.__class__.__name__}:"
@@ -620,7 +764,7 @@
                 f" {message_reader.no_messages_debug_text()}\n"
             )
         elif not message_handler.received_closed_message:
-            context.log.warn(
+            context.log.warning(
                 "[pipes] did not receive closed message from external process. Buffered messages"
                 " may have been discarded without being delivered. Use `open_dagster_pipes` as a"
                 " context manager (a with block) to ensure that cleanup is successfully completed."
diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py b/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py
index e73b3fe5c5857..3b79ba5c0d673 100644
--- a/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py
+++ b/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py
@@ -403,6 +403,9 @@ def __init__(
         self.log_modification_time = None
         self.log_path = None
 
+    def can_start(self, params: PipesParams) -> bool:
+        return self._get_log_path(params) is not None
+
     def download_log_chunk(self, params: PipesParams) -> Optional[str]:
         log_path = self._get_log_path(params)
         if log_path is None: