Skip to content

Commit

Permalink
fix lora fallback to merge (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
ljleb authored Jul 2, 2024
1 parent a86f838 commit e6c20a3
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions sd_mecha/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,22 @@ class KeyMergeVisitor(KeyVisitor):
_passthrough_callback: Callable[[str], torch.Tensor]

def visit_merge(self, node: MergeRecipeNode) -> torch.Tensor:
merged: List[Optional[torch.Tensor]] = [None] * len(node.models)
try:
return node.merge_method(
self.__visit_deeper_first(node.models),
self.__visit_deeper_first(node.models, merged),
{k: get_hyper(v, self._key, node.model_arch) for k, v in node.hypers.items()} | node.volatile_hypers,
self._key,
node.device if node.device is not None else self._default_device,
node.dtype if node.dtype is not None else self._default_dtype,
)
except KeyError as e:
except KeyError:
for n, m in ((n, m) for n, m in zip(node.models, merged) if m is not None):
if n.merge_space == node.merge_space:
return m
return self._passthrough_callback(self._key)

def __visit_deeper_first(self, nodes: Tuple[RecipeNode, ...]) -> list:
merged: List[Optional[torch.Tensor]] = [None] * len(nodes)

def __visit_deeper_first(self, nodes: Tuple[RecipeNode, ...], merged: List[Optional[torch.Tensor]]) -> list:
def depth_of_value(index) -> int:
if nodes[index] is None:
return 0
Expand Down

0 comments on commit e6c20a3

Please sign in to comment.