Skip to content

Commit

Permalink
test port existence in connection view
Browse files Browse the repository at this point in the history
Signed-off-by: Christian López Barrón <[email protected]>
  • Loading branch information
chrizzFTD committed Dec 13, 2024
1 parent e74af59 commit 37cabe3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion grill/views/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def _load_graph(self, graph):

def _add_node(nx_node):
node_data = graph.nodes[nx_node]
plugs = node_data.pop('plugs', ()) # implementation detail
plugs = node_data.get('plugs', ())
nodes_attrs = ChainMap(node_data, graph_node_attrs)
if (shape := nodes_attrs.get('shape')) == 'record':
try:
Expand Down
6 changes: 3 additions & 3 deletions grill/views/description.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,16 +255,16 @@ def traverse(api: UsdShade.ConnectableAPI):
node_id = _get_node_id(current_prim)
label = f'<<table border="1" cellspacing="2" style="ROUNDED" bgcolor="{background_color}" color="{outline_color}">'
label += table_row.format(port="", color="white", text=f'<font color="{outline_color}"><b>{api.GetPrim().GetName()}</b></font>')
plugs = {"": 0} # {graphviz port name: port index order}
for index, plug in enumerate(chain(api.GetInputs(), api.GetOutputs()), start=1): # we start at 1 because index 0 is the node itself
plugs = [""] # port names for this node. Empty string is used to refer to the node itself (no port).
for plug in chain(api.GetInputs(), api.GetOutputs()):
plug_name = plug.GetBaseName()
sources, __ = plug.GetConnectedSources() # (valid, invalid): we care only about valid sources (index 0)
color = plug_colors[type(plug)] if isinstance(plug, UsdShade.Output) or sources else background_color
label += table_row.format(port=plug_name, color=color, text=f'<font color="#242828">{plug_name}</font>')
for source in sources:
_add_edges(_get_node_id(source.source.GetPrim()), source.sourceName, node_id, plug_name)
traverse(source.source)
plugs[plug_name] = index
plugs.append(plug_name)
label += '</table>>'
all_nodes[node_id] = dict(label=label, plugs=plugs)

Expand Down
9 changes: 7 additions & 2 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ def test_connection_view(self):
stage = Usd.Stage.CreateInMemory()
material = UsdShade.Material.Define(stage, '/TexModel/boardMat')
pbrShader = UsdShade.Shader.Define(stage, '/TexModel/boardMat/PBRShader')
pbrShader.CreateInput("roughness", Sdf.ValueTypeNames.Float).Set(0.4)
material.CreateSurfaceOutput().ConnectToSource(pbrShader.ConnectableAPI(), "surface")
roughness_name = "roughness"
pbrShader.CreateInput(roughness_name, Sdf.ValueTypeNames.Float).Set(0.4)
surface_name = "surface"
material.CreateSurfaceOutput().ConnectToSource(pbrShader.ConnectableAPI(), surface_name)
# Ensure cycles don't cause recursion
cycle_input = pbrShader.CreateInput("cycle_in", Sdf.ValueTypeNames.Float)
cycle_output = pbrShader.CreateOutput("cycle_out", Sdf.ValueTypeNames.Float)
Expand All @@ -121,6 +123,9 @@ def test_connection_view(self):
# GraphView capabilities are tested elsewhere, so mock 'view' here.
viewer._graph_view.view = lambda indices: None
viewer.setPrim(material)
graph = viewer._graph_view._graph
self.assertEqual(graph.nodes[str(material.GetPrim().GetPath())]['plugs'], ['', surface_name])
self.assertEqual(graph.nodes[str(pbrShader.GetPrim().GetPath())]['plugs'], ['', cycle_input.GetName(), roughness_name, cycle_output.GetName(), surface_name])
viewer.setPrim(None)

def test_scenegraph_composition(self):
Expand Down

0 comments on commit 37cabe3

Please sign in to comment.