diff --git a/modal/_asgi.py b/modal/_asgi.py index 142e55e20..7139c8555 100644 --- a/modal/_asgi.py +++ b/modal/_asgi.py @@ -16,10 +16,11 @@ class LifespanManager: - def __init__(self, asgi_app): + def __init__(self, asgi_app, state): self.asgi_app = asgi_app self.queue = None self.startup_complete = None + self.state = state async def background_task(self): self.queue = asyncio.Queue() @@ -41,7 +42,7 @@ async def send(message): else: raise ValueError(f"Unexpected message type: {message['type']}") - await self.asgi_app({"type": "lifespan"}, receive, send) + await self.asgi_app({"type": "lifespan", "state": self.state}, receive, send) async def lifespan_startup(self): if self.queue is None or self.shutdown is None: @@ -61,7 +62,12 @@ async def lifespan_shutdown(self): def asgi_app_wrapper( asgi_app, function_io_manager ) -> Tuple[Callable[..., AsyncGenerator], Callable[..., Awaitable[None]], Callable[..., Awaitable[None]]]: + state = {} # used for lifespan state + async def fn(scope): + if "state" in scope: + raise ValueError("Unpexected state in ASGI scope") + scope["state"] = state function_call_id = current_function_call_id() assert function_call_id, "internal error: function_call_id not set in asgi_app() scope" @@ -173,7 +179,7 @@ async def receive(): app_task.result() # consume/raise exceptions if there are any! break - return fn, LifespanManager(asgi_app) + return fn, LifespanManager(asgi_app, state) def wsgi_app_wrapper(wsgi_app, function_io_manager):