Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kramstrom committed Sep 4, 2024
1 parent 269fb4e commit ce715f9
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 23 deletions.
6 changes: 2 additions & 4 deletions modal/_asgi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Modal Labs 2022
import asyncio
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, NoReturn, Optional, Tuple, cast
from typing import Any, AsyncGenerator, Callable, Dict, NoReturn, Optional, Tuple, cast

import aiohttp

Expand Down Expand Up @@ -59,9 +59,7 @@ async def lifespan_shutdown(self):
await self.shutdown


def asgi_app_wrapper(
asgi_app, function_io_manager
) -> Tuple[Callable[..., AsyncGenerator], Callable[..., Awaitable[None]], Callable[..., Awaitable[None]]]:
def asgi_app_wrapper(asgi_app, function_io_manager) -> Tuple[Callable[..., AsyncGenerator], LifespanManager]:
state = {} # used for lifespan state

async def fn(scope):
Expand Down
18 changes: 9 additions & 9 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,10 +844,11 @@ def breakpoint_wrapper():
# Execute the function.
try:
for finalized_function in finalized_functions.values():
event_loop.create_task(finalized_function.lifespan_manager.background_task())
call_lifecycle_functions(
event_loop, container_io_manager, [finalized_function.lifespan_manager.lifespan_startup]
)
if finalized_function.lifespan_manager:
event_loop.create_task(finalized_function.lifespan_manager.background_task())
call_lifecycle_functions(
event_loop, container_io_manager, [finalized_function.lifespan_manager.lifespan_startup]
)
call_function(
event_loop,
container_io_manager,
Expand All @@ -863,11 +864,10 @@ def breakpoint_wrapper():
usr1_handler = signal.signal(signal.SIGUSR1, signal.SIG_IGN)

for finalized_function in finalized_functions.values():
call_lifecycle_functions(
event_loop, container_io_manager, [finalized_function.lifespan_manager.lifespan_shutdown]
)
for task in event_loop.tasks:
task.cancel()
if finalized_function.lifespan_manager:
call_lifecycle_functions(
event_loop, container_io_manager, [finalized_function.lifespan_manager.lifespan_shutdown]
)

try:
# Identify "exit" methods and run them.
Expand Down
7 changes: 5 additions & 2 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from contextlib import AsyncExitStack
from dataclasses import dataclass
from pathlib import Path
from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Tuple

from google.protobuf.empty_pb2 import Empty
from google.protobuf.message import Message
Expand All @@ -32,6 +32,9 @@
from .exception import InputCancellation, InvalidError, SerializationError
from .running_app import RunningApp

if TYPE_CHECKING:
import modal._asgi

DYNAMIC_CONCURRENCY_INTERVAL_SECS = 3
DYNAMIC_CONCURRENCY_TIMEOUT_SECS = 10
MAX_OUTPUT_BATCH_SIZE: int = 49
Expand All @@ -50,10 +53,10 @@ class Sentinel:
@dataclass
class FinalizedFunction:
callable: Callable[..., Any]
lifespan_manager: Any
is_async: bool
is_generator: bool
data_format: int # api_pb2.DataFormat
lifespan_manager: Optional["modal._asgi.LifespanManager"] = None


class IOContext:
Expand Down
35 changes: 32 additions & 3 deletions test/container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,8 +669,9 @@ def test_asgi(servicer):


@skip_github_non_linux
def test_asgi_with_lifespan(servicer):
inputs = _get_web_inputs(path="/foo")
def test_asgi_with_lifespan(servicer, capsys):
inputs = _get_web_inputs(path="/")

_put_web_body(servicer, b"")
ret = _run_container(
servicer,
Expand All @@ -689,7 +690,35 @@ def test_asgi_with_lifespan(servicer):
assert headers[b"content-type"] == b"application/json"

# Check body
assert json.loads(second_message["body"]) == {"hello": "space"}
assert json.loads(second_message["body"]) == "foo"

captured = capsys.readouterr()
assert captured.out.strip().split("\n") == ["enter", "foo", "exit"]


@skip_github_non_linux
def test_cls_web_asgi_with_lifespan(servicer, capsys):
servicer.app_objects.setdefault("ap-1", {}).setdefault("square", "fu-2")
servicer.app_functions["fu-2"] = api_pb2.Function()

inputs = _get_web_inputs(method_name="my_app1")
ret = _run_container(
servicer,
"test.supports.functions",
"fastapi_class_multiple_asgi_apps_lifespans.*",
inputs=inputs,
is_class=True,
)

_, second_message = _unwrap_asgi(ret)
assert json.loads(second_message["body"]) == "foo1"
captured = capsys.readouterr()
output_lines: List[str] = captured.out.strip().split("\n")

assert len(output_lines) == 5
assert output_lines[2] == "foo1"
assert ["enter1", "enter2"] == sorted(output_lines[:2])
assert ["exit1", "exit2"] == sorted(output_lines[3:])


@skip_github_non_linux
Expand Down
52 changes: 47 additions & 5 deletions test/supports/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,61 @@ def fastapi_app_with_lifespan():

@contextlib.asynccontextmanager
async def lifespan(wapp: FastAPI):
print("enter!")
print("enter")
yield
print("exit!")
print("exit")

web_app = FastAPI(lifespan=lifespan)

@web_app.get("/foo")
async def foo(arg="world"):
return {"hello": arg}
@web_app.get("/")
async def foo():
print("foo")
return "foo"

return web_app


@app.cls(container_idle_timeout=300, concurrency_limit=1, allow_concurrent_inputs=100)
class fastapi_class_multiple_asgi_apps_lifespans:
@asgi_app()
def my_app1(self):
from fastapi import FastAPI

@contextlib.asynccontextmanager
async def lifespan1(wapp):
print("enter1")
yield
print("exit1")

web_app1 = FastAPI(lifespan=lifespan1)

@web_app1.get("/")
async def foo1():
print("foo1")
return "foo1"

return web_app1

@asgi_app()
def my_app2(self):
from fastapi import FastAPI

@contextlib.asynccontextmanager
async def lifespan2(wapp):
print("enter2")
yield
print("exit2")

web_app2 = FastAPI(lifespan=lifespan2)

@web_app2.get("/")
async def foo2():
print("foo2")
return "foo2"

return web_app2


@app.function()
@asgi_app()
def error_in_asgi_setup():
Expand Down

0 comments on commit ce715f9

Please sign in to comment.