diff --git a/examples/package.json b/examples/package.json index 0ff54bde..71d5a4d5 100644 --- a/examples/package.json +++ b/examples/package.json @@ -8,7 +8,7 @@ "devDependencies": { "@langchain/anthropic": "^0.3.0", "@langchain/community": "^0.3.0", - "@langchain/core": "^0.3.0", + "@langchain/core": "^0.3.6", "@langchain/groq": "^0.1.1", "@langchain/langgraph": "workspace:*", "@langchain/mistralai": "^0.1.0", diff --git a/libs/langgraph/package.json b/libs/langgraph/package.json index 92982ebb..df9489f6 100644 --- a/libs/langgraph/package.json +++ b/libs/langgraph/package.json @@ -43,7 +43,7 @@ "@jest/globals": "^29.5.0", "@langchain/anthropic": "^0.3.0", "@langchain/community": "^0.3.0", - "@langchain/core": "^0.3.0", + "@langchain/core": "^0.3.6", "@langchain/langgraph-checkpoint-sqlite": "workspace:*", "@langchain/openai": "^0.3.0", "@langchain/scripts": ">=0.1.3 <0.2.0", diff --git a/libs/langgraph/src/graph/graph.ts b/libs/langgraph/src/graph/graph.ts index 1b121ee8..e2e26389 100644 --- a/libs/langgraph/src/graph/graph.ts +++ b/libs/langgraph/src/graph/graph.ts @@ -3,14 +3,17 @@ import { _coerceToRunnable, Runnable, RunnableConfig, + RunnableInterface, + RunnableIOSchema, RunnableLike, } from "@langchain/core/runnables"; import { - Node as RunnableGraphNode, - Graph as RunnableGraph, + Node as DrawableGraphNode, + Graph as DrawableGraph, } from "@langchain/core/runnables/graph"; import { All, BaseCheckpointSaver } from "@langchain/langgraph-checkpoint"; import { z } from "zod"; +import { validate as isUuid } from "uuid"; import { PregelNode } from "../pregel/read.js"; import { Channel, Pregel } from "../pregel/index.js"; import type { PregelParams } from "../pregel/types.js"; @@ -24,7 +27,7 @@ import { Send, TAG_HIDDEN, } from "../constants.js"; -import { RunnableCallable } from "../utils.js"; +import { gatherIteratorSync, RunnableCallable } from "../utils.js"; import { InvalidUpdateError, NodeInterrupt } from "../errors.js"; /** Special reserved node name denoting the start of a graph. */ @@ -483,10 +486,10 @@ export class CompiledGraph< */ override getGraph( config?: RunnableConfig & { xray?: boolean | number } - ): RunnableGraph { + ): DrawableGraph { const xray = config?.xray; - const graph = new RunnableGraph(); - const startNodes: Record = { + const graph = new DrawableGraph(); + const startNodes: Record = { [START]: graph.addNode( { schema: z.any(), @@ -494,69 +497,161 @@ export class CompiledGraph< START ), }; - const endNodes: Record = { - [END]: graph.addNode( - { - schema: z.any(), - }, - END - ), - }; - for (const [key, node] of Object.entries>( - this.builder.nodes - )) { - if (config?.xray) { - const subgraph = isCompiledGraph(node) - ? node.getGraph({ - ...config, - xray: typeof xray === "number" && xray > 0 ? xray - 1 : xray, - }) - : node.runnable.getGraph(config); - subgraph.trimFirstNode(); - subgraph.trimLastNode(); - if (Object.keys(subgraph.nodes).length > 1) { - const [newEndNode, newStartNode] = graph.extend(subgraph, key); - if (newEndNode !== undefined) { - endNodes[key] = newEndNode; + const endNodes: Record = {}; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let subgraphs: Record> = {}; + if (xray) { + subgraphs = Object.fromEntries( + gatherIteratorSync(this.getSubgraphs()).filter( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (x): x is [string, CompiledGraph] => isCompiledGraph(x[1]) + ) + ); + } + + function addEdge( + start: string, + end: string, + label?: string, + conditional = false + ) { + if (end === END && endNodes[END] === undefined) { + endNodes[END] = graph.addNode({ schema: z.any() }, END); + } + return graph.addEdge( + startNodes[start], + endNodes[end], + label !== end ? label : undefined, + conditional + ); + } + + for (const [key, nodeSpec] of Object.entries(this.builder.nodes) as [ + N, + NodeSpec + ][]) { + const displayKey = _escapeMermaidKeywords(key); + const node = nodeSpec.runnable; + const metadata = nodeSpec.metadata ?? {}; + if ( + this.interruptBefore?.includes(key) && + this.interruptAfter?.includes(key) + ) { + metadata.__interrupt = "before,after"; + } else if (this.interruptBefore?.includes(key)) { + metadata.__interrupt = "before"; + } else if (this.interruptAfter?.includes(key)) { + metadata.__interrupt = "after"; + } + if (xray) { + const newXrayValue = typeof xray === "number" ? xray - 1 : xray; + const drawableSubgraph = + subgraphs[key] !== undefined + ? subgraphs[key].getGraph({ + ...config, + xray: newXrayValue, + }) + : node.getGraph(config); + drawableSubgraph.trimFirstNode(); + drawableSubgraph.trimLastNode(); + if (Object.keys(drawableSubgraph.nodes).length > 1) { + const [e, s] = graph.extend(drawableSubgraph, displayKey); + if (e === undefined) { + throw new Error( + `Could not extend subgraph "${key}" due to missing entrypoint.` + ); + } + + // TODO: Remove default name once we stop supporting core 0.2.0 + // eslint-disable-next-line no-inner-declarations + function _isRunnableInterface( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + thing: any + ): thing is RunnableInterface { + return thing ? thing.lc_runnable : false; + } + // eslint-disable-next-line no-inner-declarations + function _nodeDataStr( + id: string | undefined, + data: RunnableInterface | RunnableIOSchema + ): string { + if (id !== undefined && !isUuid(id)) { + return id; + } else if (_isRunnableInterface(data)) { + try { + let dataStr = data.getName(); + dataStr = dataStr.startsWith("Runnable") + ? dataStr.slice("Runnable".length) + : dataStr; + return dataStr; + } catch (error) { + return data.getName(); + } + } else { + return data.name ?? "UnknownSchema"; + } } - if (newStartNode !== undefined) { - startNodes[key] = newStartNode; + // TODO: Remove casts when we stop supporting core 0.2.0 + if (s !== undefined) { + startNodes[displayKey] = { + name: _nodeDataStr(s.id, s.data), + ...s, + } as DrawableGraphNode; } + endNodes[displayKey] = { + name: _nodeDataStr(e.id, e.data), + ...e, + } as DrawableGraphNode; } else { - const newNode = graph.addNode(node.runnable, key); - startNodes[key] = newNode; - endNodes[key] = newNode; + // TODO: Remove when we stop supporting core 0.2.0 + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + const newNode = graph.addNode(node, displayKey, metadata); + startNodes[displayKey] = newNode; + endNodes[displayKey] = newNode; } } else { - const newNode = graph.addNode(node.runnable, key); - startNodes[key] = newNode; - endNodes[key] = newNode; + // TODO: Remove when we stop supporting core 0.2.0 + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + const newNode = graph.addNode(node, displayKey, metadata); + startNodes[displayKey] = newNode; + endNodes[displayKey] = newNode; } } - for (const [start, end] of this.builder.allEdges) { - graph.addEdge(startNodes[start], endNodes[end]); + const sortedEdges = [...this.builder.allEdges].sort(([a], [b]) => { + if (a < b) { + return -1; + } else if (b > a) { + return 1; + } else { + return 0; + } + }); + for (const [start, end] of sortedEdges) { + addEdge(_escapeMermaidKeywords(start), _escapeMermaidKeywords(end)); } for (const [start, branches] of Object.entries(this.builder.branches)) { const defaultEnds: Record = { ...Object.fromEntries( Object.keys(this.builder.nodes) .filter((k) => k !== start) - .map((k) => [k, k]) + .map((k) => [_escapeMermaidKeywords(k), _escapeMermaidKeywords(k)]) ), [END]: END, }; for (const branch of Object.values(branches)) { - let ends: Record; + let ends; if (branch.ends !== undefined) { ends = branch.ends; } else { ends = defaultEnds; } for (const [label, end] of Object.entries(ends)) { - graph.addEdge( - startNodes[start], - endNodes[end], - label !== end ? label : undefined, + addEdge( + _escapeMermaidKeywords(start), + _escapeMermaidKeywords(end), + label, true ); } @@ -575,3 +670,10 @@ function isCompiledGraph(x: unknown): x is CompiledGraph { typeof (x as CompiledGraph).attachEdge === "function" ); } + +function _escapeMermaidKeywords(key: string) { + if (key === "subgraph") { + return `"${key}"`; + } + return key; +} diff --git a/libs/langgraph/src/tests/data/mermaid.png b/libs/langgraph/src/tests/data/mermaid.png index 60741bf1..e23d3770 100644 Binary files a/libs/langgraph/src/tests/data/mermaid.png and b/libs/langgraph/src/tests/data/mermaid.png differ diff --git a/libs/langgraph/src/tests/data/multiple_sinks.png b/libs/langgraph/src/tests/data/multiple_sinks.png new file mode 100644 index 00000000..6f36fde9 Binary files /dev/null and b/libs/langgraph/src/tests/data/multiple_sinks.png differ diff --git a/libs/langgraph/src/tests/data/nested_mermaid.png b/libs/langgraph/src/tests/data/nested_mermaid.png new file mode 100644 index 00000000..bc0206ef Binary files /dev/null and b/libs/langgraph/src/tests/data/nested_mermaid.png differ diff --git a/libs/langgraph/src/tests/diagrams.test.ts b/libs/langgraph/src/tests/diagrams.test.ts index 8f8d4139..aec3b44a 100644 --- a/libs/langgraph/src/tests/diagrams.test.ts +++ b/libs/langgraph/src/tests/diagrams.test.ts @@ -1,6 +1,7 @@ import { test, expect } from "@jest/globals"; import { createReactAgent } from "../prebuilt/index.js"; import { FakeSearchTool, FakeToolCallingChatModel } from "./utils.js"; +import { Annotation, StateGraph } from "../web.js"; test("prebuilt agent", async () => { // Define the tools for the agent to use @@ -14,16 +15,87 @@ test("prebuilt agent", async () => { const mermaid = graph.drawMermaid(); expect(mermaid).toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%% graph TD; -\t__start__[__start__]:::startclass; -\t__end__[__end__]:::endclass; -\tagent([agent]):::otherclass; -\ttools([tools]):::otherclass; +\t__start__([

__start__

]):::first +\tagent(agent) +\ttools(tools) +\t__end__([

__end__

]):::last \t__start__ --> agent; \ttools --> agent; -\tagent -. continue .-> tools; +\tagent -.  continue  .-> tools; \tagent -.-> __end__; -\tclassDef startclass fill:#ffdfba; -\tclassDef endclass fill:#baffc9; -\tclassDef otherclass fill:#fad7de; +\tclassDef default fill:#f2f0ff,line-height:1.2; +\tclassDef first fill-opacity:0; +\tclassDef last fill:#bfb6fc; +`); +}); + +test("graph with multiple sinks", async () => { + const StateAnnotation = Annotation.Root({}); + const app = new StateGraph(StateAnnotation) + .addNode("inner1", async () => {}) + .addNode("inner2", async () => {}) + .addNode("inner3", async () => {}) + .addEdge("__start__", "inner1") + .addConditionalEdges("inner1", async () => "inner2", ["inner2", "inner3"]) + .compile(); + + const graph = app.getGraph(); + const mermaid = graph.drawMermaid(); + expect(mermaid).toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%% +graph TD; +\t__start__([

__start__

]):::first +\tinner1(inner1) +\tinner2(inner2) +\tinner3(inner3) +\t__start__ --> inner1; +\tinner1 -.-> inner2; +\tinner1 -.-> inner3; +\tclassDef default fill:#f2f0ff,line-height:1.2; +\tclassDef first fill-opacity:0; +\tclassDef last fill:#bfb6fc; +`); +}); + +test("graph with subgraphs", async () => { + const SubgraphStateAnnotation = Annotation.Root({}); + const subgraph = new StateGraph(SubgraphStateAnnotation) + .addNode("inner1", async () => {}) + .addNode("inner2", async () => {}) + .addNode("inner3", async () => {}) + .addEdge("__start__", "inner1") + .addConditionalEdges("inner1", async () => "inner2", ["inner2", "inner3"]) + .compile(); + + const StateAnnotation = Annotation.Root({}); + + const app = new StateGraph(StateAnnotation) + .addNode("starter", async () => {}) + .addNode("inner", subgraph) + .addNode("final", async () => {}) + .addEdge("__start__", "starter") + .addConditionalEdges("starter", async () => "final", ["inner", "final"]) + .compile({ interruptBefore: ["starter"] }); + + const graph = app.getGraph({ xray: true }); + const mermaid = graph.drawMermaid(); + console.log(mermaid); + expect(mermaid).toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%% +graph TD; +\t__start__([

__start__

]):::first +\tstarter(starter
__interrupt = before) +\tinner_inner1(inner1) +\tinner_inner2(inner2) +\tinner_inner3(inner3) +\tfinal(final) +\t__start__ --> starter; +\tstarter -.-> inner_inner1; +\tstarter -.-> final; +\tsubgraph inner +\tinner_inner1 -.-> inner_inner2; +\tinner_inner1 -.-> inner_inner3; +\tend +\tclassDef default fill:#f2f0ff,line-height:1.2; +\tclassDef first fill-opacity:0; +\tclassDef last fill:#bfb6fc; `); }); diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index f1a675e6..341cdec1 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -4113,8 +4113,8 @@ describe("MessageGraph", () => { }, ]), edges: expect.arrayContaining([ - { source: "__start__", target: "agent" }, - { source: "action", target: "agent" }, + { conditional: false, source: "__start__", target: "agent" }, + { conditional: false, source: "action", target: "agent" }, { source: "agent", target: "action", @@ -4260,8 +4260,8 @@ describe("MessageGraph", () => { }, ]), edges: expect.arrayContaining([ - { source: "__start__", target: "agent" }, - { source: "action", target: "agent" }, + { conditional: false, source: "__start__", target: "agent" }, + { conditional: false, source: "action", target: "agent" }, { source: "agent", target: "action", @@ -4946,10 +4946,10 @@ it("StateGraph branch then node", async () => { }, ]), edges: expect.arrayContaining([ - { source: "__start__", target: "prepare" }, - { source: "tool_two_fast", target: "finish" }, - { source: "tool_two_slow", target: "finish" }, - { source: "finish", target: "__end__" }, + { source: "__start__", target: "prepare", conditional: false }, + { source: "tool_two_fast", target: "finish", conditional: false }, + { source: "tool_two_slow", target: "finish", conditional: false }, + { source: "finish", target: "__end__", conditional: false }, { source: "prepare", target: "tool_two_slow", conditional: true }, { source: "prepare", target: "tool_two_fast", conditional: true }, { source: "prepare", target: "finish", conditional: true }, diff --git a/yarn.lock b/yarn.lock index 8b0480fc..bdbbb12a 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1579,9 +1579,9 @@ __metadata: languageName: node linkType: hard -"@langchain/core@npm:^0.3.0": - version: 0.3.3 - resolution: "@langchain/core@npm:0.3.3" +"@langchain/core@npm:^0.3.6": + version: 0.3.6 + resolution: "@langchain/core@npm:0.3.6" dependencies: ansi-styles: ^5.0.0 camelcase: 6 @@ -1594,7 +1594,7 @@ __metadata: uuid: ^10.0.0 zod: ^3.22.4 zod-to-json-schema: ^3.22.3 - checksum: 7f803d7855bd18ac8e0c67e14ef12f5f6d6023b73bee472335a936951a206ae9d80958c81dd696338aede5def407b4eb761161aefab816461a38f0685e67bc2e + checksum: 16666fb36e3c4ea42c26910f668161051a66cfe24b3ac4f97d9094bc0482527582790c9862be17ff4a81d5f32b9acdcfb7fc7517371189ab73fae3bd8fa1a1d2 languageName: node linkType: hard @@ -1730,7 +1730,7 @@ __metadata: "@jest/globals": ^29.5.0 "@langchain/anthropic": ^0.3.0 "@langchain/community": ^0.3.0 - "@langchain/core": ^0.3.0 + "@langchain/core": ^0.3.6 "@langchain/langgraph-checkpoint": ~0.0.9 "@langchain/langgraph-checkpoint-sqlite": "workspace:*" "@langchain/openai": ^0.3.0 @@ -6240,7 +6240,7 @@ __metadata: dependencies: "@langchain/anthropic": ^0.3.0 "@langchain/community": ^0.3.0 - "@langchain/core": ^0.3.0 + "@langchain/core": ^0.3.6 "@langchain/groq": ^0.1.1 "@langchain/langgraph": "workspace:*" "@langchain/mistralai": ^0.1.0