Skip to content

Commit

Permalink
updated views._graph.Node internals to use port instead of plug
Browse files Browse the repository at this point in the history
Signed-off-by: Christian López Barrón <[email protected]>
  • Loading branch information
chrizzFTD committed Dec 13, 2024
1 parent b7c3fa0 commit 5171a88
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 95 deletions.
122 changes: 62 additions & 60 deletions grill/views/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
# - Tooltip on nodes for _GraphViewer
# - Context menu items
# - Ability to move further in canvas after Nodes don't exist
# - when switching a node left to right with precise source layers, the source node plugs do not refresh if we're moving the target node
# - when switching a node left to right with precise source layers, the source node ports do not refresh if we're moving the target node
# - refactor conditionals for _GraphSVGViewer from the description module


Expand All @@ -64,19 +64,20 @@ def _adjust_graphviz_html_table_label(label):
return label


def _get_html_table_from_fields(**fields):
def _get_html_table_from_ports(**ports):
label = '<table>'
for index, (port, text) in enumerate(fields.items()):
for index, (name, text) in enumerate(ports.items()):
bgcolor = "white" if index % 2 == 0 else "#f0f6ff" # light blue
text = f'<font color="#242828">{text}</font>'
label += f"<tr><td port='{port}' bgcolor='{bgcolor}'>{text}</td></tr>"
label += f"<tr><td port='{name}' bgcolor='{bgcolor}'>{text}</td></tr>"
label += "</table>"
return label


def _get_plugs_from_label(label) -> dict[str, str]:
def _get_ports_from_label(label) -> dict[str, str]:
if not label.startswith("{"): # Only for record labels.
raise ValueError(f"Label needs to start with '{{' to extract plugs from it, for example: '{{<plug1>item|<plug2>another_item}}'. Got label: '{label}'")
raise ValueError(f"Label needs to start with '{{' to extract ports from it, for example: '{{<port1>item|<port2>another_item}}'. Got label: '{label}'")
# see https://graphviz.org/doc/info/shapes.html#record
fields = label.strip("{}").split("|")
return dict(field.strip("<>").split(">", 1) for field in fields)

Expand All @@ -93,12 +94,12 @@ def _dot_2_svg(sourcepath):
class _Node(QtWidgets.QGraphicsTextItem):

# Note: keep 'label' as an argument to use as much as possible as-is for clients to provide their own HTML style
def __init__(self, parent=None, label="", color="", fillcolor="", plugs: tuple = (), visible=True):
def __init__(self, parent=None, label="", color="", fillcolor="", ports: tuple = (), visible=True):
super().__init__(parent)
self._edges = []
self._plugs = dict(zip(plugs, range(len(plugs)))) or {} # {identifier: index}
self._active_plugs_by_side = dict() # {index: {left[int]: {}, right[int]: {}}
self._plug_items = {} # {index: (QEllipse, QEllipse)}
self._ports = dict(zip(ports, range(len(ports)))) or {} # {identifier: index}
self._active_ports_by_side = dict() # {index: {left[int]: {}, right[int]: {}}
self._port_items = {} # {index: (QEllipse, QEllipse)}
self._pen = QtGui.QPen(QtGui.QColor(color), 1, QtCore.Qt.SolidLine, QtCore.Qt.RoundCap, QtCore.Qt.RoundJoin)
self._fillcolor = QtGui.QColor(fillcolor)
self.setHtml("<style>th, td {text-align: center;padding: 3px}</style>" + label)
Expand Down Expand Up @@ -164,61 +165,62 @@ def itemChange(self, change: QtWidgets.QGraphicsItem.GraphicsItemChange, value):
edge.adjust()
return super().itemChange(change, value)

def _activatePlug(self, edge, plug_index, side, position):
if plug_index is None:
def _activatePort(self, edge, port, side, position):
if port is None:
return # we're at the center, nothing to draw nor activate
try:
plugs_by_side = self._active_plugs_by_side[plug_index] # {index: {left[int]: {}, right[int]: {}}
except KeyError: # first time we're activating a plug, so add a visual ellipse for it
ports_by_side = self._active_ports_by_side[port] # {index: {left[int]: {}, right[int]: {}}
except KeyError: # first time we're activating a port, so add a visual ellipse for it
radius = 4

def _add_plug_item():
def _add_port_item():
item = QtWidgets.QGraphicsEllipseItem(-radius, -radius, 2 * radius, 2 * radius)
item.setPen(_NO_PEN)
self.scene().addItem(item)
return item

self._plug_items[plug_index] = (_add_plug_item(), _add_plug_item())
self._active_plugs_by_side[plug_index] = plugs_by_side = {0: dict(), 1: dict()}
self._port_items[port] = (_add_port_item(), _add_port_item())
self._active_ports_by_side[port] = ports_by_side = {0: dict(), 1: dict()}

plugs_by_side[side][edge] = True
ports_by_side[side][edge] = True
other_side = bool(not side)
inactive_plugs = plugs_by_side[other_side]
inactive_plugs.pop(edge, None)
plug_items = self._plug_items[plug_index] # {index: (QEllipse, QEllipse)}
if not inactive_plugs:
plug_items[other_side].setVisible(False)
this_item = plug_items[side]
inactive_ports = ports_by_side[other_side]
inactive_ports.pop(edge, None)
port_items = self._port_items[port] # {index: (QEllipse, QEllipse)}
if not inactive_ports:
port_items[other_side].setVisible(False)
this_item = port_items[side]
this_item.setVisible(True)
this_item.setBrush(edge._brush)
plug_items[side].setPos(position)
port_items[side].setPos(position)


class _Edge(QtWidgets.QGraphicsItem):
def __init__(self, source: _Node, target: _Node, *, source_plug: int =None, target_plug: int =None, label="", color="", is_bidirectional=False, parent: QtWidgets.QGraphicsItem = None):
def __init__(self, source: _Node, target: _Node, *, source_port: int = None, target_port: int = None, label="", color="", is_bidirectional=False, parent: QtWidgets.QGraphicsItem = None):
super().__init__(parent)
source.add_edge(self)
target.add_edge(self)
self._source = source
self._target = target
self._source_plug = source_plug
self._target_plug = target_plug
self._is_source_plugged = source_plug is not None
self._is_target_plugged = target_plug is not None
self._source_port = source_port
self._target_port = target_port
self._is_source_port_used = source_port is not None
self._is_target_port_used = target_port is not None
self._is_cycle = is_cycle = source == target

self._plug_positions = plug_positions = {}
self._port_positions = port_positions = {}
outer_shift = 10 # surrounding rect has ~5 px top and bottom

for node, plug, max_plug_idx in (source, source_plug, max(source._plugs.values(), default=0)), (target, target_plug, max(target._plugs.values(), default=0)):
# TODO: this is the main reason of why Node._ports has {port: index}. See if it can be removed
for node, port, max_port_idx in (source, source_port, max(source._ports.values(), default=0)), (target, target_port, max(target._ports.values(), default=0)):
bounds = node.boundingRect()
if plug is None:
plug_positions[node, plug] = {None: QtCore.QPointF(bounds.right() - 5, bounds.height() / 2 - 20) if is_cycle else bounds.center()}
if port is None:
port_positions[node, port] = {None: QtCore.QPointF(bounds.right() - 5, bounds.height() / 2 - 20) if is_cycle else bounds.center()}
continue
# max_plug_idx can be 0, so we add 1 since this needs to be 1-index based
port_size = (bounds.height() - outer_shift) / (max_plug_idx + 1)
y_pos = (plug * port_size) + (port_size / 2) + (outer_shift / 2)
plug_positions[node, plug] = {
# max_port_idx can be 0, so we add 1 since this needs to be 1-index based
port_size = (bounds.height() - outer_shift) / (max_port_idx + 1)
y_pos = (port * port_size) + (port_size / 2) + (outer_shift / 2)
port_positions[node, port] = {
0: QtCore.QPointF(0, y_pos), # left
1: QtCore.QPointF(bounds.right(), y_pos), # right
}
Expand All @@ -229,7 +231,7 @@ def __init__(self, source: _Node, target: _Node, *, source_plug: int =None, targ
self._line = QtCore.QLineF()
self.setZValue(-1)

self._spline_path = QtGui.QPainterPath() if (self._is_source_plugged or self._is_target_plugged) else None
self._spline_path = QtGui.QPainterPath() if (self._is_source_port_used or self._is_target_port_used) else None

self._colors = colors = color.split(":")
main_color = QtGui.QColor(colors[0])
Expand Down Expand Up @@ -265,8 +267,8 @@ def boundingRect(self) -> QtCore.QRectF:

@property
def _cycle_start_position(self):
if not self._is_source_plugged:
return self._source.pos() + self._plug_positions[self._source, self._source_plug][None]
if not self._is_source_port_used:
return self._source.pos() + self._port_positions[self._source, self._source_port][None]

return self._line.p1() + QtCore.QPointF(-3, -31)

Expand All @@ -279,14 +281,14 @@ def adjust(self):

source_on_left = self._is_cycle or (self._source.boundingRect().center().x() + source_pos.x() < target_bounds.center().x() + target_pos.x())

is_source_plugged = self._is_source_plugged
is_target_plugged = self._is_target_plugged
source_side = source_on_left if is_source_plugged else None
target_side = not source_side if is_target_plugged else None
source_point = source_pos + self._plug_positions[self._source, self._source_plug][source_side]
target_point = target_pos + self._plug_positions[self._target, self._target_plug][target_side]
is_source_port_used = self._is_source_port_used
is_target_port_used = self._is_target_port_used
source_side = source_on_left if is_source_port_used else None
target_side = not source_side if is_target_port_used else None
source_point = source_pos + self._port_positions[self._source, self._source_port][source_side]
target_point = target_pos + self._port_positions[self._target, self._target_port][target_side]

if not is_target_plugged:
if not is_target_port_used:
line = QtCore.QLineF(source_point, target_point)
if not self._spline_path and self._bidirectional_shift and source_point != target_point:
# offset in case of bidirectional connections when we are not using splines (as lines would overlap)
Expand Down Expand Up @@ -319,15 +321,15 @@ def adjust(self):
falloff = (length / 100) ** 2 if length < 100 else 1
control_point_shift = (1 if source_on_left else -1) * 75 * falloff

control_point1 = source_point + QtCore.QPointF(control_point_shift, 0) if is_source_plugged else source_point
control_point2 = target_point + QtCore.QPointF(-control_point_shift, 0) if is_target_plugged else target_point
control_point1 = source_point + QtCore.QPointF(control_point_shift, 0) if is_source_port_used else source_point
control_point2 = target_point + QtCore.QPointF(-control_point_shift, 0) if is_target_port_used else target_point

self._spline_path = QtGui.QPainterPath()
self._spline_path.moveTo(source_point)
self._spline_path.cubicTo(control_point1, control_point2, target_point)

self._source._activatePlug(self, self._source_plug, source_side, source_point)
self._target._activatePlug(self, self._target_plug, target_side, target_point)
self._source._activatePort(self, self._source_port, source_side, source_point)
self._target._activatePort(self, self._target_port, target_side, target_point)
if self._label_text:
self._label_text.setPos((source_point + target_point) / 2)

Expand Down Expand Up @@ -621,20 +623,20 @@ def _load_graph(self, graph):

def _add_node(nx_node):
node_data = graph.nodes[nx_node]
plugs = node_data.get('plugs', ())
ports = node_data.get('ports', ())
nodes_attrs = ChainMap(node_data, graph_node_attrs)
if (shape := nodes_attrs.get('shape')) == 'record':
try:
label = node_data['label']
except KeyError:
raise ValueError(f"'label' must be supplied when 'record' shape is set for node: '{nx_node}' with data: {node_data}")
if plugs:
if ports:
raise ValueError(f"record 'shape' and 'ports' are mutually exclusive, pick one for node: '{nx_node}' with data: {node_data}")
try:
plugs = _get_plugs_from_label(label)
ports = _get_ports_from_label(label)
except ValueError as exc:
raise ValueError(f"In order to use the 'record' shape, a record 'label' in the form of: '{{<port1>text1|<port2>text2}}' must be used") from exc
label = _get_html_table_from_fields(**plugs)
label = _get_html_table_from_ports(**ports)
else:
label = node_data.get('label')
if shape in {'none', 'plaintext'}:
Expand All @@ -648,7 +650,7 @@ def _add_node(nx_node):
label=label,
color=nodes_attrs.get("color", ""),
fillcolor=nodes_attrs.get("fillcolor", "white"),
plugs=plugs,
ports=ports,
visible=nodes_attrs.get('style', "") != "invis",
)
item.linkActivated.connect(self._graph_url_changed)
Expand Down Expand Up @@ -680,9 +682,9 @@ def _add_node(nx_node):
color = edge_data.get('color', edge_color)
label = edge_data.get('label', '')
kwargs = dict()
if source._plugs or target._plugs:
kwargs['target_plug'] = target._plugs[edge_data['headport']] if edge_data.get('headport') is not None else None
kwargs['source_plug'] = source._plugs[edge_data['tailport']] if edge_data.get('tailport') is not None else None
if source._ports or target._ports:
kwargs['target_port'] = target._ports[edge_data['headport']] if edge_data.get('headport') is not None else None
kwargs['source_port'] = source._ports[edge_data['tailport']] if edge_data.get('tailport') is not None else None

edge = _Edge(source, target, color=color, label=label, is_bidirectional=is_bidirectional, **kwargs)
self.scene().addItem(edge)
Expand Down
22 changes: 11 additions & 11 deletions grill/views/description.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def _graph_from_connections(prim: Usd.Prim) -> nx.MultiDiGraph:
graph.graph['edge'] = {"color": 'crimson'}

all_nodes = dict() # {node_id: {graphviz_attr: value}}
edges = list() # [(source_node_id, target_node_id, {source_plug_name, target_plug_name, graphviz_attrs})]
edges = list() # [(source_node_id, target_node_id, {source_port_name, target_port_name, graphviz_attrs})]

@cache
def _get_node_id(api):
Expand All @@ -240,7 +240,7 @@ def _add_edges(src_node, src_name, tgt_node, tgt_name):
tooltip = f"{src_node}.{src_name} -> {tgt_node}.{tgt_name}"
edges.append((src_node, tgt_node, {"tailport": src_name, "headport": tgt_name, "tooltip": tooltip}))

plug_colors = {
port_colors = {
UsdShade.Input: outline_color, # blue
UsdShade.Output: "#F08080" # "lightcoral", # pink
}
Expand All @@ -255,18 +255,18 @@ def traverse(api: UsdShade.ConnectableAPI):
node_id = _get_node_id(current_prim)
label = f'<<table border="1" cellspacing="2" style="ROUNDED" bgcolor="{background_color}" color="{outline_color}">'
label += table_row.format(port="", color="white", text=f'<font color="{outline_color}"><b>{api.GetPrim().GetName()}</b></font>')
plugs = [""] # port names for this node. Empty string is used to refer to the node itself (no port).
for plug in chain(api.GetInputs(), api.GetOutputs()):
plug_name = plug.GetBaseName()
sources, __ = plug.GetConnectedSources() # (valid, invalid): we care only about valid sources (index 0)
color = plug_colors[type(plug)] if isinstance(plug, UsdShade.Output) or sources else background_color
label += table_row.format(port=plug_name, color=color, text=f'<font color="#242828">{plug_name}</font>')
ports = [""] # port names for this node. Empty string is used to refer to the node itself (no port).
for port in chain(api.GetInputs(), api.GetOutputs()):
port_name = port.GetBaseName()
sources, __ = port.GetConnectedSources() # (valid, invalid): we care only about valid sources (index 0)
color = port_colors[type(port)] if isinstance(port, UsdShade.Output) or sources else background_color
label += table_row.format(port=port_name, color=color, text=f'<font color="#242828">{port_name}</font>')
for source in sources:
_add_edges(_get_node_id(source.source.GetPrim()), source.sourceName, node_id, plug_name)
_add_edges(_get_node_id(source.source.GetPrim()), source.sourceName, node_id, port_name)
traverse(source.source)
plugs.append(plug_name)
ports.append(port_name)
label += '</table>>'
all_nodes[node_id] = dict(label=label, plugs=plugs)
all_nodes[node_id] = dict(label=label, ports=ports)

traverse(connections_api)

Expand Down
Loading

0 comments on commit 5171a88

Please sign in to comment.