Skip to content

Commit

Permalink
feat: Add xray support for drawing subgraphs (#542)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Oct 2, 2024
1 parent cfb6c36 commit 39eb622
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 70 deletions.
2 changes: 1 addition & 1 deletion examples/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"devDependencies": {
"@langchain/anthropic": "^0.3.0",
"@langchain/community": "^0.3.0",
"@langchain/core": "^0.3.0",
"@langchain/core": "^0.3.6",
"@langchain/groq": "^0.1.1",
"@langchain/langgraph": "workspace:*",
"@langchain/mistralai": "^0.1.0",
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"@jest/globals": "^29.5.0",
"@langchain/anthropic": "^0.3.0",
"@langchain/community": "^0.3.0",
"@langchain/core": "^0.3.0",
"@langchain/core": "^0.3.6",
"@langchain/langgraph-checkpoint-sqlite": "workspace:*",
"@langchain/openai": "^0.3.0",
"@langchain/scripts": ">=0.1.3 <0.2.0",
Expand Down
194 changes: 148 additions & 46 deletions libs/langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ import {
_coerceToRunnable,
Runnable,
RunnableConfig,
RunnableInterface,
RunnableIOSchema,
RunnableLike,
} from "@langchain/core/runnables";
import {
Node as RunnableGraphNode,
Graph as RunnableGraph,
Node as DrawableGraphNode,
Graph as DrawableGraph,
} from "@langchain/core/runnables/graph";
import { All, BaseCheckpointSaver } from "@langchain/langgraph-checkpoint";
import { z } from "zod";
import { validate as isUuid } from "uuid";
import { PregelNode } from "../pregel/read.js";
import { Channel, Pregel } from "../pregel/index.js";
import type { PregelParams } from "../pregel/types.js";
Expand All @@ -24,7 +27,7 @@ import {
Send,
TAG_HIDDEN,
} from "../constants.js";
import { RunnableCallable } from "../utils.js";
import { gatherIteratorSync, RunnableCallable } from "../utils.js";
import { InvalidUpdateError, NodeInterrupt } from "../errors.js";

/** Special reserved node name denoting the start of a graph. */
Expand Down Expand Up @@ -483,80 +486,172 @@ export class CompiledGraph<
*/
override getGraph(
config?: RunnableConfig & { xray?: boolean | number }
): RunnableGraph {
): DrawableGraph {
const xray = config?.xray;
const graph = new RunnableGraph();
const startNodes: Record<string, RunnableGraphNode> = {
const graph = new DrawableGraph();
const startNodes: Record<string, DrawableGraphNode> = {
[START]: graph.addNode(
{
schema: z.any(),
},
START
),
};
const endNodes: Record<string, RunnableGraphNode> = {
[END]: graph.addNode(
{
schema: z.any(),
},
END
),
};
for (const [key, node] of Object.entries<NodeSpec<unknown, unknown>>(
this.builder.nodes
)) {
if (config?.xray) {
const subgraph = isCompiledGraph(node)
? node.getGraph({
...config,
xray: typeof xray === "number" && xray > 0 ? xray - 1 : xray,
})
: node.runnable.getGraph(config);
subgraph.trimFirstNode();
subgraph.trimLastNode();
if (Object.keys(subgraph.nodes).length > 1) {
const [newEndNode, newStartNode] = graph.extend(subgraph, key);
if (newEndNode !== undefined) {
endNodes[key] = newEndNode;
const endNodes: Record<string, DrawableGraphNode> = {};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let subgraphs: Record<string, CompiledGraph<any>> = {};
if (xray) {
subgraphs = Object.fromEntries(
gatherIteratorSync(this.getSubgraphs()).filter(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(x): x is [string, CompiledGraph<any>] => isCompiledGraph(x[1])
)
);
}

function addEdge(
start: string,
end: string,
label?: string,
conditional = false
) {
if (end === END && endNodes[END] === undefined) {
endNodes[END] = graph.addNode({ schema: z.any() }, END);
}
return graph.addEdge(
startNodes[start],
endNodes[end],
label !== end ? label : undefined,
conditional
);
}

for (const [key, nodeSpec] of Object.entries(this.builder.nodes) as [
N,
NodeSpec<RunInput, RunOutput>
][]) {
const displayKey = _escapeMermaidKeywords(key);
const node = nodeSpec.runnable;
const metadata = nodeSpec.metadata ?? {};
if (
this.interruptBefore?.includes(key) &&
this.interruptAfter?.includes(key)
) {
metadata.__interrupt = "before,after";
} else if (this.interruptBefore?.includes(key)) {
metadata.__interrupt = "before";
} else if (this.interruptAfter?.includes(key)) {
metadata.__interrupt = "after";
}
if (xray) {
const newXrayValue = typeof xray === "number" ? xray - 1 : xray;
const drawableSubgraph =
subgraphs[key] !== undefined
? subgraphs[key].getGraph({
...config,
xray: newXrayValue,
})
: node.getGraph(config);
drawableSubgraph.trimFirstNode();
drawableSubgraph.trimLastNode();
if (Object.keys(drawableSubgraph.nodes).length > 1) {
const [e, s] = graph.extend(drawableSubgraph, displayKey);
if (e === undefined) {
throw new Error(
`Could not extend subgraph "${key}" due to missing entrypoint.`
);
}

// TODO: Remove default name once we stop supporting core 0.2.0
// eslint-disable-next-line no-inner-declarations
function _isRunnableInterface(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
thing: any
): thing is RunnableInterface {
return thing ? thing.lc_runnable : false;
}
// eslint-disable-next-line no-inner-declarations
function _nodeDataStr(
id: string | undefined,
data: RunnableInterface | RunnableIOSchema
): string {
if (id !== undefined && !isUuid(id)) {
return id;
} else if (_isRunnableInterface(data)) {
try {
let dataStr = data.getName();
dataStr = dataStr.startsWith("Runnable")
? dataStr.slice("Runnable".length)
: dataStr;
return dataStr;
} catch (error) {
return data.getName();
}
} else {
return data.name ?? "UnknownSchema";
}
}
if (newStartNode !== undefined) {
startNodes[key] = newStartNode;
// TODO: Remove casts when we stop supporting core 0.2.0
if (s !== undefined) {
startNodes[displayKey] = {
name: _nodeDataStr(s.id, s.data),
...s,
} as DrawableGraphNode;
}
endNodes[displayKey] = {
name: _nodeDataStr(e.id, e.data),
...e,
} as DrawableGraphNode;
} else {
const newNode = graph.addNode(node.runnable, key);
startNodes[key] = newNode;
endNodes[key] = newNode;
// TODO: Remove when we stop supporting core 0.2.0
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
const newNode = graph.addNode(node, displayKey, metadata);
startNodes[displayKey] = newNode;
endNodes[displayKey] = newNode;
}
} else {
const newNode = graph.addNode(node.runnable, key);
startNodes[key] = newNode;
endNodes[key] = newNode;
// TODO: Remove when we stop supporting core 0.2.0
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
const newNode = graph.addNode(node, displayKey, metadata);
startNodes[displayKey] = newNode;
endNodes[displayKey] = newNode;
}
}
for (const [start, end] of this.builder.allEdges) {
graph.addEdge(startNodes[start], endNodes[end]);
const sortedEdges = [...this.builder.allEdges].sort(([a], [b]) => {
if (a < b) {
return -1;
} else if (b > a) {
return 1;
} else {
return 0;
}
});
for (const [start, end] of sortedEdges) {
addEdge(_escapeMermaidKeywords(start), _escapeMermaidKeywords(end));
}
for (const [start, branches] of Object.entries(this.builder.branches)) {
const defaultEnds: Record<string, string> = {
...Object.fromEntries(
Object.keys(this.builder.nodes)
.filter((k) => k !== start)
.map((k) => [k, k])
.map((k) => [_escapeMermaidKeywords(k), _escapeMermaidKeywords(k)])
),
[END]: END,
};
for (const branch of Object.values(branches)) {
let ends: Record<string, string>;
let ends;
if (branch.ends !== undefined) {
ends = branch.ends;
} else {
ends = defaultEnds;
}
for (const [label, end] of Object.entries(ends)) {
graph.addEdge(
startNodes[start],
endNodes[end],
label !== end ? label : undefined,
addEdge(
_escapeMermaidKeywords(start),
_escapeMermaidKeywords(end),
label,
true
);
}
Expand All @@ -575,3 +670,10 @@ function isCompiledGraph(x: unknown): x is CompiledGraph<any> {
typeof (x as CompiledGraph<any>).attachEdge === "function"
);
}

function _escapeMermaidKeywords(key: string) {
if (key === "subgraph") {
return `"${key}"`;
}
return key;
}
Binary file modified libs/langgraph/src/tests/data/mermaid.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added libs/langgraph/src/tests/data/multiple_sinks.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added libs/langgraph/src/tests/data/nested_mermaid.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
88 changes: 80 additions & 8 deletions libs/langgraph/src/tests/diagrams.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { test, expect } from "@jest/globals";
import { createReactAgent } from "../prebuilt/index.js";
import { FakeSearchTool, FakeToolCallingChatModel } from "./utils.js";
import { Annotation, StateGraph } from "../web.js";

test("prebuilt agent", async () => {
// Define the tools for the agent to use
Expand All @@ -14,16 +15,87 @@ test("prebuilt agent", async () => {
const mermaid = graph.drawMermaid();
expect(mermaid).toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
\t__start__[__start__]:::startclass;
\t__end__[__end__]:::endclass;
\tagent([agent]):::otherclass;
\ttools([tools]):::otherclass;
\t__start__([<p>__start__</p>]):::first
\tagent(agent)
\ttools(tools)
\t__end__([<p>__end__</p>]):::last
\t__start__ --> agent;
\ttools --> agent;
\tagent -. continue .-> tools;
\tagent -. &nbsp;continue&nbsp; .-> tools;
\tagent -.-> __end__;
\tclassDef startclass fill:#ffdfba;
\tclassDef endclass fill:#baffc9;
\tclassDef otherclass fill:#fad7de;
\tclassDef default fill:#f2f0ff,line-height:1.2;
\tclassDef first fill-opacity:0;
\tclassDef last fill:#bfb6fc;
`);
});

test("graph with multiple sinks", async () => {
const StateAnnotation = Annotation.Root({});
const app = new StateGraph(StateAnnotation)
.addNode("inner1", async () => {})
.addNode("inner2", async () => {})
.addNode("inner3", async () => {})
.addEdge("__start__", "inner1")
.addConditionalEdges("inner1", async () => "inner2", ["inner2", "inner3"])
.compile();

const graph = app.getGraph();
const mermaid = graph.drawMermaid();
expect(mermaid).toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
\t__start__([<p>__start__</p>]):::first
\tinner1(inner1)
\tinner2(inner2)
\tinner3(inner3)
\t__start__ --> inner1;
\tinner1 -.-> inner2;
\tinner1 -.-> inner3;
\tclassDef default fill:#f2f0ff,line-height:1.2;
\tclassDef first fill-opacity:0;
\tclassDef last fill:#bfb6fc;
`);
});

test("graph with subgraphs", async () => {
const SubgraphStateAnnotation = Annotation.Root({});
const subgraph = new StateGraph(SubgraphStateAnnotation)
.addNode("inner1", async () => {})
.addNode("inner2", async () => {})
.addNode("inner3", async () => {})
.addEdge("__start__", "inner1")
.addConditionalEdges("inner1", async () => "inner2", ["inner2", "inner3"])
.compile();

const StateAnnotation = Annotation.Root({});

const app = new StateGraph(StateAnnotation)
.addNode("starter", async () => {})
.addNode("inner", subgraph)
.addNode("final", async () => {})
.addEdge("__start__", "starter")
.addConditionalEdges("starter", async () => "final", ["inner", "final"])
.compile({ interruptBefore: ["starter"] });

const graph = app.getGraph({ xray: true });
const mermaid = graph.drawMermaid();
console.log(mermaid);
expect(mermaid).toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
\t__start__([<p>__start__</p>]):::first
\tstarter(starter<hr/><small><em>__interrupt = before</em></small>)
\tinner_inner1(inner1)
\tinner_inner2(inner2)
\tinner_inner3(inner3)
\tfinal(final)
\t__start__ --> starter;
\tstarter -.-> inner_inner1;
\tstarter -.-> final;
\tsubgraph inner
\tinner_inner1 -.-> inner_inner2;
\tinner_inner1 -.-> inner_inner3;
\tend
\tclassDef default fill:#f2f0ff,line-height:1.2;
\tclassDef first fill-opacity:0;
\tclassDef last fill:#bfb6fc;
`);
});
Loading

0 comments on commit 39eb622

Please sign in to comment.