Skip to content

Commit

Permalink
Add support to Any type to parameterized functions
Browse files Browse the repository at this point in the history
add test

typing.Annotated

Rename SandboxSnapshotFromId to SandboxSnapshotGet (#2800)

rename SandboxSnapshotFromId

[auto-commit] [skip ci] Bump the build number

Enable sandbox tests on mac (#2808)

* Enable sandbox tests on mac

* update scheduler placement test

* remove unused import

[auto-commit] [skip ci] Bump the build number
  • Loading branch information
kramstrom committed Jan 27, 2025
1 parent f078598 commit b6dfc0d
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 66 deletions.
2 changes: 2 additions & 0 deletions modal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
try:
from ._runtime.execution_context import current_function_call_id, current_input_id, interact, is_local
from ._tunnel import Tunnel, forward
from ._utils.function_utils import PickleSerialization
from .app import App, Stub
from .client import Client
from .cloud_bucket_mount import CloudBucketMount
Expand Down Expand Up @@ -78,6 +79,7 @@
"interact",
"method",
"parameter",
"PickleSerialization",
"web_endpoint",
"web_server",
"wsgi_app",
Expand Down
4 changes: 3 additions & 1 deletion modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,9 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
param_args, param_kwargs = deserialize(serialized_params, _client)
elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
param_args = ()
param_kwargs = deserialize_proto_params(serialized_params, list(function_def.class_parameter_info.schema))
param_kwargs = deserialize_proto_params(
serialized_params, list(function_def.class_parameter_info.schema), _client
)
else:
raise ExecutionError(
f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
Expand Down
9 changes: 8 additions & 1 deletion modal/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,9 @@ class ParamTypeInfo:
PARAM_TYPE_MAPPING = {
api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(default_field="string_default", proto_field="string_value", converter=str),
api_pb2.PARAM_TYPE_INT: ParamTypeInfo(default_field="int_default", proto_field="int_value", converter=int),
api_pb2.PARAM_TYPE_PICKLE: ParamTypeInfo(
default_field="pickle_default", proto_field="pickle_value", converter=serialize
),
}


Expand Down Expand Up @@ -425,7 +428,9 @@ def serialize_proto_params(python_params: dict[str, Any], schema: typing.Sequenc
return proto_bytes


def deserialize_proto_params(serialized_params: bytes, schema: list[api_pb2.ClassParameterSpec]) -> dict[str, Any]:
def deserialize_proto_params(
serialized_params: bytes, schema: list[api_pb2.ClassParameterSpec], _client
) -> dict[str, Any]:
proto_struct = api_pb2.ClassParameterSet()
proto_struct.ParseFromString(serialized_params)
value_by_name = {p.name: p for p in proto_struct.parameters}
Expand All @@ -446,6 +451,8 @@ def deserialize_proto_params(serialized_params: bytes, schema: list[api_pb2.Clas
python_value = param_value.string_value
elif schema_param.type == api_pb2.PARAM_TYPE_INT:
python_value = param_value.int_value
elif schema_param.type == api_pb2.PARAM_TYPE_PICKLE:
python_value = deserialize(param_value.pickle_value, _client)
else:
# TODO(elias): based on `parameters` declared types, we could add support for
# custom non proto types encoded as bytes in the proto, e.g. PARAM_TYPE_PYTHON_PICKLE
Expand Down
28 changes: 22 additions & 6 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import AsyncGenerator
from enum import Enum
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Literal, Optional
from typing import Annotated, Any, Callable, Literal, Optional, get_args, get_origin

from grpclib import GRPCError
from grpclib.exceptions import StreamTerminatedError
Expand All @@ -14,7 +14,7 @@
import modal_proto
from modal_proto import api_pb2

from .._serialization import deserialize, deserialize_data_format, serialize
from .._serialization import PARAM_TYPE_MAPPING, deserialize, deserialize_data_format, serialize
from .._traceback import append_modal_tb
from ..config import config, logger
from ..exception import (
Expand All @@ -37,10 +37,15 @@ class FunctionInfoType(Enum):
NOTEBOOK = "notebook"


class PickleSerialization:
pass


# TODO(elias): Add support for quoted/str annotations
CLASS_PARAM_TYPE_MAP: dict[type, tuple["api_pb2.ParameterType.ValueType", str]] = {
str: (api_pb2.PARAM_TYPE_STRING, "string_default"),
int: (api_pb2.PARAM_TYPE_INT, "int_default"),
PickleSerialization: (api_pb2.PARAM_TYPE_PICKLE, "pickle_default"),
}


Expand Down Expand Up @@ -295,12 +300,23 @@ def class_parameter_info(self) -> api_pb2.ClassParameterInfo:
signature = _get_class_constructor_signature(self.user_cls)
for param in signature.parameters.values():
has_default = param.default is not param.empty
if param.annotation not in CLASS_PARAM_TYPE_MAP:
raise InvalidError("modal.parameter() currently only support str or int types")
param_type, default_field = CLASS_PARAM_TYPE_MAP[param.annotation]
pickle_annotated = (
get_origin(param.annotation) == Annotated and PickleSerialization in get_args(param.annotation)[1:]
)
param_annotation = PickleSerialization if pickle_annotated else param.annotation
if param_annotation not in CLASS_PARAM_TYPE_MAP:
raise InvalidError(
"To use custom types you must use typing.Annotated[<type>, modal.PickleSerialization],"
+ f" got {param_annotation}."
)
param_type, default_field = CLASS_PARAM_TYPE_MAP[param_annotation]
class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default, type=param_type)
if has_default:
setattr(class_param_spec, default_field, param.default)
type_info = PARAM_TYPE_MAPPING.get(param_type)
if not type_info:
raise ValueError(f"Unsupported parameter type: {param_type}")
converted_value = type_info.converter(param.default)
setattr(class_param_spec, default_field, converted_value)
modal_parameters.append(class_param_spec)

return api_pb2.ClassParameterInfo(
Expand Down
14 changes: 10 additions & 4 deletions modal/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import os
import typing
from collections.abc import Collection
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Annotated, Any, Callable, Optional, TypeVar, Union, get_args, get_origin

from google.protobuf.message import Message
from grpclib import GRPCError, Status

from modal._utils.function_utils import CLASS_PARAM_TYPE_MAP
from modal._utils.function_utils import CLASS_PARAM_TYPE_MAP, PickleSerialization
from modal_proto import api_pb2

from ._object import _get_environment_name, _Object
Expand Down Expand Up @@ -227,7 +227,7 @@ async def keep_warm(self, warm_pool_size: int) -> None:
of containers and the warm_pool_size affects that common container pool.
```python notest
# Usage on a parametrized function.
# Usage on a parameterized function.
Model = modal.Cls.lookup("my-app", "Model")
Model("fine-tuned-model").keep_warm(2)
```
Expand Down Expand Up @@ -474,12 +474,18 @@ def validate_construction_mechanism(user_cls):

annotated_params = {k: t for k, t in annotations.items() if k in params}
for k, t in annotated_params.items():
if t not in CLASS_PARAM_TYPE_MAP:
pickle_annotated = get_origin(t) == Annotated and PickleSerialization in get_args(t)[1:]
param_annotation = PickleSerialization if pickle_annotated else t

if param_annotation not in CLASS_PARAM_TYPE_MAP:
t_name = getattr(t, "__name__", repr(t))
supported = ", ".join(t.__name__ for t in CLASS_PARAM_TYPE_MAP.keys())
raise InvalidError(
f"{user_cls.__name__}.{k}: {t_name} is not a supported parameter type. Use one of: {supported}"
)
# TODO:
# raise if cls has webhooks
# and no default value for pickle parameter

@staticmethod
def from_local(user_cls, app: "modal.app._App", class_service_function: _Function) -> "_Cls":
Expand Down
4 changes: 2 additions & 2 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ async def _load(param_bound_func: _Function, resolver: Resolver, existing_object
response = await retry_transient_errors(parent._client.stub.FunctionBindParams, req)
param_bound_func._hydrate(response.bound_function_id, parent._client, response.handle_metadata)

fun: _Function = _Function._from_loader(_load, "Function(parametrized)", hydrate_lazily=True)
fun: _Function = _Function._from_loader(_load, "Function(parameterized)", hydrate_lazily=True)

if can_use_parent and parent.is_hydrated:
# skip the resolver altogether:
Expand All @@ -1022,7 +1022,7 @@ async def keep_warm(self, warm_pool_size: int) -> None:
f = modal.Function.lookup("my-app", "function")
f.keep_warm(2)
# Usage on a parametrized function.
# Usage on a parameterized function.
Model = modal.Cls.lookup("my-app", "Model")
Model("fine-tuned-model").keep_warm(2)
```
Expand Down
18 changes: 9 additions & 9 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2329,14 +2329,6 @@ message SandboxRestoreResponse {
string sandbox_id = 1;
}

message SandboxSnapshotFromIdRequest {
string snapshot_id = 1;
}

message SandboxSnapshotFromIdResponse {
string snapshot_id = 1;
}

message SandboxSnapshotFsRequest {
string sandbox_id = 1;
float timeout = 2;
Expand All @@ -2349,6 +2341,14 @@ message SandboxSnapshotFsResponse {
ImageMetadata image_metadata = 3;
}

message SandboxSnapshotGetRequest {
string snapshot_id = 1;
}

message SandboxSnapshotGetResponse {
string snapshot_id = 1;
}

message SandboxSnapshotRequest {
string sandbox_id = 1;
}
Expand Down Expand Up @@ -2992,8 +2992,8 @@ service ModalClient {
rpc SandboxList(SandboxListRequest) returns (SandboxListResponse);
rpc SandboxRestore(SandboxRestoreRequest) returns (SandboxRestoreResponse);
rpc SandboxSnapshot(SandboxSnapshotRequest) returns (SandboxSnapshotResponse);
rpc SandboxSnapshotFromId(SandboxSnapshotFromIdRequest) returns (SandboxSnapshotFromIdResponse);
rpc SandboxSnapshotFs(SandboxSnapshotFsRequest) returns (SandboxSnapshotFsResponse);
rpc SandboxSnapshotGet(SandboxSnapshotGetRequest) returns (SandboxSnapshotGetResponse);
rpc SandboxSnapshotWait(SandboxSnapshotWaitRequest) returns (SandboxSnapshotWaitResponse);
rpc SandboxStdinWrite(SandboxStdinWriteRequest) returns (SandboxStdinWriteResponse);
rpc SandboxTagsSet(SandboxTagsSetRequest) returns (google.protobuf.Empty);
Expand Down
2 changes: 1 addition & 1 deletion modal_version/_version_generated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright Modal Labs 2025

# Note: Reset this value to -1 whenever you make a minor `0.X` release of the client.
build_number = 49 # git: 432126d
build_number = 51 # git: c7dc212
6 changes: 3 additions & 3 deletions test/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ async def _write():
(webhook_app_file, ""), # Function must be inferred
# TODO: fix modal shell auto-detection of a single class, even if it has multiple methods
# (cls_app_file, ""), # Class must be inferred
# (cls_app_file, "AParametrized"), # class name
(cls_app_file, "::AParametrized.some_method"), # method name
# (cls_app_file, "AParameterized"), # class name
(cls_app_file, "::AParameterized.some_method"), # method name
],
)
def test_shell(servicer, set_env_client, supports_dir, mock_shell_pty, rel_file, suffix):
Expand Down Expand Up @@ -777,7 +777,7 @@ def test_cls(servicer, set_env_client, test_dir):
app_file = test_dir / "supports" / "app_run_tests" / "cls.py"

print(_run(["run", app_file.as_posix(), "--x", "42", "--y", "1000"]))
_run(["run", f"{app_file.as_posix()}::AParametrized.some_method", "--x", "42", "--y", "1000"])
_run(["run", f"{app_file.as_posix()}::AParameterized.some_method", "--x", "42", "--y", "1000"])


def test_profile_list(servicer, server_url_env, modal_config):
Expand Down
7 changes: 4 additions & 3 deletions test/cls_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,7 @@ class UsingAnnotationParameters:
a: int = modal.parameter()
b: str = modal.parameter(default="hello")
c: float = modal.parameter(init=False)
d: typing.Annotated[dict, modal.PickleSerialization] = modal.parameter(default={"foo": "bar"})

@method()
def get_value(self):
Expand Down Expand Up @@ -907,10 +908,10 @@ def test_implicit_constructor():
assert c.a == 10
assert c.get_value.local() == 10
assert c.b == "hello"

d = UsingAnnotationParameters(a=11, b="goodbye")
assert c.d == {"foo": "bar"}
d = UsingAnnotationParameters(a=11, b="goodbye", d=[1, 2, 3])
assert d.b == "goodbye"

assert d.d == [1, 2, 3]
# TODO(elias): fix "eager" constructor call validation by looking at signature
# with pytest.raises(TypeError, match="missing a required argument: 'a'"):
# UsingAnnotationParameters()
Expand Down
Loading

0 comments on commit b6dfc0d

Please sign in to comment.