Skip to content

Commit

Permalink
graph: Add additional test cases to get coverage to 100%
Browse files Browse the repository at this point in the history
  • Loading branch information
tmc committed Mar 20, 2024
1 parent 65d2bc8 commit 6a2382e
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 4 deletions.
10 changes: 6 additions & 4 deletions graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,19 @@ func (g *MessageGraph) Compile() (*Runnable, error) {
}, nil
}

// Invoke executes the compiled message graph with the given input messages.
// It returns the resulting messages and an error if any occurs during the execution.
// Invoke executes the compiled message graph with the given input messages.
// It returns the resulting messages and an error if any occurs during the execution.
func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) ([]llms.MessageContent, error) {
state := messages
currentNode := r.graph.entryPoint

for {
if currentNode == END {
break
}

node, ok := r.graph.nodes[currentNode]
if !ok {
return nil, fmt.Errorf("%w: %s", ErrNodeNotFound, currentNode)
Expand All @@ -117,10 +123,6 @@ func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) (
return nil, fmt.Errorf("error in node %s: %w", currentNode, err)
}

if currentNode == END {
break
}

foundNext := false
for _, edge := range r.graph.edges {
if edge.From == currentNode {
Expand Down
122 changes: 122 additions & 0 deletions graph/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package graph_test

import (
"context"
"errors"
"fmt"
"testing"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai"
Expand Down Expand Up @@ -54,3 +56,123 @@ func ExampleMessageGraph() {
// Output:
// [{human [{What is 1 + 1?}]} {ai [{1 + 1 equals 2.}]}]
}

func TestMessageGraph(t *testing.T) {
testCases := []struct {
name string
buildGraph func() *graph.MessageGraph
inputMessages []llms.MessageContent
expectedOutput []llms.MessageContent
expectedError error
}{
{
name: "Simple graph",
buildGraph: func() *graph.MessageGraph {
g := graph.NewMessageGraph()
g.AddNode("node1", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Node 1")), nil
})
g.AddNode("node2", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Node 2")), nil
})
g.AddEdge("node1", "node2")
g.AddEdge("node2", graph.END)
g.SetEntryPoint("node1")
return g
},
inputMessages: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "Input")},
expectedOutput: []llms.MessageContent{
llms.TextParts(schema.ChatMessageTypeHuman, "Input"),
llms.TextParts(schema.ChatMessageTypeAI, "Node 1"),
llms.TextParts(schema.ChatMessageTypeAI, "Node 2"),
},
expectedError: nil,
},
{
name: "Entry point not set",
buildGraph: func() *graph.MessageGraph {
g := graph.NewMessageGraph()
g.AddNode("node1", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return state, nil
})
return g
},
expectedError: graph.ErrEntryPointNotSet,
},
{
name: "Node not found",
buildGraph: func() *graph.MessageGraph {
g := graph.NewMessageGraph()
g.AddNode("node1", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return state, nil
})
g.AddEdge("node1", "node2")
g.SetEntryPoint("node1")
return g
},
expectedError: fmt.Errorf("%w: node2", graph.ErrNodeNotFound),
},
{
name: "No outgoing edge",
buildGraph: func() *graph.MessageGraph {
g := graph.NewMessageGraph()
g.AddNode("node1", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return state, nil
})
g.SetEntryPoint("node1")
return g
},
expectedError: fmt.Errorf("%w: node1", graph.ErrNoOutgoingEdge),
},
{
name: "Error in node function",
buildGraph: func() *graph.MessageGraph {
g := graph.NewMessageGraph()
g.AddNode("node1", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return nil, errors.New("node error")
})
g.AddEdge("node1", graph.END)
g.SetEntryPoint("node1")
return g
},
expectedError: fmt.Errorf("error in node node1: node error"),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
g := tc.buildGraph()
runnable, err := g.Compile()
if err != nil {
if tc.expectedError == nil || !errors.Is(err, tc.expectedError) {
t.Fatalf("unexpected compile error: %v", err)
}
return
}

output, err := runnable.Invoke(context.Background(), tc.inputMessages)
if err != nil {
if tc.expectedError == nil || fmt.Sprint(err) != fmt.Sprint(tc.expectedError) {
t.Fatalf("unexpected invoke error: '%v', expected '%v'", err, tc.expectedError)
}
return
}

if tc.expectedError != nil {
t.Fatalf("expected error %v, but got nil", tc.expectedError)
}

if len(output) != len(tc.expectedOutput) {
t.Fatalf("expected output length %d, but got %d", len(tc.expectedOutput), len(output))
}

for i, msg := range output {
got := fmt.Sprint(msg)
expected := fmt.Sprint(tc.expectedOutput[i])
if got != expected {
t.Errorf("expected output[%d] content %q, but got %q", i, expected, got)
}
}
})
}
}

0 comments on commit 6a2382e

Please sign in to comment.