Skip to content

Commit

Permalink
fp8 weight support for Stable Cascade.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Feb 17, 2024
1 parent f870654 commit 11e3221
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion comfy/ldm/cascade/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, dim, dtype=None, device=None):
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma.to(x.device) * (x * Nx) + self.beta.to(x.device) + x
return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x


class ResBlock(nn.Module):
Expand Down

0 comments on commit 11e3221

Please sign in to comment.