Skip to content

Commit

Permalink
feat(core): Update mermaid drawing to support subgraphs (#6917)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Oct 2, 2024
1 parent 0f370ba commit 9d2600d
Show file tree
Hide file tree
Showing 5 changed files with 250 additions and 149 deletions.
191 changes: 130 additions & 61 deletions langchain-core/src/runnables/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,31 @@ import type {
import { isRunnableInterface } from "./utils.js";
import { drawMermaid, drawMermaidPng } from "./graph_mermaid.js";

const MAX_DATA_DISPLAY_NAME_LENGTH = 42;

export { Node, Edge };

function nodeDataStr(node: Node): string {
if (!isUuid(node.id)) {
return node.id;
} else if (isRunnableInterface(node.data)) {
function nodeDataStr(
id: string | undefined,
data: RunnableInterface | RunnableIOSchema
): string {
if (id !== undefined && !isUuid(id)) {
return id;
} else if (isRunnableInterface(data)) {
try {
let data = node.data.getName();
data = data.startsWith("Runnable") ? data.slice("Runnable".length) : data;
if (data.length > MAX_DATA_DISPLAY_NAME_LENGTH) {
data = `${data.substring(0, MAX_DATA_DISPLAY_NAME_LENGTH)}...`;
}
return data;
let dataStr = data.getName();
dataStr = dataStr.startsWith("Runnable")
? dataStr.slice("Runnable".length)
: dataStr;
return dataStr;
} catch (error) {
return node.data.getName();
return data.getName();
}
} else {
return node.data.name ?? "UnknownSchema";
return data.name ?? "UnknownSchema";
}
}

function nodeDataJson(node: Node) {
// if node.data is implements Runnable
// if node.data implements Runnable
if (isRunnableInterface(node.data)) {
return {
type: "runnable",
Expand All @@ -55,6 +55,11 @@ export class Graph {

edges: Edge[] = [];

constructor(params?: { nodes: Record<string, Node>; edges: Edge[] }) {
this.nodes = params?.nodes ?? this.nodes;
this.edges = params?.edges ?? this.edges;
}

// Convert the graph to a JSON-serializable format.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
toJSON(): Record<string, any> {
Expand Down Expand Up @@ -86,12 +91,22 @@ export class Graph {
};
}

addNode(data: RunnableInterface | RunnableIOSchema, id?: string): Node {
addNode(
data: RunnableInterface | RunnableIOSchema,
id?: string,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
metadata?: Record<string, any>
): Node {
if (id !== undefined && this.nodes[id] !== undefined) {
throw new Error(`Node with id ${id} already exists`);
}
const nodeId = id || uuidv4();
const node: Node = { id: nodeId, data };
const nodeId = id ?? uuidv4();
const node: Node = {
id: nodeId,
data,
name: nodeDataStr(id, data),
metadata,
};
this.nodes[nodeId] = node;
return node;
}
Expand Down Expand Up @@ -129,25 +144,11 @@ export class Graph {
}

firstNode(): Node | undefined {
const targets = new Set(this.edges.map((edge) => edge.target));
const found: Node[] = [];
Object.values(this.nodes).forEach((node) => {
if (!targets.has(node.id)) {
found.push(node);
}
});
return found[0];
return _firstNode(this);
}

lastNode(): Node | undefined {
const sources = new Set(this.edges.map((edge) => edge.source));
const found: Node[] = [];
Object.values(this.nodes).forEach((node) => {
if (!sources.has(node.id)) {
found.push(node);
}
});
return found[0];
return _lastNode(this);
}

/**
Expand Down Expand Up @@ -188,28 +189,55 @@ export class Graph {

trimFirstNode(): void {
const firstNode = this.firstNode();
if (firstNode) {
const outgoingEdges = this.edges.filter(
(edge) => edge.source === firstNode.id
);
if (Object.keys(this.nodes).length === 1 || outgoingEdges.length === 1) {
this.removeNode(firstNode);
}
if (firstNode && _firstNode(this, [firstNode.id])) {
this.removeNode(firstNode);
}
}

trimLastNode(): void {
const lastNode = this.lastNode();
if (lastNode) {
const incomingEdges = this.edges.filter(
(edge) => edge.target === lastNode.id
);
if (Object.keys(this.nodes).length === 1 || incomingEdges.length === 1) {
this.removeNode(lastNode);
}
if (lastNode && _lastNode(this, [lastNode.id])) {
this.removeNode(lastNode);
}
}

/**
* Return a new graph with all nodes re-identified,
* using their unique, readable names where possible.
*/
reid(): Graph {
const nodeLabels: Record<string, string> = Object.fromEntries(
Object.values(this.nodes).map((node) => [node.id, node.name])
);
const nodeLabelCounts = new Map<string, number>();
Object.values(nodeLabels).forEach((label) => {
nodeLabelCounts.set(label, (nodeLabelCounts.get(label) || 0) + 1);
});

const getNodeId = (nodeId: string): string => {
const label = nodeLabels[nodeId];
if (isUuid(nodeId) && nodeLabelCounts.get(label) === 1) {
return label;
} else {
return nodeId;
}
};

return new Graph({
nodes: Object.fromEntries(
Object.entries(this.nodes).map(([id, node]) => [
getNodeId(id),
{ ...node, id: getNodeId(id) },
])
),
edges: this.edges.map((edge) => ({
...edge,
source: getNodeId(edge.source),
target: getNodeId(edge.target),
})),
});
}

drawMermaid(params?: {
withStyles?: boolean;
curveStyle?: string;
Expand All @@ -219,23 +247,21 @@ export class Graph {
const {
withStyles,
curveStyle,
nodeColors = { start: "#ffdfba", end: "#baffc9", other: "#fad7de" },
nodeColors = {
default: "fill:#f2f0ff,line-height:1.2",
first: "fill-opacity:0",
last: "fill:#bfb6fc",
},
wrapLabelNWords,
} = params ?? {};
const nodes: Record<string, string> = {};
for (const node of Object.values(this.nodes)) {
nodes[node.id] = nodeDataStr(node);
}
const graph = this.reid();
const firstNode = graph.firstNode();

const firstNode = this.firstNode();
const firstNodeLabel = firstNode ? nodeDataStr(firstNode) : undefined;

const lastNode = this.lastNode();
const lastNodeLabel = lastNode ? nodeDataStr(lastNode) : undefined;
const lastNode = graph.lastNode();

return drawMermaid(nodes, this.edges, {
firstNodeLabel,
lastNodeLabel,
return drawMermaid(graph.nodes, graph.edges, {
firstNode: firstNode?.id,
lastNode: lastNode?.id,
withStyles,
curveStyle,
nodeColors,
Expand All @@ -256,3 +282,46 @@ export class Graph {
});
}
}
/**
* Find the single node that is not a target of any edge.
* Exclude nodes/sources with ids in the exclude list.
* If there is no such node, or there are multiple, return undefined.
* When drawing the graph, this node would be the origin.
*/
function _firstNode(graph: Graph, exclude: string[] = []): Node | undefined {
const targets = new Set(
graph.edges
.filter((edge) => !exclude.includes(edge.source))
.map((edge) => edge.target)
);

const found: Node[] = [];
for (const node of Object.values(graph.nodes)) {
if (!exclude.includes(node.id) && !targets.has(node.id)) {
found.push(node);
}
}
return found.length === 1 ? found[0] : undefined;
}

/**
* Find the single node that is not a source of any edge.
* Exclude nodes/targets with ids in the exclude list.
* If there is no such node, or there are multiple, return undefined.
* When drawing the graph, this node would be the destination.
*/
function _lastNode(graph: Graph, exclude: string[] = []): Node | undefined {
const sources = new Set(
graph.edges
.filter((edge) => !exclude.includes(edge.target))
.map((edge) => edge.source)
);

const found: Node[] = [];
for (const node of Object.values(graph.nodes)) {
if (!exclude.includes(node.id) && !sources.has(node.id)) {
found.push(node);
}
}
return found.length === 1 ? found[0] : undefined;
}
Loading

0 comments on commit 9d2600d

Please sign in to comment.