Skip to content

Commit

Permalink
Merge branch 'kramstrom/add-iter-gen-to-function-call' of github.com:…
Browse files Browse the repository at this point in the history
…modal-labs/modal-client into kramstrom/add-iter-gen-to-function-call
  • Loading branch information
kramstrom committed Aug 27, 2024
2 parents e91ae89 + 52f218d commit 7d98964
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
29 changes: 26 additions & 3 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,7 +1334,10 @@ async def spawn(self, *args: P.args, **kwargs: P.kwargs) -> "_FunctionCall[R]":
invocation = await self._call_generator_nowait(args, kwargs)
else:
invocation = await self._call_function_nowait(args, kwargs)
return _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)

fc = _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
fc._is_generator = self._is_generator if self._is_generator else False
return fc

def get_raw_f(self) -> Callable[..., Any]:
"""Return the inner Python object wrapped by this Modal Function."""
Expand Down Expand Up @@ -1374,6 +1377,8 @@ class _FunctionCall(typing.Generic[R], _Object, type_prefix="fc"):
Conceptually similar to a Future/Promise/AsyncResult in other contexts and languages.
"""

_is_generator: bool = False

def _invocation(self):
assert self._client.stub
return _Invocation(self._client.stub, self.object_id, self._client)
Expand All @@ -1387,8 +1392,22 @@ async def get(self, timeout: Optional[float] = None) -> R:
The returned coroutine is not cancellation-safe.
"""

if self._is_generator:
raise Exception("Cannot get the result of a generator function call. Use `iter_gen` instead.")

return await self._invocation().poll_function(timeout=timeout)

async def iter_gen(self) -> AsyncGenerator[Any, None]:
"""
Calls the generator remotely, executing it with the given arguments and returning the execution's result.
"""
if not self._is_generator:
raise Exception("Cannot iterate over a non-generator function call. Use `get` instead.")

async for res in self._invocation().run_generator():
yield res

async def get_call_graph(self) -> List[InputInfo]:
"""Returns a structure representing the call graph from a given root
call ID, along with the status of execution for each node.
Expand Down Expand Up @@ -1418,11 +1437,15 @@ async def cancel(
await retry_transient_errors(self._client.stub.FunctionCallCancel, request)

@staticmethod
async def from_id(function_call_id: str, client: Optional[_Client] = None) -> "_FunctionCall":
async def from_id(
function_call_id: str, client: Optional[_Client] = None, is_generator: bool = False
) -> "_FunctionCall":
if client is None:
client = await _Client.from_env()

return _FunctionCall._new_hydrated(function_call_id, client, None)
fc = _FunctionCall._new_hydrated(function_call_id, client, None)
fc._is_generator = is_generator
return fc


FunctionCall = synchronize_api(_FunctionCall)
Expand Down
36 changes: 34 additions & 2 deletions test/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ def test_function_future(client, servicer):

assert future.object_id not in servicer.cleared_function_calls

with pytest.raises(Exception, match="Cannot iterate"):
next(future.iter_gen())


@pytest.mark.asyncio
async def test_function_future_async(client, servicer):
Expand Down Expand Up @@ -365,9 +368,16 @@ async def async_dummy():
async def test_generator_future(client, servicer):
app = App()

later_gen_modal = app.function()(later_gen)
servicer.function_body(later_gen)
later_modal = app.function()(later_gen)
with app.run(client=client):
assert later_gen_modal.spawn() is None # until we have a nice interface for polling generator futures
future = later_modal.spawn()
assert isinstance(future, FunctionCall)

with pytest.raises(Exception, match="Cannot get"):
future.get()

assert next(future.iter_gen()) == "foo"


def gen_with_arg(i):
Expand Down Expand Up @@ -547,6 +557,28 @@ def foo():
assert rehydrated_function_call.object_id == function_call.object_id


@pytest.mark.parametrize("is_generator", [False, True])
def test_from_id_iter_gen(client, servicer, is_generator):
app = App()

f = later_gen if is_generator else later

servicer.function_body(f)
later_modal = app.function()(f)
with app.run(client=client):
future = later_modal.spawn()
assert isinstance(future, FunctionCall)

assert future.object_id
rehydrated_function_call = FunctionCall.from_id(future.object_id, client, is_generator=is_generator)
assert rehydrated_function_call.object_id == future.object_id

if is_generator:
assert next(rehydrated_function_call.iter_gen()) == "foo"
else:
assert rehydrated_function_call.get() == "hello"


lc_app = App()


Expand Down

0 comments on commit 7d98964

Please sign in to comment.