diff --git a/sd_mecha/model_detection.py b/sd_mecha/model_detection.py index 6616e47..4a1be64 100644 --- a/sd_mecha/model_detection.py +++ b/sd_mecha/model_detection.py @@ -165,20 +165,22 @@ class KeyMergeVisitor(KeyVisitor): _passthrough_callback: Callable[[str], torch.Tensor] def visit_merge(self, node: MergeRecipeNode) -> torch.Tensor: + merged: List[Optional[torch.Tensor]] = [None] * len(node.models) try: return node.merge_method( - self.__visit_deeper_first(node.models), + self.__visit_deeper_first(node.models, merged), {k: get_hyper(v, self._key, node.model_arch) for k, v in node.hypers.items()} | node.volatile_hypers, self._key, node.device if node.device is not None else self._default_device, node.dtype if node.dtype is not None else self._default_dtype, ) - except KeyError as e: + except KeyError: + for n, m in ((n, m) for n, m in zip(node.models, merged) if m is not None): + if n.merge_space == node.merge_space: + return m return self._passthrough_callback(self._key) - def __visit_deeper_first(self, nodes: Tuple[RecipeNode, ...]) -> list: - merged: List[Optional[torch.Tensor]] = [None] * len(nodes) - + def __visit_deeper_first(self, nodes: Tuple[RecipeNode, ...], merged: List[Optional[torch.Tensor]]) -> list: def depth_of_value(index) -> int: if nodes[index] is None: return 0