From 5171a888fb82e733a3ddfafa8f25a059c4a0c214 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20L=C3=B3pez=20Barr=C3=B3n?= Date: Sat, 14 Dec 2024 10:30:54 +1100 Subject: [PATCH] updated views._graph.Node internals to use port instead of plug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Christian López Barrón --- grill/views/_graph.py | 122 ++++++++++++++++---------------- grill/views/description.py | 22 +++--- tests/test_data/_mini_graph.dot | 6 +- tests/test_views.py | 33 ++++----- 4 files changed, 88 insertions(+), 95 deletions(-) diff --git a/grill/views/_graph.py b/grill/views/_graph.py index 3424fb66..a8db2d7b 100644 --- a/grill/views/_graph.py +++ b/grill/views/_graph.py @@ -38,7 +38,7 @@ # - Tooltip on nodes for _GraphViewer # - Context menu items # - Ability to move further in canvas after Nodes don't exist -# - when switching a node left to right with precise source layers, the source node plugs do not refresh if we're moving the target node +# - when switching a node left to right with precise source layers, the source node ports do not refresh if we're moving the target node # - refactor conditionals for _GraphSVGViewer from the description module @@ -64,19 +64,20 @@ def _adjust_graphviz_html_table_label(label): return label -def _get_html_table_from_fields(**fields): +def _get_html_table_from_ports(**ports): label = '' - for index, (port, text) in enumerate(fields.items()): + for index, (name, text) in enumerate(ports.items()): bgcolor = "white" if index % 2 == 0 else "#f0f6ff" # light blue text = f'{text}' - label += f"" + label += f"" label += "
{text}
{text}
" return label -def _get_plugs_from_label(label) -> dict[str, str]: +def _get_ports_from_label(label) -> dict[str, str]: if not label.startswith("{"): # Only for record labels. - raise ValueError(f"Label needs to start with '{{' to extract plugs from it, for example: '{{item|another_item}}'. Got label: '{label}'") + raise ValueError(f"Label needs to start with '{{' to extract ports from it, for example: '{{item|another_item}}'. Got label: '{label}'") + # see https://graphviz.org/doc/info/shapes.html#record fields = label.strip("{}").split("|") return dict(field.strip("<>").split(">", 1) for field in fields) @@ -93,12 +94,12 @@ def _dot_2_svg(sourcepath): class _Node(QtWidgets.QGraphicsTextItem): # Note: keep 'label' as an argument to use as much as possible as-is for clients to provide their own HTML style - def __init__(self, parent=None, label="", color="", fillcolor="", plugs: tuple = (), visible=True): + def __init__(self, parent=None, label="", color="", fillcolor="", ports: tuple = (), visible=True): super().__init__(parent) self._edges = [] - self._plugs = dict(zip(plugs, range(len(plugs)))) or {} # {identifier: index} - self._active_plugs_by_side = dict() # {index: {left[int]: {}, right[int]: {}} - self._plug_items = {} # {index: (QEllipse, QEllipse)} + self._ports = dict(zip(ports, range(len(ports)))) or {} # {identifier: index} + self._active_ports_by_side = dict() # {index: {left[int]: {}, right[int]: {}} + self._port_items = {} # {index: (QEllipse, QEllipse)} self._pen = QtGui.QPen(QtGui.QColor(color), 1, QtCore.Qt.SolidLine, QtCore.Qt.RoundCap, QtCore.Qt.RoundJoin) self._fillcolor = QtGui.QColor(fillcolor) self.setHtml("" + label) @@ -164,61 +165,62 @@ def itemChange(self, change: QtWidgets.QGraphicsItem.GraphicsItemChange, value): edge.adjust() return super().itemChange(change, value) - def _activatePlug(self, edge, plug_index, side, position): - if plug_index is None: + def _activatePort(self, edge, port, side, position): + if port is None: return # we're at the center, nothing to draw nor activate try: - plugs_by_side = self._active_plugs_by_side[plug_index] # {index: {left[int]: {}, right[int]: {}} - except KeyError: # first time we're activating a plug, so add a visual ellipse for it + ports_by_side = self._active_ports_by_side[port] # {index: {left[int]: {}, right[int]: {}} + except KeyError: # first time we're activating a port, so add a visual ellipse for it radius = 4 - def _add_plug_item(): + def _add_port_item(): item = QtWidgets.QGraphicsEllipseItem(-radius, -radius, 2 * radius, 2 * radius) item.setPen(_NO_PEN) self.scene().addItem(item) return item - self._plug_items[plug_index] = (_add_plug_item(), _add_plug_item()) - self._active_plugs_by_side[plug_index] = plugs_by_side = {0: dict(), 1: dict()} + self._port_items[port] = (_add_port_item(), _add_port_item()) + self._active_ports_by_side[port] = ports_by_side = {0: dict(), 1: dict()} - plugs_by_side[side][edge] = True + ports_by_side[side][edge] = True other_side = bool(not side) - inactive_plugs = plugs_by_side[other_side] - inactive_plugs.pop(edge, None) - plug_items = self._plug_items[plug_index] # {index: (QEllipse, QEllipse)} - if not inactive_plugs: - plug_items[other_side].setVisible(False) - this_item = plug_items[side] + inactive_ports = ports_by_side[other_side] + inactive_ports.pop(edge, None) + port_items = self._port_items[port] # {index: (QEllipse, QEllipse)} + if not inactive_ports: + port_items[other_side].setVisible(False) + this_item = port_items[side] this_item.setVisible(True) this_item.setBrush(edge._brush) - plug_items[side].setPos(position) + port_items[side].setPos(position) class _Edge(QtWidgets.QGraphicsItem): - def __init__(self, source: _Node, target: _Node, *, source_plug: int =None, target_plug: int =None, label="", color="", is_bidirectional=False, parent: QtWidgets.QGraphicsItem = None): + def __init__(self, source: _Node, target: _Node, *, source_port: int = None, target_port: int = None, label="", color="", is_bidirectional=False, parent: QtWidgets.QGraphicsItem = None): super().__init__(parent) source.add_edge(self) target.add_edge(self) self._source = source self._target = target - self._source_plug = source_plug - self._target_plug = target_plug - self._is_source_plugged = source_plug is not None - self._is_target_plugged = target_plug is not None + self._source_port = source_port + self._target_port = target_port + self._is_source_port_used = source_port is not None + self._is_target_port_used = target_port is not None self._is_cycle = is_cycle = source == target - self._plug_positions = plug_positions = {} + self._port_positions = port_positions = {} outer_shift = 10 # surrounding rect has ~5 px top and bottom - for node, plug, max_plug_idx in (source, source_plug, max(source._plugs.values(), default=0)), (target, target_plug, max(target._plugs.values(), default=0)): + # TODO: this is the main reason of why Node._ports has {port: index}. See if it can be removed + for node, port, max_port_idx in (source, source_port, max(source._ports.values(), default=0)), (target, target_port, max(target._ports.values(), default=0)): bounds = node.boundingRect() - if plug is None: - plug_positions[node, plug] = {None: QtCore.QPointF(bounds.right() - 5, bounds.height() / 2 - 20) if is_cycle else bounds.center()} + if port is None: + port_positions[node, port] = {None: QtCore.QPointF(bounds.right() - 5, bounds.height() / 2 - 20) if is_cycle else bounds.center()} continue - # max_plug_idx can be 0, so we add 1 since this needs to be 1-index based - port_size = (bounds.height() - outer_shift) / (max_plug_idx + 1) - y_pos = (plug * port_size) + (port_size / 2) + (outer_shift / 2) - plug_positions[node, plug] = { + # max_port_idx can be 0, so we add 1 since this needs to be 1-index based + port_size = (bounds.height() - outer_shift) / (max_port_idx + 1) + y_pos = (port * port_size) + (port_size / 2) + (outer_shift / 2) + port_positions[node, port] = { 0: QtCore.QPointF(0, y_pos), # left 1: QtCore.QPointF(bounds.right(), y_pos), # right } @@ -229,7 +231,7 @@ def __init__(self, source: _Node, target: _Node, *, source_plug: int =None, targ self._line = QtCore.QLineF() self.setZValue(-1) - self._spline_path = QtGui.QPainterPath() if (self._is_source_plugged or self._is_target_plugged) else None + self._spline_path = QtGui.QPainterPath() if (self._is_source_port_used or self._is_target_port_used) else None self._colors = colors = color.split(":") main_color = QtGui.QColor(colors[0]) @@ -265,8 +267,8 @@ def boundingRect(self) -> QtCore.QRectF: @property def _cycle_start_position(self): - if not self._is_source_plugged: - return self._source.pos() + self._plug_positions[self._source, self._source_plug][None] + if not self._is_source_port_used: + return self._source.pos() + self._port_positions[self._source, self._source_port][None] return self._line.p1() + QtCore.QPointF(-3, -31) @@ -279,14 +281,14 @@ def adjust(self): source_on_left = self._is_cycle or (self._source.boundingRect().center().x() + source_pos.x() < target_bounds.center().x() + target_pos.x()) - is_source_plugged = self._is_source_plugged - is_target_plugged = self._is_target_plugged - source_side = source_on_left if is_source_plugged else None - target_side = not source_side if is_target_plugged else None - source_point = source_pos + self._plug_positions[self._source, self._source_plug][source_side] - target_point = target_pos + self._plug_positions[self._target, self._target_plug][target_side] + is_source_port_used = self._is_source_port_used + is_target_port_used = self._is_target_port_used + source_side = source_on_left if is_source_port_used else None + target_side = not source_side if is_target_port_used else None + source_point = source_pos + self._port_positions[self._source, self._source_port][source_side] + target_point = target_pos + self._port_positions[self._target, self._target_port][target_side] - if not is_target_plugged: + if not is_target_port_used: line = QtCore.QLineF(source_point, target_point) if not self._spline_path and self._bidirectional_shift and source_point != target_point: # offset in case of bidirectional connections when we are not using splines (as lines would overlap) @@ -319,15 +321,15 @@ def adjust(self): falloff = (length / 100) ** 2 if length < 100 else 1 control_point_shift = (1 if source_on_left else -1) * 75 * falloff - control_point1 = source_point + QtCore.QPointF(control_point_shift, 0) if is_source_plugged else source_point - control_point2 = target_point + QtCore.QPointF(-control_point_shift, 0) if is_target_plugged else target_point + control_point1 = source_point + QtCore.QPointF(control_point_shift, 0) if is_source_port_used else source_point + control_point2 = target_point + QtCore.QPointF(-control_point_shift, 0) if is_target_port_used else target_point self._spline_path = QtGui.QPainterPath() self._spline_path.moveTo(source_point) self._spline_path.cubicTo(control_point1, control_point2, target_point) - self._source._activatePlug(self, self._source_plug, source_side, source_point) - self._target._activatePlug(self, self._target_plug, target_side, target_point) + self._source._activatePort(self, self._source_port, source_side, source_point) + self._target._activatePort(self, self._target_port, target_side, target_point) if self._label_text: self._label_text.setPos((source_point + target_point) / 2) @@ -621,20 +623,20 @@ def _load_graph(self, graph): def _add_node(nx_node): node_data = graph.nodes[nx_node] - plugs = node_data.get('plugs', ()) + ports = node_data.get('ports', ()) nodes_attrs = ChainMap(node_data, graph_node_attrs) if (shape := nodes_attrs.get('shape')) == 'record': try: label = node_data['label'] except KeyError: raise ValueError(f"'label' must be supplied when 'record' shape is set for node: '{nx_node}' with data: {node_data}") - if plugs: + if ports: raise ValueError(f"record 'shape' and 'ports' are mutually exclusive, pick one for node: '{nx_node}' with data: {node_data}") try: - plugs = _get_plugs_from_label(label) + ports = _get_ports_from_label(label) except ValueError as exc: raise ValueError(f"In order to use the 'record' shape, a record 'label' in the form of: '{{text1|text2}}' must be used") from exc - label = _get_html_table_from_fields(**plugs) + label = _get_html_table_from_ports(**ports) else: label = node_data.get('label') if shape in {'none', 'plaintext'}: @@ -648,7 +650,7 @@ def _add_node(nx_node): label=label, color=nodes_attrs.get("color", ""), fillcolor=nodes_attrs.get("fillcolor", "white"), - plugs=plugs, + ports=ports, visible=nodes_attrs.get('style', "") != "invis", ) item.linkActivated.connect(self._graph_url_changed) @@ -680,9 +682,9 @@ def _add_node(nx_node): color = edge_data.get('color', edge_color) label = edge_data.get('label', '') kwargs = dict() - if source._plugs or target._plugs: - kwargs['target_plug'] = target._plugs[edge_data['headport']] if edge_data.get('headport') is not None else None - kwargs['source_plug'] = source._plugs[edge_data['tailport']] if edge_data.get('tailport') is not None else None + if source._ports or target._ports: + kwargs['target_port'] = target._ports[edge_data['headport']] if edge_data.get('headport') is not None else None + kwargs['source_port'] = source._ports[edge_data['tailport']] if edge_data.get('tailport') is not None else None edge = _Edge(source, target, color=color, label=label, is_bidirectional=is_bidirectional, **kwargs) self.scene().addItem(edge) diff --git a/grill/views/description.py b/grill/views/description.py index 02151555..aae195fa 100644 --- a/grill/views/description.py +++ b/grill/views/description.py @@ -229,7 +229,7 @@ def _graph_from_connections(prim: Usd.Prim) -> nx.MultiDiGraph: graph.graph['edge'] = {"color": 'crimson'} all_nodes = dict() # {node_id: {graphviz_attr: value}} - edges = list() # [(source_node_id, target_node_id, {source_plug_name, target_plug_name, graphviz_attrs})] + edges = list() # [(source_node_id, target_node_id, {source_port_name, target_port_name, graphviz_attrs})] @cache def _get_node_id(api): @@ -240,7 +240,7 @@ def _add_edges(src_node, src_name, tgt_node, tgt_name): tooltip = f"{src_node}.{src_name} -> {tgt_node}.{tgt_name}" edges.append((src_node, tgt_node, {"tailport": src_name, "headport": tgt_name, "tooltip": tooltip})) - plug_colors = { + port_colors = { UsdShade.Input: outline_color, # blue UsdShade.Output: "#F08080" # "lightcoral", # pink } @@ -255,18 +255,18 @@ def traverse(api: UsdShade.ConnectableAPI): node_id = _get_node_id(current_prim) label = f'<' label += table_row.format(port="", color="white", text=f'{api.GetPrim().GetName()}') - plugs = [""] # port names for this node. Empty string is used to refer to the node itself (no port). - for plug in chain(api.GetInputs(), api.GetOutputs()): - plug_name = plug.GetBaseName() - sources, __ = plug.GetConnectedSources() # (valid, invalid): we care only about valid sources (index 0) - color = plug_colors[type(plug)] if isinstance(plug, UsdShade.Output) or sources else background_color - label += table_row.format(port=plug_name, color=color, text=f'{plug_name}') + ports = [""] # port names for this node. Empty string is used to refer to the node itself (no port). + for port in chain(api.GetInputs(), api.GetOutputs()): + port_name = port.GetBaseName() + sources, __ = port.GetConnectedSources() # (valid, invalid): we care only about valid sources (index 0) + color = port_colors[type(port)] if isinstance(port, UsdShade.Output) or sources else background_color + label += table_row.format(port=port_name, color=color, text=f'{port_name}') for source in sources: - _add_edges(_get_node_id(source.source.GetPrim()), source.sourceName, node_id, plug_name) + _add_edges(_get_node_id(source.source.GetPrim()), source.sourceName, node_id, port_name) traverse(source.source) - plugs.append(plug_name) + ports.append(port_name) label += '
>' - all_nodes[node_id] = dict(label=label, plugs=plugs) + all_nodes[node_id] = dict(label=label, ports=ports) traverse(connections_api) diff --git a/tests/test_data/_mini_graph.dot b/tests/test_data/_mini_graph.dot index eee049be..ec9bea04 100644 --- a/tests/test_data/_mini_graph.dot +++ b/tests/test_data/_mini_graph.dot @@ -7,8 +7,8 @@ edge [color=crimson]; parent [shape=box, fillcolor="#afd7ff", color="#1E90FF", style="filled,rounded"]; child1 [shape=box, fillcolor="#afd7ff", color="#1E90FF", style="filled,rounded"]; child2 [shape=box, fillcolor="#afd7ff", color="#1E90FF", style=invis]; -ancestor [shape=none, label=<
ancestor
cycle_in
roughness
cycle_out
surface
>]; -successor [shape=none, label=<
successor
surface
>]; +ancestor [ports="('', 'cycle_in', 'roughness', 'cycle_out', 'surface')", shape=none, label=<
ancestor
cycle_in
roughness
cycle_out
surface
>]; +successor [ports="('', 'surface')", shape=none, label=<
successor
surface
>]; 1 -> 1 [key=0, color="sienna:crimson:orange"]; 1 -> 2 [key=0, color=crimson]; 2 -> 1 [key=0, color=seagreen]; @@ -18,4 +18,4 @@ parent -> child1 [key=0]; parent -> child2 [key=0, label=invis]; ancestor -> ancestor [key=0, tailport="cycle_out", headport="cycle_in", tooltip="ancestor.cycle_out -> ancestor.cycle_in"]; ancestor -> successor [key=0, tailport=surface, headport=surface, tooltip="ancestor.surface -> successor.surface"]; -} +} \ No newline at end of file diff --git a/tests/test_views.py b/tests/test_views.py index afc8c90e..bb1a8029 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -124,8 +124,8 @@ def test_connection_view(self): viewer._graph_view.view = lambda indices: None viewer.setPrim(material) graph = viewer._graph_view._graph - self.assertEqual(graph.nodes[str(material.GetPrim().GetPath())]['plugs'], ['', surface_name]) - self.assertEqual(graph.nodes[str(pbrShader.GetPrim().GetPath())]['plugs'], ['', cycle_input.GetBaseName(), roughness_name, cycle_output.GetBaseName(), surface_name]) + self.assertEqual(graph.nodes[str(material.GetPrim().GetPath())]['ports'], ['', surface_name]) + self.assertEqual(graph.nodes[str(pbrShader.GetPrim().GetPath())]['ports'], ['', cycle_input.GetBaseName(), roughness_name, cycle_output.GetBaseName(), surface_name]) viewer.setPrim(None) def test_scenegraph_composition(self): @@ -549,7 +549,7 @@ def test_graph_views(self): (dict(shape='record'), "'label' must be supplied"), (dict(shape='record', label='no record'), "a record 'label' in the form of"), (dict(shape='record', label='{1}'), "a record 'label' in the form of"), - (dict(shape='record', label='{<0>1}', plugs={'first': 1, 'second': 2}), "record 'shape' and 'ports' are mutually exclusive"), + (dict(shape='record', label='{<0>1}', ports=('first', 'second')), "record 'shape' and 'ports' are mutually exclusive"), (dict(shape='none'), "A label must be provided"), ): invalid_graph = _graph.nx.MultiDiGraph() @@ -607,13 +607,7 @@ def test_graph_views(self): connection_nodes = dict( ancestor=dict( - plugs={ - '': 0, - 'cycle_in': 1, - 'roughness': 2, - 'cycle_out': 3, - 'surface': 4 - }, + ports=('', 'cycle_in', 'roughness', 'cycle_out', 'surface'), shape='none', connections=dict( surface=[('successor', 'surface')], @@ -621,7 +615,7 @@ def test_graph_views(self): ), ), successor=dict( - plugs={'': 0, 'surface': 1}, + ports=('', 'surface'), shape='none', connections=dict(), ) @@ -636,19 +630,16 @@ def _add_edges(src_node, src_name, tgt_node, tgt_name): label = f'<' label += table_row.format(port="", color="white", text=f'{node}') - # for index, plug in enumerate(data['plugs'], start=1): # we start at 1 because index 0 is the node itself - for plug, index in data['plugs'].items(): # we start at 1 because index 0 is the node itself - if not plug: + for port in data['ports']: + if not port: continue - plug_name = plug - sources = data['connections'].get(plug, []) # (valid, invalid): we care only about valid sources (index 0) + sources = data['connections'].get(port, []) # (valid, invalid): we care only about valid sources (index 0) color = r"#F08080" if sources else background_color - # color = plug_colors[type(plug)] if isinstance(plug, UsdShade.Output) or sources else background_color - label += table_row.format(port=plug_name, color=color, text=f'{plug_name}') - for source_node, source_plug in sources: - # node_id='ancestor', plug_name='cycle_out', ancestor, source.sourceName='cycle_in' + label += table_row.format(port=port, color=color, text=f'{port}') + for source_node, source_port in sources: + # node_id='ancestor', port_name='cycle_out', ancestor, source.sourceName='cycle_in' # tooltip='/TexModel/boardMat/PBRShader.cycle_in -> /TexModel/boardMat/PBRShader.cycle_out' - _add_edges(node, plug_name, source_node, source_plug) + _add_edges(node, port, source_node, source_port) label += '
>' data['label'] = label