Skip to content

Commit

Permalink
calculate_weight function to use a different dtype.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Aug 17, 2024
1 parent 83f3431 commit cd5017c
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from comfy.types import UnetWrapperFunction


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
Expand Down Expand Up @@ -426,7 +426,7 @@ def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patc
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
return self.model

def calculate_weight(self, patches, weight, key):
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
for p in patches:
strength = p[0]
v = p[1]
Expand All @@ -445,7 +445,7 @@ def calculate_weight(self, patches, weight, key):
weight *= strength_model

if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
v = (self.calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )

if len(v) == 1:
patch_type = "diff"
Expand All @@ -461,8 +461,8 @@ def calculate_weight(self, patches, weight, key):
else:
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
Expand All @@ -471,13 +471,13 @@ def calculate_weight(self, patches, weight, key):

if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
Expand All @@ -495,23 +495,23 @@ def calculate_weight(self, patches, weight, key):

if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)

if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)

if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
Expand All @@ -523,7 +523,7 @@ def calculate_weight(self, patches, weight, key):
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
Expand All @@ -543,24 +543,24 @@ def calculate_weight(self, patches, weight, key):
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))

m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))

try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
Expand All @@ -573,15 +573,15 @@ def calculate_weight(self, patches, weight, key):

dora_scale = v[5]

a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
Expand Down

0 comments on commit cd5017c

Please sign in to comment.