Skip to content

Commit

Permalink
Add Annotation.Root to make it easier to access State, Update and Nod…
Browse files Browse the repository at this point in the history
…e types (#307)

* Add Annotation.Root to make it easier to access State, Update and Node types

* Use declare to prevent overwriting, add test

* Lint

* Update test

---------

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
nfcampos and jacoblee93 authored Aug 12, 2024
1 parent 1e66bc2 commit ffef128
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 97 deletions.
3 changes: 2 additions & 1 deletion langgraph/.eslintrc.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module.exports = {
project: "./tsconfig.json",
sourceType: "module",
},
plugins: ["@typescript-eslint", "no-instanceof"],
plugins: ["@typescript-eslint", "no-instanceof", "eslint-plugin-jest"],
ignorePatterns: [
".eslintrc.cjs",
"scripts",
Expand Down Expand Up @@ -43,6 +43,7 @@ module.exports = {
],
"import/no-unresolved": 0,
"import/prefer-default-export": 0,
'jest/no-focused-tests': 'error',
"keyword-spacing": "error",
"max-classes-per-file": 0,
"max-len": 0,
Expand Down
1 change: 1 addition & 0 deletions langgraph/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"eslint-config-airbnb-base": "^15.0.0",
"eslint-config-prettier": "^8.6.0",
"eslint-plugin-import": "^2.29.1",
"eslint-plugin-jest": "^28.8.0",
"eslint-plugin-no-instanceof": "^1.0.1",
"eslint-plugin-prettier": "^4.2.1",
"jest": "^29.5.0",
Expand Down
104 changes: 104 additions & 0 deletions langgraph/src/graph/annotation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import { RunnableLike } from "@langchain/core/runnables";
import { BaseChannel } from "../channels/base.js";
import { BinaryOperator, BinaryOperatorAggregate } from "../channels/binop.js";
import { LastValue } from "../channels/last_value.js";

export type SingleReducer<ValueType, UpdateType = ValueType> =
| {
reducer: BinaryOperator<ValueType, UpdateType>;
default?: () => ValueType;
}
| {
/**
* @deprecated Use `reducer` instead
*/
value: BinaryOperator<ValueType, UpdateType>;
default?: () => ValueType;
}
| null;

export interface StateDefinition {
[key: string]: BaseChannel | (() => BaseChannel);
}

type ExtractValueType<C> = C extends BaseChannel
? C["ValueType"]
: C extends () => BaseChannel
? ReturnType<C>["ValueType"]
: never;

type ExtractUpdateType<C> = C extends BaseChannel
? C["UpdateType"]
: C extends () => BaseChannel
? ReturnType<C>["UpdateType"]
: never;

export type StateType<SD extends StateDefinition> = {
[key in keyof SD]: ExtractValueType<SD[key]>;
};

export type UpdateType<SD extends StateDefinition> = {
[key in keyof SD]?: ExtractUpdateType<SD[key]>;
};

export type NodeType<SD extends StateDefinition> = RunnableLike<
StateType<SD>,
UpdateType<SD>
>;

export class AnnotationRoot<SD extends StateDefinition> {
lc_graph_name = "AnnotationRoot";

declare State: StateType<SD>;

declare Update: UpdateType<SD>;

declare Node: NodeType<SD>;

spec: SD;

constructor(s: SD) {
this.spec = s;
}
}

export function Annotation<ValueType>(): LastValue<ValueType>;

export function Annotation<ValueType, UpdateType = ValueType>(
annotation: SingleReducer<ValueType, UpdateType>
): BinaryOperatorAggregate<ValueType, UpdateType>;

export function Annotation<ValueType, UpdateType = ValueType>(
annotation?: SingleReducer<ValueType, UpdateType>
): BaseChannel<ValueType, UpdateType> {
if (annotation) {
return getChannel<ValueType, UpdateType>(annotation);
} else {
// @ts-expect-error - Annotation without reducer
return new LastValue<ValueType>();
}
}
Annotation.Root = <S extends StateDefinition>(sd: S) => new AnnotationRoot(sd);

export function getChannel<V, U = V>(
reducer: SingleReducer<V, U>
): BaseChannel<V, U> {
if (
typeof reducer === "object" &&
reducer &&
"reducer" in reducer &&
reducer.reducer
) {
return new BinaryOperatorAggregate(reducer.reducer, reducer.default);
}
if (
typeof reducer === "object" &&
reducer &&
"value" in reducer &&
reducer.value
) {
return new BinaryOperatorAggregate(reducer.value, reducer.default);
}
// @ts-expect-error - Annotation without reducer
return new LastValue<V>();
}
4 changes: 1 addition & 3 deletions langgraph/src/graph/index.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
export { Annotation, type StateType, type UpdateType } from "./annotation.js";
export { END, START, Graph } from "./graph.js";
export {
type StateGraphArgs,
StateGraph,
type CompiledStateGraph,
Annotation,
type StateType,
type UpdateType,
} from "./state.js";
export { MessageGraph, messagesStateReducer } from "./message.js";
2 changes: 1 addition & 1 deletion langgraph/src/graph/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type Messages =
| BaseMessageLike;

export function messagesStateReducer(
left: Messages,
left: BaseMessage[],
right: Messages
): BaseMessage[] {
const leftArray = Array.isArray(left) ? left : [left];
Expand Down
104 changes: 23 additions & 81 deletions langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ import {
RunnableLike,
} from "@langchain/core/runnables";
import { BaseChannel } from "../channels/base.js";
import { BinaryOperator, BinaryOperatorAggregate } from "../channels/binop.js";
import { END, CompiledGraph, Graph, START, Branch } from "./graph.js";
import { LastValue } from "../channels/last_value.js";
import {
ChannelWrite,
ChannelWriteEntry,
Expand All @@ -22,64 +20,17 @@ import { RunnableCallable } from "../utils.js";
import { All } from "../pregel/types.js";
import { _isSend, Send, TAG_HIDDEN } from "../constants.js";
import { InvalidUpdateError } from "../errors.js";
import {
AnnotationRoot,
getChannel,
SingleReducer,
StateDefinition,
StateType,
UpdateType,
} from "./annotation.js";

const ROOT = "__root__";

export function Annotation<ValueType>(): LastValue<ValueType>;

export function Annotation<ValueType, UpdateType = ValueType>(
annotation: SingleReducer<ValueType, UpdateType>
): BinaryOperatorAggregate<ValueType, UpdateType>;

export function Annotation<ValueType, UpdateType = ValueType>(
annotation?: SingleReducer<ValueType, UpdateType>
): BaseChannel<ValueType, UpdateType> {
if (annotation) {
return getChannel<ValueType, UpdateType>(annotation);
} else {
// @ts-expect-error - Annotation without reducer
return new LastValue<ValueType>();
}
}

interface StateDefinition {
[key: string]: BaseChannel | (() => BaseChannel);
}

type ExtractValueType<C> = C extends BaseChannel
? C["ValueType"]
: C extends () => BaseChannel
? ReturnType<C>["ValueType"]
: never;

type ExtractUpdateType<C> = C extends BaseChannel
? C["UpdateType"]
: C extends () => BaseChannel
? ReturnType<C>["UpdateType"]
: never;

export type StateType<S extends StateDefinition> = {
[key in keyof S]: ExtractValueType<S[key]>;
};

export type UpdateType<S extends StateDefinition> = {
[key in keyof S]?: ExtractUpdateType<S[key]>;
};

type SingleReducer<ValueType, UpdateType = ValueType> =
| {
reducer: BinaryOperator<ValueType, UpdateType>;
default?: () => ValueType;
}
| {
/**
* @deprecated Use `reducer` instead
*/
value: BinaryOperator<ValueType, UpdateType>;
default?: () => ValueType;
}
| null;

export type ChannelReducers<Channels extends object> = {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[K in keyof Channels]: SingleReducer<Channels[K], any>;
Expand All @@ -106,13 +57,14 @@ export class StateGraph<

constructor(
fields: SD extends StateDefinition
? SD | StateGraphArgs<S>
? SD | AnnotationRoot<SD> | StateGraphArgs<S>
: StateGraphArgs<S>
) {
super();
if (isStateDefinition(fields)) {
if (isStateDefinition(fields) || isAnnotationRoot(fields)) {
const spec = isAnnotationRoot(fields) ? fields.spec : fields;
this.channels = {};
for (const [key, val] of Object.entries(fields)) {
for (const [key, val] of Object.entries(spec)) {
if (typeof val === "function") {
this.channels[key] = val();
} else {
Expand Down Expand Up @@ -261,27 +213,6 @@ function _getChannels<Channels extends Record<string, unknown> | unknown>(
return channels;
}

function getChannel<V, U = V>(reducer: SingleReducer<V, U>): BaseChannel<V, U> {
if (
typeof reducer === "object" &&
reducer &&
"reducer" in reducer &&
reducer.reducer
) {
return new BinaryOperatorAggregate(reducer.reducer, reducer.default);
}
if (
typeof reducer === "object" &&
reducer &&
"value" in reducer &&
reducer.value
) {
return new BinaryOperatorAggregate(reducer.value, reducer.default);
}
// @ts-expect-error - Annotation without reducer
return new LastValue<V>();
}

export class CompiledStateGraph<
S,
U,
Expand Down Expand Up @@ -450,3 +381,14 @@ function isStateDefinition(obj: unknown): obj is StateDefinition {
Object.values(obj).every((v) => typeof v === "function" || isBaseChannel(v))
);
}

function isAnnotationRoot<SD extends StateDefinition>(
obj: unknown | AnnotationRoot<SD>
): obj is AnnotationRoot<SD> {
return (
typeof obj === "object" &&
obj !== null &&
"lc_graph_name" in obj &&
obj.lc_graph_name === "AnnotationRoot"
);
}
18 changes: 13 additions & 5 deletions langgraph/src/tests/graph.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
/* eslint-disable @typescript-eslint/no-unused-vars */
import { describe, it, expect } from "@jest/globals";
import { Annotation, StateGraph } from "../graph/state.js";
import { StateGraph } from "../graph/state.js";
import { END, START } from "../web.js";
import { Annotation } from "../graph/annotation.js";

describe("State", () => {
it("should validate a new node key correctly ", () => {
Expand All @@ -19,17 +21,23 @@ describe("State", () => {
});

it("should allow reducers with different argument types", async () => {
const State = {
const StateAnnotation = Annotation.Root({
val: Annotation<number>,
testval: Annotation<string[], string>({
reducer: (left, right) =>
right ? left.concat([right.toString()]) : left,
}),
};
const stateGraph = new StateGraph(State);
});
const stateGraph = new StateGraph(StateAnnotation);

const graph = stateGraph
.addNode("testnode", (_) => ({ testval: "hi!", val: 3 }))
.addNode("testnode", (state: typeof StateAnnotation.State) => {
// Should properly be typed as string
state.testval.concat(["stringval"]);
// @ts-expect-error Should be typed as a number
const valValue: string | undefined | null = state.val;
return { testval: "hi!", val: 3 };
})
.addEdge(START, "testnode")
.addEdge("testnode", END)
.compile();
Expand Down
10 changes: 5 additions & 5 deletions langgraph/src/tests/pregel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1901,12 +1901,12 @@ describe("StateGraph", () => {
});
});

it.only("State graph packets", async () => {
const AgentState = {
it("State graph packets", async () => {
const AgentState = Annotation.Root({
messages: Annotation({
reducer: messagesStateReducer,
}),
};
});
const searchApi = tool(
async ({ query }) => {
return `result for ${query}`;
Expand Down Expand Up @@ -1960,13 +1960,13 @@ describe("StateGraph", () => {
],
});

const agent = async (state: StateType<typeof AgentState>) => {
const agent = async (state: typeof AgentState.State) => {
return {
messages: await model.invoke(state.messages),
};
};

const shouldContinue = async (state: StateType<typeof AgentState>) => {
const shouldContinue = async (state: typeof AgentState.State) => {
// TODO: Support this?
// expect(state.something_extra).toEqual("hi there");
const toolCalls = (state.messages[state.messages.length - 1] as AIMessage)
Expand Down
Loading

0 comments on commit ffef128

Please sign in to comment.