Add Flux fp16 support hack.
comfyanonymous committed Aug 7, 2024
1 parent 6969fc9 commit 8115d8c
Showing 2 changed files with 9 additions and 2 deletions.
9 changes: 8 additions & 1 deletion comfy/ldm/flux/layers.py
@@ -188,6 +188,10 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
         # calculate the txt bloks
         txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
         txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+
+        if txt.dtype == torch.float16:
+            txt = txt.clip(-65504, 65504)
+
         return img, txt


@@ -239,7 +243,10 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        return x + mod.gate * output
+        x = x + mod.gate * output
+        if x.dtype == torch.float16:
+            x = x.clip(-65504, 65504)
+        return x


 class LastLayer(nn.Module):
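For context on what the "hack" does: torch.float16 can only represent values up to 65504, so oversized activations overflow to inf and then poison later ops with NaNs. A minimal standalone sketch, not part of the commit, illustrating the same clip:

```python
import torch

# fp16's largest finite value; anything bigger overflows to +/-inf.
print(torch.finfo(torch.float16).max)   # 65504.0

# Simulate an activation tensor that has already overflowed in fp16.
act = torch.tensor([70000.0, -70000.0, 123.0]).to(torch.float16)
print(act)                  # tensor([inf, -inf, 123.], dtype=torch.float16)

# Downstream arithmetic on inf produces NaNs (e.g. the mean subtraction
# inside a normalization layer).
print(act - act.mean())     # tensor([nan, nan, nan], dtype=torch.float16)

# The commit's workaround: clamp back into fp16's finite range.
act = act.clip(-65504, 65504)
print(act - act.mean())     # finite again
```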
2 changes: 1 addition & 1 deletion comfy/supported_models.py
@@ -642,7 +642,7 @@ class Flux(supported_models_base.BASE):

     memory_usage_factor = 2.8

-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]

10 comments on commit 8115d8c

@RandomGitUser321 (Contributor) commented on 8115d8c Aug 7, 2024

Using Flux:

The Load Diffusion Model node will crash Comfy if you're using an fp8 version of the transformer and weight_dtype is left on "default" instead of being picked manually.

The regular Load Checkpoint node also crashes Comfy when loading an all-in-one fp8 checkpoint that has everything packaged into it, probably for the same reason as the Load Diffusion Model node.

This is on an RTX 2080:

model weight dtype torch.float8_e4m3fn, manual cast: torch.float16
model_type FLOW
Model doesn't have a device attribute.

@comfyanonymous (Owner, Author)

The crash would be because of OOM. Try to increase your page file size or set it to system managed.

@RandomGitUser321 (Contributor) commented on 8115d8c Aug 7, 2024

Yeah, that was it: memory use was peaking into the 40-something GB range. I have 32 GB of system RAM and had reduced my pagefile to 8 GB while testing something else last night; apparently I forgot I did that. Bumped it back up to 16 GB and it's working again. I'll probably take it back up to 32 GB anyway.

@deepfree2023

Whenever I switch from one Flux model to another Flux model, Python crashes while loading the new model, before any VRAM or RAM usage changes; sometimes I need to log out of Windows and back in to recover. (12 GB VRAM + 32 GB RAM + 20 GB pagefile)

One Flux model -> another Flux model (crash)
One Flux model -> another non-Flux model (success) -> another Flux model (crash)

@JorgeR81 commented on 8115d8c Aug 9, 2024

This commit is slowing down my inference times by 50%.
#4271 (comment)

@JorgeR81

Also, the images are slightly different (tried with fp8 models).

What does this commit do? Are the models now being upcast only to fp16, instead of fp32?

If so, it doesn't seem to make much difference in terms of RAM usage; I still need more than 32 GB of RAM while the model is loading.
#4239

@deepfree2023

Re my earlier report about Python crashing when switching from one Flux model to another: upgraded RAM to 48 GB, and the problem was gone.

@comfyanonymous (Owner, Author)

The default is bf16, but it will use fp16 instead of fp32 if your card has poor support for bf16.
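For readers wondering how that plays out against the supported_inference_dtypes list changed in this commit, here is a rough sketch, not ComfyUI's actual model_management code, with a made-up device_supports_bf16 flag standing in for the real hardware check:

```python
import torch

def pick_inference_dtype(supported_dtypes, device_supports_bf16, device_supports_fp16=True):
    # Hypothetical helper mirroring the behaviour described above: prefer bf16,
    # fall back to fp16 on cards with poor bf16 support, then fall back to fp32.
    if torch.bfloat16 in supported_dtypes and device_supports_bf16:
        return torch.bfloat16
    if torch.float16 in supported_dtypes and device_supports_fp16:
        return torch.float16
    return torch.float32

# With this commit Flux advertises [bfloat16, float16, float32], so a card
# without good bf16 support (e.g. a GTX 10xx) now lands on fp16 instead of fp32.
print(pick_inference_dtype([torch.bfloat16, torch.float16, torch.float32],
                           device_supports_bf16=False))   # torch.float16
```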

@JorgeR81 commented on 8115d8c Aug 9, 2024

Yeah, I probably have poor bf16 support (GTX 1070).

But using fp16 instead of fp32 doesn't seem to bring any improvement here: generation times are slower, and I still need more than 32 GB of RAM.

In terms of image quality, it's about the same. Some images seem slightly better, others slightly worse, depending on the image itself. But theoretically, using fp16 instead of fp32 should give worse quality, right?
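On the quality question: fp16 does carry less precision than fp32 (roughly 3 significant decimal digits versus roughly 7), which is why outputs can drift slightly in either direction. A quick illustration:

```python
import torch

x = torch.tensor(3.14159265, dtype=torch.float32)
print(x.item())                        # ~3.1415927  (about 7 significant digits)
print(x.to(torch.float16).item())      # 3.140625    (about 3-4 significant digits)

# Machine epsilon, i.e. the relative rounding error of each format.
print(torch.finfo(torch.float32).eps)  # ~1.19e-07
print(torch.finfo(torch.float16).eps)  # ~9.77e-04
```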

@JorgeR81 commented on 8115d8c Aug 9, 2024

Ideally, I'd like to be able to choose between upcasting to fp16 or fp32.

And fp16 should be usable with less than 32 GB of RAM when loading the models (without needing the page file).
