Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement getState, updateState, getStateHistory #148

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion langgraph/src/checkpoint/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ export abstract class BaseCheckpointSaver {
config: RunnableConfig
): Promise<CheckpointTuple | undefined>;

abstract list(config: RunnableConfig): AsyncGenerator<CheckpointTuple>;
abstract list(
config: RunnableConfig,
limit?: number,
before?: RunnableConfig
): AsyncGenerator<CheckpointTuple>;

abstract put(
config: RunnableConfig,
Expand Down
15 changes: 11 additions & 4 deletions langgraph/src/checkpoint/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,21 @@ export class MemorySaver extends BaseCheckpointSaver {
return undefined;
}

async *list(config: RunnableConfig): AsyncGenerator<CheckpointTuple> {
async *list(
config: RunnableConfig,
limit?: number,
before?: RunnableConfig
): AsyncGenerator<CheckpointTuple> {
const thread_id = config.configurable?.thread_id;
const checkpoints = this.storage[thread_id] ?? {};

// sort in desc order
for (const [checkpoint_id, checkpoint] of Object.entries(checkpoints).sort(
(a, b) => b[0].localeCompare(a[0])
)) {
for (const [checkpoint_id, checkpoint] of Object.entries(checkpoints)
.filter((c) =>
before ? c[0] < before.configurable?.checkpoint_id : true
)
.sort((a, b) => b[0].localeCompare(a[0]))
.slice(0, limit)) {
yield {
config: { configurable: { thread_id, checkpoint_id } },
checkpoint: this.serde.parse(checkpoint[0]) as Checkpoint,
Expand Down
21 changes: 15 additions & 6 deletions langgraph/src/checkpoint/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,25 @@ CREATE TABLE IF NOT EXISTS checkpoints (
return undefined;
}

async *list(config: RunnableConfig): AsyncGenerator<CheckpointTuple> {
async *list(
config: RunnableConfig,
limit?: number,
before?: RunnableConfig
): AsyncGenerator<CheckpointTuple> {
this.setup();
const thread_id = config.configurable?.thread_id;
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ${
before ? "AND checkpoint_id < ?" : ""
} ORDER BY checkpoint_id DESC`;
if (limit) {
sql += ` LIMIT ${limit}`;
}
const args = [thread_id, before?.configurable?.checkpoint_id].filter(
Boolean
);

try {
const rows: Row[] = this.db
.prepare(
`SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ORDER BY checkpoint_id DESC`
)
.all(thread_id) as Row[];
const rows: Row[] = this.db.prepare(sql).all(...args) as Row[];

if (rows) {
for (const row of rows) {
Expand Down
150 changes: 149 additions & 1 deletion langgraph/src/pregel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
RunnableConfig,
RunnableFunc,
RunnableLike,
RunnableSequence,
_coerceToRunnable,
ensureConfig,
patchConfig,
Expand Down Expand Up @@ -40,7 +41,12 @@ import {
TAG_HIDDEN,
} from "../constants.js";
import { initializeAsyncLocalStorageSingleton } from "../setup/async_local_storage.js";
import { All, PregelExecutableTask, PregelTaskDescription } from "./types.js";
import {
All,
PregelExecutableTask,
PregelTaskDescription,
StateSnapshot,
} from "./types.js";
import {
EmptyChannelError,
GraphRecursionError,
Expand Down Expand Up @@ -308,6 +314,148 @@ export class Pregel<
}
}

async getState(config: RunnableConfig): Promise<StateSnapshot> {
if (!this.checkpointer) {
throw new GraphValueError("No checkpointer set");
}

const saved = await this.checkpointer.getTuple(config);
const checkpoint = saved ? saved.checkpoint : emptyCheckpoint();
const channels = emptyChannels(this.channels, checkpoint);
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const [_, nextTasks] = _prepareNextTasks(
checkpoint,
this.nodes,
channels,
false
);
return {
values: readChannels(channels, this.streamChannelsAsIs),
next: nextTasks.map((task) => task.name),
metadata: saved?.metadata,
config: saved ? saved.config : config,
parentConfig: saved?.parentConfig,
};
}

async *getStateHistory(
config: RunnableConfig,
limit?: number,
before?: RunnableConfig
): AsyncIterableIterator<StateSnapshot> {
if (!this.checkpointer) {
throw new GraphValueError("No checkpointer set");
}
for await (const saved of this.checkpointer.list(config, limit, before)) {
const channels = emptyChannels(this.channels, saved.checkpoint);
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const [_, nextTasks] = _prepareNextTasks(
saved.checkpoint,
this.nodes,
channels,
false
);
yield {
values: readChannels(channels, this.streamChannelsAsIs),
next: nextTasks.map((task) => task.name),
metadata: saved.metadata,
config: saved.config,
parentConfig: saved.parentConfig,
};
}
}

async updateState(
config: RunnableConfig,
values: Record<string, unknown> | unknown,
asNode?: keyof Nn
): Promise<RunnableConfig> {
if (!this.checkpointer) {
throw new GraphValueError("No checkpointer set");
}

// Get the latest checkpoint
const saved = await this.checkpointer.getTuple(config);
const checkpoint = saved
? copyCheckpoint(saved.checkpoint)
: emptyCheckpoint();
// Find last that updated the state, if not provided
const maxSeens = Object.entries(checkpoint.versions_seen).reduce(
(acc, [node, versions]) => {
const maxSeen = Math.max(...Object.values(versions));
if (maxSeen) {
if (!acc[maxSeen]) {
acc[maxSeen] = [];
}
acc[maxSeen].push(node);
}
return acc;
},
{} as Record<number, string[]>
);
if (!asNode && !Object.keys(maxSeens).length) {
if (!Array.isArray(this.inputs) && this.inputs in this.nodes) {
asNode = this.inputs as keyof Nn;
}
} else if (!asNode) {
const maxSeen = Math.max(...Object.keys(maxSeens).map(Number));
const nodes = maxSeens[maxSeen];
if (nodes.length === 1) {
asNode = nodes[0] as keyof Nn;
}
}
if (!asNode) {
throw new InvalidUpdateError("Ambiguous update, specify as_node");
}
// update channels
const channels = emptyChannels(this.channels, checkpoint);
// create task to run all writers of the chosen node
const writers = this.nodes[asNode].getWriters();
if (!writers.length) {
throw new InvalidUpdateError(
`No writers found for node ${asNode as string}`
);
}
const task: PregelExecutableTask<keyof Nn, keyof Cc> = {
name: asNode,
input: values,
proc:
// eslint-disable-next-line @typescript-eslint/no-explicit-any
writers.length > 1 ? RunnableSequence.from(writers as any) : writers[0],
writes: [],
config: undefined,
};
// execute task
await task.proc.invoke(
task.input,
patchConfig(config, {
runName: `${this.name}UpdateState`,
configurable: {
[CONFIG_KEY_SEND]: (items: [keyof Cc, unknown][]) =>
task.writes.push(...items),
[CONFIG_KEY_READ]: _localRead.bind(
undefined,
checkpoint,
channels,
task.writes as Array<[string, unknown]>
),
},
})
);
// apply to checkpoint and save
_applyWrites(checkpoint, channels, task.writes);
const step = (saved?.metadata?.step ?? -2) + 1;
return await this.checkpointer.put(
saved?.config ?? config,
createCheckpoint(checkpoint, channels, step),
{
source: "update",
step,
writes: { [asNode]: values },
}
);
}

_defaults(config: PregelOptions<Nn, Cc>): [
boolean, // debug
StreamMode, // stream mode
Expand Down
5 changes: 5 additions & 0 deletions langgraph/src/pregel/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Runnable, RunnableConfig } from "@langchain/core/runnables";
import { CheckpointMetadata } from "../checkpoint/base.js";

export interface PregelTaskDescription {
readonly name: string;
Expand Down Expand Up @@ -30,6 +31,10 @@ export interface StateSnapshot {
* Config used to fetch this snapshot
*/
readonly config: RunnableConfig;
/**
* Metadata about the checkpoint
*/
readonly metadata?: CheckpointMetadata;
/**
* Config used to fetch the parent snapshot, if any
* @default undefined
Expand Down
Loading