
Merge branch 'main' into kramstrom/add-iter-gen-to-function-call
kramstrom authored Sep 2, 2024
2 parents af7517d + 51a53c0 commit 95459f6
Showing 13 changed files with 263 additions and 70 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,12 @@ We appreciate your patience while we speedily work towards a stable release of t
 
 <!-- NEW CONTENT GENERATED BELOW. PLEASE PRESERVE THIS COMMENT. -->
 
+### 0.64.67 (2024-08-30)
+
+- Fix a regression where `modal launch` did not show progress output when starting the container.
+
+
+
 ### 0.64.38 (2024-08-16)
 
 - Added a `modal app rollback` CLI command for rolling back an App deployment to a previous version.
39 changes: 16 additions & 23 deletions modal/_container_entrypoint.py
@@ -211,7 +211,7 @@ class DaemonizedThreadPool:
     # Used instead of ThreadPoolExecutor, since the latter won't allow
     # the interpreter to shut down before the currently running tasks
     # have finished
-    def __init__(self, max_threads):
+    def __init__(self, max_threads: int):
         self.max_threads = max_threads
 
     def __enter__(self):
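The comment above is the whole point of DaemonizedThreadPool: its workers are daemon threads, so interpreter shutdown does not wait for in-flight tasks the way concurrent.futures.ThreadPoolExecutor does. A minimal sketch of that pattern (illustrative only, not the class's actual implementation):

```python
import threading

def run_daemonized(fn, *args):
    # Daemon threads do not block interpreter shutdown, unlike
    # ThreadPoolExecutor, which joins its worker threads at exit.
    t = threading.Thread(target=fn, args=args, daemon=True)
    t.start()
    return t
```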
@@ -321,9 +321,8 @@ def call_function(
     user_code_event_loop: UserCodeEventLoop,
     container_io_manager: "modal._container_io_manager.ContainerIOManager",
     finalized_functions: Dict[str, FinalizedFunction],
-    input_concurrency: int,
-    batch_max_size: Optional[int],
-    batch_wait_ms: Optional[int],
+    batch_max_size: int,
+    batch_wait_ms: int,
 ):
     async def run_input_async(io_context: IOContext) -> None:
         started_at = time.time()
@@ -416,8 +415,8 @@ def run_input_sync(io_context: IOContext) -> None:
            )
        reset_context()
 
-    if input_concurrency > 1:
-        with DaemonizedThreadPool(max_threads=input_concurrency) as thread_pool:
+    if container_io_manager.target_concurrency > 1:
+        with DaemonizedThreadPool(max_threads=container_io_manager.max_concurrency) as thread_pool:
 
            def make_async_cancel_callback(task):
                def f():
@@ -448,10 +447,10 @@ async def run_concurrent_inputs():
                # for them to resolve gracefully:
                async with TaskContext(0.01) as task_context:
                    async for io_context in container_io_manager.run_inputs_outputs.aio(
-                        finalized_functions, input_concurrency, batch_max_size, batch_wait_ms
+                        finalized_functions, batch_max_size, batch_wait_ms
                    ):
-                        # Note that run_inputs_outputs will not return until the concurrency semaphore has
-                        # released all its slots so that they can be acquired by the run_inputs_outputs finalizer
+                        # Note that run_inputs_outputs will not return until all the input slots are released
+                        # so that they can be acquired by the run_inputs_outputs finalizer
                        # This prevents leaving the task_context before outputs have been created
                        # TODO: refactor to make this a bit more easy to follow?
                        if io_context.finalized_function.is_async:
@@ -464,9 +463,7 @@
 
            user_code_event_loop.run(run_concurrent_inputs())
    else:
-        for io_context in container_io_manager.run_inputs_outputs(
-            finalized_functions, input_concurrency, batch_max_size, batch_wait_ms
-        ):
+        for io_context in container_io_manager.run_inputs_outputs(finalized_functions, batch_max_size, batch_wait_ms):
            if io_context.finalized_function.is_async:
                user_code_event_loop.run(run_input_async(io_context))
            else:
@@ -767,16 +764,13 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
        # if the app can't be inferred by the imported function, use name-based fallback
        active_app = get_active_app_fallback(function_def)
 
-    # Container can fetch multiple inputs simultaneously
-    if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
-        # Concurrency and batching doesn't apply for `modal shell`.
-        input_concurrency = 1
-        batch_max_size = 0
-        batch_wait_ms = 0
-    else:
-        input_concurrency = function_def.allow_concurrent_inputs or 1
-        batch_max_size = function_def.batch_max_size or 0
-        batch_wait_ms = function_def.batch_linger_ms or 0
+    if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
+        # Concurrency and batching don't apply for `modal shell`.
+        batch_max_size = 0
+        batch_wait_ms = 0
+    else:
+        batch_max_size = function_def.batch_max_size or 0
+        batch_wait_ms = function_def.batch_linger_ms or 0
 
    # Get ids and metadata for objects (primarily functions and classes) on the app
    container_app: RunningApp = container_io_manager.get_app_objects()
@@ -842,7 +836,6 @@ def breakpoint_wrapper():
            event_loop,
            container_io_manager,
            finalized_functions,
-            input_concurrency,
            batch_max_size,
            batch_wait_ms,
        )
165 changes: 137 additions & 28 deletions modal/_container_io_manager.py
@@ -8,6 +8,7 @@
 import sys
 import time
 import traceback
+from contextlib import AsyncExitStack
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Tuple
@@ -30,6 +31,8 @@
 from .exception import InputCancellation, InvalidError
 from .running_app import RunningApp
 
+DYNAMIC_CONCURRENCY_INTERVAL_SECS = 3
+DYNAMIC_CONCURRENCY_TIMEOUT_SECS = 10
 MAX_OUTPUT_BATCH_SIZE: int = 49
 
 RTT_S: float = 0.5  # conservative estimate of RTT in seconds.
@@ -177,6 +180,51 @@ def validate_output_data(self, data: Any) -> List[Any]:
         return data
 
 
+class InputSlots:
+    """A semaphore that allows dynamically adjusting the concurrency."""
+
+    active: int
+    value: int
+    waiter: Optional[asyncio.Future]
+    closed: bool
+
+    def __init__(self, value: int) -> None:
+        self.active = 0
+        self.value = value
+        self.waiter = None
+        self.closed = False
+
+    async def acquire(self) -> None:
+        if self.active < self.value:
+            self.active += 1
+        elif self.waiter is None:
+            self.waiter = asyncio.get_running_loop().create_future()
+            await self.waiter
+        else:
+            raise RuntimeError("Concurrent waiters are not supported.")
+
+    def _wake_waiter(self) -> None:
+        if self.active < self.value and self.waiter is not None:
+            self.waiter.set_result(None)
+            self.waiter = None
+            self.active += 1
+
+    def release(self) -> None:
+        self.active -= 1
+        self._wake_waiter()
+
+    def set_value(self, value: int) -> None:
+        if self.closed:
+            return
+        self.value = value
+        self._wake_waiter()
+
+    async def close(self) -> None:
+        self.closed = True
+        for _ in range(self.value):
+            await self.acquire()
+
+
 class _ContainerIOManager:
     """Synchronizes all RPC calls and network operations for a running container.
@@ -196,8 +244,11 @@ class _ContainerIOManager:
     current_inputs: Dict[str, IOContext]  # input_id -> IOContext
     current_input_started_at: Optional[float]
 
-    _input_concurrency: Optional[int]
-    _semaphore: Optional[asyncio.Semaphore]
+    _target_concurrency: int
+    _max_concurrency: int
+    _concurrency_loop: Optional[asyncio.Task]
+    _input_slots: InputSlots
+
     _environment_name: str
     _heartbeat_loop: Optional[asyncio.Task]
     _heartbeat_condition: asyncio.Condition
@@ -224,9 +275,18 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):
         self.current_inputs = {}
         self.current_input_started_at = None
 
-        self._input_concurrency = None
+        if container_args.function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
+            target_concurrency = 1
+            max_concurrency = 0
+        else:
+            target_concurrency = container_args.function_def.target_concurrent_inputs or 1
+            max_concurrency = container_args.function_def.max_concurrent_inputs or target_concurrency
+
+        self._target_concurrency = target_concurrency
+        self._max_concurrency = max_concurrency
+        self._concurrency_loop = None
+        self._input_slots = InputSlots(target_concurrency)
 
-        self._semaphore = None
         self._environment_name = container_args.environment_name
         self._heartbeat_loop = None
         self._heartbeat_condition = asyncio.Condition()
@@ -297,7 +357,7 @@ async def _heartbeat_handle_cancellations(self) -> bool:
                # Pause processing of the current input by signaling self a SIGUSR1.
                input_ids_to_cancel = response.cancel_input_event.input_ids
                if input_ids_to_cancel:
-                    if self._input_concurrency > 1:
+                    if self._target_concurrency > 1:
                        for input_id in input_ids_to_cancel:
                            if input_id in self.current_inputs:
                                self.current_inputs[input_id].cancel()
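For context, "signaling self a SIGUSR1" boils down to the following (illustrative; the corresponding handler lives in the entrypoint, outside this diff):

```python
import os
import signal

os.kill(os.getpid(), signal.SIGUSR1)  # deliver SIGUSR1 to this very process
```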
@@ -330,6 +390,39 @@ def stop_heartbeat(self):
         if self._heartbeat_loop:
             self._heartbeat_loop.cancel()
 
+    @asynccontextmanager
+    async def dynamic_concurrency_manager(self) -> AsyncGenerator[None, None]:
+        async with TaskContext() as tc:
+            self._concurrency_loop = t = tc.create_task(self._dynamic_concurrency_loop())
+            t.set_name("dynamic concurrency loop")
+            try:
+                yield
+            finally:
+                t.cancel()
+
+    async def _dynamic_concurrency_loop(self):
+        logger.debug(f"Starting dynamic concurrency loop for task {self.task_id}")
+        while 1:
+            try:
+                request = api_pb2.FunctionGetDynamicConcurrencyRequest(
+                    function_id=self.function_id,
+                    target_concurrency=self._target_concurrency,
+                    max_concurrency=self._max_concurrency,
+                )
+                resp = await retry_transient_errors(
+                    self._client.stub.FunctionGetDynamicConcurrency,
+                    request,
+                    attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS,
+                )
+                if resp.concurrency != self._input_slots.value:
+                    logger.debug(f"Dynamic concurrency set from {self._input_slots.value} to {resp.concurrency}")
+                    self._input_slots.set_value(resp.concurrency)
+
+            except Exception as exc:
+                logger.debug(f"Failed to get dynamic concurrency for task {self.task_id}, {exc}")
+
+            await asyncio.sleep(DYNAMIC_CONCURRENCY_INTERVAL_SECS)
+
 async def get_app_objects(self) -> RunningApp:
     req = api_pb2.AppGetObjectsRequest(app_id=self.app_id, include_unindexed=True)
     resp = await retry_transient_errors(self._client.stub.AppGetObjects, req)
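dynamic_concurrency_manager runs the polling loop as a background task for the lifetime of a with-block and cancels it on exit. TaskContext is modal-internal; a plain-asyncio sketch of the same shape:

```python
import asyncio
from contextlib import asynccontextmanager

@asynccontextmanager
async def background_loop(coro_factory):
    # Run a polling coroutine for the lifetime of the with-block.
    task = asyncio.create_task(coro_factory())
    try:
        yield task
    finally:
        task.cancel()  # stop polling when the block exits
```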
@@ -470,12 +563,13 @@ async def _generate_inputs(
         request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
         iteration = 0
         while self._fetching_inputs:
+            await self._input_slots.acquire()
+
             request.average_call_time = self.get_average_call_time()
             request.max_values = self.get_max_inputs_to_fetch()  # Deprecated; remove.
-            request.input_concurrency = self._input_concurrency
+            request.input_concurrency = self.get_input_concurrency()
             request.batch_max_size, request.batch_linger_ms = batch_max_size, batch_wait_ms
 
-            await self._semaphore.acquire()
             yielded = False
             try:
                 # If number of active inputs is at max queue size, this will block.
@@ -508,7 +602,7 @@
                    final_input_received = True
                    break
 
-                # If yielded, allow semaphore to be released via exit_context
+                # If yielded, allow input slots to be released via exit_context
                yield inputs
                yielded = True
 
@@ -517,35 +611,34 @@
                    return
            finally:
                if not yielded:
-                    self._semaphore.release()
+                    self._input_slots.release()
 
    @synchronizer.no_io_translation
    async def run_inputs_outputs(
        self,
        finalized_functions: Dict[str, FinalizedFunction],
-        input_concurrency: int = 1,
        batch_max_size: int = 0,
        batch_wait_ms: int = 0,
    ) -> AsyncIterator[IOContext]:
        # Ensure we do not fetch new inputs when container is too busy.
-        # Before trying to fetch an input, acquire the semaphore:
-        # - if no input is fetched, release the semaphore.
-        # - or, when the output for the fetched input is sent, release the semaphore.
-        self._input_concurrency = input_concurrency
-        self._semaphore = asyncio.Semaphore(input_concurrency)
-
-        async for inputs in self._generate_inputs(batch_max_size, batch_wait_ms):
-            io_context = await IOContext.create(self._client, finalized_functions, inputs, batch_max_size > 0)
-            for input_id in io_context.input_ids:
-                self.current_inputs[input_id] = io_context
+        # Before trying to fetch an input, acquire an input slot:
+        # - if no input is fetched, release the input slot.
+        # - or, when the output for the fetched input is sent, release the input slot.
+        dynamic_concurrency_manager = (
+            self.dynamic_concurrency_manager() if self._max_concurrency > self._target_concurrency else AsyncExitStack()
+        )
+        async with dynamic_concurrency_manager:
+            async for inputs in self._generate_inputs(batch_max_size, batch_wait_ms):
+                io_context = await IOContext.create(self._client, finalized_functions, inputs, batch_max_size > 0)
+                for input_id in io_context.input_ids:
+                    self.current_inputs[input_id] = io_context
 
-            self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time()
-            yield io_context
-            self.current_input_id, self.current_input_started_at = (None, None)
+                self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time()
+                yield io_context
+                self.current_input_id, self.current_input_started_at = (None, None)
 
-        # collect all active input slots, meaning all inputs have wrapped up.
-        for _ in range(input_concurrency):
-            await self._semaphore.acquire()
+            # collect all active input slots, meaning all inputs have wrapped up.
+            await self._input_slots.close()
 
    @synchronizer.no_io_translation
    async def _push_outputs(
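The AsyncExitStack() branch above works because an exit stack with nothing registered on it is a no-op async context manager, so the code path is identical whether or not dynamic concurrency is enabled. A runnable sketch (real_manager is a hypothetical stand-in for dynamic_concurrency_manager):

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager

@asynccontextmanager
async def real_manager():
    print("concurrency loop started")  # stand-in for spawning the poller
    try:
        yield
    finally:
        print("concurrency loop stopped")

async def run(dynamic: bool):
    cm = real_manager() if dynamic else AsyncExitStack()  # no-op when disabled
    async with cm:
        pass  # fetch and run inputs here

asyncio.run(run(False))
```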
@@ -666,14 +759,16 @@ async def handle_input_exception(
            repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
            repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
 
+        data: bytes = self.serialize_exception(exc) or b""
+        data_result_part = await self.format_blob_data(data)
        results = [
            api_pb2.GenericResult(
                status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
                exception=repr_exc,
                traceback=traceback.format_exc(),
                serialized_tb=serialized_tb,
                tb_line_cache=tb_line_cache,
-                **await self.format_blob_data(self.serialize_exception(exc)),
+                **data_result_part,
            )
            for _ in io_context.input_ids
        ]
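The hoisting above serializes and blob-formats the exception once per batch and reuses the result for every input, instead of redoing that work inside the comprehension. Schematically (all names here are illustrative):

```python
def expensive_format(exc_repr: str) -> dict:
    return {"data": exc_repr.encode()}  # stand-in for serialize + blob-format

input_ids = ["in-1", "in-2", "in-3"]
shared = expensive_format("ValueError('boom')")     # computed once per batch
results = [{"id": i, **shared} for i in input_ids]  # reused for each input
```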
@@ -692,7 +787,7 @@ async def exit_context(self, started_at, input_ids: List[str]):
        for input_id in input_ids:
            self.current_inputs.pop(input_id)
 
-        self._semaphore.release()
+        self._input_slots.release()
 
    @synchronizer.no_io_translation
    async def push_outputs(self, io_context: IOContext, started_at: float, data: Any, data_format: int) -> None:
@@ -840,6 +935,20 @@ async def interact(self, from_breakpoint: bool = False):
            print("Error: Failed to start PTY shell.")
            raise e
 
+    @property
+    def target_concurrency(self) -> int:
+        return self._target_concurrency
+
+    @property
+    def max_concurrency(self) -> int:
+        return self._max_concurrency
+
+    @classmethod
+    def get_input_concurrency(cls) -> int:
+        io_manager = cls._singleton
+        assert io_manager
+        return io_manager._input_slots.value
+
    @classmethod
    def stop_fetching_inputs(cls):
        assert cls._singleton
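get_input_concurrency is a classmethod so any call site can read the live slot value without holding an instance; it goes through the class-level _singleton. The general pattern, sketched with illustrative names:

```python
from typing import Optional

class Manager:
    _singleton: Optional["Manager"] = None

    def __init__(self, value: int) -> None:
        self.value = value
        Manager._singleton = self

    @classmethod
    def current_value(cls) -> int:
        assert cls._singleton is not None
        return cls._singleton.value

Manager(3)
print(Manager.current_value())  # -> 3
```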
(Diffs for the remaining 10 changed files are not shown.)
