Skip to content

Commit

Permalink
feat: migrated checkpoint APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
chentschel committed Oct 28, 2024
1 parent 301fd12 commit 738226d
Showing 1 changed file with 97 additions and 21 deletions.
118 changes: 97 additions & 21 deletions libs/checkpoint-vercel-kv/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ import {

// snake_case is used to match Python implementation
interface KVRow {
checkpoint: string;
metadata: string;
parent_checkpoint_id: string;
type: string;
checkpoint: Uint8Array;
metadata: Uint8Array;
}

interface KVConfig {
Expand All @@ -40,8 +42,7 @@ export class VercelKVSaver extends BaseCheckpointSaver {
* for the given thread ID is retrieved.
*/
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
const thread_id = config.configurable?.thread_id;
const checkpoint_id = config.configurable?.checkpoint_id;
const { thread_id, checkpoint_id } = config.configurable ?? {};

if (!thread_id) {
return undefined;
Expand All @@ -57,22 +58,52 @@ export class VercelKVSaver extends BaseCheckpointSaver {
return undefined;
}

const checkpointP = this.serde.loadsTyped(row.type, row.checkpoint);
const metadataP = this.serde.loadsTyped(row.type, row.metadata);

const [checkpoint, metadata] = await Promise.all([
this.serde.parse(row.checkpoint),
this.serde.parse(row.metadata),
checkpointP as Checkpoint,
metadataP as CheckpointMetadata,
]);

const pendingWrites: CheckpointPendingWrite[] = [];

// PENDING WRITES
// const serializedWrites = await this.kv.mget(
// `${thread_id}:${checkpoint_id}`
// );

// const pendingWrites: CheckpointPendingWrite[] = await Promise.all(
// serializedWrites.map(async (serializedWrite) => {
// return [
// serializedWrite.task_id,
// serializedWrite.channel,
// await this.serde.loadsTyped(
// serializedWrite.type,
// serializedWrite.value
// ),
// ] as CheckpointPendingWrite;
// })
// );

return {
checkpoint: checkpoint as Checkpoint,
metadata: metadata as CheckpointMetadata,
config: checkpoint_id
? config
: {
checkpoint,
metadata,
pendingWrites,
config: {
configurable: {
thread_id,
checkpoint_id: (checkpoint as Checkpoint).id,
},
},
parentConfig: row.parent_checkpoint_id
? {
configurable: {
thread_id,
checkpoint_id: (checkpoint as Checkpoint).id,
checkpoint_id: row.parent_checkpoint_id,
},
},
}
: undefined,
};
}

Expand Down Expand Up @@ -104,22 +135,29 @@ export class VercelKVSaver extends BaseCheckpointSaver {
// Execute the LUA script with the thread_id as an argument
const keys: string[] = await this.kv.eval(luaScript, [], [thread_id]);

// Filter keys based on the before parameter
const filteredKeys = keys.filter((key: string) => {
const [, checkpoint_id] = key.split(":");

return !before || checkpoint_id < before?.configurable?.checkpoint_id;
});

// TODO: Implement filter by metadata in the KV query.

const sortedKeys = filteredKeys
.sort((a: string, b: string) => b.localeCompare(a))
.slice(0, limit);

const rows: (KVRow | null)[] = await this.kv.mget(...sortedKeys);

for (const row of rows) {
if (row) {
const checkpointP = this.serde.loadsTyped(row.type, row.checkpoint);
const metadataP = this.serde.loadsTyped(row.type, row.metadata);

const [checkpoint, metadata] = await Promise.all([
this.serde.parse(row.checkpoint),
this.serde.parse(row.metadata),
checkpointP as Checkpoint,
metadataP as CheckpointMetadata,
]);

yield {
Expand All @@ -129,13 +167,25 @@ export class VercelKVSaver extends BaseCheckpointSaver {
checkpoint_id: (checkpoint as Checkpoint).id,
},
},
checkpoint: checkpoint as Checkpoint,
metadata: metadata as CheckpointMetadata,
checkpoint: checkpoint,
metadata: metadata,
parentConfig: row.parent_checkpoint_id
? {
configurable: {
thread_id,
checkpoint_id: row.parent_checkpoint_id,
},
}
: undefined,
};
}
}
}

/**
* Saves a checkpoint. The checkpoint is associated
* with the provided config and its parent config (if any).
*/
async put(
config: RunnableConfig,
checkpoint: Checkpoint,
Expand All @@ -147,9 +197,18 @@ export class VercelKVSaver extends BaseCheckpointSaver {
throw new Error("Thread ID and Checkpoint ID must be defined");
}

const [checkpointType, checkpointValue] = this.serde.dumpsTyped(checkpoint);
const [metadataType, metadataValue] = this.serde.dumpsTyped(metadata);

if (checkpointType !== metadataType) {
throw new Error("Mismatched checkpoint and metadata types.");
}

const row: KVRow = {
checkpoint: this.serde.stringify(checkpoint),
metadata: this.serde.stringify(metadata),
parent_checkpoint_id: config.configurable?.checkpoint_id,
type: checkpointType,
checkpoint: checkpointValue,
metadata: metadataValue,
};

// LUA script to set checkpoint data atomically"
Expand All @@ -173,6 +232,9 @@ export class VercelKVSaver extends BaseCheckpointSaver {
};
}

/**
* Saves intermediate writes associated with a checkpoint.
*/
async putWrites(
config: RunnableConfig,
writes: PendingWrite[],
Expand All @@ -191,7 +253,21 @@ export class VercelKVSaver extends BaseCheckpointSaver {
);
}

const key = `${thread_id}:${checkpoint_ns}:${checkpoint_id}:${taskId}`;
await this.kv.set(key, writes);
const values: Record<string, any> = writes.reduce(
(acc, [channel, value], idx) => {
const key = `${thread_id}:${checkpoint_id}:${taskId}:${idx}`;
const [type, serializedValue] = this.serde.dumpsTyped(value);
return {
...acc,
[key]: {
channel,
type,
value: serializedValue,
},
};
}
);

await this.kv.mset(values);
}
}

0 comments on commit 738226d

Please sign in to comment.