Concurrent Logging Interface #2167

Closed · wants to merge 4 commits
2 changes: 1 addition & 1 deletion modal/_container_entrypoint.py
@@ -727,7 +727,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):

_client: _Client = synchronizer._translate_in(client) # TODO(erikbern): ugly

- with container_io_manager.heartbeats(is_snapshotting_function), UserCodeEventLoop() as event_loop:
+ with container_io_manager.heartbeats_and_logs(is_snapshotting_function), UserCodeEventLoop() as event_loop:
# If this is a serialized function, fetch the definition from the server
if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
ser_cls, ser_fun = container_io_manager.get_serialized_function()
61 changes: 54 additions & 7 deletions modal/_container_io_manager.py
@@ -8,6 +8,7 @@
import sys
import time
import traceback
from collections import defaultdict
from contextlib import AsyncExitStack
from dataclasses import dataclass
from pathlib import Path
@@ -55,6 +56,31 @@ class FinalizedFunction:
data_format: int # api_pb2.DataFormat


class LabeledLogAccumulator:
"""Accumulates logs across inputs, sending them in a single RPC call. Flushed by external caller."""

data_by_key: Dict[tuple[str, str, int], str] # keyed by (input_id, function_call_id, file_descriptor)
ready: asyncio.Event

def __init__(self):
self.data_by_key = defaultdict(str)
self.ready = asyncio.Event()

def add(self, input_id: str, function_call_id: str, fd: api_pb2.FileDescriptor, data: str) -> None:
key = (input_id, function_call_id, fd)
self.data_by_key[key] += data
self.ready.set()

async def flush(self, client: _Client) -> None:
logs = [
api_pb2.TaskLogs(input_id=input_id, function_call_id=function_call_id, data=data, file_descriptor=fd)
for (input_id, function_call_id, fd), data in self.data_by_key.items()
]
self.data_by_key.clear()
self.ready.clear()
await retry_transient_errors(client.stub.ContainerLog, api_pb2.ContainerLogRequest(logs=logs))
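
For reference, a minimal usage sketch of the accumulator outside the container runtime (illustrative only, not part of the diff; it assumes modal's api_pb2 enums and a connected _Client whose stub exposes ContainerLog):

# Illustrative sketch: interleaved writes from two concurrent inputs are
# buffered per (input_id, function_call_id, fd) key and drained in a
# single ContainerLog RPC.
import asyncio
from modal._container_io_manager import LabeledLogAccumulator
from modal_proto import api_pb2

async def demo(client) -> None:
    acc = LabeledLogAccumulator()
    stdout = api_pb2.FileDescriptor.FILE_DESCRIPTOR_STDOUT

    acc.add("in-a", "fc-1", stdout, "hello from a\n")
    acc.add("in-b", "fc-1", stdout, "hello from b\n")
    acc.add("in-a", "fc-1", stdout, "more from a\n")  # concatenated onto in-a's buffer

    await acc.ready.wait()   # already set by the first add()
    await acc.flush(client)  # one RPC with one TaskLogs entry per key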


class IOContext:
"""Context object for managing input, function calls, and function executions
in a batched or single input context.
@@ -245,6 +271,8 @@ class _ContainerIOManager:
current_inputs: Dict[str, IOContext] # input_id -> IOContext
current_input_started_at: Optional[float]

log_accumulator: LabeledLogAccumulator

_target_concurrency: int
_max_concurrency: int
_concurrency_loop: Optional[asyncio.Task]
@@ -253,6 +281,7 @@
_environment_name: str
_heartbeat_loop: Optional[asyncio.Task]
_heartbeat_condition: asyncio.Condition
_log_rpc_condition: asyncio.Condition
_waiting_for_memory_snapshot: bool

_is_interactivity_enabled: bool
@@ -276,6 +305,8 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):
self.current_inputs = {}
self.current_input_started_at = None

self.log_accumulator = LabeledLogAccumulator()

if container_args.function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
target_concurrency = 1
max_concurrency = 0
@@ -291,6 +322,7 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):
self._environment_name = container_args.environment_name
self._heartbeat_loop = None
self._heartbeat_condition = asyncio.Condition()
self._log_rpc_condition = asyncio.Condition()
self._waiting_for_memory_snapshot = False

self._is_interactivity_enabled = False
@@ -376,20 +408,33 @@ async def _heartbeat_handle_cancellations(self) -> bool:
return True
return False

async def _run_structured_log_loop(self) -> None:
while True:
await self.log_accumulator.ready.wait()
# Prevent log RPCs from going out before the snapshot restore is complete.
async with self._log_rpc_condition:
while self._waiting_for_memory_snapshot:
await self._log_rpc_condition.wait()
await self.log_accumulator.flush(self._client)
await asyncio.sleep(0.01)

@asynccontextmanager
- async def heartbeats(self, wait_for_mem_snap: bool) -> AsyncGenerator[None, None]:
+ async def heartbeats_and_logs(self, wait_for_mem_snap: bool) -> AsyncGenerator[None, None]:
async with TaskContext() as tc:
- self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop())
- t.set_name("heartbeat loop")
+ self._heartbeat_loop = tc.create_task(self._run_heartbeat_loop())
+ self._structured_log_task = tc.create_task(self._run_structured_log_loop())
+ self._heartbeat_loop.set_name("heartbeat loop")
+ self._structured_log_task.set_name("structured logs loop")
self._waiting_for_memory_snapshot = wait_for_mem_snap
try:
yield
finally:
- t.cancel()
+ self._structured_log_task.cancel()
+ self._heartbeat_loop.cancel()
+ # Flush any remaining logs (awaitable here, inside the async context manager)
+ await self.log_accumulator.flush(self._client)

def stop_heartbeat(self):
if self._heartbeat_loop:
self._heartbeat_loop.cancel()

@asynccontextmanager
async def dynamic_concurrency_manager(self) -> AsyncGenerator[None, None]:
@@ -890,6 +935,7 @@ async def memory_snapshot(self) -> None:
# Notify the heartbeat and structured-log loops that the snapshot phase has
# begun, in order to prevent them from sending RPCs mid-snapshot
self._waiting_for_memory_snapshot = True
self._log_rpc_condition.notify_all()
self._heartbeat_condition.notify_all()

await self._client.stub.ContainerCheckpoint(
@@ -904,6 +950,7 @@ async def memory_snapshot(self) -> None:
# Turn heartbeats and log flushes back on. This is safe since the snapshot
# RPC and the restore phase have finished.
self._waiting_for_memory_snapshot = False
self._log_rpc_condition.notify_all()
self._heartbeat_condition.notify_all()

async def volume_commit(self, volume_ids: List[str]) -> None:
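Both loops gate on the same standard asyncio.Condition pattern: the flag is flipped while holding the condition's lock, waiters are woken with notify_all(), and each loop re-checks the flag before sending RPCs. A self-contained sketch of that pattern, independent of modal:

# Sketch of the snapshot gate: a background sender pauses while
# `snapshotting` is set and resumes once notify_all() wakes it up.
import asyncio

class SnapshotGate:
    def __init__(self) -> None:
        self.cond = asyncio.Condition()
        self.snapshotting = False

    async def sender_loop(self) -> None:
        while True:
            async with self.cond:
                while self.snapshotting:  # re-check the flag on every wakeup
                    await self.cond.wait()
            # Safe to send RPCs here: no snapshot is in flight.
            await asyncio.sleep(0.01)

    async def snapshot(self) -> None:
        async with self.cond:
            self.snapshotting = True   # pause senders before checkpointing
            self.cond.notify_all()
        await asyncio.sleep(0.1)       # checkpoint + restore would happen here
        async with self.cond:
            self.snapshotting = False  # resume senders after restore
            self.cond.notify_all()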
69 changes: 69 additions & 0 deletions modal/experimental.py
@@ -1,4 +1,11 @@
# Copyright Modal Labs 2022
import contextlib
import sys
from typing import Any, Literal

from modal.execution_context import current_function_call_id, current_input_id
from modal_proto import api_pb2

from ._container_io_manager import _ContainerIOManager


@@ -13,3 +20,65 @@ def get_local_input_concurrency():
"""Get the container's local input concurrency. Return 0 if the container is not running."""

return _ContainerIOManager.get_input_concurrency()


def log(message: str, end: str = "\n", stream: Literal["stdout", "stderr"] = "stdout") -> None:
"""Add message to the current input's log for structured logging."""
io_manager = _ContainerIOManager._singleton
assert io_manager is not None
fd = (
api_pb2.FileDescriptor.FILE_DESCRIPTOR_STDOUT
if stream == "stdout"
else api_pb2.FileDescriptor.FILE_DESCRIPTOR_STDERR
)
io_manager.log_accumulator.add(
input_id=current_input_id(),
function_call_id=current_function_call_id(),
fd=fd,
data=message + end,
)
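
A usage sketch of the new helper in application code (hypothetical example, not part of the diff; it assumes this PR is installed and that the function runs with concurrent inputs):

# Hypothetical app code: each log() call is tagged with the current
# input and function-call IDs, so interleaved output from concurrent
# inputs can be demultiplexed server-side.
import modal
from modal.experimental import log

app = modal.App("structured-logs-demo")

@app.function(allow_concurrent_inputs=10)
def handler(x: int) -> int:
    log(f"processing input {x}")
    log(f"non-fatal issue for {x}", stream="stderr")
    return x * x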


class DemuxStream:
"""
Demultiplexes a stream from multiple concurrent inputs.
"""

def __init__(self, passthrough: Any, fd: api_pb2.FileDescriptor, io_manager: "_ContainerIOManager"):
self.passthrough = passthrough
self.fd = fd
self.io_manager = io_manager

def write(self, data: str):
input_id, function_call_id = current_input_id(), current_function_call_id()
if input_id is None or function_call_id is None:
# Not running inside an input context; fall back to the real stream
self.passthrough.write(data + " (no tag)")
else:
self.io_manager.log_accumulator.add(
input_id,
function_call_id,
self.fd,
data,
)

def flush(self):
# Writes are forwarded eagerly; nothing is buffered locally
pass

def getvalue(self):
return ""


@contextlib.contextmanager
def tagged_logs():
"""Redirect sys.stdout/sys.stderr so writes are tagged with the current input and function call."""
io_manager = _ContainerIOManager._singleton
assert io_manager is not None

original_stdout = sys.stdout
original_stderr = sys.stderr
try:
sys.stdout = DemuxStream(original_stdout, api_pb2.FileDescriptor.FILE_DESCRIPTOR_STDOUT, io_manager)
sys.stderr = DemuxStream(original_stderr, api_pb2.FileDescriptor.FILE_DESCRIPTOR_STDERR, io_manager)
yield
finally:
sys.stdout = original_stdout
sys.stderr = original_stderr
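
And a sketch of the redirection variant, which tags plain print() output without touching call sites (hypothetical, continuing the example above):

# Hypothetical: inside tagged_logs(), ordinary stdout/stderr writes go
# through DemuxStream and are attributed to the current input instead
# of being written to the real streams.
from modal.experimental import tagged_logs

@app.function(allow_concurrent_inputs=10)
def noisy_handler(x: int) -> int:
    with tagged_logs():
        print(f"processing input {x}")  # captured and tagged, not printed locally
        return x * x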
5 changes: 2 additions & 3 deletions modal_proto/api.proto
@@ -740,8 +740,7 @@ message ContainerHeartbeatResponse {
}

message ContainerLogRequest {
- string message = 1;
- string input_id = 2;
+ repeated TaskLogs logs = 1;
}
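
For reference, the batched payload that flush() now assembles, using the TaskLogs field names seen in _container_io_manager.py above (values are illustrative):

# Illustrative: one ContainerLogRequest carrying the buffered logs of
# two concurrent inputs in a single RPC.
from modal_proto import api_pb2

request = api_pb2.ContainerLogRequest(
    logs=[
        api_pb2.TaskLogs(
            input_id="in-a",
            function_call_id="fc-1",
            file_descriptor=api_pb2.FileDescriptor.FILE_DESCRIPTOR_STDOUT,
            data="hello from a\nmore from a\n",
        ),
        api_pb2.TaskLogs(
            input_id="in-b",
            function_call_id="fc-1",
            file_descriptor=api_pb2.FileDescriptor.FILE_DESCRIPTOR_STDOUT,
            data="hello from b\n",
        ),
    ]
)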

message ContainerStopRequest {
@@ -2329,7 +2328,7 @@ service ModalClient {
rpc ContainerExecPutInput(ContainerExecPutInputRequest) returns (google.protobuf.Empty);
rpc ContainerExecWait(ContainerExecWaitRequest) returns (ContainerExecWaitResponse);
rpc ContainerHeartbeat(ContainerHeartbeatRequest) returns (ContainerHeartbeatResponse);
- rpc ContainerLog(ContainerLogRequest) returns (google.protobuf.Empty);
+ rpc ContainerLog(stream ContainerLogRequest) returns (google.protobuf.Empty);
rpc ContainerStop(ContainerStopRequest) returns (ContainerStopResponse);

// Dicts
Expand Down