diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py
index 37856a38c..d28123320 100644
--- a/modal/_utils/function_utils.py
+++ b/modal/_utils/function_utils.py
@@ -190,6 +190,47 @@ def serialized_function(self) -> bytes:
             logger.debug(f"Serializing function for class service function {self.cls.__qualname__} as empty")
             return b""
 
+    def get_cls_vars(self) -> Dict[str, Any]:
+        """Return the non-callable, non-dunder attributes found on `self.cls`.
+
+        Names come from `dir()`. Returns an empty dict when `self.cls` is None.
+        """
+        if self.cls is None:
+            return {}
+
+        cls_vars: Dict[str, Any] = {}
+        for attr in dir(self.cls):
+            # Test the name first so dunder attributes are never fetched.
+            if attr.startswith("__"):
+                continue
+            value = getattr(self.cls, attr)
+            if not callable(value):
+                cls_vars[attr] = value
+        return cls_vars
+
+    def get_cls_var_attrs(self) -> Dict[str, Any]:
+        """Return the class variables that `self.raw_f` accesses as attributes.
+
+        A class variable is included when its name is the target of a
+        `LOAD_ATTR` or `STORE_ATTR` instruction in the function's bytecode. The
+        match is by name only, and only the function's own code object is
+        scanned.
+        """
+        import dis
+
+        cls_vars = self.get_cls_vars()
+        code = getattr(self.raw_f, "__code__", None)
+        if not cls_vars or code is None:
+            # Nothing to match against, or no bytecode to inspect.
+            return {}
+
+        accessed = {
+            instr.argval
+            for instr in dis.get_instructions(code)
+            if instr.opname in ("LOAD_ATTR", "STORE_ATTR")
+        }
+        return {name: value for name, value in cls_vars.items() if name in accessed}
+
     def get_globals(self) -> Dict[str, Any]:
         from .._vendor.cloudpickle import _extract_code_globals
 
diff --git a/modal/image.py b/modal/image.py
index 103b82202..0a2597b0d 100644
--- a/modal/image.py
+++ b/modal/image.py
@@ -322,6 +322,9 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[s
                 build_function_id = build_function.object_id
 
                 globals = build_function._get_info().get_globals()
+                # NOTE(review): merged last, so a class attribute replaces a same-named global - confirm intended.
+                attrs = build_function._get_info().get_cls_var_attrs()
+                globals = {**globals, **attrs}
                 filtered_globals = {}
                 for k, v in globals.items():
                     if isfunction(v):
diff --git a/test/conftest.py b/test/conftest.py
index 5fe7737f6..73b59a19c 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -825,6 +825,12 @@ async def FunctionUpdateSchedulingParams(self, stream):
 
     async def ImageGetOrCreate(self, stream):
         request: api_pb2.ImageGetOrCreateRequest = await stream.recv_message()
+        # Reuse the id of an already-registered image with an identical definition.
+        requested = request.image.SerializeToString()
+        for existing_id, existing_image in self.images.items():
+            if existing_image.SerializeToString() == requested:
+                await stream.send_message(api_pb2.ImageGetOrCreateResponse(image_id=existing_id))
+                return
         idx = len(self.images) + 1
         image_id = f"im-{idx}"
 
diff --git a/test/image_test.py b/test/image_test.py
index 032aff850..74155ccb7 100644
--- a/test/image_test.py
+++ b/test/image_test.py
@@ -750,9 +750,71 @@ def f(self):
         print("bar!", VARIABLE_6)
 
 
+class FooInstance:
+    not_used_by_build_method: str = "normal"
+    used_by_build_method: str = "normal"
+
+    @build()
+    def build_func(self):
+        global VARIABLE_5
+
+        print("global variable", VARIABLE_5)
+        print("static class var", FooInstance.used_by_build_method)
+        FooInstance.used_by_build_method = "normal"
+
+
+def test_image_cls_var_rebuild(client, servicer):
+    """Changing a class variable the build method accesses gives a new image; changing it back does not."""
+    rebuild_app = App()
+    try:
+        rebuild_app.cls(image=Image.debian_slim())(FooInstance)
+        with rebuild_app.run(client=client):
+            image_ids = list(servicer.images)
+
+        FooInstance.used_by_build_method = "rebuild"
+        rebuild_app.cls(image=Image.debian_slim())(FooInstance)
+        with rebuild_app.run(client=client):
+            image_ids_rebuild = list(servicer.images)
+        # Ensure that a new image was created
+        assert image_ids[-1] != image_ids_rebuild[-1]
+
+        FooInstance.used_by_build_method = "normal"
+        rebuild_app.cls(image=Image.debian_slim())(FooInstance)
+        with rebuild_app.run(client=client):
+            image_ids = list(servicer.images)
+        # Ensure that no new image was created
+        assert len(image_ids) == len(image_ids_rebuild)
+    finally:
+        # Restore the shared class state even when an assertion fails.
+        FooInstance.used_by_build_method = "normal"
+
+
+def test_image_cls_var_no_rebuild(client, servicer):
+    """Changing a class variable the build method never touches must not give a new image."""
+    rebuild_app = App()
+    try:
+        rebuild_app.cls(image=Image.debian_slim())(FooInstance)
+        with rebuild_app.run(client=client):
+            image_id = list(servicer.images)[-1]
+        rebuild_app.cls(image=Image.debian_slim())(FooInstance)
+        with rebuild_app.run(client=client):
+            image_id2 = list(servicer.images)[-1]
+
+        FooInstance.not_used_by_build_method = "no rebuild"
+        rebuild_app.cls(image=Image.debian_slim())(FooInstance)
+        with rebuild_app.run(client=client):
+            image_id3 = list(servicer.images)[-1]
+
+        assert image_id == image_id2
+        assert image_id2 == image_id3
+    finally:
+        # Restore the shared class state so later tests see the original value.
+        FooInstance.not_used_by_build_method = "normal"
+
+
 def test_image_build_snapshot(client, servicer):
     with cls_app.run(client=client):
-        image_id = list(servicer.images.keys())[-1]
+        image_id = list(servicer.images)[-1]
         layers = get_image_layers(image_id, servicer)
 
         assert "foo!" in layers[0].build_function.definition