Add a way to set a different compute dtype for the model at runtime.
Currently only works for diffusion models.
comfyanonymous committed Feb 14, 2025
1 parent 8773ccf commit 019c702
Showing 1 changed file with 18 additions and 2 deletions.
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -218,6 +218,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
+        self.force_cast_weights = False
         self.patches_uuid = uuid.uuid4()
         self.parent = None
 
@@ -277,6 +278,8 @@ def clone(self):
         n.object_patches_backup = self.object_patches_backup
         n.parent = self
 
+        n.force_cast_weights = self.force_cast_weights
+
         # attachments
         n.attachments = {}
         for k in self.attachments:
@@ -424,6 +427,12 @@ def set_model_forward_timestep_embed_patch(self, patch):
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj
 
+    def set_model_compute_dtype(self, dtype):
+        self.add_object_patch("manual_cast_dtype", dtype)
+        if dtype is not None:
+            self.force_cast_weights = True
+        self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
+
     def add_weight_wrapper(self, name, function):
         self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
         self.patches_uuid = uuid.uuid4()
@@ -602,6 +611,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
             if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                 continue
 
+            cast_weight = self.force_cast_weights
             if lowvram_weight:
                 if hasattr(m, "comfy_cast_weights"):
                     m.weight_function = []
@@ -620,8 +630,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                             m.bias_function = [LowVramPatch(bias_key, self.patches)]
                             patch_counter += 1
 
-                    m.prev_comfy_cast_weights = m.comfy_cast_weights
-                    m.comfy_cast_weights = True
+                    cast_weight = True
             else:
                 if hasattr(m, "comfy_cast_weights"):
                     wipe_lowvram_weight(m)
@@ -630,6 +639,10 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                     mem_counter += module_mem
                     load_completely.append((module_mem, n, m, params))
 
+            if cast_weight:
+                m.prev_comfy_cast_weights = m.comfy_cast_weights
+                m.comfy_cast_weights = True
+
             if weight_key in self.weight_wrapper_patches:
                 m.weight_function.extend(self.weight_wrapper_patches[weight_key])
 
@@ -766,6 +779,7 @@ def partially_unload(self, device_to, memory_to_free=0):
                 weight_key = "{}.weight".format(n)
                 bias_key = "{}.bias".format(n)
                 if move_weight:
+                    cast_weight = self.force_cast_weights
                     m.to(device_to)
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
@@ -775,7 +789,9 @@ def partially_unload(self, device_to, memory_to_free=0):
                         if bias_key in self.patches:
                             m.bias_function.append(LowVramPatch(bias_key, self.patches))
                             patch_counter += 1
+                        cast_weight = True
 
+                    if cast_weight:
                         m.prev_comfy_cast_weights = m.comfy_cast_weights
                         m.comfy_cast_weights = True
                     m.comfy_patched_weights = False

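The load() change is easier to read flattened out of the diff: both paths that previously wrote the cast flags directly now just set a local boolean, and a single shared block applies it per module. A simplified sketch, with the lowvram patch setup elided:

cast_weight = self.force_cast_weights  # runtime dtype override forces casting
if lowvram_weight:
    ...                                # attach LowVramPatch weight/bias functions
    cast_weight = True                 # lowvram modules always cast on the fly
if cast_weight:
    m.prev_comfy_cast_weights = m.comfy_cast_weights
    m.comfy_cast_weights = True        # module casts its weights at forward time

partially_unload() gets the same treatment, with cast_weight likewise seeded from force_cast_weights.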