diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 49be9e87520f..21e4db23a166 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -1007,10 +1007,12 @@ def forward( )[0] if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(motion_module, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func( + motion_module, hidden_states, None, None, None, num_frames, None + ) hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: - hidden_states = motion_module(hidden_states, num_frames=num_frames) + hidden_states = motion_module(hidden_states, None, None, None, num_frames, None) hidden_states = resnet(input_tensor=hidden_states, temb=temb) return hidden_states diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 05050e05bb19..b88b6f16b9fb 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -953,24 +953,15 @@ def test_gradient_checkpointing_is_applied( init_dict["block_out_channels"] = block_out_channels model_class_copy = copy.copy(self.model_class) - - modules_with_gc_enabled = {} - - # now monkey patch the following function: - # def _set_gradient_checkpointing(self, module, value=False): - # if hasattr(module, "gradient_checkpointing"): - # module.gradient_checkpointing = value - - def _set_gradient_checkpointing_new(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - modules_with_gc_enabled[module.__class__.__name__] = True - - model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new - model = model_class_copy(**init_dict) model.enable_gradient_checkpointing() + modules_with_gc_enabled = {} + for submodule in model.modules(): + if hasattr(submodule, "gradient_checkpointing"): + self.assertTrue(submodule.gradient_checkpointing) + modules_with_gc_enabled[submodule.__class__.__name__] = True + assert set(modules_with_gc_enabled.keys()) == expected_set assert all(modules_with_gc_enabled.values()), "All modules should be enabled"