Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
kramstrom committed Sep 4, 2024
1 parent a2702ff commit 3863e1b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
13 changes: 7 additions & 6 deletions modal/_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@
class LifespanManager:
def __init__(self, asgi_app, state):
self.asgi_app = asgi_app
self.queue = None
self.startup_complete = None
self.queue: Optional[asyncio.Queue[Dict[str, Any]]] = None
self.startup_complete: Optional[asyncio.Future[None]] = None
self.state = state

async def background_task(self):
self.queue = asyncio.Queue()
self.startup = asyncio.Future()
self.shutdown = asyncio.Future()
self.startup: asyncio.Future[None] = asyncio.Future()
self.shutdown: asyncio.Future[None] = asyncio.Future()

async def receive():
if self.queue is None:
raise ValueError("queue is not initialized, call background_task first")
return await self.queue.get()

async def send(message):
Expand Down Expand Up @@ -60,7 +62,7 @@ async def lifespan_shutdown(self):


def asgi_app_wrapper(asgi_app, function_io_manager) -> Tuple[Callable[..., AsyncGenerator], LifespanManager]:
state = {} # used for lifespan state
state: Dict[str, Any] = {} # used for lifespan state

async def fn(scope):
if "state" in scope:
Expand All @@ -69,7 +71,6 @@ async def fn(scope):
function_call_id = current_function_call_id()
assert function_call_id, "internal error: function_call_id not set in asgi_app() scope"

# TODO: Add support for the ASGI lifecycle spec.
messages_from_app: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(1)
messages_to_app: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(1)

Expand Down
16 changes: 8 additions & 8 deletions test/asgi_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def aio(_function_call_id):
async def test_success():
mock_manager = MockIOManager()
_set_current_context_ids(["in-123"], ["fc-123"])
wrapped_app = asgi_app_wrapper(app, mock_manager)
wrapped_app, lifespan_manager = asgi_app_wrapper(app, mock_manager)
asgi_scope = _asgi_get_scope("/")
outputs = [output async for output in wrapped_app(asgi_scope)]
assert len(outputs) == 2
Expand All @@ -89,7 +89,7 @@ async def test_success():
async def test_endpoint_exception(endpoint_url):
mock_manager = MockIOManager()
_set_current_context_ids(["in-123"], ["fc-123"])
wrapped_app = asgi_app_wrapper(app, mock_manager)
wrapped_app, lifespan_manager = asgi_app_wrapper(app, mock_manager)
asgi_scope = _asgi_get_scope(endpoint_url)
outputs = []

Expand Down Expand Up @@ -122,7 +122,7 @@ async def test_broken_io_unused(caplog):
# and not raise an exception - but print a warning since it's unexpected
mock_manager = BrokenIOManager()
_set_current_context_ids(["in-123"], ["fc-123"])
wrapped_app = asgi_app_wrapper(app, mock_manager)
wrapped_app, lifespan_manager = asgi_app_wrapper(app, mock_manager)
asgi_scope = _asgi_get_scope("/")
outputs = []

Expand All @@ -141,7 +141,7 @@ async def test_broken_io_unused(caplog):
async def test_broken_io_used():
mock_manager = BrokenIOManager()
_set_current_context_ids(["in-123"], ["fc-123"])
wrapped_app = asgi_app_wrapper(app, mock_manager)
wrapped_app, lifespan_manager = asgi_app_wrapper(app, mock_manager)
asgi_scope = _asgi_get_scope("/async_reading_body", "POST")
outputs = []
with pytest.raises(ClientDisconnect):
Expand All @@ -165,7 +165,7 @@ async def aio(_function_call_id):
async def test_first_message_timeout(monkeypatch):
monkeypatch.setattr("modal._asgi.FIRST_MESSAGE_TIMEOUT_SECONDS", 0.1) # simulate timeout
_set_current_context_ids(["in-123"], ["fc-123"])
wrapped_app = asgi_app_wrapper(app, SlowIOManager())
wrapped_app, lifespan_manager = asgi_app_wrapper(app, SlowIOManager())
asgi_scope = _asgi_get_scope("/async_reading_body", "POST")
outputs = []
with pytest.raises(ClientDisconnect):
Expand All @@ -181,7 +181,7 @@ async def test_cancellation_cleanup(caplog):
# this test mostly exists to get some coverage on the cancellation/error paths and
# ensure nothing unexpected happens there
_set_current_context_ids(["in-123"], ["fc-123"])
wrapped_app = asgi_app_wrapper(app, SlowIOManager())
wrapped_app, lifespan_manager = asgi_app_wrapper(app, SlowIOManager())
asgi_scope = _asgi_get_scope("/async_reading_body", "POST")
outputs = []

Expand All @@ -200,7 +200,7 @@ async def app_runner():
@pytest.mark.asyncio
async def test_streaming_response():
_set_current_context_ids(["in-123"], ["fc-123"])
wrapped_app = asgi_app_wrapper(app, SlowIOManager())
wrapped_app, lifespan_manager = asgi_app_wrapper(app, SlowIOManager())
asgi_scope = _asgi_get_scope("/streaming_response", "GET")
outputs = []
async for output in wrapped_app(asgi_scope):
Expand All @@ -227,7 +227,7 @@ async def aio(_function_call_id):
async def test_streaming_body():
_set_current_context_ids(["in-123"], ["fc-123"])

wrapped_app = asgi_app_wrapper(app, StreamingIOManager())
wrapped_app, lifespan_manager = asgi_app_wrapper(app, StreamingIOManager())
asgi_scope = _asgi_get_scope("/async_reading_body", "POST")
outputs = []
async for output in wrapped_app(asgi_scope):
Expand Down

0 comments on commit 3863e1b

Please sign in to comment.