Skip to content

Commit

Permalink
[dagster-aws] make PipesCloudWatchMessageReader subclass
Browse files Browse the repository at this point in the history
PipesThreadedMessageReader
  • Loading branch information
danielgafni committed Oct 2, 2024
1 parent d48132c commit b80d4f9
Show file tree
Hide file tree
Showing 14 changed files with 564 additions and 172 deletions.
Binary file modified docs/content/api/modules.json.gz
Binary file not shown.
Binary file modified docs/content/api/searchindex.json.gz
Binary file not shown.
Binary file modified docs/content/api/sections.json.gz
Binary file not shown.
1 change: 0 additions & 1 deletion docs/content/integrations/deltalake/reference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,6 @@ config = S3Config(copy_if_not_exists="header: cf-copy-destination-if-none-match:

</TabGroup>


In cases where non-AWS S3 implementations are used, the endpoint URL or the S3 service needs to be provided.

```py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,9 @@ from dagster_deltalake import LocalConfig
from dagster_deltalake_pandas import DeltaLakePandasIOManager

from dagster import Definitions
from . import assets

all_assets = load_assets_from_modules([assets])

defs = Definitions(
assets=all_assets,
assets=[iris_dataset],
resources={
"io_manager": DeltaLakePandasIOManager(
root_uri="path/to/deltalake", # required
Expand Down
6 changes: 5 additions & 1 deletion python_modules/dagster/dagster/_core/pipes/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
from functools import cached_property
from queue import Queue
from typing import (
Expand Down Expand Up @@ -290,13 +291,16 @@ class PipesSession:
indicating the location from which the external process should load context data.
message_reader_params (PipesParams): Parameters yielded by the message reader, indicating
the location to which the external process should write messages.
created_at (datetime): The time at which the session was created. Useful as cutoff for
reading logs.
"""

context_data: PipesContextData
message_handler: PipesMessageHandler
context_injector_params: PipesParams
message_reader_params: PipesParams
context: OpExecutionContext
created_at: datetime = field(default_factory=datetime.now)

@cached_property
def default_remote_invocation_tags(self) -> Dict[str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
PipesS3ContextInjector,
)
from dagster_aws.pipes.message_readers import (
PipesCloudWatchLogReader,
PipesCloudWatchMessageReader,
PipesLambdaLogsMessageReader,
PipesS3MessageReader,
Expand All @@ -22,6 +23,7 @@
"PipesLambdaEventContextInjector",
"PipesS3MessageReader",
"PipesLambdaLogsMessageReader",
"PipesCloudWatchLogReader",
"PipesCloudWatchMessageReader",
"PipesEMRServerlessClient",
]
181 changes: 123 additions & 58 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/ecs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pprint import pformat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Any, Dict, Optional, cast

import boto3
import botocore
import dagster._check as check
from dagster import PipesClient
from dagster import DagsterInvariantViolationError, PipesClient
from dagster._annotations import experimental, public
from dagster._core.definitions.resource_annotation import TreatAsResourceParam
from dagster._core.errors import DagsterExecutionInterruptedError
Expand All @@ -16,19 +16,23 @@
)
from dagster._core.pipes.utils import PipesEnvContextInjector, open_pipes_session

from dagster_aws.pipes.message_readers import PipesCloudWatchMessageReader
from dagster_aws.pipes.message_readers import PipesCloudWatchLogReader, PipesCloudWatchMessageReader

if TYPE_CHECKING:
from mypy_boto3_ecs.client import ECSClient
from mypy_boto3_ecs.type_defs import RunTaskRequestRequestTypeDef
from mypy_boto3_ecs.type_defs import (
DescribeTasksResponseTypeDef,
RunTaskRequestRequestTypeDef,
RunTaskResponseTypeDef,
)


@experimental
class PipesECSClient(PipesClient, TreatAsResourceParam):
"""A pipes client for running AWS ECS tasks.
Args:
client (Optional[boto3.client]): The boto ECS client used to launch the ECS task
client (Any): The boto ECS client used to launch the ECS task
context_injector (Optional[PipesContextInjector]): A context injector to use to inject
context into the ECS task. Defaults to :py:class:`PipesEnvContextInjector`.
message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
Expand All @@ -38,7 +42,7 @@ class PipesECSClient(PipesClient, TreatAsResourceParam):

def __init__(
self,
client: Optional[boto3.client] = None, # pyright: ignore (reportGeneralTypeIssues)
client=None,
context_injector: Optional[PipesContextInjector] = None,
message_reader: Optional[PipesMessageReader] = None,
forward_termination: bool = True,
Expand All @@ -59,6 +63,7 @@ def run(
context: OpExecutionContext,
run_task_params: "RunTaskRequestRequestTypeDef",
extras: Optional[Dict[str, Any]] = None,
pipes_container_name: Optional[str] = None,
) -> PipesClientCompletedInvocation:
"""Run ECS tasks, enriched with the pipes protocol.
Expand All @@ -68,6 +73,8 @@ def run(
Must contain ``taskDefinition`` key.
See `Boto3 API Documentation <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/run_task.html#run-task>`_
extras (Optional[Dict[str, Any]]): Additional information to pass to the Pipes session in the external process.
pipes_container_name (Optional[str]): If running more than one container in the task,
and using :py:class:`PipesCloudWatchMessageReader`, specify the container name which will be running Pipes.
Returns:
PipesClientCompletedInvocation: Wrapper containing results reported by the external
Expand All @@ -79,12 +86,19 @@ def run(
context_injector=self._context_injector,
extras=extras,
) as session:
params = run_task_params
# we can't be running more than 1 replica of the task
# because this guarantees multiple Pipes sessions running at the same time
# which we don't support yet

task_definition = params["taskDefinition"]
cluster = params.get("cluster")
if run_task_params.get("count") and run_task_params.get("count", 1) > 1:
raise DagsterInvariantViolationError(
"Running more than one ECS task is not supported."
)

task_definition = run_task_params["taskDefinition"]
cluster = run_task_params.get("cluster")

overrides = cast(dict, params.get("overrides") or {})
overrides = cast(dict, run_task_params.get("overrides") or {})
overrides["containerOverrides"] = overrides.get("containerOverrides", [])

# get all containers from task definition
Expand Down Expand Up @@ -134,47 +148,97 @@ def run(
}
)

params["overrides"] = ( # pyright: ignore (reportGeneralTypeIssues)
run_task_params["overrides"] = ( # pyright: ignore (reportGeneralTypeIssues)
overrides # assign in case overrides was created here as an empty dict
)

response = self._client.run_task(**params)
response = self._client.run_task(**run_task_params)

if len(response["tasks"]) > 1:
# this error should never happen, as we're running a single task
raise DagsterInvariantViolationError(
f"Expected to get a single task from response, got multiple: {response['tasks']}"
)

tasks: List[str] = [task["taskArn"] for task in response["tasks"]] # pyright: ignore (reportTypedDictNotRequiredAccess)
task = response["tasks"][0]
task_arn = task["taskArn"] # pyright: ignore (reportTypedDictNotRequiredAccess)
task_id = task_arn.split("/")[-1]
containers = task["containers"] # pyright: ignore (reportTypedDictNotRequiredAccess)

def get_cloudwatch_params(container_name: str) -> Optional[Dict[str, str]]:
"""This will either return the log group and stream for the container, or None in case of a bad log configuration."""
if log_config := log_configurations.get(container_name):
if log_config["logDriver"] == "awslogs":
log_group = log_config["options"]["awslogs-group"] # pyright: ignore (reportTypedDictNotRequiredAccess)

# stream name is combined from: prefix, container name, task id
log_stream = f"{log_config['options']['awslogs-stream-prefix']}/{container_name}/{task_id}" # pyright: ignore (reportTypedDictNotRequiredAccess)

return {"log_group": log_group, "log_stream": log_stream}
else:
context.log.warning(
f"[pipes] Unsupported log driver {log_config['logDriver']} for Pipes container {container_name} in task {task_arn}. Dagster Pipes won't be able to read CloudWatch logs from this container."
)
else:
context.log.warning(
f"[pipes] log configuration for container {container_name} not found in task definition {task_definition}."
)
return None

try:
response = self._wait_for_tasks_completion(tasks=tasks, cluster=cluster)
if (
isinstance(self._message_reader, PipesCloudWatchMessageReader)
and len(containers) > 1
and not pipes_container_name
):
raise DagsterInvariantViolationError(
"When using PipesCloudWatchMessageReader with more than one container, pipes_container_name must be set."
)
elif (
isinstance(self._message_reader, PipesCloudWatchMessageReader)
and len(containers) == 1
):
pipes_container_name = containers[0]["name"] # pyright: ignore (reportTypedDictNotRequiredAccess)

if isinstance(self._message_reader, PipesCloudWatchMessageReader):
pipes_container_name = cast(str, pipes_container_name)

params = get_cloudwatch_params(pipes_container_name)

if params:
# update log group and stream for the message reader
# it should start receiving messages shortly after this call
session.report_launched({"extras": params})

# collect logs from all containers
for task in response["tasks"]:
task_id = task["taskArn"].split("/")[-1]

for container in task["containers"]:
if log_config := log_configurations.get(container["name"]):
if log_config["logDriver"] == "awslogs":
log_group = log_config["options"]["awslogs-group"] # pyright: ignore (reportTypedDictNotRequiredAccess)

# stream name is combined from: prefix, container name, task id
log_stream = f"{log_config['options']['awslogs-stream-prefix']}/{container['name']}/{task_id}" # pyright: ignore (reportTypedDictNotRequiredAccess)

if isinstance(self._message_reader, PipesCloudWatchMessageReader):
self._message_reader.consume_cloudwatch_logs(
log_group,
log_stream,
start_time=int(task["createdAt"].timestamp() * 1000),
)
else:
context.log.warning(
f"[pipes] Unsupported log driver {log_config['logDriver']} for container {container['name']} in task {task['taskArn']}. Dagster Pipes won't be able to receive messages from this container."
)

# TODO: insert container names into the log message
# right now all logs will be mixed together, which is not very good

for container in containers:
if isinstance(self._message_reader, PipesCloudWatchMessageReader):
params = get_cloudwatch_params(container["name"]) # pyright: ignore (reportTypedDictNotRequiredAccess)

if params:
self._message_reader.add_log_reader(
container["name"], # pyright: ignore (reportTypedDictNotRequiredAccess)
PipesCloudWatchLogReader(
client=self._message_reader.client,
log_group=params["log_group"],
log_stream=params["log_stream"],
start_time=int(session.created_at.timestamp() * 1000),
),
)

response = self._wait_for_completion(response, cluster=cluster)

# check for failed containers
failed_containers = {}

for task in response["tasks"]:
for container in task["containers"]:
for container in task["containers"]: # pyright: ignore (reportTypedDictNotRequiredAccess)
if container.get("exitCode") not in (0, None):
failed_containers[container["runtimeId"]] = container.get("exitCode")
failed_containers[container["runtimeId"]] = container.get("exitCode") # pyright: ignore (reportTypedDictNotRequiredAccess)

if failed_containers:
raise RuntimeError(
Expand All @@ -186,37 +250,38 @@ def run(
context.log.warning(
"[pipes] Dagster process interrupted, terminating ECS tasks"
)
self._terminate_tasks(context=context, tasks=tasks, cluster=cluster)
self._terminate(context=context, wait_response=response, cluster=cluster)
raise

context.log.info(f"[pipes] ECS tasks {tasks} completed")

context.log.info(f"[pipes] ECS task {task_arn} completed")
return PipesClientCompletedInvocation(session)

def _wait_for_tasks_completion(
self, tasks: List[str], cluster: Optional[str] = None
) -> Dict[str, Any]:
def _wait_for_completion(
self, start_response: "RunTaskResponseTypeDef", cluster: Optional[str] = None
) -> "DescribeTasksResponseTypeDef":
waiter = self._client.get_waiter("tasks_stopped")

params: Dict[str, Any] = {"tasks": tasks}
params: Dict[str, Any] = {"tasks": [start_response["tasks"][0]["taskArn"]]} # pyright: ignore (reportGeneralTypeIssues)

if cluster:
params["cluster"] = cluster

waiter.wait(**params)
return self._client.describe_tasks(**params) # pyright: ignore (reportReturnType)
return self._client.describe_tasks(**params)

def _terminate_tasks(
self, context: OpExecutionContext, tasks: List[str], cluster: Optional[str] = None
def _terminate(
self,
context: OpExecutionContext,
wait_response: "DescribeTasksResponseTypeDef",
cluster: Optional[str] = None,
):
for task in tasks:
try:
self._client.stop_task(
cluster=cluster, # pyright: ignore ()
task=task,
reason="Dagster process was interrupted",
)
except botocore.exceptions.ClientError as e: # pyright: ignore (reportAttributeAccessIssue)
context.log.warning(
f"[pipes] Couldn't stop ECS task {task} in cluster {cluster}:\n{e}"
)
task = wait_response["tasks"][0]

try:
self._client.stop_task(
cluster=cluster, # pyright: ignore ()
task=wait_response["tasks"][0]["taskArn"], # pyright: ignore (reportGeneralTypeIssues)
reason="Dagster process was interrupted",
)
except botocore.exceptions.ClientError as e: # pyright: ignore (reportAttributeAccessIssue)
context.log.warning(f"[pipes] Couldn't stop ECS task {task} in cluster {cluster}:\n{e}")
Loading

0 comments on commit b80d4f9

Please sign in to comment.