Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rebuild image on changes to static class variables #1959

Merged
merged 10 commits into from
Jul 5, 2024
31 changes: 31 additions & 0 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,37 @@ def serialized_function(self) -> bytes:
logger.debug(f"Serializing function for class service function {self.cls.__qualname__} as empty")
return b""

def get_cls_vars(self) -> Dict[str, Any]:
    """Return the non-callable, non-dunder attributes reachable on ``self.cls``.

    Names come from ``dir(self.cls)``, so inherited attributes are included
    and the resulting dict follows ``dir()``'s ordering.

    Returns:
        A mapping of attribute name to its current value, or an empty dict
        when ``self.cls`` is ``None``.
    """
    if self.cls is None:
        return {}

    cls_vars: Dict[str, Any] = {}
    for attr in dir(self.cls):
        # Filter on the name first so dunder attributes are never fetched.
        if attr.startswith("__"):
            continue
        # Fetch exactly once: a second getattr would re-run any descriptor
        # and could observe a different value than the one tested below.
        value = getattr(self.cls, attr)
        if not callable(value):
            cls_vars[attr] = value
    return cls_vars

def get_cls_var_attrs(self) -> Dict[str, Any]:
import dis

import opcode
Comment on lines +204 to +206
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these heavy imports? We should already have loaded them anyway on client startup via modal._serialization -> modal._vendor.cloudpickle. Maybe we should care more about optimizing client startup time (although I think we need serialization utilities to do most things).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They should be pretty light

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, no reason not to import them at the top of the module then IMO


LOAD_ATTR = opcode.opmap["LOAD_ATTR"]
STORE_ATTR = opcode.opmap["STORE_ATTR"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it weird to rebuild if users only store a new value for a class attribute, but don't read it, during the build phase? (Actually, separate discussion, but since we're doing this anyway we should maybe warn or error if users try to set class/instance attributes in @build, since that won't propagate to runtime containers, and they get confused about it occasionally.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh this is fair


func = self.raw_f
code = func.__code__
f_attr_ops = set()
for instr in dis.get_instructions(code):
if instr.opcode == LOAD_ATTR:
f_attr_ops.add(instr.argval)
elif instr.opcode == STORE_ATTR:
f_attr_ops.add(instr.argval)

cls_vars = self.get_cls_vars()
f_attrs = {k: cls_vars[k] for k in cls_vars if k in f_attr_ops}
return f_attrs

def get_globals(self) -> Dict[str, Any]:
from .._vendor.cloudpickle import _extract_code_globals

Expand Down
2 changes: 2 additions & 0 deletions modal/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[s
build_function_id = build_function.object_id

globals = build_function._get_info().get_globals()
attrs = build_function._get_info().get_cls_var_attrs()
globals = {**globals, **attrs}
filtered_globals = {}
for k, v in globals.items():
if isfunction(v):
Expand Down
4 changes: 4 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,10 @@ async def FunctionUpdateSchedulingParams(self, stream):

async def ImageGetOrCreate(self, stream):
request: api_pb2.ImageGetOrCreateRequest = await stream.recv_message()
for k in self.images:
if request.image.SerializeToString() == self.images[k].SerializeToString():
await stream.send_message(api_pb2.ImageGetOrCreateResponse(image_id=k))
return
idx = len(self.images) + 1
image_id = f"im-{idx}"

Expand Down
52 changes: 51 additions & 1 deletion test/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,9 +750,59 @@ def f(self):
print("bar!", VARIABLE_6)


# Test fixture: a class with two static class variables and one @build method.
# The tests below reassign these class variables between app runs and compare
# the image ids recorded by the servicer.
class FooInstance:
# Never referenced inside build_func.
not_used_by_build_method: str = "normal"
# Both read and written inside build_func.
used_by_build_method: str = "normal"

@build()
def build_func(self):
global VARIABLE_5

# Reads the module-level global VARIABLE_5.
print("global variable", VARIABLE_5)
# Attribute load of `used_by_build_method` on the class.
print("static class var", FooInstance.used_by_build_method)
# Attribute store to the same name.
FooInstance.used_by_build_method = "normal"


def test_image_cls_var_rebuild(client, servicer):
thundergolfer marked this conversation as resolved.
Show resolved Hide resolved
rebuild_app = App()
image_ids = []
rebuild_app.cls(image=Image.debian_slim())(FooInstance)
with rebuild_app.run(client=client):
image_ids = list(servicer.images)
FooInstance.used_by_build_method = "rebuild"
rebuild_app.cls(image=Image.debian_slim())(FooInstance)
with rebuild_app.run(client=client):
image_ids_rebuild = list(servicer.images)
# Ensure that a new image was created
assert image_ids[-1] != image_ids_rebuild[-1]
FooInstance.used_by_build_method = "normal"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you intend to check that this doesn't rebuild again after you revert the change? That seems like a good idea if so.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sure good point

rebuild_app.cls(image=Image.debian_slim())(FooInstance)
with rebuild_app.run(client=client):
image_ids = list(servicer.images)
# Ensure that no new image was created
assert len(image_ids) == len(image_ids_rebuild)


def test_image_cls_var_no_rebuild(client, servicer):
    """Changing a class variable the build method never touches keeps the image.

    The class is registered and run three times: twice unchanged, then once
    after reassigning ``FooInstance.not_used_by_build_method``. The most
    recent image id recorded by the servicer must be the same every time.
    """
    rebuild_app = App()

    def register_and_run():
        # Re-register the class so the image is resolved again, then report
        # the most recently recorded image id while the app is running.
        rebuild_app.cls(image=Image.debian_slim())(FooInstance)
        with rebuild_app.run(client=client):
            return list(servicer.images)[-1]

    original_value = FooInstance.not_used_by_build_method
    try:
        image_id = register_and_run()
        image_id2 = register_and_run()
        FooInstance.not_used_by_build_method = "no rebuild"
        image_id3 = register_and_run()
    finally:
        # Restore the shared class state so other tests using FooInstance
        # are not affected by this one, even if a run above fails.
        FooInstance.not_used_by_build_method = original_value

    assert image_id == image_id2
    assert image_id2 == image_id3


def test_image_build_snapshot(client, servicer):
with cls_app.run(client=client):
image_id = list(servicer.images.keys())[-1]
image_id = list(servicer.images)[-1]
layers = get_image_layers(image_id, servicer)

assert "foo!" in layers[0].build_function.definition
Expand Down
Loading