Export a model with multiple entry points #7458
We need to have more documentation regarding multimethods... you can check out the unit test written in #7281.
Thanks, would you expect multimethods to work when modifying some inner state of a module?

```python
import torch

import executorch.exir
from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
    ExecutorBackendPartitioner,
)


class SharedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._v = torch.nn.Parameter(torch.ones(1, dtype=torch.float))


class Module1(torch.nn.Module):
    def __init__(self, shared_module):
        super().__init__()
        self.shared_module = shared_module

    def forward(self, x):
        self.shared_module._v[:] = self.shared_module._v + x
        return self.shared_module._v


class Module2(torch.nn.Module):
    def __init__(self, shared_module):
        super().__init__()
        self.shared_module = shared_module

    def forward(self, x):
        self.shared_module._v.fill_(0.0)
        return x


def export():
    shared_module = SharedModule()
    module_1 = Module1(shared_module)
    module_2 = Module2(shared_module)
    example_inputs = (torch.randn(1),)
    module_1(*example_inputs)
    module_2(*example_inputs)
    ep1 = torch.export.export_for_training(module_1, example_inputs)
    ep2 = torch.export.export_for_training(module_2, example_inputs)
    edge_program_manager = executorch.exir.to_edge(
        {
            "forward1": ep1,
            "forward2": ep2,
        },
        compile_config=executorch.exir.EdgeCompileConfig(
            _check_ir_validity=False, _use_edge_ops=True
        ),
    )
    edge_program_manager = edge_program_manager.to_backend(
        ExecutorBackendPartitioner()
    ).to_executorch()


with torch.no_grad():
    export()
```

However, this resulted in the following error. Maybe there is something wrong in the way I'm trying to export the modules?
I would be somewhat surprised if we have support for shared state that can be modified from two different methods, like `self._v` in your example. Note that instead of using `Parameter`, you might want to use the `register_buffer` API from torch. cc @JacobSzwejbka for shared mutable state. Now regarding the actual error, I suspect it will also be resolved if you register a buffer, since only buffers, not parameters, can be mutated/changed by the program. That is why I think you are seeing that error (https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/schemas.py#L824).
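For reference, a minimal sketch of that suggested change, keeping the rest of the example above the same and swapping the `Parameter` for a registered buffer:

```python
import torch


class SharedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers, unlike Parameters, are allowed to be mutated by an
        # exported program, so register the shared state as a buffer.
        self.register_buffer("_v", torch.ones(1, dtype=torch.float))
```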
@LaurentMazare in case this helps, here are some examples of the
Hi @LaurentMazare, so I have been lurking here for a bit and figured I would hijack the thread, looking for feedback. I ran across the same series of forum posts and git issues as you, and have just gotten a basic version of multiple-method export with shared mutable state working here: https://github.com/cptspacemanspiff/execu-tools
This is mostly based on #3518, and consists of a Python export wrapper plus additional specialized code on the C++ runner side. How it works: the state is not technically 'shared', at least not on the Python side; it just so happens that the memory addresses that contain the shared data used by method 1 are the same as the memory addresses that contain the data for method 2. On the Python side:
On the C++ side:
I got it to work for a toy example and kind of want feedback. When I started down this rabbit hole, I was not expecting it to be so involved (I just wanted an easy encoder/decoder pipeline for Hugging Face models...). So I would like external feedback on how janky this is, and on whether it is something that should be done at all. Also @JacobSzwejbka, any thoughts on hidden issues that using the memory planner in this way can cause, especially the hacking around the mutable buffers?
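(A toy, framework-free illustration of the aliasing idea described above, where two "methods" hold views over the same storage. This is not code from the linked repo, just a sketch of the concept:)

```python
import torch

# One arena, allocated once by the caller.
storage = torch.zeros(4)

# Each "method" gets a view over the same underlying memory; neither
# module knows the state is shared.
state_for_method_1 = storage[:]
state_for_method_2 = storage[:]

state_for_method_1.add_(1.0)  # "method 1" mutates its state in place
assert torch.equal(state_for_method_2, torch.ones(4))  # "method 2" sees it
```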
@cptspacemanspiff can you put up an example PR whenever you think it is ready? I presume you are using this in your private repo, but I would be curious to know whether it can serve as an example of how to achieve something similar, or be a contribution to the ET stack. Would love it if it could be the latter.
Using memory planning to essentially move the shared state to its own id, and then allocating the same buffer for it at runtime, is currently the best way to do this. It's pretty crap UX though; I've been trying to think of a better way to do it. At the very least, we should probably have a default memory-plan utility that will do this for you if you give it the names of the buffers to lift. It's a little hard because export itself doesn't have a good concept of multiple entry points today, so we have to do something downstream anyway. We hack export by overriding forward; see the sketch below (I'm sure this blows up in a really dumb way if the method you are overwriting to forward itself calls forward). This also won't work if the shared state needs to be consumed by a delegate AOT. We currently don't have a way for delegates (like XNNPack, CoreML, etc.) to share state in a non-hacked way, though we are actively looking into that this month.
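For illustration, here is a minimal sketch of that forward-override hack. The module and method names (`step`, `reset`) are made up for this example, it assumes the methods only mutate registered buffers, and it is not an official ExecuTorch utility:

```python
import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(1))

    def step(self, x):
        self.state.add_(x)
        return self.state.clone()  # clone to avoid returning an alias of the buffer

    def reset(self):
        self.state.zero_()
        return self.state.clone()


def export_method(model, method_name, example_inputs):
    # Temporarily point the class's `forward` at the target method so that
    # torch.export traces that entry point, then restore the original.
    cls = type(model)
    original_forward = cls.forward
    cls.forward = getattr(cls, method_name)
    try:
        return torch.export.export(model, example_inputs)
    finally:
        cls.forward = original_forward


model = Model()
ep_step = export_method(model, "step", (torch.randn(1),))
ep_reset = export_method(model, "reset", ())
```

Each call produces a separate `ExportedProgram`; merging their state back together is exactly the downstream step discussed above.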
Maybe we introduce the concept of Program-scoped memory arenas vs. method-scoped ones? That way the shared nature is explicit in the schema and lowering. cc @dbort for thoughts.
Hello,
I hope this is the appropriate place to ask such questions; let me know if it would be better suited elsewhere.
We would like to run a model via ExecuTorch. The trickiness is that this model has multiple methods that we want to expose, each of which can manipulate its state (one method is the actual forward pass, the other allows one to reset the internal state). I don't think I've found a proper way to do this.
I came across "export multiple functions of a pytorch module", which suggests calling `export` multiple times and relying on "Later when ExecuTorch serializes to a binary, the weights/buffer in that structure are then merged into one state dict", but I didn't manage to get that to work. First, even though the documentation of `torch.export` mentions that it can apply to a callable here, it seems to only work on modules. And after trying to call `torch.export` on two different modules with a common state, the state doesn't seem to actually get shared.
Do you know if it's possible to achieve this with ExecuTorch at the moment? Any pointer to a model that already does this? (I looked into the examples and googled around, but no luck.)
In case it helps, you can see the code where I tried to export two modules that share a common state here.