diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py index 7d0e28069054..4449811ab747 100644 --- a/examples/flux-control/train_control_flux.py +++ b/examples/flux-control/train_control_flux.py @@ -812,6 +812,8 @@ def main(args): for name, module in flux_transformer.named_modules(): if "transformer_blocks" in name: module.requires_grad_(True) + else: + module.requirs_grad_(False) def unwrap_model(model): model = accelerator.unwrap_model(model)