Skip to content

Commit

Permalink
WIP support for Wan t2v model.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Feb 25, 2025
1 parent f400760 commit 6302301
Show file tree
Hide file tree
Showing 10 changed files with 1,307 additions and 3 deletions.
28 changes: 28 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,31 @@ class Cosmos1CV8x8x8(LatentFormat):
]

latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]

class Wan21(LatentFormat):
latent_channels = 16
latent_dimensions = 3

def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]).view(1, self.latent_channels, 1, 1, 1)


self.taesd_decoder_name = None #TODO

def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std

def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
Loading

0 comments on commit 6302301

Please sign in to comment.