Skip to content

Commit

Permalink
Fix some lora loading slowdowns.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Aug 12, 2024
1 parent 52a471c commit 517f4a9
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,13 +355,14 @@ def patch_model(self, device_to=None, patch_weights=True):

return self.model

def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
for n, m in self.model.named_modules():
lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):

if not full_load and hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
Expand Down Expand Up @@ -401,8 +402,11 @@ def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weigh
if weight.device == device_to:
continue

self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
self.patch_weight_to_device(bias_key)
weight_to = None
if full_load:#TODO
weight_to = device_to
self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM
self.patch_weight_to_device(bias_key, device_to=weight_to)
m.to(device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))

Expand Down Expand Up @@ -665,12 +669,13 @@ def partially_unload(self, device_to, memory_to_free=0):
return memory_freed

def partially_load(self, device_to, extra_memory=0):
    """Load more of a partially (lowvram) loaded model onto device_to.

    Grants the model up to `extra_memory` additional bytes of device memory
    beyond what it currently occupies. If the budget would cover the whole
    model, the weights are fully loaded (full_load=True) instead of staying
    in lowvram streaming mode.

    Returns the number of additional bytes of weights actually loaded
    (0 if the model is not in lowvram mode).
    """
    full_load = False
    # Nothing to do unless the model is currently partially loaded.
    if self.model.model_lowvram == False:
        return 0
    # If the current footprint plus the extra budget exceeds the total model
    # size, we can afford to load everything.
    if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
        full_load = True
    current_used = self.model.model_loaded_weight_memory
    self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
    # lowvram_load updates model_loaded_weight_memory; report the delta.
    return self.model.model_loaded_weight_memory - current_used

def current_loaded_device(self):
Expand Down

0 comments on commit 517f4a9

Please sign in to comment.