From 36b2214e30db955a10b27ae0d58453bab99dac96 Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Sun, 28 Jan 2024 20:40:18 -0800
Subject: [PATCH 01/30] Execution Model Inversion

This PR inverts the execution model -- from recursively calling nodes to
using a topological sort of the nodes. This change allows the node graph
to be modified during execution, which brings two major advantages:

1. The implementation of lazy evaluation in nodes. For example, if a
   "Mix Images" node has a mix factor of exactly 0.0, the second image
   input doesn't even need to be evaluated (and vice versa if the mix
   factor is 1.0).
2. Dynamic expansion of nodes. This allows for the creation of dynamic
   "node groups". Specifically, custom nodes can return subgraphs that
   replace the original node in the graph. This is an incredibly
   powerful concept. Using this functionality, it was easy to implement:
   a. Components (a.k.a. node groups)
   b. Flow control (i.e. while loops) via tail recursion
   c. All-in-one nodes that replicate the WebUI functionality
   d. and more

All of those were implemented entirely via custom nodes, so those
features are *not* a part of this PR. (There are some front-end changes
that should occur before that functionality is made widely available,
particularly around variant sockets.)

The custom nodes associated with this PR can be found at:
https://github.com/BadCafeCode/execution-inversion-demo-comfyui

Note that some of them require that variant socket types ("*") be
enabled.
---
 comfy/caching.py                    | 316 +++++++++++++++
 comfy/cli_args.py                   |   4 +
 comfy/graph.py                      | 172 ++++++++
 comfy/graph_utils.py                | 140 +++++++
 execution.py                        | 602 ++++++++++++++++------------
 main.py                             |   4 +-
 server.py                           |   4 +
 tests-ui/tests/groupNode.test.js    |   2 +
 web/extensions/core/groupNode.js    |   4 +-
 web/extensions/core/widgetInputs.js |   2 +-
 web/scripts/api.js                  |   2 +-
 web/scripts/app.js                  |   6 +-
 web/scripts/ui.js                   |   9 +-
 13 files changed, 1008 insertions(+), 259 deletions(-)
 create mode 100644 comfy/caching.py
 create mode 100644 comfy/graph.py
 create mode 100644 comfy/graph_utils.py

diff --git a/comfy/caching.py b/comfy/caching.py
new file mode 100644
index 00000000000..ef047dcc5d8
--- /dev/null
+++ b/comfy/caching.py
@@ -0,0 +1,316 @@
+import itertools
+from typing import Sequence, Mapping
+
+import nodes
+
+from comfy.graph_utils import is_link
+
+class CacheKeySet:
+    def __init__(self, dynprompt, node_ids, is_changed_cache):
+        self.keys = {}
+        self.subcache_keys = {}
+
+    def add_keys(node_ids):
+        raise NotImplementedError()
+
+    def all_node_ids(self):
+        return set(self.keys.keys())
+
+    def get_used_keys(self):
+        return self.keys.values()
+
+    def get_used_subcache_keys(self):
+        return self.subcache_keys.values()
+
+    def get_data_key(self, node_id):
+        return self.keys.get(node_id, None)
+
+    def get_subcache_key(self, node_id):
+        return self.subcache_keys.get(node_id, None)
+
+class Unhashable:
+    def __init__(self):
+        self.value = float("NaN")
+
+def to_hashable(obj):
+    # So that we don't infinitely recurse since frozenset and tuples
+    # are Sequences.
+    if isinstance(obj, (int, float, str, bool, type(None))):
+        return obj
+    elif isinstance(obj, Mapping):
+        return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
+    elif isinstance(obj, Sequence):
+        return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
+    else:
+        # TODO - Support other objects like tensors?
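+        # Note: two Unhashable instances never compare equal (and the NaN
+        # payload would not compare equal either), so a cache key containing
+        # one can never match a previous key -- nodes with unsupported input
+        # types are simply re-executed rather than served stale results.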
+ return Unhashable() + +class CacheKeySetID(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.add_keys(node_ids) + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = (node_id, node["class_type"]) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + +class CacheKeySetInputSignature(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.is_changed_cache = is_changed_cache + self.add_keys(node_ids) + + def include_node_id_in_input(self): + return False + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + def get_node_signature(self, dynprompt, node_id): + signature = [] + ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) + signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) + for ancestor_id in ancestors: + signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) + return to_hashable(signature) + + def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + node = dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + signature = [class_type, self.is_changed_cache.get(node_id)] + if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT): + signature.append(node_id) + inputs = node["inputs"] + for key in sorted(inputs.keys()): + if is_link(inputs[key]): + (ancestor_id, ancestor_socket) = inputs[key] + ancestor_index = ancestor_order_mapping[ancestor_id] + signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) + else: + signature.append((key, inputs[key])) + return signature + + # This function returns a list of all ancestors of the given node. The order of the list is + # deterministic based on which specific inputs the ancestor is connected by. 
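+    # Because get_immediate_node_signature() encodes links as
+    # ("ANCESTOR", index, socket) tuples using this ordering rather than raw
+    # node IDs, two structurally identical graphs produce the same signature
+    # even when their node IDs differ.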
+ def get_ordered_ancestry(self, dynprompt, node_id): + ancestors = [] + order_mapping = {} + self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping) + return ancestors, order_mapping + + def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): + inputs = dynprompt.get_node(node_id)["inputs"] + input_keys = sorted(inputs.keys()) + for key in input_keys: + if is_link(inputs[key]): + ancestor_id = inputs[key][0] + if ancestor_id not in order_mapping: + ancestors.append(ancestor_id) + order_mapping[ancestor_id] = len(ancestors) - 1 + self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) + +class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + + def include_node_id_in_input(self): + return True + +class BasicCache: + def __init__(self, key_class): + self.key_class = key_class + self.dynprompt = None + self.cache_key_set = None + self.cache = {} + self.subcaches = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + self.dynprompt = dynprompt + self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) + self.is_changed_cache = is_changed_cache + + def all_node_ids(self): + assert self.cache_key_set is not None + node_ids = self.cache_key_set.all_node_ids() + for subcache in self.subcaches.values(): + node_ids = node_ids.union(subcache.all_node_ids()) + return node_ids + + def clean_unused(self): + assert self.cache_key_set is not None + preserve_keys = set(self.cache_key_set.get_used_keys()) + preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) + to_remove = [] + for key in self.cache: + if key not in preserve_keys: + to_remove.append(key) + for key in to_remove: + del self.cache[key] + + to_remove = [] + for key in self.subcaches: + if key not in preserve_subcaches: + to_remove.append(key) + for key in to_remove: + del self.subcaches[key] + + def _set_immediate(self, node_id, value): + assert self.cache_key_set is not None + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + def _get_immediate(self, node_id): + assert self.cache_key_set is not None + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key in self.cache: + return self.cache[cache_key] + else: + return None + + def _ensure_subcache(self, node_id, children_ids): + assert self.cache_key_set is not None + subcache_key = self.cache_key_set.get_subcache_key(node_id) + subcache = self.subcaches.get(subcache_key, None) + if subcache is None: + subcache = BasicCache(self.key_class) + self.subcaches[subcache_key] = subcache + subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) + return subcache + + def _get_subcache(self, node_id): + assert self.cache_key_set is not None + subcache_key = self.cache_key_set.get_subcache_key(node_id) + if subcache_key in self.subcaches: + return self.subcaches[subcache_key] + else: + return None + + def recursive_debug_dump(self): + result = [] + for key in self.cache: + result.append({"key": key, "value": self.cache[key]}) + for key in self.subcaches: + result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()}) + return result + +class HierarchicalCache(BasicCache): + def __init__(self, key_class): + super().__init__(key_class) + + def _get_cache_for(self, node_id): + parent_id = self.dynprompt.get_parent_node_id(node_id) + if parent_id 
is None: + return self + + hierarchy = [] + while parent_id is not None: + hierarchy.append(parent_id) + parent_id = self.dynprompt.get_parent_node_id(parent_id) + + cache = self + for parent_id in reversed(hierarchy): + cache = cache._get_subcache(parent_id) + if cache is None: + return None + return cache + + def get(self, node_id): + cache = self._get_cache_for(node_id) + if cache is None: + return None + return cache._get_immediate(node_id) + + def set(self, node_id, value): + cache = self._get_cache_for(node_id) + assert cache is not None + cache._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + cache = self._get_cache_for(node_id) + assert cache is not None + return cache._ensure_subcache(node_id, children_ids) + + def all_active_values(self): + active_nodes = self.all_node_ids() + result = [] + for node_id in active_nodes: + value = self.get(node_id) + if value is not None: + result.append(value) + return result + +class LRUCache(BasicCache): + def __init__(self, key_class, max_size=100): + super().__init__(key_class) + self.max_size = max_size + self.min_generation = 0 + self.generation = 0 + self.used_generation = {} + self.children = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + super().set_prompt(dynprompt, node_ids, is_changed_cache) + self.generation += 1 + for node_id in node_ids: + self._mark_used(node_id) + print("LRUCache: Now at generation %d" % self.generation) + + def clean_unused(self): + print("LRUCache: Cleaning unused. Current size: %d/%d" % (len(self.cache), self.max_size)) + while len(self.cache) > self.max_size and self.min_generation < self.generation: + print("LRUCache: Evicting generation %d" % self.min_generation) + self.min_generation += 1 + to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation] + for key in to_remove: + del self.cache[key] + del self.used_generation[key] + if key in self.children: + del self.children[key] + + def get(self, node_id): + self._mark_used(node_id) + return self._get_immediate(node_id) + + def _mark_used(self, node_id): + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key is not None: + self.used_generation[cache_key] = self.generation + + def set(self, node_id, value): + self._mark_used(node_id) + return self._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + self.cache_key_set.add_keys(children_ids) + self._mark_used(node_id) + cache_key = self.cache_key_set.get_data_key(node_id) + self.children[cache_key] = [] + for child_id in children_ids: + self._mark_used(child_id) + self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) + return self + + def all_active_values(self): + explored = set() + to_explore = set(self.cache_key_set.get_used_keys()) + while len(to_explore) > 0: + cache_key = to_explore.pop() + if cache_key not in explored: + self.used_generation[cache_key] = self.generation + explored.add(cache_key) + if cache_key in self.children: + to_explore.update(self.children[cache_key]) + return [self.cache[key] for key in explored if key in self.cache] + diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b4bbfbfab53..2cbefefebd9 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -87,6 +87,10 @@ class LatentPreviewMethod(enum.Enum): parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) +cache_group = 
parser.add_mutually_exclusive_group() +cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") +cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") + attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") diff --git a/comfy/graph.py b/comfy/graph.py new file mode 100644 index 00000000000..2612317f9e6 --- /dev/null +++ b/comfy/graph.py @@ -0,0 +1,172 @@ +import nodes + +from comfy.graph_utils import is_link + +class DynamicPrompt: + def __init__(self, original_prompt): + # The original prompt provided by the user + self.original_prompt = original_prompt + # Any extra pieces of the graph created during execution + self.ephemeral_prompt = {} + self.ephemeral_parents = {} + self.ephemeral_display = {} + + def get_node(self, node_id): + if node_id in self.ephemeral_prompt: + return self.ephemeral_prompt[node_id] + if node_id in self.original_prompt: + return self.original_prompt[node_id] + return None + + def add_ephemeral_node(self, node_id, node_info, parent_id, display_id): + self.ephemeral_prompt[node_id] = node_info + self.ephemeral_parents[node_id] = parent_id + self.ephemeral_display[node_id] = display_id + + def get_real_node_id(self, node_id): + while node_id in self.ephemeral_parents: + node_id = self.ephemeral_parents[node_id] + return node_id + + def get_parent_node_id(self, node_id): + return self.ephemeral_parents.get(node_id, None) + + def get_display_node_id(self, node_id): + while node_id in self.ephemeral_display: + node_id = self.ephemeral_display[node_id] + return node_id + + def all_node_ids(self): + return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) + +def get_input_info(class_def, input_name): + valid_inputs = class_def.INPUT_TYPES() + input_info = None + input_category = None + if "required" in valid_inputs and input_name in valid_inputs["required"]: + input_category = "required" + input_info = valid_inputs["required"][input_name] + elif "optional" in valid_inputs and input_name in valid_inputs["optional"]: + input_category = "optional" + input_info = valid_inputs["optional"][input_name] + elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]: + input_category = "hidden" + input_info = valid_inputs["hidden"][input_name] + if input_info is None: + return None, None, None + input_type = input_info[0] + if len(input_info) > 1: + extra_info = input_info[1] + else: + extra_info = {} + return input_type, input_category, extra_info + +class TopologicalSort: + def __init__(self, dynprompt): + self.dynprompt = dynprompt + self.pendingNodes = {} + self.blockCount = {} # Number of nodes this node is directly blocked by + self.blocking = {} # Which nodes are blocked by this node + + def get_input_info(self, unique_id, input_name): + class_type = self.dynprompt.get_node(unique_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + return get_input_info(class_def, input_name) + + def make_input_strong_link(self, to_node_id, to_input): + inputs = self.dynprompt.get_node(to_node_id)["inputs"] + if to_input not in inputs: + raise Exception("Node %s says it 
needs input %s, but there is no input to that node at all" % (to_node_id, to_input)) + value = inputs[to_input] + if not is_link(value): + raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input)) + from_node_id, from_socket = value + self.add_strong_link(from_node_id, from_socket, to_node_id) + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + self.add_node(from_node_id) + if to_node_id not in self.blocking[from_node_id]: + self.blocking[from_node_id][to_node_id] = {} + self.blockCount[to_node_id] += 1 + self.blocking[from_node_id][to_node_id][from_socket] = True + + def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None): + if unique_id in self.pendingNodes: + return + self.pendingNodes[unique_id] = True + self.blockCount[unique_id] = 0 + self.blocking[unique_id] = {} + + inputs = self.dynprompt.get_node(unique_id)["inputs"] + for input_name in inputs: + value = inputs[input_name] + if is_link(value): + from_node_id, from_socket = value + if subgraph_nodes is not None and from_node_id not in subgraph_nodes: + continue + input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + is_lazy = "lazy" in input_info and input_info["lazy"] + if include_lazy or not is_lazy: + self.add_strong_link(from_node_id, from_socket, unique_id) + + def get_ready_nodes(self): + return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] + + def pop_node(self, unique_id): + del self.pendingNodes[unique_id] + for blocked_node_id in self.blocking[unique_id]: + self.blockCount[blocked_node_id] -= 1 + del self.blocking[unique_id] + + def is_empty(self): + return len(self.pendingNodes) == 0 + +# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, +# it can still be returned to the graph after having further dependencies added. +class ExecutionList(TopologicalSort): + def __init__(self, dynprompt, output_cache): + super().__init__(dynprompt) + self.output_cache = output_cache + self.staged_node_id = None + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + if self.output_cache.get(from_node_id) is not None: + # Nothing to do + return + super().add_strong_link(from_node_id, from_socket, to_node_id) + + def stage_node_execution(self): + assert self.staged_node_id is None + if self.is_empty(): + return None + available = self.get_ready_nodes() + if len(available) == 0: + raise Exception("Dependency cycle detected") + next_node = available[0] + # If an output node is available, do that first. + # Technically this has no effect on the overall length of execution, but it feels better as a user + # for a PreviewImage to display a result as soon as it can + # Some other heuristics could probably be used here to improve the UX further. + for node_id in available: + class_type = self.dynprompt.get_node(node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + next_node = node_id + break + self.staged_node_id = next_node + return self.staged_node_id + + def unstage_node_execution(self): + assert self.staged_node_id is not None + self.staged_node_id = None + + def complete_node_execution(self): + node_id = self.staged_node_id + self.pop_node(node_id) + self.staged_node_id = None + +# Return this from a node and any users will be blocked with the given error message. 
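+# A message of None blocks silently: execute() only reports an
+# "execution_error" to the client when block.message is not None, so a
+# flow-control style node can prune an entire branch without surfacing
+# an error.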
+class ExecutionBlocker: + def __init__(self, message): + self.message = message + diff --git a/comfy/graph_utils.py b/comfy/graph_utils.py new file mode 100644 index 00000000000..a0042e078f7 --- /dev/null +++ b/comfy/graph_utils.py @@ -0,0 +1,140 @@ +def is_link(obj): + if not isinstance(obj, list): + return False + if len(obj) != 2: + return False + if not isinstance(obj[0], str): + return False + if not isinstance(obj[1], int) and not isinstance(obj[1], float): + return False + return True + +# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end +class GraphBuilder: + _default_prefix_root = "" + _default_prefix_call_index = 0 + _default_prefix_graph_index = 0 + + def __init__(self, prefix = None): + if prefix is None: + self.prefix = GraphBuilder.alloc_prefix() + else: + self.prefix = prefix + self.nodes = {} + self.id_gen = 1 + + @classmethod + def set_default_prefix(cls, prefix_root, call_index, graph_index = 0): + cls._default_prefix_root = prefix_root + cls._default_prefix_call_index = call_index + if graph_index is not None: + cls._default_prefix_graph_index = graph_index + + @classmethod + def alloc_prefix(cls, root=None, call_index=None, graph_index=None): + if root is None: + root = GraphBuilder._default_prefix_root + if call_index is None: + call_index = GraphBuilder._default_prefix_call_index + if graph_index is None: + graph_index = GraphBuilder._default_prefix_graph_index + result = "%s.%d.%d." % (root, call_index, graph_index) + GraphBuilder._default_prefix_graph_index += 1 + return result + + def node(self, class_type, id=None, **kwargs): + if id is None: + id = str(self.id_gen) + self.id_gen += 1 + id = self.prefix + id + if id in self.nodes: + return self.nodes[id] + + node = Node(id, class_type, kwargs) + self.nodes[id] = node + return node + + def lookup_node(self, id): + id = self.prefix + id + return self.nodes.get(id) + + def finalize(self): + output = {} + for node_id, node in self.nodes.items(): + output[node_id] = node.serialize() + return output + + def replace_node_output(self, node_id, index, new_value): + node_id = self.prefix + node_id + to_remove = [] + for node in self.nodes.values(): + for key, value in node.inputs.items(): + if is_link(value) and value[0] == node_id and value[1] == index: + if new_value is None: + to_remove.append((node, key)) + else: + node.inputs[key] = new_value + for node, key in to_remove: + del node.inputs[key] + + def remove_node(self, id): + id = self.prefix + id + del self.nodes[id] + +class Node: + def __init__(self, id, class_type, inputs): + self.id = id + self.class_type = class_type + self.inputs = inputs + self.override_display_id = None + + def out(self, index): + return [self.id, index] + + def set_input(self, key, value): + if value is None: + if key in self.inputs: + del self.inputs[key] + else: + self.inputs[key] = value + + def get_input(self, key): + return self.inputs.get(key) + + def set_override_display_id(self, override_display_id): + self.override_display_id = override_display_id + + def serialize(self): + serialized = { + "class_type": self.class_type, + "inputs": self.inputs + } + if self.override_display_id is not None: + serialized["override_display_id"] = self.override_display_id + return serialized + +def add_graph_prefix(graph, outputs, prefix): + # Change the node IDs and any internal links + new_graph = {} + for node_id, node_info in graph.items(): + # Make sure the added nodes have unique IDs + new_node_id = prefix + node_id + new_node = { 
"class_type": node_info["class_type"], "inputs": {} } + for input_name, input_value in node_info.get("inputs", {}).items(): + if is_link(input_value): + new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]] + else: + new_node["inputs"][input_name] = input_value + new_graph[new_node_id] = new_node + + # Change the node IDs in the outputs + new_outputs = [] + for n in range(len(outputs)): + output = outputs[n] + if is_link(output): + new_outputs.append([prefix + output[0], output[1]]) + else: + new_outputs.append(output) + + return new_graph, tuple(new_outputs) + diff --git a/execution.py b/execution.py index 00908eadd46..3b18d2a7aaa 100644 --- a/execution.py +++ b/execution.py @@ -4,6 +4,7 @@ import threading import heapq import traceback +from enum import Enum import inspect from typing import List, Literal, NamedTuple, Optional @@ -11,29 +12,97 @@ import nodes import comfy.model_management +import comfy.graph_utils +from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker +from comfy.graph_utils import is_link, GraphBuilder +from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID + +class ExecutionResult(Enum): + SUCCESS = 0 + FAILURE = 1 + SLEEPING = 2 + +class IsChangedCache: + def __init__(self, dynprompt, outputs_cache): + self.dynprompt = dynprompt + self.outputs_cache = outputs_cache + self.is_changed = {} + + def get(self, node_id): + if node_id not in self.is_changed: + node = self.dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, "IS_CHANGED"): + if "is_changed" in node: + self.is_changed[node_id] = node["is_changed"] + else: + input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) + try: + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") + node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] + self.is_changed[node_id] = node["is_changed"] + except: + node["is_changed"] = float("NaN") + self.is_changed[node_id] = node["is_changed"] + else: + self.is_changed[node_id] = False + return self.is_changed[node_id] -def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): +class CacheSet: + def __init__(self, lru_size=None): + if lru_size is None or lru_size == 0: + self.init_classic_cache() + else: + self.init_lru_cache(lru_size) + self.all = [self.outputs, self.ui, self.objects] + + # Useful for those with ample RAM/VRAM -- allows experimenting without + # blowing away the cache every time + def init_lru_cache(self, cache_size): + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size) + self.objects = HierarchicalCache(CacheKeySetID) + + # Performs like the old cache -- dump data ASAP + def init_classic_cache(self): + self.outputs = HierarchicalCache(CacheKeySetInputSignature) + self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID) + self.objects = HierarchicalCache(CacheKeySetID) + + def recursive_debug_dump(self): + result = { + "outputs": self.outputs.recursive_debug_dump(), + "ui": self.ui.recursive_debug_dump(), + } + return result + +def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynprompt=None, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} for x in inputs: input_data = inputs[x] - if 
isinstance(input_data, list):
+        input_type, input_category, input_info = get_input_info(class_def, x)
+        if is_link(input_data) and not input_info.get("rawLink", False):
             input_unique_id = input_data[0]
             output_index = input_data[1]
-            if input_unique_id not in outputs:
-                input_data_all[x] = (None,)
+            if outputs is None:
+                continue # This might be a lazily-evaluated input
+            cached_output = outputs.get(input_unique_id)
+            if cached_output is None:
                 continue
-            obj = outputs[input_unique_id][output_index]
+            obj = cached_output[output_index]
             input_data_all[x] = obj
-        else:
-            if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
-                input_data_all[x] = [input_data]
+        elif input_category is not None:
+            input_data_all[x] = [input_data]
     if "hidden" in valid_inputs:
         h = valid_inputs["hidden"]
         for x in h:
             if h[x] == "PROMPT":
                 input_data_all[x] = [prompt]
+            if h[x] == "DYNPROMPT":
+                input_data_all[x] = [dynprompt]
             if h[x] == "EXTRA_PNGINFO":
                 if "extra_pnginfo" in extra_data:
                     input_data_all[x] = [extra_data['extra_pnginfo']]
@@ -41,7 +110,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
             input_data_all[x] = [unique_id]
     return input_data_all
 
-def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
+def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
     # check if node wants the lists
     input_is_list = False
     if hasattr(obj, "INPUT_IS_LIST"):
@@ -63,51 +132,97 @@ def slice_dict(d, i):
     if input_is_list:
         if allow_interrupt:
             nodes.before_node_execution()
-        results.append(getattr(obj, func)(**input_data_all))
+        execution_block = None
+        for k, v in input_data_all.items():
+            for input in v:
+                if isinstance(input, ExecutionBlocker):
+                    execution_block = execution_block_cb(input) if execution_block_cb is not None else input
+                    break
+
+        if execution_block is None:
+            if pre_execute_cb is not None:
+                pre_execute_cb(0)
+            results.append(getattr(obj, func)(**input_data_all))
+        else:
+            results.append(execution_block)
     elif max_len_input == 0:
         if allow_interrupt:
             nodes.before_node_execution()
         results.append(getattr(obj, func)())
-    else: 
+    else:
         for i in range(max_len_input):
             if allow_interrupt:
                 nodes.before_node_execution()
-            results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
+            input_dict = slice_dict(input_data_all, i)
+            execution_block = None
+            for k, v in input_dict.items():
+                if isinstance(v, ExecutionBlocker):
+                    execution_block = execution_block_cb(v) if execution_block_cb is not None else v
+                    break
+            if execution_block is None:
+                if pre_execute_cb is not None:
+                    pre_execute_cb(i)
+                results.append(getattr(obj, func)(**input_dict))
+            else:
+                results.append(execution_block)
     return results
 
-def get_output_data(obj, input_data_all):
+def merge_result_data(results, obj):
+    # check which outputs need concatenating
+    output = []
+    output_is_list = [False] * len(results[0])
+    if hasattr(obj, "OUTPUT_IS_LIST"):
+        output_is_list = obj.OUTPUT_IS_LIST
+
+    # merge node execution results
+    for i, is_list in zip(range(len(results[0])), output_is_list):
+        if is_list:
+            output.append([x for o in results for x in o[i]])
+        else:
+            output.append([o[i] for o in results])
+    return output
+
+def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
    results = []
    uis = []
-    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
-
-    for r in return_values:
+    subgraph_results = []
+
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + has_subgraph = False + for i in range(len(return_values)): + r = return_values[i] if isinstance(r, dict): if 'ui' in r: uis.append(r['ui']) - if 'result' in r: - results.append(r['result']) + if 'expand' in r: + # Perform an expansion, but do not append results + has_subgraph = True + new_graph = r['expand'] + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + subgraph_results.append((new_graph, result)) + elif 'result' in r: + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + results.append(result) + subgraph_results.append((None, result)) else: + if isinstance(r, ExecutionBlocker): + r = tuple([r] * len(obj.RETURN_TYPES)) results.append(r) - output = [] - if len(results) > 0: - # check which outputs need concatenating - output_is_list = [False] * len(results[0]) - if hasattr(obj, "OUTPUT_IS_LIST"): - output_is_list = obj.OUTPUT_IS_LIST - - # merge node execution results - for i, is_list in zip(range(len(results[0])), output_is_list): - if is_list: - output.append([x for o in results for x in o[i]]) - else: - output.append([o[i] for o in results]) - + if has_subgraph: + output = subgraph_results + elif len(results) > 0: + output = merge_result_data(results, obj) + else: + output = [] ui = dict() if len(uis) > 0: ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} - return output, ui + return output, ui, has_subgraph def format_value(x): if x is None: @@ -117,53 +232,144 @@ def format_value(x): else: return str(x) -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage): +def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] + real_node_id = dynprompt.get_real_node_id(unique_id) + display_node_id = dynprompt.get_display_node_id(unique_id) + parent_node_id = dynprompt.get_parent_node_id(unique_id) + inputs = dynprompt.get_node(unique_id)['inputs'] + class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if unique_id in outputs: - return (True, None, None) - - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage) - if result[0] is not True: - # Another node failed further upstream - return result + if caches.outputs.get(unique_id) is not None: + if server.client_id is not None: + cached_output = caches.ui.get(unique_id) or {} + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) + return (ExecutionResult.SUCCESS, None, None) input_data_all = None try: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, 
server.client_id) - - obj = object_storage.get((unique_id, class_type), None) - if obj is None: - obj = class_def() - object_storage[(unique_id, class_type)] = obj + if unique_id in pending_subgraph_results: + cached_results = pending_subgraph_results[unique_id] + resolved_outputs = [] + for is_subgraph, result in cached_results: + if not is_subgraph: + resolved_outputs.append(result) + else: + resolved_output = [] + for r in result: + if is_link(r): + source_node, source_output = r[0], r[1] + node_output = caches.outputs.get(source_node)[source_output] + for o in node_output: + resolved_output.append(o) - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data + else: + resolved_output.append(r) + resolved_outputs.append(tuple(resolved_output)) + output_data = merge_result_data(resolved_outputs, class_def) + output_ui = [] + has_subgraph = False + else: + input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt.original_prompt, dynprompt, extra_data) + if server.client_id is not None: + server.last_node_id = display_node_id + server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) + + obj = caches.objects.get(unique_id) + if obj is None: + obj = class_def() + caches.objects.set(unique_id, obj) + + if hasattr(obj, "check_lazy_status"): + required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) + required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) + required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all] + if len(required_inputs) > 0: + for i in required_inputs: + execution_list.make_input_strong_link(unique_id, i) + return (ExecutionResult.SLEEPING, None, None) + + def execution_block_cb(block): + if block.message is not None: + mes = { + "prompt_id": prompt_id, + "node_id": unique_id, + "node_type": class_type, + "executed": list(executed), + + "exception_message": "Execution Blocked: %s" % block.message, + "exception_type": "ExecutionBlocked", + "traceback": [], + "current_inputs": [], + "current_outputs": [], + } + server.send_sync("execution_error", mes, server.client_id) + return ExecutionBlocker(None) + else: + return block + def pre_execute_cb(call_index): + GraphBuilder.set_default_prefix(unique_id, call_index, 0) + output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui + caches.ui.set(unique_id, { + "meta": { + "node_id": unique_id, + "display_node": display_node_id, + "parent_node": parent_node_id, + "real_node_id": real_node_id, + }, + "output": output_ui + }) if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + if has_subgraph: + cached_outputs = [] + new_node_ids = [] + new_output_ids = [] + new_output_links = [] + for i in range(len(output_data)): + new_graph, node_outputs = output_data[i] + if new_graph is None: + cached_outputs.append((False, node_outputs)) + else: + # Check for conflicts + for node_id in new_graph.keys(): + if dynprompt.get_node(node_id) is not None: + raise Exception("Attempt to add duplicate node %s" % 
node_id) + break + for node_id, node_info in new_graph.items(): + new_node_ids.append(node_id) + display_id = node_info.get("override_display_id", unique_id) + dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id) + # Figure out if the newly created node is an output node + class_type = node_info["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + new_output_ids.append(node_id) + for i in range(len(node_outputs)): + if is_link(node_outputs[i]): + from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1] + new_output_links.append((from_node_id, from_socket)) + cached_outputs.append((True, node_outputs)) + new_node_ids = set(new_node_ids) + for cache in caches.all: + cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused() + for node_id in new_output_ids: + execution_list.add_node(node_id) + for link in new_output_links: + execution_list.add_strong_link(link[0], link[1], unique_id) + pending_subgraph_results[unique_id] = cached_outputs + return (ExecutionResult.SLEEPING, None, None) + caches.outputs.set(unique_id, output_data) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") # skip formatting inputs/outputs error_details = { - "node_id": unique_id, + "node_id": real_node_id, } - return (False, error_details, iex) + return (ExecutionResult.FAILURE, error_details, iex) except Exception as ex: typ, _, tb = sys.exc_info() exception_type = full_type_name(typ) @@ -173,109 +379,32 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute for name, inputs in input_data_all.items(): input_data_formatted[name] = [format_value(x) for x in inputs] - output_data_formatted = {} - for node_id, node_outputs in outputs.items(): - output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] - logging.error("!!! 
Exception during processing !!!") logging.error(traceback.format_exc()) error_details = { - "node_id": unique_id, + "node_id": real_node_id, "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), - "current_inputs": input_data_formatted, - "current_outputs": output_data_formatted + "current_inputs": input_data_formatted } - return (False, error_details, ex) + return (ExecutionResult.FAILURE, error_details, ex) executed.add(unique_id) - return (True, None, None) - -def recursive_will_execute(prompt, outputs, current_item): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - will_execute = [] - if unique_id in outputs: - return [] - - for x in inputs: - input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - will_execute += recursive_will_execute(prompt, outputs, input_unique_id) - - return will_execute + [unique_id] - -def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - - is_changed_old = '' - is_changed = '' - to_delete = False - if hasattr(class_def, 'IS_CHANGED'): - if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: - is_changed_old = old_prompt[unique_id]['is_changed'] - if 'is_changed' not in prompt[unique_id]: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs) - if input_data_all is not None: - try: - #is_changed = class_def.IS_CHANGED(**input_data_all) - is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") - prompt[unique_id]['is_changed'] = is_changed - except: - to_delete = True - else: - is_changed = prompt[unique_id]['is_changed'] - - if unique_id not in outputs: - return True - - if not to_delete: - if is_changed != is_changed_old: - to_delete = True - elif unique_id not in old_prompt: - to_delete = True - elif inputs == old_prompt[unique_id]['inputs']: - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id in outputs: - to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) - else: - to_delete = True - if to_delete: - break - else: - to_delete = True - - if to_delete: - d = outputs.pop(unique_id) - del d - return to_delete + return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server): + def __init__(self, server, lru_size=None): + self.lru_size = lru_size self.server = server self.reset() def reset(self): - self.outputs = {} - self.object_storage = {} - self.outputs_ui = {} + self.caches = CacheSet(self.lru_size) self.status_messages = [] self.success = True - self.old_prompt = {} def add_message(self, event, data, broadcast: bool): self.status_messages.append((event, data)) @@ -302,7 +431,6 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e "node_id": node_id, "node_type": class_type, "executed": list(executed), - "exception_message": error["exception_message"], "exception_type": error["exception_type"], "traceback": error["traceback"], @@ -311,18 +439,6 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e } self.add_message("execution_error", mes, broadcast=False) - # Next, remove the subsequent outputs since 
they will not be executed - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) @@ -335,61 +451,45 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): - #delete cached outputs if nodes don't exist for them - to_delete = [] - for o in self.outputs: - if o not in prompt: - to_delete += [o] - for o in to_delete: - d = self.outputs.pop(o) - del d - to_delete = [] - for o in self.object_storage: - if o[0] not in prompt: - to_delete += [o] - else: - p = prompt[o[0]] - if o[1] != p['class_type']: - to_delete += [o] - for o in to_delete: - d = self.object_storage.pop(o) - del d - - for x in prompt: - recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) - - current_outputs = set(self.outputs.keys()) - for x in list(self.outputs_ui.keys()): - if x not in current_outputs: - d = self.outputs_ui.pop(x) - del d + dynamic_prompt = DynamicPrompt(prompt) + is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() + + current_outputs = self.caches.outputs.all_node_ids() comfy.model_management.cleanup_models() self.add_message("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, broadcast=False) + pending_subgraph_results = {} executed = set() - output_node_id = None - to_execute = [] - + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) for node_id in list(execute_outputs): - to_execute += [(0, node_id)] - - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - output_node_id = to_execute.pop(0)[-1] - - # This call shouldn't raise anything if there's an error deep in - # the actual SD code, instead it will report the node where the - # error was raised - self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) - if self.success is not True: - self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) - break + execution_list.add_node(node_id) - for x in executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) + while not execution_list.is_empty(): + node_id = execution_list.stage_node_execution() + result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.SLEEPING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() + + ui_outputs = {} + meta_outputs = {} + for ui_info in self.caches.ui.all_active_values(): + node_id = ui_info["meta"]["node_id"] + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = 
ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: comfy.model_management.unload_all_models() @@ -406,7 +506,7 @@ def validate_inputs(prompt, item, validated): obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] class_inputs = obj_class.INPUT_TYPES() - required_inputs = class_inputs['required'] + valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) errors = [] valid = True @@ -415,22 +515,23 @@ def validate_inputs(prompt, item, validated): if hasattr(obj_class, "VALIDATE_INPUTS"): validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args - for x in required_inputs: + for x in valid_inputs: + type_input, input_category, extra_info = get_input_info(obj_class, x) if x not in inputs: - error = { - "type": "required_input_missing", - "message": "Required input is missing", - "details": f"{x}", - "extra_info": { - "input_name": x + if input_category == "required": + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } } - } - errors.append(error) + errors.append(error) continue val = inputs[x] - info = required_inputs[x] - type_input = info[0] + info = (type_input, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -501,6 +602,9 @@ def validate_inputs(prompt, item, validated): if type_input == "STRING": val = str(val) inputs[x] = val + if type_input == "BOOLEAN": + val = bool(val) + inputs[x] = val except Exception as ex: error = { "type": "invalid_input_type", @@ -516,33 +620,32 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: - error = { - "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } + if "min" in extra_info and val < extra_info["min"]: + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, } - errors.append(error) - continue - if "max" in info[1] and val > info[1]["max"]: - error = { - "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } + } + errors.append(error) + continue + if "max" in extra_info and val > extra_info["max"]: + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, } - errors.append(error) - continue + } + errors.append(error) + continue if x not in validate_function_inputs: if isinstance(type_input, list): @@ -582,7 +685,7 @@ def validate_inputs(prompt, item, validated): ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") for x in input_filtered: for i, r in enumerate(ret): - if r is not True: + if r is not True and not isinstance(r, ExecutionBlocker): details = f"{x}" if r is not False: details += f" - {str(r)}" @@ -741,7 +844,7 @@ class ExecutionStatus(NamedTuple): completed: 
bool messages: List[str] - def task_done(self, item_id, outputs, + def task_done(self, item_id, history_result, status: Optional['PromptQueue.ExecutionStatus']): with self.mutex: prompt = self.currently_running.pop(item_id) @@ -754,9 +857,10 @@ def task_done(self, item_id, outputs, self.history[prompt[1]] = { "prompt": prompt, - "outputs": copy.deepcopy(outputs), + "outputs": {}, 'status': status_dict, } + self.history[prompt[1]].update(history_result) self.server.queue_updated() def get_current_queue(self): diff --git a/main.py b/main.py index 69d9bce6cb7..8cd869e4885 100644 --- a/main.py +++ b/main.py @@ -91,7 +91,7 @@ def cuda_malloc_warning(): print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") def prompt_worker(q, server): - e = execution.PromptExecutor(server) + e = execution.PromptExecutor(server, lru_size=args.cache_lru) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -111,7 +111,7 @@ def prompt_worker(q, server): e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True q.task_done(item_id, - e.outputs_ui, + e.history_result, status=execution.PromptQueue.ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, diff --git a/server.py b/server.py index dca06f6fc32..8f2896b1ba6 100644 --- a/server.py +++ b/server.py @@ -396,6 +396,7 @@ def node_info(node_class): obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] info = {} info['input'] = obj_class.INPUT_TYPES() + info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} info['output'] = obj_class.RETURN_TYPES info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] @@ -632,6 +633,9 @@ async def start(self, address, port, verbose=True, call_on_start=None): site = web.TCPSite(runner, address, port) await site.start() + self.address = address + self.port = port + if verbose: print("Starting server\n") print("To see the GUI go to: http://{}:{}".format(address, port)) diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index e6ebedd9150..15b784d6768 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -443,6 +443,7 @@ describe("group node", () => { new CustomEvent("executed", { detail: { node: `${nodes.save.id}`, + display_node: `${nodes.save.id}`, output: { images: [ { @@ -483,6 +484,7 @@ describe("group node", () => { new CustomEvent("executed", { detail: { node: `${group.id}:5`, + display_node: `${group.id}:5`, output: { images: [ { diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 0f041fcd2f9..b78d33aac7c 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -956,8 +956,8 @@ export class GroupNodeHandler { const executed = handleEvent.call( this, "executed", - (d) => d?.node, - (d, id, node) => ({ ...d, node: id, merge: !node.resetExecution }) + (d) => d?.display_node, + (d, id, node) => ({ ...d, node: id, display_node: id, merge: !node.resetExecution }) ); const onRemoved = node.onRemoved; diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 3f1c1f8c126..f89c731e6bb 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -3,7 +3,7 @@ import { app } from 
"../../scripts/app.js"; import { applyTextReplacements } from "../../scripts/utils.js"; const CONVERTED_TYPE = "converted-widget"; -const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; +const VALID_TYPES = ["STRING", "combo", "number", "toggle", "BOOLEAN"]; const CONFIG = Symbol(); const GET_CONFIG = Symbol(); const TARGET = Symbol(); // Used for reroutes to specify the real target widget diff --git a/web/scripts/api.js b/web/scripts/api.js index 3a9bcc87a4e..ae3fbd13a01 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -126,7 +126,7 @@ class ComfyApi extends EventTarget { this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); break; case "executing": - this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); + this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.display_node })); break; case "executed": this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); diff --git a/web/scripts/app.js b/web/scripts/app.js index 6df393ba60d..d1687845438 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1255,7 +1255,7 @@ export class ComfyApp { }); api.addEventListener("executed", ({ detail }) => { - const output = this.nodeOutputs[detail.node]; + const output = this.nodeOutputs[detail.display_node]; if (detail.merge && output) { for (const k in detail.output ?? {}) { const v = output[k]; @@ -1266,9 +1266,9 @@ export class ComfyApp { } } } else { - this.nodeOutputs[detail.node] = detail.output; + this.nodeOutputs[detail.display_node] = detail.output; } - const node = this.graph.getNodeById(detail.node); + const node = this.graph.getNodeById(detail.display_node); if (node) { if (node.onExecuted) node.onExecuted(detail.output); diff --git a/web/scripts/ui.js b/web/scripts/ui.js index d4835c6e445..d69434993b0 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -227,7 +227,14 @@ class ComfyList { onclick: async () => { await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow); if (item.outputs) { - app.nodeOutputs = item.outputs; + app.nodeOutputs = {}; + for (const [key, value] of Object.entries(item.outputs)) { + if (item.meta && item.meta[key] && item.meta[key].display_node) { + app.nodeOutputs[item.meta[key].display_node] = value; + } else { + app.nodeOutputs[key] = value; + } + } } }, }), From e4e20d79b22eae61eba9cbf9ebc6349294d28e11 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Wed, 14 Feb 2024 21:04:50 -0800 Subject: [PATCH 02/30] Allow `input_info` to be of type `None` --- comfy/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/graph.py b/comfy/graph.py index 2612317f9e6..656ebeb9795 100644 --- a/comfy/graph.py +++ b/comfy/graph.py @@ -105,7 +105,7 @@ def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None): if subgraph_nodes is not None and from_node_id not in subgraph_nodes: continue input_type, input_category, input_info = self.get_input_info(unique_id, input_name) - is_lazy = "lazy" in input_info and input_info["lazy"] + is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] if include_lazy or not is_lazy: self.add_strong_link(from_node_id, from_socket, unique_id) From 2c7145d5251fc5f55355b262163f638d984b9007 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Wed, 14 Feb 2024 21:05:44 -0800 Subject: [PATCH 03/30] Handle errors (like OOM) more gracefully --- execution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/execution.py b/execution.py index 3b18d2a7aaa..1e426a32fe7 100644 --- a/execution.py 
+++ b/execution.py @@ -435,7 +435,7 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e "exception_type": error["exception_type"], "traceback": error["traceback"], "current_inputs": error["current_inputs"], - "current_outputs": error["current_outputs"], + "current_outputs": list(current_outputs), } self.add_message("execution_error", mes, broadcast=False) From 12627ca75aaa81045dab1bbfa9bdf73bef46fde3 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Wed, 14 Feb 2024 21:06:53 -0800 Subject: [PATCH 04/30] Add a command-line argument to enable variants This allows the use of nodes that have sockets of type '*' without applying a patch to the code. --- comfy/cli_args.py | 1 + execution.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 2cbefefebd9..74354ea9426 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -117,6 +117,7 @@ class LatentPreviewMethod(enum.Enum): parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") +parser.add_argument("--enable-variants", action="store_true", help="Enables '*' type nodes.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/execution.py b/execution.py index 1e426a32fe7..57c9cbf784b 100644 --- a/execution.py +++ b/execution.py @@ -16,6 +16,7 @@ from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy.graph_utils import is_link, GraphBuilder from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID +from comfy.cli_args import args class ExecutionResult(Enum): SUCCESS = 0 @@ -550,7 +551,8 @@ def validate_inputs(prompt, item, validated): o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - if r[val[1]] != type_input: + is_variant = args.enable_variants and (r[val[1]] == "*" or type_input == "*") + if r[val[1]] != type_input and not is_variant: received_type = r[val[1]] details = f"{x}, {received_type} != {type_input}" error = { From 9c1e3f7b986c040cf5fd1665fcd964b357dcc39d Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 17 Feb 2024 21:02:59 -0800 Subject: [PATCH 05/30] Fix an overly aggressive assertion. This could happen when attempting to evaluate `IS_CHANGED` for a node during the creation of the cache (in order to create the cache key). 
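In code terms, the behavior change is roughly this (illustrative only; "5"
stands in for an arbitrary node id):

    cache = HierarchicalCache(CacheKeySetInputSignature)
    # No prompt has been set yet, so the cache key set does not exist.
    cache._get_immediate("5")  # before: AssertionError; after: None (a cache miss)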
--- comfy/caching.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/caching.py b/comfy/caching.py index ef047dcc5d8..7730a37137e 100644 --- a/comfy/caching.py +++ b/comfy/caching.py @@ -172,7 +172,8 @@ def _set_immediate(self, node_id, value): self.cache[cache_key] = value def _get_immediate(self, node_id): - assert self.cache_key_set is not None + if self.cache_key_set is None: + return None cache_key = self.cache_key_set.get_data_key(node_id) if cache_key in self.cache: return self.cache[cache_key] From 508d286b8fbd4cfa0281ee6c0d25aab95a7c014c Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 17 Feb 2024 21:56:46 -0800 Subject: [PATCH 06/30] Fix Pyright warnings --- comfy/caching.py | 23 +++++++++++++---------- execution.py | 10 ++++------ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/comfy/caching.py b/comfy/caching.py index 7730a37137e..936e2e6dfee 100644 --- a/comfy/caching.py +++ b/comfy/caching.py @@ -1,5 +1,6 @@ import itertools from typing import Sequence, Mapping +from comfy.graph import DynamicPrompt import nodes @@ -10,7 +11,7 @@ def __init__(self, dynprompt, node_ids, is_changed_cache): self.keys = {} self.subcache_keys = {} - def add_keys(node_ids): + def add_keys(self, node_ids): raise NotImplementedError() def all_node_ids(self): @@ -66,7 +67,7 @@ def __init__(self, dynprompt, node_ids, is_changed_cache): self.is_changed_cache = is_changed_cache self.add_keys(node_ids) - def include_node_id_in_input(self): + def include_node_id_in_input(self) -> bool: return False def add_keys(self, node_ids): @@ -131,8 +132,9 @@ def include_node_id_in_input(self): class BasicCache: def __init__(self, key_class): self.key_class = key_class - self.dynprompt = None - self.cache_key_set = None + self.initialized = False + self.dynprompt: DynamicPrompt + self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} @@ -140,16 +142,17 @@ def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) self.is_changed_cache = is_changed_cache + self.initialized = True def all_node_ids(self): - assert self.cache_key_set is not None + assert self.initialized node_ids = self.cache_key_set.all_node_ids() for subcache in self.subcaches.values(): node_ids = node_ids.union(subcache.all_node_ids()) return node_ids def clean_unused(self): - assert self.cache_key_set is not None + assert self.initialized preserve_keys = set(self.cache_key_set.get_used_keys()) preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) to_remove = [] @@ -167,12 +170,12 @@ def clean_unused(self): del self.subcaches[key] def _set_immediate(self, node_id, value): - assert self.cache_key_set is not None + assert self.initialized cache_key = self.cache_key_set.get_data_key(node_id) self.cache[cache_key] = value def _get_immediate(self, node_id): - if self.cache_key_set is None: + if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) if cache_key in self.cache: @@ -181,7 +184,6 @@ def _get_immediate(self, node_id): return None def _ensure_subcache(self, node_id, children_ids): - assert self.cache_key_set is not None subcache_key = self.cache_key_set.get_subcache_key(node_id) subcache = self.subcaches.get(subcache_key, None) if subcache is None: @@ -191,7 +193,7 @@ def _ensure_subcache(self, node_id, children_ids): return subcache def _get_subcache(self, node_id): - assert self.cache_key_set is not None + assert self.initialized 
subcache_key = self.cache_key_set.get_subcache_key(node_id) if subcache_key in self.subcaches: return self.subcaches[subcache_key] @@ -211,6 +213,7 @@ def __init__(self, key_class): super().__init__(key_class) def _get_cache_for(self, node_id): + assert self.dynprompt is not None parent_id = self.dynprompt.get_parent_node_id(node_id) if parent_id is None: return self diff --git a/execution.py b/execution.py index 57c9cbf784b..4d9a4b98c7c 100644 --- a/execution.py +++ b/execution.py @@ -84,7 +84,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro for x in inputs: input_data = inputs[x] input_type, input_category, input_info = get_input_info(class_def, x) - if is_link(input_data) and not input_info.get("rawLink", False): + if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] if outputs is None: @@ -94,7 +94,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro continue obj = cached_output[output_index] input_data_all[x] = obj - elif input_category is not None: + else: input_data_all[x] = [input_data] if "hidden" in valid_inputs: @@ -336,8 +336,7 @@ def pre_execute_cb(call_index): # Check for conflicts for node_id in new_graph.keys(): if dynprompt.get_node(node_id) is not None: - raise Exception("Attempt to add duplicate node %s" % node_id) - break + raise Exception("Attempt to add duplicate node %s. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder." % node_id) for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) @@ -518,6 +517,7 @@ def validate_inputs(prompt, item, validated): for x in valid_inputs: type_input, input_category, extra_info = get_input_info(obj_class, x) + assert extra_info is not None if x not in inputs: if input_category == "required": error = { @@ -698,8 +698,6 @@ def validate_inputs(prompt, item, validated): "details": details, "extra_info": { "input_name": x, - "input_config": info, - "received_value": val, } } errors.append(error) From fff22830a915397305b70a94893bb87f644fe033 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sun, 18 Feb 2024 01:41:21 -0800 Subject: [PATCH 07/30] Add execution model unit tests --- pytest.ini | 3 +- tests/inference/test_execution.py | 294 +++++++++++++++ .../testing_nodes/testing-pack/__init__.py | 23 ++ .../testing_nodes/testing-pack/conditions.py | 194 ++++++++++ .../testing-pack/flow_control.py | 169 +++++++++ .../testing-pack/specific_tests.py | 116 ++++++ .../testing_nodes/testing-pack/stubs.py | 61 +++ .../testing_nodes/testing-pack/util.py | 353 ++++++++++++++++++ 8 files changed, 1212 insertions(+), 1 deletion(-) create mode 100644 tests/inference/test_execution.py create mode 100644 tests/inference/testing_nodes/testing-pack/__init__.py create mode 100644 tests/inference/testing_nodes/testing-pack/conditions.py create mode 100644 tests/inference/testing_nodes/testing-pack/flow_control.py create mode 100644 tests/inference/testing_nodes/testing-pack/specific_tests.py create mode 100644 tests/inference/testing_nodes/testing-pack/stubs.py create mode 100644 tests/inference/testing_nodes/testing-pack/util.py diff --git a/pytest.ini b/pytest.ini index b5a68e0f12f..d34fb51907f 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,6 @@ [pytest] markers = inference: mark as inference test (deselect with '-m "not inference"') + execution: mark as execution test (deselect with '-m 
"not execution"') testpaths = tests -addopts = -s \ No newline at end of file +addopts = -s diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py new file mode 100644 index 00000000000..52e1956279c --- /dev/null +++ b/tests/inference/test_execution.py @@ -0,0 +1,294 @@ +from io import BytesIO +import numpy +from PIL import Image +import pytest +from pytest import fixture +import time +import torch +from typing import Union, Dict +import json +import subprocess +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import urllib.request +import urllib.parse +from comfy.graph_utils import GraphBuilder, Node + +class RunResult: + def __init__(self, prompt_id: str): + self.outputs: Dict[str,Dict] = {} + self.runs: Dict[str,bool] = {} + self.prompt_id: str = prompt_id + + def get_output(self, node: Node): + return self.outputs.get(node.id, None) + + def did_run(self, node: Node): + return self.runs.get(node.id, False) + + def get_images(self, node: Node): + output = self.get_output(node) + if output is None: + return [] + return output.get('image_objects', []) + + def get_prompt_id(self): + return self.prompt_id + +class ComfyClient: + def __init__(self): + self.test_name = "" + + def connect(self, + listen:str = '127.0.0.1', + port:Union[str,int] = 8188, + client_id: str = str(uuid.uuid4()) + ): + self.client_id = client_id + self.server_address = f"{listen}:{port}" + ws = websocket.WebSocket() + ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id)) + self.ws = ws + + def queue_prompt(self, prompt): + p = {"prompt": prompt, "client_id": self.client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + + def get_image(self, filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response: + return response.read() + + def get_history(self, prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: + return json.loads(response.read()) + + def set_test_name(self, name): + self.test_name = name + + def run(self, graph): + prompt = graph.finalize() + for node in graph.nodes.values(): + if node.class_type == 'SaveImage': + node.inputs['filename_prefix'] = self.test_name + + prompt_id = self.queue_prompt(prompt)['prompt_id'] + result = RunResult(prompt_id) + while True: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['prompt_id'] != prompt_id: + continue + if data['node'] is None: + break + result.runs[data['node']] = True + elif message['type'] == 'execution_error': + raise Exception(message['data']) + elif message['type'] == 'execution_cached': + pass # Probably want to store this off for testing + + history = self.get_history(prompt_id)[prompt_id] + for o in history['outputs']: + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + result.outputs[node_id] = node_output + if 'images' in node_output: + images_output = [] + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + image_obj = 
Image.open(BytesIO(image_data)) + images_output.append(image_obj) + node_output['image_objects'] = images_output + + return result + +# +# Loop through these variables +# +@pytest.mark.execution +class TestExecution: + # + # Initialize server and client + # + @fixture(scope="class", autouse=True) + def _server(self, args_pytest): + # Start server + p = subprocess.Popen([ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + '--enable-variants', + ]) + yield + p.kill() + torch.cuda.empty_cache() + + def start_client(self, listen:str, port:int): + # Start client + comfy_client = ComfyClient() + # Connect to server (with retries) + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + comfy_client.connect(listen=listen, port=port) + except ConnectionRefusedError as e: + print(e) + print(f"({i+1}/{n_tries}) Retrying...") + else: + break + return comfy_client + + @fixture(scope="class", autouse=True) + def shared_client(self, args_pytest, _server): + client = self.start_client(args_pytest["listen"], args_pytest["port"]) + yield client + del client + torch.cuda.empty_cache() + + @fixture + def client(self, shared_client, request): + shared_client.set_test_name(f"execution[{request.node.name}]") + yield shared_client + + def clear_cache(self, client: ComfyClient): + g = GraphBuilder(prefix="foo") + random = g.node("StubImage", content="NOISE", height=1, width=1, batch_size=1) + g.node("PreviewImage", images=random.out(0)) + client.run(g) + + @fixture + def builder(self): + yield GraphBuilder(prefix="") + + def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + output = g.node("SaveImage", images=lazy_mix.out(0)) + result = client.run(g) + + result_image = result.get_images(output)[0] + assert numpy.array(result_image).any() == 0, "Image should be black" + assert result.did_run(input1) + assert not result.did_run(input2) + assert result.did_run(mask) + assert result.did_run(lazy_mix) + + def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): + self.clear_cache(client) + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + result1 = client.run(g) + result2 = client.run(g) + for node_id, node in g.nodes.items(): + assert result1.did_run(node), f"Node {node_id} didn't run" + assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + + def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): + self.clear_cache(client) + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, 
height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + result1 = client.run(g) + mask.inputs['value'] = 0.4 + result2 = client.run(g) + for node_id, node in g.nodes.items(): + assert result1.did_run(node), f"Node {node_id} didn't run" + assert not result2.did_run(input1), "Input1 should have been cached" + assert not result2.did_run(input2), "Input2 should have been cached" + assert result2.did_run(mask), "Mask should have been re-run" + assert result2.did_run(lazy_mix), "Lazy mix should have been re-run" + + def test_error(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + # Different size of the two images + input2 = g.node("StubImage", content="NOISE", height=256, width=256, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + try: + client.run(g) + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" + + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): + g = builder + # Creating the nodes in this specific order previously caused a bug + save = g.node("SaveImage") + is_changed = g.node("TestCustomIsChanged", should_change=False) + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + save.set_input('images', is_changed.out(0)) + is_changed.set_input('image', input1.out(0)) + + result1 = client.run(g) + result2 = client.run(g) + is_changed.set_input('should_change', True) + result3 = client.run(g) + result4 = client.run(g) + assert result1.did_run(is_changed), "is_changed should have been run" + assert not result2.did_run(is_changed), "is_changed should have been cached" + assert result3.did_run(is_changed), "is_changed should have been re-run" + assert result4.did_run(is_changed), "is_changed should not have been cached" + + def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder): + self.clear_cache(client) + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input4 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + average = g.node("TestVariadicAverage", input1=input1.out(0), input2=input2.out(0), input3=input3.out(0), input4=input4.out(0)) + output = g.node("SaveImage", images=average.out(0)) + + result = client.run(g) + result_image = result.get_images(output)[0] + expected = 255 // 4 + assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" + assert result.did_run(input1) + assert result.did_run(input2) + + def test_for_loop(self, client: ComfyClient, builder: GraphBuilder): + g = builder + iterations = 4 + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + is_changed = g.node("TestCustomIsChanged", should_change=True, image=input2.out(0)) + for_open = g.node("TestForLoopOpen", remaining=iterations, 
initial_value1=is_changed.out(0)) + average = g.node("TestVariadicAverage", input1=input1.out(0), input2=for_open.out(2)) + for_close = g.node("TestForLoopClose", flow_control=for_open.out(0), initial_value1=average.out(0)) + output = g.node("SaveImage", images=for_close.out(0)) + + for iterations in range(1, 5): + for_open.set_input('remaining', iterations) + result = client.run(g) + result_image = result.get_images(output)[0] + expected = 255 // (2 ** iterations) + assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" + assert result.did_run(is_changed) diff --git a/tests/inference/testing_nodes/testing-pack/__init__.py b/tests/inference/testing_nodes/testing-pack/__init__.py new file mode 100644 index 00000000000..dcc71659a02 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/__init__.py @@ -0,0 +1,23 @@ +from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS +from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS +from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS +from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS +from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS + +# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) +# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) + +NODE_CLASS_MAPPINGS = {} +NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS) + +NODE_DISPLAY_NAME_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS) + diff --git a/tests/inference/testing_nodes/testing-pack/conditions.py b/tests/inference/testing_nodes/testing-pack/conditions.py new file mode 100644 index 00000000000..0c200ee2892 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/conditions.py @@ -0,0 +1,194 @@ +import re +import torch + +class TestIntConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "operation": (["==", "!=", "<", ">", "<=", ">="],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "int_condition" + + CATEGORY = "Testing/Logic" + + def int_condition(self, a, b, operation): + if operation == "==": + return (a == b,) + elif operation == "!=": + return (a != b,) + elif operation == "<": + return (a < b,) + elif operation == ">": + return (a > b,) + elif operation == "<=": + return (a <= b,) + elif operation == ">=": + return (a >= b,) + + +class TestFloatConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 
999999999999.0, "step": 1}), + "b": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}), + "operation": (["==", "!=", "<", ">", "<=", ">="],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "float_condition" + + CATEGORY = "Testing/Logic" + + def float_condition(self, a, b, operation): + if operation == "==": + return (a == b,) + elif operation == "!=": + return (a != b,) + elif operation == "<": + return (a < b,) + elif operation == ">": + return (a > b,) + elif operation == "<=": + return (a <= b,) + elif operation == ">=": + return (a >= b,) + +class TestStringConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("STRING", {"multiline": False}), + "b": ("STRING", {"multiline": False}), + "operation": (["a == b", "a != b", "a IN b", "a MATCH REGEX(b)", "a BEGINSWITH b", "a ENDSWITH b"],), + "case_sensitive": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "string_condition" + + CATEGORY = "Testing/Logic" + + def string_condition(self, a, b, operation, case_sensitive): + if not case_sensitive: + a = a.lower() + b = b.lower() + + if operation == "a == b": + return (a == b,) + elif operation == "a != b": + return (a != b,) + elif operation == "a IN b": + return (a in b,) + elif operation == "a MATCH REGEX(b)": + try: + return (re.match(b, a) is not None,) + except: + return (False,) + elif operation == "a BEGINSWITH b": + return (a.startswith(b),) + elif operation == "a ENDSWITH b": + return (a.endswith(b),) + +class TestToBoolNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("*",), + }, + "optional": { + "invert": ("BOOLEAN", {"default": False}), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "to_bool" + + CATEGORY = "Testing/Logic" + + def to_bool(self, value, invert = False): + if isinstance(value, torch.Tensor): + if value.max().item() == 0 and value.min().item() == 0: + result = False + else: + result = True + else: + try: + result = bool(value) + except: + # Can't convert it? Well then it's something or other. I dunno, I'm not a Python programmer. 
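+                # (bool() raised, so fall back to treating the value as truthy)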
+ result = True + + if invert: + result = not result + + return (result,) + +class TestBoolOperationNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("BOOLEAN",), + "b": ("BOOLEAN",), + "op": (["a AND b", "a OR b", "a XOR b", "NOT a"],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "bool_operation" + + CATEGORY = "Testing/Logic" + + def bool_operation(self, a, b, op): + if op == "a AND b": + return (a and b,) + elif op == "a OR b": + return (a or b,) + elif op == "a XOR b": + return (a ^ b,) + elif op == "NOT a": + return (not a,) + + +CONDITION_NODE_CLASS_MAPPINGS = { + "TestIntConditions": TestIntConditions, + "TestFloatConditions": TestFloatConditions, + "TestStringConditions": TestStringConditions, + "TestToBoolNode": TestToBoolNode, + "TestBoolOperationNode": TestBoolOperationNode, +} + +CONDITION_NODE_DISPLAY_NAME_MAPPINGS = { + "TestIntConditions": "Int Condition", + "TestFloatConditions": "Float Condition", + "TestStringConditions": "String Condition", + "TestToBoolNode": "To Bool", + "TestBoolOperationNode": "Bool Operation", +} diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/inference/testing_nodes/testing-pack/flow_control.py new file mode 100644 index 00000000000..8befdcf1973 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/flow_control.py @@ -0,0 +1,169 @@ +from comfy.graph_utils import GraphBuilder, is_link +from comfy.graph import ExecutionBlocker + +NUM_FLOW_SOCKETS = 5 +class TestWhileLoopOpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "condition": ("BOOLEAN", {"default": True}), + }, + "optional": { + }, + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"]["initial_value%d" % i] = ("*",) + return inputs + + RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS) + RETURN_NAMES = tuple(["FLOW_CONTROL"] + ["value%d" % i for i in range(NUM_FLOW_SOCKETS)]) + FUNCTION = "while_loop_open" + + CATEGORY = "Testing/Flow" + + def while_loop_open(self, condition, **kwargs): + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get("initial_value%d" % i, None)) + return tuple(["stub"] + values) + +class TestWhileLoopClose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "flow_control": ("FLOW_CONTROL", {"rawLink": True}), + "condition": ("BOOLEAN", {"forceInput": True}), + }, + "optional": { + }, + "hidden": { + "dynprompt": "DYNPROMPT", + "unique_id": "UNIQUE_ID", + } + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"]["initial_value%d" % i] = ("*",) + return inputs + + RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS) + RETURN_NAMES = tuple(["value%d" % i for i in range(NUM_FLOW_SOCKETS)]) + FUNCTION = "while_loop_close" + + CATEGORY = "Testing/Flow" + + def explore_dependencies(self, node_id, dynprompt, upstream): + node_info = dynprompt.get_node(node_id) + if "inputs" not in node_info: + return + for k, v in node_info["inputs"].items(): + if is_link(v): + parent_id = v[0] + if parent_id not in upstream: + upstream[parent_id] = [] + self.explore_dependencies(parent_id, dynprompt, upstream) + upstream[parent_id].append(node_id) + + def collect_contained(self, node_id, upstream, contained): + if node_id not in upstream: + return + for child_id in upstream[node_id]: + if child_id not in contained: + contained[child_id] = True + self.collect_contained(child_id, upstream, contained) + + + def while_loop_close(self, 
flow_control, condition, dynprompt=None, unique_id=None, **kwargs): + assert dynprompt is not None + if not condition: + # We're done with the loop + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get("initial_value%d" % i, None)) + return tuple(values) + + # We want to loop + upstream = {} + # Get the list of all nodes between the open and close nodes + self.explore_dependencies(unique_id, dynprompt, upstream) + + contained = {} + open_node = flow_control[0] + self.collect_contained(open_node, upstream, contained) + contained[unique_id] = True + contained[open_node] = True + + # We'll use the default prefix, but to avoid having node names grow exponentially in size, + # we'll use "Recurse" for the name of the recursively-generated copy of this node. + graph = GraphBuilder() + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id) + node.set_override_display_id(node_id) + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.lookup_node("Recurse" if node_id == unique_id else node_id) + assert node is not None + for k, v in original_node["inputs"].items(): + if is_link(v) and v[0] in contained: + parent = graph.lookup_node(v[0]) + assert parent is not None + node.set_input(k, parent.out(v[1])) + else: + node.set_input(k, v) + new_open = graph.lookup_node(open_node) + assert new_open is not None + for i in range(NUM_FLOW_SOCKETS): + key = "initial_value%d" % i + new_open.set_input(key, kwargs.get(key, None)) + my_clone = graph.lookup_node("Recurse") + assert my_clone is not None + result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS)) + return { + "result": tuple(result), + "expand": graph.finalize(), + } + +class TestExecutionBlockerNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "input": ("*",), + "block": ("BOOLEAN",), + "verbose": ("BOOLEAN", {"default": False}), + }, + } + return inputs + + RETURN_TYPES = ("*",) + RETURN_NAMES = ("output",) + FUNCTION = "execution_blocker" + + CATEGORY = "Testing/Flow" + + def execution_blocker(self, input, block, verbose): + if block: + return (ExecutionBlocker("Blocked Execution" if verbose else None),) + return (input,) + +FLOW_CONTROL_NODE_CLASS_MAPPINGS = { + "TestWhileLoopOpen": TestWhileLoopOpen, + "TestWhileLoopClose": TestWhileLoopClose, + "TestExecutionBlocker": TestExecutionBlockerNode, +} +FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = { + "TestWhileLoopOpen": "While Loop Open", + "TestWhileLoopClose": "While Loop Close", + "TestExecutionBlocker": "Execution Blocker", +} diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py new file mode 100644 index 00000000000..e3d864b443b --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -0,0 +1,116 @@ +import torch + +class TestLazyMixImages: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image1": ("IMAGE",{"lazy": True}), + "image2": ("IMAGE",{"lazy": True}), + "mask": ("MASK",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "mix" + + CATEGORY = "Testing/Nodes" + + def check_lazy_status(self, mask, image1 = None, image2 = None): + mask_min = mask.min() + mask_max = mask.max() + needed = [] + if image1 is None and (mask_min != 1.0 or mask_max != 1.0): + needed.append("image1") + if image2 is None and 
(mask_min != 0.0 or mask_max != 0.0):
+            needed.append("image2")
+        return needed
+
+    # Not trying to handle different batch sizes here just to keep the demo simple
+    def mix(self, mask, image1 = None, image2 = None):
+        mask_min = mask.min()
+        mask_max = mask.max()
+        if mask_min == 0.0 and mask_max == 0.0:
+            return (image1,)
+        elif mask_min == 1.0 and mask_max == 1.0:
+            return (image2,)
+
+        if len(mask.shape) == 2:
+            mask = mask.unsqueeze(0)
+        if len(mask.shape) == 3:
+            mask = mask.unsqueeze(3)
+        if mask.shape[3] < image1.shape[3]:
+            mask = mask.repeat(1, 1, 1, image1.shape[3])
+
+        result = image1 * (1. - mask) + image2 * mask
+        return (result,)
+
+class TestVariadicAverage:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "input1": ("IMAGE",),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "variadic_average"
+
+    CATEGORY = "Testing/Nodes"
+
+    def variadic_average(self, input1, **kwargs):
+        inputs = [input1]
+        while 'input' + str(len(inputs) + 1) in kwargs:
+            inputs.append(kwargs['input' + str(len(inputs) + 1)])
+        return (torch.stack(inputs).mean(dim=0),)
+
+
+class TestCustomIsChanged:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+            },
+            "optional": {
+                "should_change": ("BOOL", {"default": False}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "custom_is_changed"
+
+    CATEGORY = "Testing/Nodes"
+
+    def custom_is_changed(self, image, should_change=False):
+        return (image,)
+
+    @classmethod
+    def IS_CHANGED(cls, should_change=False, *args, **kwargs):
+        if should_change:
+            return float("NaN")
+        else:
+            return False
+
+TEST_NODE_CLASS_MAPPINGS = {
+    "TestLazyMixImages": TestLazyMixImages,
+    "TestVariadicAverage": TestVariadicAverage,
+    "TestCustomIsChanged": TestCustomIsChanged,
+}
+
+TEST_NODE_DISPLAY_NAME_MAPPINGS = {
+    "TestLazyMixImages": "Lazy Mix Images",
+    "TestVariadicAverage": "Variadic Average",
+    "TestCustomIsChanged": "Custom IsChanged",
+}
diff --git a/tests/inference/testing_nodes/testing-pack/stubs.py b/tests/inference/testing_nodes/testing-pack/stubs.py
new file mode 100644
index 00000000000..b2a5ebf3d70
--- /dev/null
+++ b/tests/inference/testing_nodes/testing-pack/stubs.py
@@ -0,0 +1,61 @@
+import torch
+
+class StubImage:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "content": (['WHITE', 'BLACK', 'NOISE'],),
+                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
+                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
+                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "stub_image"
+
+    CATEGORY = "Testing/Stub Nodes"
+
+    def stub_image(self, content, height, width, batch_size):
+        if content == "WHITE":
+            return (torch.ones(batch_size, height, width, 3),)
+        elif content == "BLACK":
+            return (torch.zeros(batch_size, height, width, 3),)
+        elif content == "NOISE":
+            return (torch.rand(batch_size, height, width, 3),)
+
+class StubMask:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
+                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
+                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
+            },
+        }
+
+    
RETURN_TYPES = ("MASK",) + FUNCTION = "stub_mask" + + CATEGORY = "Testing/Stub Nodes" + + def stub_mask(self, value, height, width, batch_size): + return (torch.ones(batch_size, height, width) * value,) + +TEST_STUB_NODE_CLASS_MAPPINGS = { + "StubImage": StubImage, + "StubMask": StubMask, +} +TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { + "StubImage": "Stub Image", + "StubMask": "Stub Mask", +} diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/inference/testing_nodes/testing-pack/util.py new file mode 100644 index 00000000000..16209d3fc1b --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/util.py @@ -0,0 +1,353 @@ +from comfy.graph_utils import GraphBuilder + +class TestAccumulateNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "to_add": ("*",), + }, + "optional": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + FUNCTION = "accumulate" + + CATEGORY = "Testing/Lists" + + def accumulate(self, to_add, accumulation = None): + if accumulation is None: + value = [to_add] + else: + value = accumulation["accum"] + [to_add] + return ({"accum": value},) + +class TestAccumulationHeadNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION", "*",) + FUNCTION = "accumulation_head" + + CATEGORY = "Testing/Lists" + + def accumulation_head(self, accumulation): + accum = accumulation["accum"] + if len(accum) == 0: + return (accumulation, None) + else: + return ({"accum": accum[1:]}, accum[0]) + +class TestAccumulationTailNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION", "*",) + FUNCTION = "accumulation_tail" + + CATEGORY = "Testing/Lists" + + def accumulation_tail(self, accumulation): + accum = accumulation["accum"] + if len(accum) == 0: + return (None, accumulation) + else: + return ({"accum": accum[:-1]}, accum[-1]) + +class TestAccumulationToListNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("*",) + OUTPUT_IS_LIST = (True,) + + FUNCTION = "accumulation_to_list" + + CATEGORY = "Testing/Lists" + + def accumulation_to_list(self, accumulation): + return (accumulation["accum"],) + +class TestListToAccumulationNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": ("*",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + INPUT_IS_LIST = (True,) + + FUNCTION = "list_to_accumulation" + + CATEGORY = "Testing/Lists" + + def list_to_accumulation(self, list): + return ({"accum": list},) + +class TestAccumulationGetLengthNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("INT",) + + FUNCTION = "accumlength" + + CATEGORY = "Testing/Lists" + + def accumlength(self, accumulation): + return (len(accumulation['accum']),) + +class TestAccumulationGetItemNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + "index": ("INT", {"default":0, "step":1}) + }, + } + + RETURN_TYPES = ("*",) + + FUNCTION = "get_item" + + CATEGORY = "Testing/Lists" + + def get_item(self, 
accumulation, index): + return (accumulation['accum'][index],) + +class TestAccumulationSetItemNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + "index": ("INT", {"default":0, "step":1}), + "value": ("*",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + + FUNCTION = "set_item" + + CATEGORY = "Testing/Lists" + + def set_item(self, accumulation, index, value): + new_accum = accumulation['accum'][:] + new_accum[index] = value + return ({"accum": new_accum},) + +class TestIntMathOperation: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "operation": (["add", "subtract", "multiply", "divide", "modulo", "power"],), + }, + } + + RETURN_TYPES = ("INT",) + FUNCTION = "int_math_operation" + + CATEGORY = "Testing/Logic" + + def int_math_operation(self, a, b, operation): + if operation == "add": + return (a + b,) + elif operation == "subtract": + return (a - b,) + elif operation == "multiply": + return (a * b,) + elif operation == "divide": + return (a // b,) + elif operation == "modulo": + return (a % b,) + elif operation == "power": + return (a ** b,) + + +from .flow_control import NUM_FLOW_SOCKETS +class TestForLoopOpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}), + }, + "optional": { + "initial_value%d" % i: ("*",) for i in range(1, NUM_FLOW_SOCKETS) + }, + "hidden": { + "initial_value0": ("*",) + } + } + + RETURN_TYPES = tuple(["FLOW_CONTROL", "INT",] + ["*"] * (NUM_FLOW_SOCKETS-1)) + RETURN_NAMES = tuple(["flow_control", "remaining"] + ["value%d" % i for i in range(1, NUM_FLOW_SOCKETS)]) + FUNCTION = "for_loop_open" + + CATEGORY = "Testing/Flow" + + def for_loop_open(self, remaining, **kwargs): + graph = GraphBuilder() + if "initial_value0" in kwargs: + remaining = kwargs["initial_value0"] + while_open = graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{("initial_value%d" % i): kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)}) + outputs = [kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)] + return { + "result": tuple(["stub", remaining] + outputs), + "expand": graph.finalize(), + } + +class TestForLoopClose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "flow_control": ("FLOW_CONTROL", {"rawLink": True}), + }, + "optional": { + "initial_value%d" % i: ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS) + }, + } + + RETURN_TYPES = tuple(["*"] * (NUM_FLOW_SOCKETS-1)) + RETURN_NAMES = tuple(["value%d" % i for i in range(1, NUM_FLOW_SOCKETS)]) + FUNCTION = "for_loop_close" + + CATEGORY = "Testing/Flow" + + def for_loop_close(self, flow_control, **kwargs): + graph = GraphBuilder() + while_open = flow_control[0] + sub = graph.node("TestIntMathOperation", operation="subtract", a=[while_open,1], b=1) + cond = graph.node("TestToBoolNode", value=sub.out(0)) + input_values = {("initial_value%d" % i): kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)} + while_close = graph.node("TestWhileLoopClose", + flow_control=flow_control, + condition=cond.out(0), + 
initial_value0=sub.out(0), + **input_values) + return { + "result": tuple([while_close.out(i) for i in range(1, NUM_FLOW_SOCKETS)]), + "expand": graph.finalize(), + } + +NUM_LIST_SOCKETS = 10 +class TestMakeListNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value1": ("*",), + }, + "optional": { + "value%d" % i: ("*",) for i in range(1, NUM_LIST_SOCKETS) + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "make_list" + OUTPUT_IS_LIST = (True,) + + CATEGORY = "Testing/Lists" + + def make_list(self, **kwargs): + result = [] + for i in range(NUM_LIST_SOCKETS): + if "value%d" % i in kwargs: + result.append(kwargs["value%d" % i]) + return (result,) + +UTILITY_NODE_CLASS_MAPPINGS = { + "TestAccumulateNode": TestAccumulateNode, + "TestAccumulationHeadNode": TestAccumulationHeadNode, + "TestAccumulationTailNode": TestAccumulationTailNode, + "TestAccumulationToListNode": TestAccumulationToListNode, + "TestListToAccumulationNode": TestListToAccumulationNode, + "TestAccumulationGetLengthNode": TestAccumulationGetLengthNode, + "TestAccumulationGetItemNode": TestAccumulationGetItemNode, + "TestAccumulationSetItemNode": TestAccumulationSetItemNode, + "TestForLoopOpen": TestForLoopOpen, + "TestForLoopClose": TestForLoopClose, + "TestIntMathOperation": TestIntMathOperation, + "TestMakeListNode": TestMakeListNode, +} +UTILITY_NODE_DISPLAY_NAME_MAPPINGS = { + "TestAccumulateNode": "Accumulate", + "TestAccumulationHeadNode": "Accumulation Head", + "TestAccumulationTailNode": "Accumulation Tail", + "TestAccumulationToListNode": "Accumulation to List", + "TestListToAccumulationNode": "List to Accumulation", + "TestAccumulationGetLengthNode": "Accumulation Get Length", + "TestAccumulationGetItemNode": "Accumulation Get Item", + "TestAccumulationSetItemNode": "Accumulation Set Item", + "TestForLoopOpen": "For Loop Open", + "TestForLoopClose": "For Loop Close", + "TestIntMathOperation": "Int Math Operation", + "TestMakeListNode": "Make List", +} From e60dbe3a44075ad0b72490755088c2e4e42a1d2f Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Wed, 21 Feb 2024 19:36:51 -0800 Subject: [PATCH 08/30] Fix issue with unused literals Behavior should now match the master branch with regard to undeclared inputs. Undeclared inputs that are socket connections will be used while undeclared inputs that are literals will be ignored. --- execution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/execution.py b/execution.py index 4d9a4b98c7c..afedc075811 100644 --- a/execution.py +++ b/execution.py @@ -94,7 +94,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro continue obj = cached_output[output_index] input_data_all[x] = obj - else: + elif input_category is not None: input_data_all[x] = [input_data] if "hidden" in valid_inputs: From 6d09dd70f8e6400ab9952bfc2bb98ba10c360395 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 24 Feb 2024 23:17:01 -0800 Subject: [PATCH 09/30] Make custom VALIDATE_INPUTS skip normal validation Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`, that variable will be a dictionary of the socket type of all incoming connections. If that argument exists, normal socket type validation will not occur. This removes the last hurdle for enabling variant types entirely from custom nodes, so I've removed that command-line option. I've added appropriate unit tests for these changes. 
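For illustration, a minimal node that opts into the new hook might look like
this (a hypothetical example; the TestCustomValidation nodes added in this
commit exercise the same pattern):

    class AnyPassthrough:
        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"value": ("*",)}}

        RETURN_TYPES = ("*",)
        FUNCTION = "passthrough"
        CATEGORY = "Example"

        @classmethod
        def VALIDATE_INPUTS(cls, input_types):
            # Declaring `input_types` suppresses the built-in socket type
            # check; this node therefore accepts any upstream type.
            return True

        def passthrough(self, value):
            return (value,)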
--- comfy/cli_args.py | 1 - execution.py | 61 +++++----- tests/inference/test_execution.py | 63 +++++++++- .../testing-pack/flow_control.py | 4 + .../testing-pack/specific_tests.py | 112 ++++++++++++++++-- .../testing_nodes/testing-pack/stubs.py | 44 +++++++ .../testing_nodes/testing-pack/tools.py | 48 ++++++++ .../testing_nodes/testing-pack/util.py | 11 ++ 8 files changed, 305 insertions(+), 39 deletions(-) create mode 100644 tests/inference/testing_nodes/testing-pack/tools.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 74354ea9426..2cbefefebd9 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -117,7 +117,6 @@ class LatentPreviewMethod(enum.Enum): parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") -parser.add_argument("--enable-variants", action="store_true", help="Enables '*' type nodes.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/execution.py b/execution.py index afedc075811..c8c89d01f20 100644 --- a/execution.py +++ b/execution.py @@ -92,6 +92,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro cached_output = outputs.get(input_unique_id) if cached_output is None: continue + if output_index >= len(cached_output): + continue obj = cached_output[output_index] input_data_all[x] = obj elif input_category is not None: @@ -514,6 +516,7 @@ def validate_inputs(prompt, item, validated): validate_function_inputs = [] if hasattr(obj_class, "VALIDATE_INPUTS"): validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args + received_types = {} for x in valid_inputs: type_input, input_category, extra_info = get_input_info(obj_class, x) @@ -551,9 +554,9 @@ def validate_inputs(prompt, item, validated): o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - is_variant = args.enable_variants and (r[val[1]] == "*" or type_input == "*") - if r[val[1]] != type_input and not is_variant: - received_type = r[val[1]] + received_type = r[val[1]] + received_types[x] = received_type + if 'input_types' not in validate_function_inputs and received_type != type_input: details = f"{x}, {received_type} != {type_input}" error = { "type": "return_type_mismatch", @@ -622,34 +625,34 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if "min" in extra_info and val < extra_info["min"]: - error = { - "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, + if x not in validate_function_inputs: + if "min" in extra_info and val < extra_info["min"]: + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } } - } - errors.append(error) - continue - if "max" in extra_info and val > extra_info["max"]: - error = { - "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, + errors.append(error) + continue + if "max" in extra_info and val > extra_info["max"]: + error = { + "type": 
"value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } } - } - errors.append(error) - continue + errors.append(error) + continue - if x not in validate_function_inputs: if isinstance(type_input, list): if val not in type_input: input_config = info @@ -682,6 +685,8 @@ def validate_inputs(prompt, item, validated): for x in input_data_all: if x in validate_function_inputs: input_filtered[x] = input_data_all[x] + if 'input_types' in validate_function_inputs: + input_filtered['input_types'] = [received_types] #ret = obj_class.VALIDATE_INPUTS(**input_filtered) ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 52e1956279c..6a4fa3dd1d6 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -12,6 +12,7 @@ import uuid import urllib.request import urllib.parse +import urllib.error from comfy.graph_utils import GraphBuilder, Node class RunResult: @@ -125,7 +126,6 @@ def _server(self, args_pytest): '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', - '--enable-variants', ]) yield p.kill() @@ -237,6 +237,67 @@ def test_error(self, client: ComfyClient, builder: GraphBuilder): except Exception as e: assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" + @pytest.mark.parametrize("test_value, expect_error", [ + (5, True), + ("foo", True), + (5.0, False), + ]) + def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0) + g.node("SaveImage", images=validation1.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + + @pytest.mark.parametrize("test_type, test_value", [ + ("StubInt", 5), + ("StubFloat", 5.0) + ]) + def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation1.out(0)) + + with pytest.raises(urllib.error.HTTPError): + client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation2.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation3.out(0)) + + if expect_error: + with 
pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): g = builder # Creating the nodes in this specific order previously caused a bug diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/inference/testing_nodes/testing-pack/flow_control.py index 8befdcf1973..43f1ce02ffc 100644 --- a/tests/inference/testing_nodes/testing-pack/flow_control.py +++ b/tests/inference/testing_nodes/testing-pack/flow_control.py @@ -1,7 +1,9 @@ from comfy.graph_utils import GraphBuilder, is_link from comfy.graph import ExecutionBlocker +from .tools import VariantSupport NUM_FLOW_SOCKETS = 5 +@VariantSupport() class TestWhileLoopOpen: def __init__(self): pass @@ -31,6 +33,7 @@ def while_loop_open(self, condition, **kwargs): values.append(kwargs.get("initial_value%d" % i, None)) return tuple(["stub"] + values) +@VariantSupport() class TestWhileLoopClose: def __init__(self): pass @@ -131,6 +134,7 @@ def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=No "expand": graph.finalize(), } +@VariantSupport() class TestExecutionBlockerNode: def __init__(self): pass diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index e3d864b443b..8c103c18af7 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -1,9 +1,7 @@ import torch +from .tools import VariantSupport class TestLazyMixImages: - def __init__(self): - pass - @classmethod def INPUT_TYPES(cls): return { @@ -50,9 +48,6 @@ def mix(self, mask, image1 = None, image2 = None): return (result[0],) class TestVariadicAverage: - def __init__(self): - pass - @classmethod def INPUT_TYPES(cls): return { @@ -74,9 +69,6 @@ def variadic_average(self, input1, **kwargs): class TestCustomIsChanged: - def __init__(self): - pass - @classmethod def INPUT_TYPES(cls): return { @@ -103,14 +95,116 @@ def IS_CHANGED(cls, should_change=False, *args, **kwargs): else: return False +class TestCustomValidation1: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation1" + + CATEGORY = "Testing/Nodes" + + def custom_validation1(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input1=None, input2=None): + if input1 is not None: + if not isinstance(input1, (torch.Tensor, float)): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, (torch.Tensor, float)): + return f"Invalid type of input2: {type(input2)}" + + return True + +class TestCustomValidation2: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation2" + + CATEGORY = "Testing/Nodes" + + def custom_validation2(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None): + if input1 is not 
None: + if not isinstance(input1, (torch.Tensor, float)): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, (torch.Tensor, float)): + return f"Invalid type of input2: {type(input2)}" + + if 'input1' in input_types: + if input_types['input1'] not in ["IMAGE", "FLOAT"]: + return f"Invalid type of input1: {input_types['input1']}" + if 'input2' in input_types: + if input_types['input2'] not in ["IMAGE", "FLOAT"]: + return f"Invalid type of input2: {input_types['input2']}" + + return True + +@VariantSupport() +class TestCustomValidation3: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation3" + + CATEGORY = "Testing/Nodes" + + def custom_validation3(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, "TestCustomIsChanged": TestCustomIsChanged, + "TestCustomValidation1": TestCustomValidation1, + "TestCustomValidation2": TestCustomValidation2, + "TestCustomValidation3": TestCustomValidation3, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestLazyMixImages": "Lazy Mix Images", "TestVariadicAverage": "Variadic Average", "TestCustomIsChanged": "Custom IsChanged", + "TestCustomValidation1": "Custom Validation 1", + "TestCustomValidation2": "Custom Validation 2", + "TestCustomValidation3": "Custom Validation 3", } diff --git a/tests/inference/testing_nodes/testing-pack/stubs.py b/tests/inference/testing_nodes/testing-pack/stubs.py index b2a5ebf3d70..9be6eac9d13 100644 --- a/tests/inference/testing_nodes/testing-pack/stubs.py +++ b/tests/inference/testing_nodes/testing-pack/stubs.py @@ -51,11 +51,55 @@ def INPUT_TYPES(cls): def stub_mask(self, value, height, width, batch_size): return (torch.ones(batch_size, height, width) * value,) +class StubInt: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}), + }, + } + + RETURN_TYPES = ("INT",) + FUNCTION = "stub_int" + + CATEGORY = "Testing/Stub Nodes" + + def stub_int(self, value): + return (value,) + +class StubFloat: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}), + }, + } + + RETURN_TYPES = ("FLOAT",) + FUNCTION = "stub_float" + + CATEGORY = "Testing/Stub Nodes" + + def stub_float(self, value): + return (value,) + TEST_STUB_NODE_CLASS_MAPPINGS = { "StubImage": StubImage, "StubMask": StubMask, + "StubInt": StubInt, + "StubFloat": StubFloat, } TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { "StubImage": "Stub Image", "StubMask": "Stub Mask", + "StubInt": "Stub Int", + "StubFloat": "Stub Float", } diff --git a/tests/inference/testing_nodes/testing-pack/tools.py b/tests/inference/testing_nodes/testing-pack/tools.py new file mode 100644 index 00000000000..6c8d5eaa0a9 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/tools.py @@ -0,0 +1,48 @@ + +class SmartType(str): + def __ne__(self, other): + if self == "*" or other == "*": + return False + selfset = set(self.split(',')) + otherset = set(other.split(',')) + return not selfset.issubset(otherset) + +def 
VariantSupport(): + def decorator(cls): + if hasattr(cls, "INPUT_TYPES"): + old_input_types = getattr(cls, "INPUT_TYPES") + def new_input_types(*args, **kwargs): + types = old_input_types(*args, **kwargs) + for category in ["required", "optional"]: + if category not in types: + continue + for key, value in types[category].items(): + if isinstance(value, tuple): + types[category][key] = (SmartType(value[0]),) + value[1:] + return types + setattr(cls, "INPUT_TYPES", new_input_types) + if hasattr(cls, "RETURN_TYPES"): + old_return_types = cls.RETURN_TYPES + setattr(cls, "RETURN_TYPES", tuple(SmartType(x) for x in old_return_types)) + if hasattr(cls, "VALIDATE_INPUTS"): + # Reflection is used to determine what the function signature is, so we can't just change the function signature + raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet") + else: + def validate_inputs(input_types): + inputs = cls.INPUT_TYPES() + for key, value in input_types.items(): + if isinstance(value, SmartType): + continue + if "required" in inputs and key in inputs["required"]: + expected_type = inputs["required"][key][0] + elif "optional" in inputs and key in inputs["optional"]: + expected_type = inputs["optional"][key][0] + else: + expected_type = None + if expected_type is not None and SmartType(value) != expected_type: + return f"Invalid type of {key}: {value} (expected {expected_type})" + return True + setattr(cls, "VALIDATE_INPUTS", validate_inputs) + return cls + return decorator + diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/inference/testing_nodes/testing-pack/util.py index 16209d3fc1b..8e2065c7bc2 100644 --- a/tests/inference/testing_nodes/testing-pack/util.py +++ b/tests/inference/testing_nodes/testing-pack/util.py @@ -1,5 +1,7 @@ from comfy.graph_utils import GraphBuilder +from .tools import VariantSupport +@VariantSupport() class TestAccumulateNode: def __init__(self): pass @@ -27,6 +29,7 @@ def accumulate(self, to_add, accumulation = None): value = accumulation["accum"] + [to_add] return ({"accum": value},) +@VariantSupport() class TestAccumulationHeadNode: def __init__(self): pass @@ -75,6 +78,7 @@ def accumulation_tail(self, accumulation): else: return ({"accum": accum[:-1]}, accum[-1]) +@VariantSupport() class TestAccumulationToListNode: def __init__(self): pass @@ -97,6 +101,7 @@ def INPUT_TYPES(cls): def accumulation_to_list(self, accumulation): return (accumulation["accum"],) +@VariantSupport() class TestListToAccumulationNode: def __init__(self): pass @@ -119,6 +124,7 @@ def INPUT_TYPES(cls): def list_to_accumulation(self, list): return ({"accum": list},) +@VariantSupport() class TestAccumulationGetLengthNode: def __init__(self): pass @@ -140,6 +146,7 @@ def INPUT_TYPES(cls): def accumlength(self, accumulation): return (len(accumulation['accum']),) +@VariantSupport() class TestAccumulationGetItemNode: def __init__(self): pass @@ -162,6 +169,7 @@ def INPUT_TYPES(cls): def get_item(self, accumulation, index): return (accumulation['accum'][index],) +@VariantSupport() class TestAccumulationSetItemNode: def __init__(self): pass @@ -222,6 +230,7 @@ def int_math_operation(self, a, b, operation): from .flow_control import NUM_FLOW_SOCKETS +@VariantSupport() class TestForLoopOpen: def __init__(self): pass @@ -257,6 +266,7 @@ def for_loop_open(self, remaining, **kwargs): "expand": graph.finalize(), } +@VariantSupport() class TestForLoopClose: def __init__(self): pass @@ -295,6 +305,7 @@ def for_loop_close(self, flow_control, **kwargs): } NUM_LIST_SOCKETS = 
10 +@VariantSupport() class TestMakeListNode: def __init__(self): pass From 03394ace8c52575aa6fed0f0428ce414040735ef Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 23 Mar 2024 16:49:45 -0700 Subject: [PATCH 10/30] Fix example in unit test This wouldn't have caused any issues in the unit test, but it would have bugged the UI if someone copy+pasted it into their own node pack. --- tests/inference/testing_nodes/testing-pack/tools.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/inference/testing_nodes/testing-pack/tools.py b/tests/inference/testing_nodes/testing-pack/tools.py index 6c8d5eaa0a9..34b28c0eb48 100644 --- a/tests/inference/testing_nodes/testing-pack/tools.py +++ b/tests/inference/testing_nodes/testing-pack/tools.py @@ -1,4 +1,9 @@ +def MakeSmartType(t): + if isinstance(t, str): + return SmartType(t) + return t + class SmartType(str): def __ne__(self, other): if self == "*" or other == "*": @@ -18,12 +23,12 @@ def new_input_types(*args, **kwargs): continue for key, value in types[category].items(): if isinstance(value, tuple): - types[category][key] = (SmartType(value[0]),) + value[1:] + types[category][key] = (MakeSmartType(value[0]),) + value[1:] return types setattr(cls, "INPUT_TYPES", new_input_types) if hasattr(cls, "RETURN_TYPES"): old_return_types = cls.RETURN_TYPES - setattr(cls, "RETURN_TYPES", tuple(SmartType(x) for x in old_return_types)) + setattr(cls, "RETURN_TYPES", tuple(MakeSmartType(x) for x in old_return_types)) if hasattr(cls, "VALIDATE_INPUTS"): # Reflection is used to determine what the function signature is, so we can't just change the function signature raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet") @@ -39,7 +44,7 @@ def validate_inputs(input_types): expected_type = inputs["optional"][key][0] else: expected_type = None - if expected_type is not None and SmartType(value) != expected_type: + if expected_type is not None and MakeSmartType(value) != expected_type: return f"Invalid type of {key}: {value} (expected {expected_type})" return True setattr(cls, "VALIDATE_INPUTS", validate_inputs) From a0bf532558327a04ff4852138ce4b2c0bca28e56 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 20 Apr 2024 17:52:23 -0700 Subject: [PATCH 11/30] Use fstrings instead of '%' formatting syntax --- comfy/caching.py | 3 --- comfy/graph.py | 4 ++-- comfy/graph_utils.py | 2 +- execution.py | 4 ++-- .../testing-pack/flow_control.py | 14 ++++++------- .../testing_nodes/testing-pack/util.py | 20 +++++++++---------- 6 files changed, 22 insertions(+), 25 deletions(-) diff --git a/comfy/caching.py b/comfy/caching.py index 936e2e6dfee..060d53d5584 100644 --- a/comfy/caching.py +++ b/comfy/caching.py @@ -269,12 +269,9 @@ def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.generation += 1 for node_id in node_ids: self._mark_used(node_id) - print("LRUCache: Now at generation %d" % self.generation) def clean_unused(self): - print("LRUCache: Cleaning unused. 
Current size: %d/%d" % (len(self.cache), self.max_size)) while len(self.cache) > self.max_size and self.min_generation < self.generation: - print("LRUCache: Evicting generation %d" % self.min_generation) self.min_generation += 1 to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation] for key in to_remove: diff --git a/comfy/graph.py b/comfy/graph.py index 656ebeb9795..97a759de862 100644 --- a/comfy/graph.py +++ b/comfy/graph.py @@ -76,10 +76,10 @@ def get_input_info(self, unique_id, input_name): def make_input_strong_link(self, to_node_id, to_input): inputs = self.dynprompt.get_node(to_node_id)["inputs"] if to_input not in inputs: - raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input)) + raise Exception(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all") value = inputs[to_input] if not is_link(value): - raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input)) + raise Exception(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant") from_node_id, from_socket = value self.add_strong_link(from_node_id, from_socket, to_node_id) diff --git a/comfy/graph_utils.py b/comfy/graph_utils.py index a0042e078f7..10c7ec541a6 100644 --- a/comfy/graph_utils.py +++ b/comfy/graph_utils.py @@ -38,7 +38,7 @@ def alloc_prefix(cls, root=None, call_index=None, graph_index=None): call_index = GraphBuilder._default_prefix_call_index if graph_index is None: graph_index = GraphBuilder._default_prefix_graph_index - result = "%s.%d.%d." % (root, call_index, graph_index) + result = f"{root}.{call_index}.{graph_index}." GraphBuilder._default_prefix_graph_index += 1 return result diff --git a/execution.py b/execution.py index 050eea163b3..edeb7105985 100644 --- a/execution.py +++ b/execution.py @@ -299,7 +299,7 @@ def execution_block_cb(block): "node_type": class_type, "executed": list(executed), - "exception_message": "Execution Blocked: %s" % block.message, + "exception_message": f"Execution Blocked: {block.message}", "exception_type": "ExecutionBlocked", "traceback": [], "current_inputs": [], @@ -337,7 +337,7 @@ def pre_execute_cb(call_index): # Check for conflicts for node_id in new_graph.keys(): if dynprompt.get_node(node_id) is not None: - raise Exception("Attempt to add duplicate node %s. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder." % node_id) + raise Exception(f"Attempt to add duplicate node {node_id}. 
Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/inference/testing_nodes/testing-pack/flow_control.py index 43f1ce02ffc..1ef1cf803dd 100644 --- a/tests/inference/testing_nodes/testing-pack/flow_control.py +++ b/tests/inference/testing_nodes/testing-pack/flow_control.py @@ -18,11 +18,11 @@ def INPUT_TYPES(cls): }, } for i in range(NUM_FLOW_SOCKETS): - inputs["optional"]["initial_value%d" % i] = ("*",) + inputs["optional"][f"initial_value{i}"] = ("*",) return inputs RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS) - RETURN_NAMES = tuple(["FLOW_CONTROL"] + ["value%d" % i for i in range(NUM_FLOW_SOCKETS)]) + RETURN_NAMES = tuple(["FLOW_CONTROL"] + [f"value{i}" for i in range(NUM_FLOW_SOCKETS)]) FUNCTION = "while_loop_open" CATEGORY = "Testing/Flow" @@ -30,7 +30,7 @@ def INPUT_TYPES(cls): def while_loop_open(self, condition, **kwargs): values = [] for i in range(NUM_FLOW_SOCKETS): - values.append(kwargs.get("initial_value%d" % i, None)) + values.append(kwargs.get(f"initial_value{i}", None)) return tuple(["stub"] + values) @VariantSupport() @@ -53,11 +53,11 @@ def INPUT_TYPES(cls): } } for i in range(NUM_FLOW_SOCKETS): - inputs["optional"]["initial_value%d" % i] = ("*",) + inputs["optional"][f"initial_value{i}"] = ("*",) return inputs RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS) - RETURN_NAMES = tuple(["value%d" % i for i in range(NUM_FLOW_SOCKETS)]) + RETURN_NAMES = tuple([f"value{i}" for i in range(NUM_FLOW_SOCKETS)]) FUNCTION = "while_loop_close" CATEGORY = "Testing/Flow" @@ -89,7 +89,7 @@ def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=No # We're done with the loop values = [] for i in range(NUM_FLOW_SOCKETS): - values.append(kwargs.get("initial_value%d" % i, None)) + values.append(kwargs.get(f"initial_value{i}", None)) return tuple(values) # We want to loop @@ -124,7 +124,7 @@ def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=No new_open = graph.lookup_node(open_node) assert new_open is not None for i in range(NUM_FLOW_SOCKETS): - key = "initial_value%d" % i + key = f"initial_value{i}" new_open.set_input(key, kwargs.get(key, None)) my_clone = graph.lookup_node("Recurse") assert my_clone is not None diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/inference/testing_nodes/testing-pack/util.py index 8e2065c7bc2..fea83e37a23 100644 --- a/tests/inference/testing_nodes/testing-pack/util.py +++ b/tests/inference/testing_nodes/testing-pack/util.py @@ -242,7 +242,7 @@ def INPUT_TYPES(cls): "remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}), }, "optional": { - "initial_value%d" % i: ("*",) for i in range(1, NUM_FLOW_SOCKETS) + f"initial_value{i}": ("*",) for i in range(1, NUM_FLOW_SOCKETS) }, "hidden": { "initial_value0": ("*",) @@ -250,7 +250,7 @@ def INPUT_TYPES(cls): } RETURN_TYPES = tuple(["FLOW_CONTROL", "INT",] + ["*"] * (NUM_FLOW_SOCKETS-1)) - RETURN_NAMES = tuple(["flow_control", "remaining"] + ["value%d" % i for i in range(1, NUM_FLOW_SOCKETS)]) + RETURN_NAMES = tuple(["flow_control", "remaining"] + [f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)]) FUNCTION = "for_loop_open" CATEGORY = "Testing/Flow" @@ -259,8 +259,8 @@ def for_loop_open(self, remaining, **kwargs): graph = GraphBuilder() if "initial_value0" in kwargs: remaining = 
kwargs["initial_value0"] - while_open = graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{("initial_value%d" % i): kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)}) - outputs = [kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)] + while_open = graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{(f"initial_value{i}"): kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)}) + outputs = [kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)] return { "result": tuple(["stub", remaining] + outputs), "expand": graph.finalize(), @@ -278,12 +278,12 @@ def INPUT_TYPES(cls): "flow_control": ("FLOW_CONTROL", {"rawLink": True}), }, "optional": { - "initial_value%d" % i: ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS) + f"initial_value{i}": ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS) }, } RETURN_TYPES = tuple(["*"] * (NUM_FLOW_SOCKETS-1)) - RETURN_NAMES = tuple(["value%d" % i for i in range(1, NUM_FLOW_SOCKETS)]) + RETURN_NAMES = tuple([f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)]) FUNCTION = "for_loop_close" CATEGORY = "Testing/Flow" @@ -293,7 +293,7 @@ def for_loop_close(self, flow_control, **kwargs): while_open = flow_control[0] sub = graph.node("TestIntMathOperation", operation="subtract", a=[while_open,1], b=1) cond = graph.node("TestToBoolNode", value=sub.out(0)) - input_values = {("initial_value%d" % i): kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)} + input_values = {f"initial_value{i}": kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)} while_close = graph.node("TestWhileLoopClose", flow_control=flow_control, condition=cond.out(0), @@ -317,7 +317,7 @@ def INPUT_TYPES(cls): "value1": ("*",), }, "optional": { - "value%d" % i: ("*",) for i in range(1, NUM_LIST_SOCKETS) + f"value{i}": ("*",) for i in range(1, NUM_LIST_SOCKETS) }, } @@ -330,8 +330,8 @@ def INPUT_TYPES(cls): def make_list(self, **kwargs): result = [] for i in range(NUM_LIST_SOCKETS): - if "value%d" % i in kwargs: - result.append(kwargs["value%d" % i]) + if f"value{i}" in kwargs: + result.append(kwargs[f"value{i}"]) return (result,) UTILITY_NODE_CLASS_MAPPINGS = { From 5dc13651b0a867e80d5cbf3deb7acdc84c5debe6 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 20 Apr 2024 18:12:42 -0700 Subject: [PATCH 12/30] Use custom exception types. 
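
As an illustrative sketch (the handler names here are hypothetical and
not part of this diff), callers can now distinguish failure modes
instead of catching a bare Exception:

    from comfy.graph import DependencyCycleError, NodeInputError

    try:
        node_id = execution_list.stage_node_execution()
    except DependencyCycleError as e:
        abort_prompt(e)      # hypothetical: fail the whole prompt
    except NodeInputError as e:
        report_bad_input(e)  # hypothetical: surface as a node error

No behavior change is intended for callers that continue to catch
Exception.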
--- comfy/graph.py | 12 +++++++++--- execution.py | 5 ++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/comfy/graph.py b/comfy/graph.py index 97a759de862..bf4be4ae1d2 100644 --- a/comfy/graph.py +++ b/comfy/graph.py @@ -2,6 +2,12 @@ from comfy.graph_utils import is_link +class DependencyCycleError(Exception): + pass + +class NodeInputError(Exception): + pass + class DynamicPrompt: def __init__(self, original_prompt): # The original prompt provided by the user @@ -76,10 +82,10 @@ def get_input_info(self, unique_id, input_name): def make_input_strong_link(self, to_node_id, to_input): inputs = self.dynprompt.get_node(to_node_id)["inputs"] if to_input not in inputs: - raise Exception(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all") + raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all") value = inputs[to_input] if not is_link(value): - raise Exception(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant") + raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant") from_node_id, from_socket = value self.add_strong_link(from_node_id, from_socket, to_node_id) @@ -141,7 +147,7 @@ def stage_node_execution(self): return None available = self.get_ready_nodes() if len(available) == 0: - raise Exception("Dependency cycle detected") + raise DependencyCycleError("Dependency cycle detected") next_node = available[0] # If an output node is available, do that first. # Technically this has no effect on the overall length of execution, but it feels better as a user diff --git a/execution.py b/execution.py index edeb7105985..ecd0850afb3 100644 --- a/execution.py +++ b/execution.py @@ -23,6 +23,9 @@ class ExecutionResult(Enum): FAILURE = 1 SLEEPING = 2 +class DuplicateNodeError(Exception): + pass + class IsChangedCache: def __init__(self, dynprompt, outputs_cache): self.dynprompt = dynprompt @@ -337,7 +340,7 @@ def pre_execute_cb(call_index): # Check for conflicts for node_id in new_graph.keys(): if dynprompt.get_node(node_id) is not None: - raise Exception(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") + raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) From dd3bafb40b37e377ffa630edee238c19c03a6d44 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 20 Apr 2024 22:40:38 -0700 Subject: [PATCH 13/30] Display an error for dependency cycles Previously, dependency cycles that were created during node expansion would cause the application to quit (due to an uncaught exception). Now, we'll throw a proper error to the UI. We also make an attempt to 'blame' the most relevant node in the UI. 
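
For reference, the error payload reported to the UI has the same shape
as other execution errors (the values shown here are illustrative):

    {
        "node_id": blamed_node,  # display id of the most relevant node
        "exception_message": "Dependency cycle detected",
        "exception_type": "graph.DependencyCycleError",
        "traceback": [],
        "current_inputs": []
    }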
--- comfy/graph.py | 42 +++++++++++++++++-- execution.py | 6 ++- tests/inference/test_execution.py | 30 +++++++++++++ .../testing-pack/specific_tests.py | 32 ++++++++++++++ 4 files changed, 106 insertions(+), 4 deletions(-) diff --git a/comfy/graph.py b/comfy/graph.py index bf4be4ae1d2..b20c7bf3828 100644 --- a/comfy/graph.py +++ b/comfy/graph.py @@ -144,10 +144,27 @@ def add_strong_link(self, from_node_id, from_socket, to_node_id): def stage_node_execution(self): assert self.staged_node_id is None if self.is_empty(): - return None + return None, None, None available = self.get_ready_nodes() if len(available) == 0: - raise DependencyCycleError("Dependency cycle detected") + cycled_nodes = self.get_nodes_in_cycle() + # Because cycles composed entirely of static nodes are caught during initial validation, + # we will 'blame' the first node in the cycle that is not a static node. + blamed_node = cycled_nodes[0] + for node_id in cycled_nodes: + display_node_id = self.dynprompt.get_display_node_id(node_id) + if display_node_id != node_id: + blamed_node = display_node_id + break + ex = DependencyCycleError("Dependency cycle detected") + error_details = { + "node_id": blamed_node, + "exception_message": str(ex), + "exception_type": "graph.DependencyCycleError", + "traceback": [], + "current_inputs": [] + } + return None, error_details, ex next_node = available[0] # If an output node is available, do that first. # Technically this has no effect on the overall length of execution, but it feels better as a user @@ -160,7 +177,7 @@ def stage_node_execution(self): next_node = node_id break self.staged_node_id = next_node - return self.staged_node_id + return self.staged_node_id, None, None def unstage_node_execution(self): assert self.staged_node_id is not None @@ -171,6 +188,25 @@ def complete_node_execution(self): self.pop_node(node_id) self.staged_node_id = None + def get_nodes_in_cycle(self): + # We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle. + # We're skipping some of the performance optimizations from the original TopologicalSort to keep + # the code simple (and because having a cycle in the first place is a catastrophic error) + blocked_by = { node_id: {} for node_id in self.pendingNodes } + for from_node_id in self.blocking: + for to_node_id in self.blocking[from_node_id]: + if True in self.blocking[from_node_id][to_node_id].values(): + blocked_by[to_node_id][from_node_id] = True + to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] + while len(to_remove) > 0: + for node_id in to_remove: + for to_node_id in blocked_by: + if node_id in blocked_by[to_node_id]: + del blocked_by[to_node_id][node_id] + del blocked_by[node_id] + to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] + return list(blocked_by.keys()) + # Return this from a node and any users will be blocked with the given error message. 
class ExecutionBlocker: def __init__(self, message): diff --git a/execution.py b/execution.py index ecd0850afb3..ee4637a38ce 100644 --- a/execution.py +++ b/execution.py @@ -473,7 +473,11 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): execution_list.add_node(node_id) while not execution_list.is_empty(): - node_id = execution_list.stage_node_execution() + node_id, error, ex = execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 6a4fa3dd1d6..0ae70b8caad 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -234,6 +234,7 @@ def test_error(self, client: ComfyClient, builder: GraphBuilder): try: client.run(g) + assert False, "Should have raised an error" except Exception as e: assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" @@ -298,6 +299,35 @@ def test_validation_error_edge3(self, test_type, test_value, expect_error, clien else: client.run(g) + def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), mask=mask.out(0)) + lazy_mix2 = g.node("TestLazyMixImages", image1=lazy_mix1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix2.out(0)) + + # When the cycle exists on initial submission, it should raise a validation error + with pytest.raises(urllib.error.HTTPError): + client.run(g) + + def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + generator = g.node("TestDynamicDependencyCycle", input1=input1.out(0), input2=input2.out(0)) + g.node("SaveImage", images=generator.out(0)) + + # When the cycle is in a graph that is generated dynamically, it should raise a runtime error + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" + assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node" + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): g = builder # Creating the nodes in this specific order previously caused a bug diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index 8c103c18af7..56b8f70b2af 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -1,5 +1,6 @@ import torch from .tools import VariantSupport +from comfy.graph_utils import GraphBuilder class 
TestLazyMixImages:
     @classmethod
@@ -191,6 +192,35 @@ def custom_validation3(self, input1, input2):
             result = input1 * input2
         return (result,)
 
+class TestDynamicDependencyCycle:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "input1": ("IMAGE",),
+                "input2": ("IMAGE",),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "dynamic_dependency_cycle"
+
+    CATEGORY = "Testing/Nodes"
+
+    def dynamic_dependency_cycle(self, input1, input2):
+        g = GraphBuilder()
+        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
+        mix1 = g.node("TestLazyMixImages", image1=input1, mask=mask.out(0))
+        mix2 = g.node("TestLazyMixImages", image1=mix1.out(0), image2=input2, mask=mask.out(0))
+
+        # Create the cycle
+        mix1.set_input("image2", mix2.out(0))
+
+        return {
+            "result": (mix2.out(0),),
+            "expand": g.finalize(),
+        }
+
 TEST_NODE_CLASS_MAPPINGS = {
     "TestLazyMixImages": TestLazyMixImages,
     "TestVariadicAverage": TestVariadicAverage,
@@ -198,6 +228,7 @@ def custom_validation3(self, input1, input2):
     "TestCustomValidation1": TestCustomValidation1,
     "TestCustomValidation2": TestCustomValidation2,
     "TestCustomValidation3": TestCustomValidation3,
+    "TestDynamicDependencyCycle": TestDynamicDependencyCycle,
 }
 
 TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -207,4 +238,5 @@ def custom_validation3(self, input1, input2):
     "TestCustomValidation1": "Custom Validation 1",
     "TestCustomValidation2": "Custom Validation 2",
     "TestCustomValidation3": "Custom Validation 3",
+    "TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
 }

From 7dbee88485ede3b514d55e72ef1f8adf79036fd0 Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Sat, 20 Apr 2024 22:54:38 -0700
Subject: [PATCH 14/30] Add docs on when ExecutionBlocker should be used

---
 comfy/graph.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/comfy/graph.py b/comfy/graph.py
index b20c7bf3828..1135620ec0c 100644
--- a/comfy/graph.py
+++ b/comfy/graph.py
@@ -208,6 +208,16 @@ def get_nodes_in_cycle(self):
         return list(blocked_by.keys())
 
 # Return this from a node and any users will be blocked with the given error message.
+# If the message is None, execution will be blocked silently instead.
+# Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
+# possible, a lazy input will be more efficient and have a better user experience.
+# This functionality is useful in two cases:
+# 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
+#    like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
+#    lazy evaluation to let it conditionally disable itself.)
+# 2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
+#    (I would recommend not making nodes like this in the future -- instead, make multiple nodes with
+#    different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
class ExecutionBlocker: def __init__(self, message): self.message = message From b5e4583583da31b8061ee156d7d4434d8a8150d7 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 20 Apr 2024 23:01:34 -0700 Subject: [PATCH 15/30] Remove unused functionality --- comfy/graph_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/graph_utils.py b/comfy/graph_utils.py index 10c7ec541a6..8595e942d32 100644 --- a/comfy/graph_utils.py +++ b/comfy/graph_utils.py @@ -27,8 +27,7 @@ def __init__(self, prefix = None): def set_default_prefix(cls, prefix_root, call_index, graph_index = 0): cls._default_prefix_root = prefix_root cls._default_prefix_call_index = call_index - if graph_index is not None: - cls._default_prefix_graph_index = graph_index + cls._default_prefix_graph_index = graph_index @classmethod def alloc_prefix(cls, root=None, call_index=None, graph_index=None): From 75774c6ad111b702c56c2efb605dedf49459a211 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 20 Apr 2024 23:02:43 -0700 Subject: [PATCH 16/30] Rename ExecutionResult.SLEEPING to PENDING --- execution.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/execution.py b/execution.py index ee4637a38ce..52de3ec2cf9 100644 --- a/execution.py +++ b/execution.py @@ -21,7 +21,7 @@ class ExecutionResult(Enum): SUCCESS = 0 FAILURE = 1 - SLEEPING = 2 + PENDING = 2 class DuplicateNodeError(Exception): pass @@ -292,7 +292,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp if len(required_inputs) > 0: for i in required_inputs: execution_list.make_input_strong_link(unique_id, i) - return (ExecutionResult.SLEEPING, None, None) + return (ExecutionResult.PENDING, None, None) def execution_block_cb(block): if block.message is not None: @@ -363,7 +363,7 @@ def pre_execute_cb(call_index): for link in new_output_links: execution_list.add_strong_link(link[0], link[1], unique_id) pending_subgraph_results[unique_id] = cached_outputs - return (ExecutionResult.SLEEPING, None, None) + return (ExecutionResult.PENDING, None, None) caches.outputs.set(unique_id, output_data) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") @@ -482,7 +482,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break - elif result == ExecutionResult.SLEEPING: + elif result == ExecutionResult.PENDING: execution_list.unstage_node_execution() else: # result == ExecutionResult.SUCCESS: execution_list.complete_node_execution() From ecbef304ed5d389081a15aba922eea3195e978e7 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 20 Apr 2024 23:07:18 -0700 Subject: [PATCH 17/30] Remove superfluous function parameter --- comfy/graph.py | 3 +++ execution.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/graph.py b/comfy/graph.py index 1135620ec0c..9149ca4a23a 100644 --- a/comfy/graph.py +++ b/comfy/graph.py @@ -45,6 +45,9 @@ def get_display_node_id(self, node_id): def all_node_ids(self): return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) + def get_original_prompt(self): + return self.original_prompt + def get_input_info(class_def, input_name): valid_inputs = class_def.INPUT_TYPES() input_info = None diff --git a/execution.py b/execution.py index 52de3ec2cf9..6d0f40a932a 100644 --- a/execution.py +++ b/execution.py @@ -81,7 
+81,7 @@ def recursive_debug_dump(self): } return result -def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynprompt=None, extra_data={}): +def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} for x in inputs: @@ -106,7 +106,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro h = valid_inputs["hidden"] for x in h: if h[x] == "PROMPT": - input_data_all[x] = [prompt] + input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] if h[x] == "DYNPROMPT": input_data_all[x] = [dynprompt] if h[x] == "EXTRA_PNGINFO": @@ -275,7 +275,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp output_ui = [] has_subgraph = False else: - input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt.original_prompt, dynprompt, extra_data) + input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) From 1f065889053a9b1328dde614a729eae886ee60a3 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sun, 21 Apr 2024 00:10:04 -0700 Subject: [PATCH 18/30] Pass None for uneval inputs instead of default This applies to `VALIDATE_INPUTS`, `check_lazy_status`, and lazy values in evaluation functions. --- execution.py | 19 +++++++--- tests/inference/test_execution.py | 16 ++++++++ .../testing-pack/specific_tests.py | 37 +++++++++++++++++-- 3 files changed, 64 insertions(+), 8 deletions(-) diff --git a/execution.py b/execution.py index 6d0f40a932a..2aa939c1bbd 100644 --- a/execution.py +++ b/execution.py @@ -41,7 +41,7 @@ def get(self, node_id): if "is_changed" in node: self.is_changed[node_id] = node["is_changed"] else: - input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) + input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) try: is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] @@ -84,18 +84,25 @@ def recursive_debug_dump(self): def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} + missing_keys = {} for x in inputs: input_data = inputs[x] input_type, input_category, input_info = get_input_info(class_def, x) + def mark_missing(): + missing_keys[x] = True + input_data_all[x] = (None,) if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] if outputs is None: + mark_missing() continue # This might be a lazily-evaluated input cached_output = outputs.get(input_unique_id) if cached_output is None: + mark_missing() continue if output_index >= len(cached_output): + mark_missing() continue obj = cached_output[output_index] input_data_all[x] = obj @@ -113,7 +120,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] - return input_data_all + return input_data_all, missing_keys def map_node_over_list(obj, input_data_all, func, 
allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): # check if node wants the lists @@ -275,7 +282,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp output_ui = [] has_subgraph = False else: - input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -288,7 +295,9 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp if hasattr(obj, "check_lazy_status"): required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) - required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all] + required_inputs = [x for x in required_inputs if isinstance(x,str) and ( + x not in input_data_all or x in missing_keys + )] if len(required_inputs) > 0: for i in required_inputs: execution_list.make_input_strong_link(unique_id, i) @@ -685,7 +694,7 @@ def validate_inputs(prompt, item, validated): continue if len(validate_function_inputs) > 0: - input_data_all = get_input_data(inputs, obj_class, unique_id) + input_data_all, _ = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs: diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 0ae70b8caad..40de998331f 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -299,6 +299,22 @@ def test_validation_error_edge3(self, test_type, test_value, expect_error, clien else: client.run(g) + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation4.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index 56b8f70b2af..03dbc8b42d8 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -18,7 +18,7 @@ def INPUT_TYPES(cls): CATEGORY = "Testing/Nodes" - def check_lazy_status(self, mask, image1 = None, image2 = None): + def check_lazy_status(self, mask, image1, image2): mask_min = mask.min() mask_max = mask.max() needed = [] @@ -29,7 +29,7 @@ def check_lazy_status(self, mask, image1 = None, image2 = None): return needed # Not trying to handle different batch sizes here just to keep the demo simple - def mix(self, mask, image1 = None, image2 = None): + def mix(self, mask, image1, image2): mask_min = mask.min() mask_max = mask.max() if mask_min == 
0.0 and mask_max == 0.0: @@ -45,7 +45,6 @@ def mix(self, mask, image1 = None, image2 = None): mask = mask.repeat(1, 1, 1, image1.shape[3]) result = image1 * (1. - mask) + image2 * mask, - print(result[0]) return (result[0],) class TestVariadicAverage: @@ -192,6 +191,36 @@ def custom_validation3(self, input1, input2): result = input1 * input2 return (result,) +class TestCustomValidation4: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("FLOAT",), + "input2": ("FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation4" + + CATEGORY = "Testing/Nodes" + + def custom_validation4(self, input1, input2): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input1, input2): + if input1 is not None: + if not isinstance(input1, float): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, float): + return f"Invalid type of input2: {type(input2)}" + + return True + class TestDynamicDependencyCycle: @classmethod def INPUT_TYPES(cls): @@ -228,6 +257,7 @@ def dynamic_dependency_cycle(self, input1, input2): "TestCustomValidation1": TestCustomValidation1, "TestCustomValidation2": TestCustomValidation2, "TestCustomValidation3": TestCustomValidation3, + "TestCustomValidation4": TestCustomValidation4, "TestDynamicDependencyCycle": TestDynamicDependencyCycle, } @@ -238,5 +268,6 @@ def dynamic_dependency_cycle(self, input1, input2): "TestCustomValidation1": "Custom Validation 1", "TestCustomValidation2": "Custom Validation 2", "TestCustomValidation3": "Custom Validation 3", + "TestCustomValidation4": "Custom Validation 4", "TestDynamicDependencyCycle": "Dynamic Dependency Cycle", } From 2dda3f2827691539eac6730802b38ef3b73a04e3 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sun, 21 Apr 2024 15:39:20 -0700 Subject: [PATCH 19/30] Add a test for mixed node expansion This test ensures that a node that returns a combination of expanded subgraphs and literal values functions correctly. 
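
For context, a single "result" tuple may mix literal values with
sockets from the expanded subgraph. A minimal sketch of such a node
method (hypothetical, not part of this diff; assumes the built-in
ImageInvert node is available):

    def expand_me(self, image):
        g = GraphBuilder()
        inverted = g.node("ImageInvert", image=image)
        return {
            # literal tensor and subgraph socket side by side
            "result": (image, inverted.out(0)),
            "expand": g.finalize(),
        }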
---
 execution.py                                  |  1 +
 tests/inference/test_execution.py             | 19 ++++++++++
 .../testing-pack/specific_tests.py            | 35 +++++++++++++++++++
 3 files changed, 55 insertions(+)

diff --git a/execution.py b/execution.py
index 2aa939c1bbd..9db69cbd2a4 100644
--- a/execution.py
+++ b/execution.py
@@ -224,6 +224,7 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
             if isinstance(r, ExecutionBlocker):
                 r = tuple([r] * len(obj.RETURN_TYPES))
             results.append(r)
+            subgraph_results.append((None, r))
 
     if has_subgraph:
         output = subgraph_results
diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py
index 40de998331f..1fb58d5e0fe 100644
--- a/tests/inference/test_execution.py
+++ b/tests/inference/test_execution.py
@@ -399,3 +399,22 @@ def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
         expected = 255 // (2 ** iterations)
         assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
         assert result.did_run(is_changed)
+
+    def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3)
+        mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0))
+        output_dynamic = g.node("SaveImage", images=mixed.out(0))
+        output_literal = g.node("SaveImage", images=mixed.out(1))
+
+        result = client.run(g)
+        images_dynamic = result.get_images(output_dynamic)
+        assert len(images_dynamic) == 3, "Should have 3 images"
+        assert numpy.array(images_dynamic[0]).min() == 25 and numpy.array(images_dynamic[0]).max() == 25, "First image should be 0.1"
+        assert numpy.array(images_dynamic[1]).min() == 51 and numpy.array(images_dynamic[1]).max() == 51, "Second image should be 0.2"
+        assert numpy.array(images_dynamic[2]).min() == 76 and numpy.array(images_dynamic[2]).max() == 76, "Third image should be 0.3"
+
+        images_literal = result.get_images(output_literal)
+        assert len(images_literal) == 3, "Should have 3 images"
+        for i in range(3):
+            assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py
index 03dbc8b42d8..8e8ce32ced8 100644
--- a/tests/inference/testing_nodes/testing-pack/specific_tests.py
+++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py
@@ -250,6 +250,39 @@ def dynamic_dependency_cycle(self, input1, input2):
             "expand": g.finalize(),
         }
 
+class TestMixedExpansionReturns:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "input1": ("FLOAT",),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE","IMAGE")
+    FUNCTION = "mixed_expansion_returns"
+
+    CATEGORY = "Testing/Nodes"
+
+    def mixed_expansion_returns(self, input1):
+        white_image = torch.ones([1, 512, 512, 3])
+        if input1 <= 0.1:
+            return (torch.ones([1, 512, 512, 3]) * 0.1, white_image)
+        elif input1 <= 0.2:
+            return {
+                "result": (torch.ones([1, 512, 512, 3]) * 0.2, white_image),
+            }
+        else:
+            g = GraphBuilder()
+            mask = g.node("StubMask", value=0.3, height=512, width=512, batch_size=1)
+            black = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+            white = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+            mix = g.node("TestLazyMixImages", image1=black.out(0), image2=white.out(0), mask=mask.out(0))
+            return {
+                "result": (mix.out(0), white_image),
+                "expand": 
g.finalize(), + } + TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, @@ -259,6 +292,7 @@ def dynamic_dependency_cycle(self, input1, input2): "TestCustomValidation3": TestCustomValidation3, "TestCustomValidation4": TestCustomValidation4, "TestDynamicDependencyCycle": TestDynamicDependencyCycle, + "TestMixedExpansionReturns": TestMixedExpansionReturns, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { @@ -270,4 +304,5 @@ def dynamic_dependency_cycle(self, input1, input2): "TestCustomValidation3": "Custom Validation 3", "TestCustomValidation4": "Custom Validation 4", "TestDynamicDependencyCycle": "Dynamic Dependency Cycle", + "TestMixedExpansionReturns": "Mixed Expansion Returns", } From 06f3ce9200734b54f123ff81e5a28757cc9ce5cb Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sun, 21 Apr 2024 16:10:01 -0700 Subject: [PATCH 20/30] Raise exception for bad get_node calls. --- comfy/graph.py | 8 +++++++- execution.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/comfy/graph.py b/comfy/graph.py index 9149ca4a23a..6671bc3a023 100644 --- a/comfy/graph.py +++ b/comfy/graph.py @@ -8,6 +8,9 @@ class DependencyCycleError(Exception): class NodeInputError(Exception): pass +class NodeNotFoundError(Exception): + pass + class DynamicPrompt: def __init__(self, original_prompt): # The original prompt provided by the user @@ -22,7 +25,10 @@ def get_node(self, node_id): return self.ephemeral_prompt[node_id] if node_id in self.original_prompt: return self.original_prompt[node_id] - return None + raise NodeNotFoundError(f"Node {node_id} not found") + + def has_node(self, node_id): + return node_id in self.original_prompt or node_id in self.ephemeral_prompt def add_ephemeral_node(self, node_id, node_info, parent_id, display_id): self.ephemeral_prompt[node_id] = node_info diff --git a/execution.py b/execution.py index 9db69cbd2a4..4688be2ad00 100644 --- a/execution.py +++ b/execution.py @@ -349,7 +349,7 @@ def pre_execute_cb(call_index): else: # Check for conflicts for node_id in new_graph.keys(): - if dynprompt.get_node(node_id) is not None: + if dynprompt.has_node(node_id): raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. 
Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) From fa48ad3a1fcb45b0027eae58fdab073cb9f269ef Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sun, 21 Apr 2024 21:39:52 -0700 Subject: [PATCH 21/30] Minor refactor of IsChangedCache.get --- execution.py | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/execution.py b/execution.py index 4abc10d9e47..1e35da8db1a 100644 --- a/execution.py +++ b/execution.py @@ -33,24 +33,28 @@ def __init__(self, dynprompt, outputs_cache): self.is_changed = {} def get(self, node_id): - if node_id not in self.is_changed: - node = self.dynprompt.get_node(node_id) - class_type = node["class_type"] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if hasattr(class_def, "IS_CHANGED"): - if "is_changed" in node: - self.is_changed[node_id] = node["is_changed"] - else: - input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) - try: - is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") - node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] - self.is_changed[node_id] = node["is_changed"] - except: - node["is_changed"] = float("NaN") - self.is_changed[node_id] = node["is_changed"] - else: - self.is_changed[node_id] = False + if node_id in self.is_changed: + return self.is_changed[node_id] + + node = self.dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if not hasattr(class_def, "IS_CHANGED"): + self.is_changed[node_id] = False + return self.is_changed[node_id] + + if "is_changed" in node: + self.is_changed[node_id] = node["is_changed"] + return self.is_changed[node_id] + + input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) + try: + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") + node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] + except: + node["is_changed"] = float("NaN") + finally: + self.is_changed[node_id] = node["is_changed"] return self.is_changed[node_id] class CacheSet: From afa4c7b26089c6d3c47db9da5c56908e902dae4a Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sun, 21 Apr 2024 21:55:03 -0700 Subject: [PATCH 22/30] Refactor `map_node_over_list` function --- execution.py | 50 ++++++++++++++++---------------------------------- 1 file changed, 16 insertions(+), 34 deletions(-) diff --git a/execution.py b/execution.py index 1e35da8db1a..5895e1031cc 100644 --- a/execution.py +++ b/execution.py @@ -128,59 +128,41 @@ def mark_missing(): def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): # check if node wants the lists - input_is_list = False - if hasattr(obj, "INPUT_IS_LIST"): - input_is_list = obj.INPUT_IS_LIST + input_is_list = getattr(obj, "INPUT_IS_LIST", False) if len(input_data_all) == 0: max_len_input = 0 else: - max_len_input = max([len(x) for x in input_data_all.values()]) + max_len_input = max(len(x) for x in input_data_all.values()) # get a slice of inputs, repeat last input when list isn't long enough def slice_dict(d, i): - d_new = dict() - for k,v in d.items(): - d_new[k] = v[i if len(v) > i else -1] - return d_new + return {k: v[i if len(v) > i else -1] for k, v in d.items()} results = [] - if input_is_list: + def process_inputs(inputs, index=None): if 
allow_interrupt: nodes.before_node_execution() execution_block = None - for k, v in input_data_all.items(): - for input in v: - if isinstance(v, ExecutionBlocker): - execution_block = execution_block_cb(v) if execution_block_cb is not None else v - break - + for k, v in inputs.items(): + if isinstance(v, ExecutionBlocker): + execution_block = execution_block_cb(v) if execution_block_cb else v + break if execution_block is None: - if pre_execute_cb is not None: - pre_execute_cb(0) - results.append(getattr(obj, func)(**input_data_all)) + if pre_execute_cb is not None and index is not None: + pre_execute_cb(index) + results.append(getattr(obj, func)(**inputs)) else: results.append(execution_block) + + if input_is_list: + process_inputs(input_data_all, 0) elif max_len_input == 0: - if allow_interrupt: - nodes.before_node_execution() - results.append(getattr(obj, func)()) + process_inputs({}) else: for i in range(max_len_input): - if allow_interrupt: - nodes.before_node_execution() input_dict = slice_dict(input_data_all, i) - execution_block = None - for k, v in input_dict.items(): - if isinstance(v, ExecutionBlocker): - execution_block = execution_block_cb(v) if execution_block_cb is not None else v - break - if execution_block is None: - if pre_execute_cb is not None: - pre_execute_cb(i) - results.append(getattr(obj, func)(**input_dict)) - else: - results.append(execution_block) + process_inputs(input_dict, i) return results def merge_result_data(results, obj): From 8d17f3c7bf683f37fed2a8a9821e65a2296be82e Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sun, 16 Jun 2024 18:39:24 -0700 Subject: [PATCH 23/30] Fix ui output for duplicated nodes --- comfy/caching.py | 44 ++++++--------------- execution.py | 16 ++++---- tests/inference/test_execution.py | 66 +++++++++++++++++-------------- 3 files changed, 59 insertions(+), 67 deletions(-) diff --git a/comfy/caching.py b/comfy/caching.py index 060d53d5584..abcf68ae452 100644 --- a/comfy/caching.py +++ b/comfy/caching.py @@ -122,13 +122,6 @@ def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_map order_mapping[ancestor_id] = len(ancestors) - 1 self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) -class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature): - def __init__(self, dynprompt, node_ids, is_changed_cache): - super().__init__(dynprompt, node_ids, is_changed_cache) - - def include_node_id_in_input(self): - return True - class BasicCache: def __init__(self, key_class): self.key_class = key_class @@ -151,10 +144,8 @@ def all_node_ids(self): node_ids = node_ids.union(subcache.all_node_ids()) return node_ids - def clean_unused(self): - assert self.initialized + def _clean_cache(self): preserve_keys = set(self.cache_key_set.get_used_keys()) - preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) to_remove = [] for key in self.cache: if key not in preserve_keys: @@ -162,6 +153,9 @@ def clean_unused(self): for key in to_remove: del self.cache[key] + def _clean_subcaches(self): + preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) + to_remove = [] for key in self.subcaches: if key not in preserve_subcaches: @@ -169,6 +163,11 @@ def clean_unused(self): for key in to_remove: del self.subcaches[key] + def clean_unused(self): + assert self.initialized + self._clean_cache() + self._clean_subcaches() + def _set_immediate(self, node_id, value): assert self.initialized cache_key = self.cache_key_set.get_data_key(node_id) @@ -246,15 +245,6 @@ def 
ensure_subcache_for(self, node_id, children_ids): assert cache is not None return cache._ensure_subcache(node_id, children_ids) - def all_active_values(self): - active_nodes = self.all_node_ids() - result = [] - for node_id in active_nodes: - value = self.get(node_id) - if value is not None: - result.append(value) - return result - class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): super().__init__(key_class) @@ -279,6 +269,7 @@ def clean_unused(self): del self.used_generation[key] if key in self.children: del self.children[key] + self._clean_subcaches() def get(self, node_id): self._mark_used(node_id) @@ -294,6 +285,9 @@ def set(self, node_id, value): return self._set_immediate(node_id, value) def ensure_subcache_for(self, node_id, children_ids): + # Just uses subcaches for tracking 'live' nodes + super()._ensure_subcache(node_id, children_ids) + self.cache_key_set.add_keys(children_ids) self._mark_used(node_id) cache_key = self.cache_key_set.get_data_key(node_id) @@ -303,15 +297,3 @@ def ensure_subcache_for(self, node_id, children_ids): self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self - def all_active_values(self): - explored = set() - to_explore = set(self.cache_key_set.get_used_keys()) - while len(to_explore) > 0: - cache_key = to_explore.pop() - if cache_key not in explored: - self.used_generation[cache_key] = self.generation - explored.add(cache_key) - if cache_key in self.children: - to_explore.update(self.children[cache_key]) - return [self.cache[key] for key in explored if key in self.cache] - diff --git a/execution.py b/execution.py index 5895e1031cc..b5dd94f5e13 100644 --- a/execution.py +++ b/execution.py @@ -15,7 +15,7 @@ import comfy.graph_utils from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy.graph_utils import is_link, GraphBuilder -from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID +from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID from comfy.cli_args import args class ExecutionResult(Enum): @@ -69,13 +69,13 @@ def __init__(self, lru_size=None): # blowing away the cache every time def init_lru_cache(self, cache_size): self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size) + self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID) # Performs like the old cache -- dump data ASAP def init_classic_cache(self): self.outputs = HierarchicalCache(CacheKeySetInputSignature) - self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID) + self.ui = HierarchicalCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def recursive_debug_dump(self): @@ -486,10 +486,12 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): ui_outputs = {} meta_outputs = {} - for ui_info in self.caches.ui.all_active_values(): - node_id = ui_info["meta"]["node_id"] - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] + all_node_ids = self.caches.ui.all_node_ids() + for node_id in all_node_ids: + ui_info = self.caches.ui.get(node_id) + if ui_info is not None: + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] self.history_result = { "outputs": ui_outputs, "meta": meta_outputs, diff --git a/tests/inference/test_execution.py 
index 1fb58d5e0fe..e9f93797622 100644
--- a/tests/inference/test_execution.py
+++ b/tests/inference/test_execution.py
@@ -117,16 +117,26 @@ class TestExecution:
     #
     # Initialize server and client
     #
-    @fixture(scope="class", autouse=True)
-    def _server(self, args_pytest):
+    @fixture(scope="class", autouse=True, params=[
+        # (use_lru, lru_size)
+        (False, 0),
+        (True, 0),
+        (True, 100),
+    ])
+    def _server(self, args_pytest, request):
         # Start server
-        p = subprocess.Popen([
-            'python','main.py',
-            '--output-directory', args_pytest["output_dir"],
-            '--listen', args_pytest["listen"],
-            '--port', str(args_pytest["port"]),
-            '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
-            ])
+        pargs = [
+            'python','main.py',
+            '--output-directory', args_pytest["output_dir"],
+            '--listen', args_pytest["listen"],
+            '--port', str(args_pytest["port"]),
+            '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
+        ]
+        use_lru, lru_size = request.param
+        if use_lru:
+            pargs += ['--cache-lru', str(lru_size)]
+        print("Running server with args:", pargs)
+        p = subprocess.Popen(pargs)
         yield
         p.kill()
         torch.cuda.empty_cache()
@@ -159,15 +169,9 @@ def client(self, shared_client, request):
         shared_client.set_test_name(f"execution[{request.node.name}]")
         yield shared_client
 
-    def clear_cache(self, client: ComfyClient):
-        g = GraphBuilder(prefix="foo")
-        random = g.node("StubImage", content="NOISE", height=1, width=1, batch_size=1)
-        g.node("PreviewImage", images=random.out(0))
-        client.run(g)
-
     @fixture
-    def builder(self):
-        yield GraphBuilder(prefix="")
+    def builder(self, request):
+        yield GraphBuilder(prefix=request.node.name)
 
     def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
@@ -187,7 +191,6 @@ def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
         assert result.did_run(lazy_mix)
 
     def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
-        self.clear_cache(client)
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -196,14 +199,12 @@ def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
         lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
         g.node("SaveImage", images=lazy_mix.out(0))
 
-        result1 = client.run(g)
+        client.run(g)
         result2 = client.run(g)
         for node_id, node in g.nodes.items():
-            assert result1.did_run(node), f"Node {node_id} didn't run"
             assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
 
     def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
-        self.clear_cache(client)
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -212,15 +213,11 @@ def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
         lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
         g.node("SaveImage", images=lazy_mix.out(0))
 
-        result1 = client.run(g)
+        client.run(g)
         mask.inputs['value'] = 0.4
         result2 = client.run(g)
-        for node_id, node in g.nodes.items():
-            assert result1.did_run(node), f"Node {node_id} didn't run"
         assert not result2.did_run(input1), "Input1 should have been cached"
         assert not result2.did_run(input2), "Input2 should have been cached"
-        assert result2.did_run(mask), "Mask should have been re-run"
-        assert result2.did_run(lazy_mix), "Lazy mix should have been re-run"
 
     def test_error(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
@@ -365,7 +362,6 @@ def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
         assert result4.did_run(is_changed), "is_changed should not have been cached"
 
     def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
-        self.clear_cache(client)
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -378,8 +374,6 @@ def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
         result_image = result.get_images(output)[0]
         expected = 255 // 4
         assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
-        assert result.did_run(input1)
-        assert result.did_run(input2)
 
     def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
@@ -418,3 +412,17 @@ def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilde
         assert len(images_literal) == 3, "Should have 3 images"
         for i in range(3):
             assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
+
+    def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+
+        output1 = g.node("PreviewImage", images=input1.out(0))
+        output2 = g.node("PreviewImage", images=input1.out(0))
+
+        result = client.run(g)
+        images1 = result.get_images(output1)
+        images2 = result.get_images(output2)
+        assert len(images1) == 1, "Should have 1 image"
+        assert len(images2) == 1, "Should have 1 image"
+

From 4712df8a05c60fd247fc59d2a934b4822c8dc2cb Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Sun, 16 Jun 2024 20:33:30 -0700
Subject: [PATCH 24/30] Add documentation on `check_lazy_status`
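
For illustration, a minimal sketch of a node method using this hook (the
class and input names here are hypothetical, not part of this change):

    class LazyMixExample:
        # INPUT_TYPES would mark image1/image2 with "lazy": True and declare
        # a non-lazy FLOAT mix_factor.
        def check_lazy_status(self, mix_factor, image1=None, image2=None):
            # Unevaluated lazy inputs arrive as None; request only what is needed.
            needed = []
            if image1 is None and mix_factor < 1.0:
                needed.append("image1")
            if image2 is None and mix_factor > 0.0:
                needed.append("image2")
            return needed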
---
 custom_nodes/example_node.py.example | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example
index f066325930d..20f9bea75da 100644
--- a/custom_nodes/example_node.py.example
+++ b/custom_nodes/example_node.py.example
@@ -54,7 +54,8 @@ class Example:
                     "min": 0, #Minimum value
                     "max": 4096, #Maximum value
                     "step": 64, #Slider's step
-                    "display": "number" # Cosmetic only: display as "number" or "slider"
+                    "display": "number", # Cosmetic only: display as "number" or "slider"
+                    "lazy": True # Will only be evaluated if check_lazy_status requires it
                 }),
                 "float_field": ("FLOAT", {
                     "default": 1.0,
@@ -62,11 +63,14 @@ class Example:
                     "max": 10.0,
                     "step": 0.01,
                     "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
-                    "display": "number"}),
+                    "display": "number",
+                    "lazy": True
+                }),
                 "print_to_screen": (["enable", "disable"],),
                 "string_field": ("STRING", {
                     "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
-                    "default": "Hello World!"
+                    "default": "Hello World!",
+                    "lazy": True
                 }),
             },
         }
@@ -80,6 +84,23 @@ class Example:
 
     CATEGORY = "Example"
 
+    def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
+        """
+        Return a list of input names that need to be evaluated.
+
+        This function will be called if there are any lazy inputs which have not yet been
+        evaluated. As long as you return at least one field which has not yet been evaluated
+        (and more exist), this function will be called again once the value of the requested
+        field is available.
+
+        Any evaluated inputs will be passed as arguments to this function. Any unevaluated
+        inputs will have the value None.
+        """
+        if print_to_screen == "enable":
+            return ["int_field", "float_field", "string_field"]
+        else:
+            return []
+
     def test(self, image, string_field, int_field, float_field, print_to_screen):
         if print_to_screen == "enable":
             print(f"""Your input contains:

From 48d03c47bc9fa550192bfb81781221e2d151a694 Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Sat, 20 Jul 2024 20:28:25 -0700
Subject: [PATCH 25/30] Add file for execution model unit tests

---
 tests/inference/extra_model_paths.yaml | 4 ++++
 1 file changed, 4 insertions(+)
 create mode 100644 tests/inference/extra_model_paths.yaml

diff --git a/tests/inference/extra_model_paths.yaml b/tests/inference/extra_model_paths.yaml
new file mode 100644
index 00000000000..75b2e1ae4a6
--- /dev/null
+++ b/tests/inference/extra_model_paths.yaml
@@ -0,0 +1,4 @@
+# Config for testing nodes
+testing:
+  custom_nodes: tests/inference/testing_nodes
+

From 64e3a43527a5f0b7a6f906ccbff947f134a282a0 Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Thu, 1 Aug 2024 20:04:49 -0700
Subject: [PATCH 26/30] Clean up Javascript code as per review

---
 web/scripts/ui.js | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/web/scripts/ui.js b/web/scripts/ui.js
index f35e2e0c33d..05258e378c7 100644
--- a/web/scripts/ui.js
+++ b/web/scripts/ui.js
@@ -241,11 +241,8 @@ class ComfyList {
 					if (item.outputs) {
 						app.nodeOutputs = {};
 						for (const [key, value] of Object.entries(item.outputs)) {
-							if (item.meta && item.meta[key] && item.meta[key].display_node) {
-								app.nodeOutputs[item.meta[key].display_node] = value;
-							} else {
-								app.nodeOutputs[key] = value;
-							}
+							const realKey = item?.meta?.[key]?.display_node ?? key;
+							app.nodeOutputs[realKey] = value;
 						}
 					}
 				},

From bb5de4dfd4beef06af6666b8ae950329a0e9a8d5 Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Thu, 1 Aug 2024 20:05:14 -0700
Subject: [PATCH 27/30] Improve documentation

Converted some comments to docstrings as per review
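
As a usage sketch for the ExecutionBlocker docstring converted below (the
node here is hypothetical, not part of this change):

    from comfy.graph import ExecutionBlocker

    class ImageGateExample:
        RETURN_TYPES = ("IMAGE",)
        FUNCTION = "gate"

        def gate(self, images, enabled):
            if enabled:
                return (images,)
            # A message of None blocks downstream users (e.g. SaveImage)
            # silently; a string would be shown to the user as an error.
            return (ExecutionBlocker(None),)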
---
 comfy/graph.py | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/comfy/graph.py b/comfy/graph.py
index 6671bc3a023..8980c693c50 100644
--- a/comfy/graph.py
+++ b/comfy/graph.py
@@ -136,9 +136,11 @@ def pop_node(self, unique_id):
     def is_empty(self):
         return len(self.pendingNodes) == 0
 
-# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
-# it can still be returned to the graph after having further dependencies added.
 class ExecutionList(TopologicalSort):
+    """
+    ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
+    it can still be returned to the graph after having further dependencies added.
+    """
     def __init__(self, dynprompt, output_cache):
         super().__init__(dynprompt)
         self.output_cache = output_cache
@@ -216,18 +218,20 @@ def get_nodes_in_cycle(self):
             to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
         return list(blocked_by.keys())
 
-# Return this from a node and any users will be blocked with the given error message.
-# If the message is None, execution will be blocked silently instead.
-# Generally, you should avoid using this functionality unless absolutley necessary. Whenever it's
-# possible, a lazy input will be more efficient and have a better user experience.
-# This functionality is useful in two cases:
-# 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
-#    like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
-#    lazy evaluation to let it conditionally disable itself.)
-# 2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
-#    (I would recommend not making nodes like this in the future -- instead, make multiple nodes with
-#    different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
 class ExecutionBlocker:
+    """
+    Return this from a node and any users will be blocked with the given error message.
+    If the message is None, execution will be blocked silently instead.
+    Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
+    possible, a lazy input will be more efficient and have a better user experience.
+    This functionality is useful in two cases:
+    1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
+       like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
+       lazy evaluation to let it conditionally disable itself.)
+    2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
+       (I would recommend not making nodes like this in the future -- instead, make multiple nodes with
+       different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
+    """
     def __init__(self, message):
         self.message = message

From c4666bf7dbf48e6f70ca53fb2add15744e78084d Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Thu, 1 Aug 2024 20:24:08 -0700
Subject: [PATCH 28/30] Add a new unit test for mixed lazy results

This test validates that when an output list is fed to a lazy node, the
node will properly evaluate previous nodes that are needed by any inputs
to the lazy node. No code in the execution model has been changed. The
test already passes.
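
Conceptually (values taken from the asserts in the test below), the three
mask values fan out to three evaluations of the lazy mix node, each of
which may need different upstream images:

    mask 0.0 -> image1 (BLACK) suffices; result is all 0
    mask 0.5 -> both images are needed; result is all 127
    mask 1.0 -> image2 (WHITE) suffices; result is all 255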
---
 tests/inference/test_execution.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py
index e9f93797622..b9dec659831 100644
--- a/tests/inference/test_execution.py
+++ b/tests/inference/test_execution.py
@@ -413,6 +413,23 @@ def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilde
         for i in range(3):
             assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
 
+    def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0)
+        mask = g.node("StubMask", value=val_list.out(0), height=512, width=512, batch_size=1)
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+        mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
+        rebatch = g.node("RebatchImages", images=mix.out(0), batch_size=3)
+        output = g.node("SaveImage", images=rebatch.out(0))
+
+        result = client.run(g)
+        images = result.get_images(output)
+        assert len(images) == 3, "Should have 3 images"
+        assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be 0.0"
+        assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5"
+        assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0"
+
     def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

From 36131f0a3b70ff0bde5fdc098b31abdc5de6f878 Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Wed, 7 Aug 2024 22:22:50 -0700
Subject: [PATCH 29/30] Allow kwargs in VALIDATE_INPUTS functions

When `**kwargs` is used, validation is skipped for all inputs, as if each
had been mentioned explicitly.
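
For illustration, a sketch of the rule the validation machinery now follows
(the helper name and structure are hypothetical -- the real logic lives
inline in validate_inputs):

    import inspect

    def inputs_covered_by_validator(validate_fn, all_input_names):
        # A **kwargs validator covers every input; otherwise only the
        # explicitly named arguments are covered. Covered inputs skip the
        # built-in min/max checks.
        spec = inspect.getfullargspec(validate_fn)
        if spec.varkw is not None:
            return set(all_input_names)
        return set(spec.args) & set(all_input_names)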
---
 execution.py                                      | 11 +++++---
 tests/inference/test_execution.py                 | 16 +++++++++++
 .../testing-pack/specific_tests.py                | 27 +++++++++++++++++++
 3 files changed, 50 insertions(+), 4 deletions(-)

diff --git a/execution.py b/execution.py
index 30c0b3834b6..3bffae8e94f 100644
--- a/execution.py
+++ b/execution.py
@@ -530,8 +530,11 @@ def validate_inputs(prompt, item, validated):
     valid = True
 
     validate_function_inputs = []
+    validate_has_kwargs = False
     if hasattr(obj_class, "VALIDATE_INPUTS"):
-        validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
+        argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
+        validate_function_inputs = argspec.args
+        validate_has_kwargs = argspec.varkw is not None
 
     received_types = {}
 
     for x in valid_inputs:
@@ -641,7 +644,7 @@ def validate_inputs(prompt, item, validated):
                     errors.append(error)
                     continue
 
-            if x not in validate_function_inputs:
+            if x not in validate_function_inputs and not validate_has_kwargs:
                 if "min" in extra_info and val < extra_info["min"]:
                     error = {
                         "type": "value_smaller_than_min",
@@ -695,11 +698,11 @@ def validate_inputs(prompt, item, validated):
             errors.append(error)
             continue
 
-    if len(validate_function_inputs) > 0:
+    if len(validate_function_inputs) > 0 or validate_has_kwargs:
         input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
         input_filtered = {}
         for x in input_data_all:
-            if x in validate_function_inputs:
+            if x in validate_function_inputs or validate_has_kwargs:
                 input_filtered[x] = input_data_all[x]
         if 'input_types' in validate_function_inputs:
             input_filtered['input_types'] = [received_types]
diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py
index b9dec659831..8616ca1e8e8 100644
--- a/tests/inference/test_execution.py
+++ b/tests/inference/test_execution.py
@@ -312,6 +312,22 @@ def test_validation_error_edge4(self, test_type, test_value, expect_error, clien
         else:
             client.run(g)
 
+    @pytest.mark.parametrize("test_value1, test_value2, expect_error", [
+        (0.0, 0.5, False),
+        (0.0, 5.0, False),
+        (0.0, 7.0, True)
+    ])
+    def test_validation_error_kwargs(self, test_value1, test_value2, expect_error, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        validation5 = g.node("TestCustomValidation5", input1=test_value1, input2=test_value2)
+        g.node("SaveImage", images=validation5.out(0))
+
+        if expect_error:
+            with pytest.raises(urllib.error.HTTPError):
+                client.run(g)
+        else:
+            client.run(g)
+
     def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py
index 8e8ce32ced8..5884cae0c5a 100644
--- a/tests/inference/testing_nodes/testing-pack/specific_tests.py
+++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py
@@ -221,6 +221,31 @@ def VALIDATE_INPUTS(cls, input1, input2):
 
         return True
 
+class TestCustomValidation5:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "input1": ("FLOAT", {"min": 0.0, "max": 1.0}),
+                "input2": ("FLOAT", {"min": 0.0, "max": 1.0}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "custom_validation5"
+
+    CATEGORY = "Testing/Nodes"
+
+    def custom_validation5(self, input1, input2):
+        value = input1 * input2
+        return (torch.ones([1, 512, 512, 3]) * value,)
+
+    @classmethod
+    def VALIDATE_INPUTS(cls, **kwargs):
+        if kwargs['input2'] == 7.0:
+            return "7s are not allowed. I've never liked 7s."
+        return True
+
 class TestDynamicDependencyCycle:
     @classmethod
     def INPUT_TYPES(cls):
@@ -291,6 +316,7 @@ def mixed_expansion_returns(self, input1):
     "TestCustomValidation2": TestCustomValidation2,
     "TestCustomValidation3": TestCustomValidation3,
     "TestCustomValidation4": TestCustomValidation4,
+    "TestCustomValidation5": TestCustomValidation5,
     "TestDynamicDependencyCycle": TestDynamicDependencyCycle,
     "TestMixedExpansionReturns": TestMixedExpansionReturns,
 }
@@ -303,6 +329,7 @@ def mixed_expansion_returns(self, input1):
     "TestCustomValidation2": "Custom Validation 2",
     "TestCustomValidation3": "Custom Validation 3",
     "TestCustomValidation4": "Custom Validation 4",
+    "TestCustomValidation5": "Custom Validation 5",
     "TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
     "TestMixedExpansionReturns": "Mixed Expansion Returns",
 }

From fd7229e51b3fb3a07372c1fec4e6e58017277a2f Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Wed, 7 Aug 2024 22:35:18 -0700
Subject: [PATCH 30/30] List cached nodes in `execution_cached` message

This was previously just bugged in this PR.
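
With the fix, the message should carry only the nodes of the current prompt
that were satisfied from the cache -- roughly (node IDs hypothetical):

    {"type": "execution_cached", "data": {"nodes": ["3", "5"], "prompt_id": prompt_id}}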
---
 execution.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/execution.py b/execution.py
index 3bffae8e94f..ee675893328 100644
--- a/execution.py
+++ b/execution.py
@@ -466,15 +466,19 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
                 cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
                 cache.clean_unused()
 
-            current_outputs = self.caches.outputs.all_node_ids()
+            cached_nodes = []
+            for node_id in prompt:
+                if self.caches.outputs.get(node_id) is not None:
+                    cached_nodes.append(node_id)
 
             comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
-                          { "nodes": list(current_outputs) , "prompt_id": prompt_id},
+                          { "nodes": cached_nodes, "prompt_id": prompt_id},
                           broadcast=False)
             pending_subgraph_results = {}
             executed = set()
             execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
+            current_outputs = self.caches.outputs.all_node_ids()
             for node_id in list(execute_outputs):
                 execution_list.add_node(node_id)