Skip to content

Commit

Permalink
update test
Browse files — browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Jan 25, 2025
1 parent 50d0a28 commit de92b67
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 17 deletions.
6 changes: 4 additions & 2 deletions src/diffusers/models/unets/unet_motion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,10 +1007,12 @@ def forward(
)[0]

if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(motion_module, hidden_states, temb)
hidden_states = self._gradient_checkpointing_func(
motion_module, hidden_states, None, None, None, num_frames, None
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = motion_module(hidden_states, num_frames=num_frames)
hidden_states = motion_module(hidden_states, None, None, None, num_frames, None)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)

return hidden_states
Expand Down
21 changes: 6 additions & 15 deletions tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,24 +953,15 @@ def test_gradient_checkpointing_is_applied(
init_dict["block_out_channels"] = block_out_channels

model_class_copy = copy.copy(self.model_class)

modules_with_gc_enabled = {}

# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value

def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True

model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()

modules_with_gc_enabled = {}
for submodule in model.modules():
if hasattr(submodule, "gradient_checkpointing"):
self.assertTrue(submodule.gradient_checkpointing)
modules_with_gc_enabled[submodule.__class__.__name__] = True

assert set(modules_with_gc_enabled.keys()) == expected_set
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"

Expand Down

0 comments on commit de92b67

Please sign in to comment.