Skip to content

Commit

Permalink
Implement send API (#305)
Browse files Browse the repository at this point in the history
* Allow multiple writes to the same node at a given step

* Implement send

* Fix

* Fix build

* Fix test

* Adds map reduce test

* Fix streaming and test

* Fix

* Update test

* Fix display of next tasks in graph state
  • Loading branch information
jacoblee93 authored Aug 12, 2024
1 parent 802111d commit 1e66bc2
Show file tree
Hide file tree
Showing 16 changed files with 935 additions and 141 deletions.
1 change: 1 addition & 0 deletions langgraph/.eslintrc.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ module.exports = {
"no-await-in-loop": 0,
"no-bitwise": 0,
"no-console": 0,
"no-empty-function": 0,
"no-restricted-syntax": 0,
"no-shadow": 0,
"no-continue": 0,
Expand Down
1 change: 1 addition & 0 deletions langgraph/src/channels/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,6 @@ export function createCheckpoint<ValueType>(
channel_values: values,
channel_versions: { ...checkpoint.channel_versions },
versions_seen: deepCopy(checkpoint.versions_seen),
pending_sends: checkpoint.pending_sends ?? [],
};
}
2 changes: 1 addition & 1 deletion langgraph/src/channels/ephemeral_value.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export class EphemeralValue<Value> extends BaseChannel<Value, Value, Value> {
}

fromCheckpoint(checkpoint?: Value) {
const empty = new EphemeralValue<Value>();
const empty = new EphemeralValue<Value>(this.guard);
if (checkpoint) {
empty.value = checkpoint;
}
Expand Down
8 changes: 8 additions & 0 deletions langgraph/src/checkpoint/base.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { RunnableConfig } from "@langchain/core/runnables";
import { DefaultSerializer, SerializerProtocol } from "../serde/base.js";
import { uuid6 } from "./id.js";
import { SendInterface } from "../constants.js";

export interface CheckpointMetadata {
source: "input" | "loop" | "update";
Expand Down Expand Up @@ -50,6 +51,11 @@ export interface Checkpoint<
* @default {}
*/
versions_seen: Record<N, Record<C, number>>;
/**
* List of packets sent to nodes but not yet processed.
* Cleared by the next checkpoint.
*/
pending_sends: SendInterface[];
}

export interface ReadonlyCheckpoint extends Readonly<Checkpoint> {
Expand Down Expand Up @@ -101,6 +107,7 @@ export function emptyCheckpoint(): Checkpoint {
channel_values: {},
channel_versions: {},
versions_seen: {},
pending_sends: [],
};
}

Expand All @@ -112,6 +119,7 @@ export function copyCheckpoint(checkpoint: ReadonlyCheckpoint): Checkpoint {
channel_values: { ...checkpoint.channel_values },
channel_versions: { ...checkpoint.channel_versions },
versions_seen: deepCopy(checkpoint.versions_seen),
pending_sends: [...checkpoint.pending_sends],
};
}

Expand Down
36 changes: 36 additions & 0 deletions langgraph/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,39 @@ export const INTERRUPT = "__interrupt__";
export const TAG_HIDDEN = "langsmith:hidden";

export const TASKS = "__pregel_tasks";

export interface SendInterface {
node: string;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
args: any;
}

export function _isSendInterface(x: unknown): x is SendInterface {
const operation = x as SendInterface;
return typeof operation.node === "string" && operation.args !== undefined;
}

/**
* A message or packet to send to a specific node in the graph.
*
* The `Send` class is used within a `StateGraph`'s conditional edges to
* dynamically invoke a node with a custom state at the next step.
*
* Importantly, the sent state can differ from the core graph's state,
* allowing for flexible and dynamic workflow management.
*
* One such example is a "map-reduce" workflow where your graph invokes
* the same node multiple times in parallel with different states,
* before aggregating the results back into the main graph's state.
*/
export class Send implements SendInterface {
lg_name = "Send";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
constructor(public node: string, public args: any) {}
}

export function _isSend(x: unknown): x is Send {
const operation = x as Send;
return operation.lg_name === "Send";
}
45 changes: 29 additions & 16 deletions langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ import { BaseChannel } from "../channels/base.js";
import { EphemeralValue } from "../channels/ephemeral_value.js";
import { All } from "../pregel/types.js";
import { ChannelWrite, PASSTHROUGH } from "../pregel/write.js";
import { TAG_HIDDEN } from "../constants.js";
import { _isSend, Send, TAG_HIDDEN } from "../constants.js";
import { RunnableCallable } from "../utils.js";
import { InvalidUpdateError } from "../errors.js";

export const START = "__start__";
export const END = "__end__";
Expand All @@ -33,7 +34,11 @@ export class Branch<IO, N extends string> {
condition: (
input: IO,
config?: RunnableConfig
) => string | string[] | Promise<string> | Promise<string[]>;
) =>
| string
| Send
| (string | Send)[]
| Promise<string | Send | (string | Send)[]>;

ends?: Record<string, N | typeof END>;

Expand All @@ -48,7 +53,7 @@ export class Branch<IO, N extends string> {
}

compile(
writer: (dests: string[]) => Runnable | undefined,
writer: (dests: (string | Send)[]) => Runnable | undefined,
reader?: (config: RunnableConfig) => IO
) {
return ChannelWrite.registerWriter(
Expand All @@ -62,23 +67,27 @@ export class Branch<IO, N extends string> {
async _route(
input: IO,
config: RunnableConfig,
writer: (dests: string[]) => Runnable | undefined,
writer: (dests: (string | Send)[]) => Runnable | undefined,
reader?: (config: RunnableConfig) => IO
): Promise<Runnable | undefined> {
let result = await this.condition(reader ? reader(config) : input, config);
if (!Array.isArray(result)) {
result = [result];
}

let destinations: string[];
let destinations: (string | Send)[];
if (this.ends) {
destinations = result.map((r) => this.ends![r]);
// destinations = [r if isinstance(r, Send) else self.ends[r] for r in result]
destinations = result.map((r) => (_isSend(r) ? r : this.ends![r]));
} else {
destinations = result;
}
if (destinations.some((dest) => !dest)) {
throw new Error("Branch condition returned unknown or null destination");
}
if (destinations.filter(_isSend).some((packet) => packet.node === END)) {
throw new InvalidUpdateError("Cannot send a packet to the END node");
}
return writer(destinations);
}
}
Expand Down Expand Up @@ -118,9 +127,9 @@ export class Graph<
return this.edges;
}

addNode<K extends string>(
addNode<K extends string, NodeInput = RunInput>(
key: K,
action: RunnableLike<RunInput, RunOutput>
action: RunnableLike<NodeInput, RunOutput>
): Graph<N | K, RunInput, RunOutput> {
this.warnIfCompiled(
`Adding a node to a graph that has already been compiled. This will not be reflected in the compiled graph.`
Expand All @@ -134,7 +143,8 @@ export class Graph<
}

this.nodes[key as unknown as N] = _coerceToRunnable<RunInput, RunOutput>(
action
// Account for arbitrary state due to Send API
action as RunnableLike<RunInput, RunOutput>
);

return this as Graph<N | K, RunInput, RunOutput>;
Expand Down Expand Up @@ -394,13 +404,16 @@ export class CompiledGraph<
// attach branch writer
this.nodes[start].pipe(
branch.compile((dests) => {
const channels = dests.map((dest) =>
dest === END ? END : `branch:${start}:${name}:${dest}`
);
return new ChannelWrite(
channels.map((channel) => ({ channel, value: PASSTHROUGH })),
[TAG_HIDDEN]
);
const writes = dests.map((dest) => {
if (_isSend(dest)) {
return dest;
}
return {
channel: dest === END ? END : `branch:${start}:${name}:${dest}`,
value: PASSTHROUGH,
};
});
return new ChannelWrite(writes, [TAG_HIDDEN]);
})
);

Expand Down
25 changes: 16 additions & 9 deletions langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { NamedBarrierValue } from "../channels/named_barrier_value.js";
import { EphemeralValue } from "../channels/ephemeral_value.js";
import { RunnableCallable } from "../utils.js";
import { All } from "../pregel/types.js";
import { TAG_HIDDEN } from "../constants.js";
import { _isSend, Send, TAG_HIDDEN } from "../constants.js";
import { InvalidUpdateError } from "../errors.js";

const ROOT = "__root__";
Expand Down Expand Up @@ -139,9 +139,9 @@ export class StateGraph<
]);
}

addNode<K extends string>(
addNode<K extends string, NodeInput = S>(
key: K,
action: RunnableLike<S, U>
action: RunnableLike<NodeInput, U>
): StateGraph<SD, S, U, N | K> {
if (key in this.channels) {
throw new Error(
Expand Down Expand Up @@ -333,7 +333,7 @@ export class CompiledStateGraph<
writers: [new ChannelWrite(stateWriteEntries, [TAG_HIDDEN])],
});
} else {
this.channels[key] = new EphemeralValue();
this.channels[key] = new EphemeralValue(false);
this.nodes[key] = new PregelNode<S, U>({
triggers: [],
// read state keys
Expand Down Expand Up @@ -403,10 +403,17 @@ export class CompiledStateGraph<
if (!filteredDests.length) {
return;
}
const writes: ChannelWriteEntry[] = filteredDests.map((dest) => ({
channel: `branch:${start}:${name}:${dest}`,
value: start,
}));
const writes: (ChannelWriteEntry | Send)[] = filteredDests.map(
(dest) => {
if (_isSend(dest)) {
return dest;
}
return {
channel: `branch:${start}:${name}:${dest}`,
value: start,
};
}
);
return new ChannelWrite(writes, [TAG_HIDDEN]);
},
// reader
Expand All @@ -424,7 +431,7 @@ export class CompiledStateGraph<
}
const channelName = `branch:${start}:${name}:${end}`;
(this.channels as Record<string, BaseChannel>)[channelName] =
new EphemeralValue();
new EphemeralValue(false);
this.nodes[end as N].triggers.push(channelName);
}
}
Expand Down
Loading

0 comments on commit 1e66bc2

Please sign in to comment.