Skip to content

Commit

Permalink
Implement support for taef1 latent previews (#4409)
Browse files Browse the repository at this point in the history
* add taef1 handling to several places

* remove guess_latent_channels and add latent_channels info directly to flux model

* remove TODO

* fix numbers
  • Loading branch information
mturnshek authored Aug 16, 2024
1 parent 05a9f3f commit 1770fc7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 2 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class StableAudio1(LatentFormat):
latent_channels = 64

class Flux(SD3):
latent_channels = 16
def __init__(self):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
Expand All @@ -162,6 +163,7 @@ def __init__(self):
[-0.0005, -0.0530, -0.0020],
[-0.1273, -0.0932, -0.0680]
]
self.taesd_decoder_name = "taef1_decoder"

def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
Expand Down
13 changes: 12 additions & 1 deletion nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,8 @@ def vae_list():
sd1_taesd_dec = False
sd3_taesd_enc = False
sd3_taesd_dec = False
f1_taesd_enc = False
f1_taesd_dec = False

for v in approx_vaes:
if v.startswith("taesd_decoder."):
Expand All @@ -679,12 +681,18 @@ def vae_list():
sd3_taesd_dec = True
elif v.startswith("taesd3_encoder."):
sd3_taesd_enc = True
elif v.startswith("taef1_encoder."):
f1_taesd_dec = True
elif v.startswith("taef1_decoder."):
f1_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
if sd3_taesd_dec and sd3_taesd_enc:
vaes.append("taesd3")
if f1_taesd_dec and f1_taesd_enc:
vaes.append("taef1")
return vaes

@staticmethod
Expand Down Expand Up @@ -712,6 +720,9 @@ def load_taesd(name):
elif name == "taesd3":
sd["vae_scale"] = torch.tensor(1.5305)
sd["vae_shift"] = torch.tensor(0.0609)
elif name == "taef1":
sd["vae_scale"] = torch.tensor(0.3611)
sd["vae_shift"] = torch.tensor(0.1159)
return sd

@classmethod
Expand All @@ -724,7 +735,7 @@ def INPUT_TYPES(s):

#TODO: scale factor?
def load_vae(self, vae_name):
if vae_name in ["taesd", "taesdxl", "taesd3"]:
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
Expand Down

1 comment on commit 1770fc7

@JorgeR81
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Flux preview is excellent now. Thanks. !

The models weren't downloaded automatically, but there was a warning in the console, so I put them in the models/vae_approx folder manually.

https://github.com/madebyollin/taesd/blob/main/taef1_decoder.pth
https://github.com/madebyollin/taesd/blob/main/taef1_encoder.pth

Please sign in to comment.