fix offload gpu tests etc #10366
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Commits 583a7e9 to 42d3a6a
@@ -1080,7 +1080,7 @@ def test_cpu_offload(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
Follow-up fix for this: https://github.com/huggingface/diffusers/pull/10340/files#r1895134336
LGTM
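For context, a rough sketch of how the two `accelerate` size helpers differ and how such a size is typically turned into a `max_memory` budget for an offload test. This is illustrative only, not the test's actual code: the `make_max_memory` name and the split fraction are made up, and it assumes a recent `accelerate` version that exports both helpers from `accelerate.utils`.

```python
# Illustrative sketch only: function name and fractions are hypothetical.
import torch
from accelerate.utils import compute_module_persistent_sizes, compute_module_sizes


def make_max_memory(model: torch.nn.Module, gpu_fraction: float = 0.5) -> dict:
    # compute_module_sizes counts parameters plus all buffers of the model;
    # compute_module_persistent_sizes skips non-persistent buffers, i.e. it
    # reflects only what a saved checkpoint contains. The "" key holds the
    # total size of the whole model in bytes.
    full_size = compute_module_sizes(model)[""]
    persistent_size = compute_module_persistent_sizes(model)[""]

    # Keep only a fraction of the model on accelerator 0 so the remainder
    # gets offloaded to CPU RAM.
    return {0: int(gpu_fraction * full_size), "cpu": persistent_size * 2}
```

A budget like this would then be passed along with `device_map="auto"` when reloading the model, so that part of it ends up offloaded during the test.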
    self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
) -> torch.Tensor:
    hidden_states = self.norm(hidden_states)
    shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
Not really a fan of this kind of device casting in `forward`, but okay to keep it since we don't have a better solution yet. These usually end up creating problems for anything that modifies device/dtype with hooks, and we then have to use workarounds.

Going forward, I think `nn.Parameter`s can be put in their own dummy `nn.Module` so that device map, or other things we're introducing (like group offloading or fp8 layerwise upcasting), work out of the box (as they will handle the weight/type casting of inputs in overridden pre-hook methods). If this sounds good, I will do future model integrations with this design.
ohh I actually did not think about this at all (I just copied from the original code) - could you explain why we need this device casting here?
Ah okay, I see. I think I missed it when reviewing the PR that added Sana, otherwise I would probably have removed it then. I'm not really sure why it is needed here, and I think it might be okay to remove in this PR.
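For illustration only, a sketch of what that removal might look like (not the PR's actual change); it relies on device_map / offloading hooks keeping both tensors on the same device:

```python
# Without the explicit cast; assumes temb and scale_shift_table already live
# on the same device (e.g. placed there by device_map or offloading hooks).
shift, scale = (scale_shift_table[None] + temb[:, None]).chunk(2, dim=1)
```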