Remove the FP32 Wrapper when evaluating (#10617)
Remove the FP32 Wrapper

Co-authored-by: Linoy Tsaban <[email protected]>
lmxyy and linoytsaban authored Jan 21, 2025
1 parent 012d08b commit 158a5a8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/dreambooth/train_dreambooth_flux.py
@@ -1716,9 +1716,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         pipeline = FluxPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             vae=vae,
-            text_encoder=accelerator.unwrap_model(text_encoder_one),
-            text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-            transformer=accelerator.unwrap_model(transformer),
+            text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
+            text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
+            transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
             revision=args.revision,
             variant=args.variant,
             torch_dtype=weight_dtype,
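Context for the change, as a minimal sketch with a toy module (not taken from the training script): when mixed precision is enabled, Accelerate patches a prepared model's forward so its outputs are cast back to FP32; passing keep_fp32_wrapper=False to unwrap_model strips that patch, so the unwrapped text encoders and transformer run purely in weight_dtype inside the validation pipeline. The snippet below assumes bf16 mixed precision is available on the current device.

    # Minimal sketch: a small Linear stands in for the transformer / text encoders.
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision="bf16")
    weight_dtype = torch.bfloat16

    # prepare() patches the module's forward so outputs are upcast to FP32.
    model = accelerator.prepare(torch.nn.Linear(8, 8))
    model.to(accelerator.device, dtype=weight_dtype)  # mirror torch_dtype=weight_dtype
    x = torch.randn(1, 8, device=accelerator.device, dtype=weight_dtype)

    with torch.no_grad():
        # Default unwrap keeps the FP32 wrapper: outputs come back as torch.float32.
        print(accelerator.unwrap_model(model)(x).dtype)
        # keep_fp32_wrapper=False restores the original forward: outputs stay in weight_dtype.
        print(accelerator.unwrap_model(model, keep_fp32_wrapper=False)(x).dtype)

Stripping the wrapper keeps validation inference consistently in weight_dtype, matching the torch_dtype=weight_dtype that the script passes to FluxPipeline.from_pretrained.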
