From 37cabe3091e4395fd77752e4d551204391c20c73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20L=C3=B3pez=20Barr=C3=B3n?= Date: Sat, 14 Dec 2024 09:57:29 +1100 Subject: [PATCH] test port existence in connection view MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Christian López Barrón --- grill/views/_graph.py | 2 +- grill/views/description.py | 6 +++--- tests/test_views.py | 9 +++++++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/grill/views/_graph.py b/grill/views/_graph.py index 5d50371c..3424fb66 100644 --- a/grill/views/_graph.py +++ b/grill/views/_graph.py @@ -621,7 +621,7 @@ def _load_graph(self, graph): def _add_node(nx_node): node_data = graph.nodes[nx_node] - plugs = node_data.pop('plugs', ()) # implementation detail + plugs = node_data.get('plugs', ()) nodes_attrs = ChainMap(node_data, graph_node_attrs) if (shape := nodes_attrs.get('shape')) == 'record': try: diff --git a/grill/views/description.py b/grill/views/description.py index f26227de..02151555 100644 --- a/grill/views/description.py +++ b/grill/views/description.py @@ -255,8 +255,8 @@ def traverse(api: UsdShade.ConnectableAPI): node_id = _get_node_id(current_prim) label = f'<' label += table_row.format(port="", color="white", text=f'{api.GetPrim().GetName()}') - plugs = {"": 0} # {graphviz port name: port index order} - for index, plug in enumerate(chain(api.GetInputs(), api.GetOutputs()), start=1): # we start at 1 because index 0 is the node itself + plugs = [""] # port names for this node. Empty string is used to refer to the node itself (no port). + for plug in chain(api.GetInputs(), api.GetOutputs()): plug_name = plug.GetBaseName() sources, __ = plug.GetConnectedSources() # (valid, invalid): we care only about valid sources (index 0) color = plug_colors[type(plug)] if isinstance(plug, UsdShade.Output) or sources else background_color @@ -264,7 +264,7 @@ def traverse(api: UsdShade.ConnectableAPI): for source in sources: _add_edges(_get_node_id(source.source.GetPrim()), source.sourceName, node_id, plug_name) traverse(source.source) - plugs[plug_name] = index + plugs.append(plug_name) label += '
>' all_nodes[node_id] = dict(label=label, plugs=plugs) diff --git a/tests/test_views.py b/tests/test_views.py index 1b8cb95e..f60b6e88 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -111,8 +111,10 @@ def test_connection_view(self): stage = Usd.Stage.CreateInMemory() material = UsdShade.Material.Define(stage, '/TexModel/boardMat') pbrShader = UsdShade.Shader.Define(stage, '/TexModel/boardMat/PBRShader') - pbrShader.CreateInput("roughness", Sdf.ValueTypeNames.Float).Set(0.4) - material.CreateSurfaceOutput().ConnectToSource(pbrShader.ConnectableAPI(), "surface") + roughness_name = "roughness" + pbrShader.CreateInput(roughness_name, Sdf.ValueTypeNames.Float).Set(0.4) + surface_name = "surface" + material.CreateSurfaceOutput().ConnectToSource(pbrShader.ConnectableAPI(), surface_name) # Ensure cycles don't cause recursion cycle_input = pbrShader.CreateInput("cycle_in", Sdf.ValueTypeNames.Float) cycle_output = pbrShader.CreateOutput("cycle_out", Sdf.ValueTypeNames.Float) @@ -121,6 +123,9 @@ def test_connection_view(self): # GraphView capabilities are tested elsewhere, so mock 'view' here. viewer._graph_view.view = lambda indices: None viewer.setPrim(material) + graph = viewer._graph_view._graph + self.assertEqual(graph.nodes[str(material.GetPrim().GetPath())]['plugs'], ['', surface_name]) + self.assertEqual(graph.nodes[str(pbrShader.GetPrim().GetPath())]['plugs'], ['', cycle_input.GetName(), roughness_name, cycle_output.GetName(), surface_name]) viewer.setPrim(None) def test_scenegraph_composition(self):