Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

agents: add CallOptions to Plan() (issue #353) #857

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ service-account.json

embeddings/cybertron/models/*
examples/cybertron-embedding-example/models/*
!/.idea/
3 changes: 2 additions & 1 deletion agents/agents.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package agents
import (
"context"

"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
)
Expand All @@ -11,7 +12,7 @@ import (
type Agent interface {
// Plan Given an input and previous steps decide what to do next. Returns
// either actions or a finish.
Plan(ctx context.Context, intermediateSteps []schema.AgentStep, inputs map[string]string) ([]schema.AgentAction, *schema.AgentFinish, error) //nolint:lll
Plan(ctx context.Context, intermediateSteps []schema.AgentStep, inputs map[string]string, options ...chains.ChainCallOption) ([]schema.AgentAction, *schema.AgentFinish, error) //nolint:lll
GetInputKeys() []string
GetOutputKeys() []string
GetTools() []tools.Tool
Expand Down
7 changes: 5 additions & 2 deletions agents/conversational.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func (a *ConversationalAgent) Plan(
ctx context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
options ...chains.ChainCallOption,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
fullInputs := make(map[string]any, len(inputs))
for key, value := range inputs {
Expand All @@ -80,12 +81,14 @@ func (a *ConversationalAgent) Plan(
}
}

options = append(options, chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}))
options = append(options, chains.WithStreamingFunc(stream))

output, err := chains.Predict(
ctx,
a.Chain,
fullInputs,
chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}),
chains.WithStreamingFunc(stream),
options...,
)
if err != nil {
return nil, nil, err
Expand Down
7 changes: 2 additions & 5 deletions agents/conversational_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,10 @@ func TestConversationalWithMemory(t *testing.T) {
llm, err := openai.New(openai.WithModel("gpt-4"))
require.NoError(t, err)

executor, err := Initialize(
llm,
[]tools.Tool{tools.Calculator{}},
ConversationalReactDescription,
executor := NewExecutor(
NewConversationalAgent(llm, []tools.Tool{tools.Calculator{}}),
WithMemory(memory.NewConversationBuffer()),
)
require.NoError(t, err)

_, err = chains.Run(context.Background(), executor, "Hi! my name is Bob and the year I was born is 1987")
require.NoError(t, err)
Expand Down
7 changes: 4 additions & 3 deletions agents/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func NewExecutor(agent Agent, opts ...Option) *Executor {
}
}

func (e *Executor) Call(ctx context.Context, inputValues map[string]any, _ ...chains.ChainCallOption) (map[string]any, error) { //nolint:lll
func (e *Executor) Call(ctx context.Context, inputValues map[string]any, options ...chains.ChainCallOption) (map[string]any, error) { //nolint:lll
inputs, err := inputsToString(inputValues)
if err != nil {
return nil, err
Expand All @@ -57,7 +57,7 @@ func (e *Executor) Call(ctx context.Context, inputValues map[string]any, _ ...ch
steps := make([]schema.AgentStep, 0)
for i := 0; i < e.MaxIterations; i++ {
var finish map[string]any
steps, finish, err = e.doIteration(ctx, steps, nameToTool, inputs)
steps, finish, err = e.doIteration(ctx, steps, nameToTool, inputs, options...)
if finish != nil || err != nil {
return finish, err
}
Expand All @@ -79,8 +79,9 @@ func (e *Executor) doIteration( // nolint
steps []schema.AgentStep,
nameToTool map[string]tools.Tool,
inputs map[string]string,
options ...chains.ChainCallOption,
) ([]schema.AgentStep, map[string]any, error) {
actions, finish, err := e.Agent.Plan(ctx, steps, inputs)
actions, finish, err := e.Agent.Plan(ctx, steps, inputs, options...)
if errors.Is(err, ErrUnableToParseOutput) && e.ErrorHandler != nil {
formattedObservation := err.Error()
if e.ErrorHandler.Formatter != nil {
Expand Down
11 changes: 4 additions & 7 deletions agents/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func (a *testAgent) Plan(
_ context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
_ ...chains.ChainCallOption,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
a.recordedIntermediateSteps = intermediateSteps
a.recordedInputs = inputs
Expand Down Expand Up @@ -90,15 +91,11 @@ func TestExecutorWithMRKLAgent(t *testing.T) {
require.NoError(t, err)

calculator := tools.Calculator{}

a, err := agents.Initialize(
llm,
[]tools.Tool{searchTool, calculator},
agents.ZeroShotReactDescription,
executor := agents.NewExecutor(
agents.NewOneShotAgent(llm, []tools.Tool{searchTool, calculator}),
)
require.NoError(t, err)

result, err := chains.Run(context.Background(), a, "If a person lived three times as long as Jacklyn Zeman, how long would they live") //nolint:lll
result, err := chains.Run(context.Background(), executor, "If a person lived three times as long as Jacklyn Zeman, how long would they live") //nolint:lll
require.NoError(t, err)

require.True(t, strings.Contains(result, "210"), "correct answer 210 not in response")
Expand Down
7 changes: 5 additions & 2 deletions agents/mrkl.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func (a *OneShotZeroAgent) Plan(
ctx context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
options ...chains.ChainCallOption,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
fullInputs := make(map[string]any, len(inputs))
for key, value := range inputs {
Expand All @@ -82,12 +83,14 @@ func (a *OneShotZeroAgent) Plan(
}
}

options = append(options, chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}))
options = append(options, chains.WithStreamingFunc(stream))

output, err := chains.Predict(
ctx,
a.Chain,
fullInputs,
chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}),
chains.WithStreamingFunc(stream),
options...,
)
if err != nil {
return nil, nil, err
Expand Down
3 changes: 3 additions & 0 deletions agents/openai_functions_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/schema"
Expand Down Expand Up @@ -67,10 +68,12 @@ func (o *OpenAIFunctionsAgent) functions() []llms.FunctionDefinition {
}

// Plan decides what action to take or returns the final result of the input.
// TODO(fix): add {options ...chains.ChainCallOption} to llm request.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not implement this here?

func (o *OpenAIFunctionsAgent) Plan(
ctx context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
_ ...chains.ChainCallOption,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
fullInputs := make(map[string]any, len(inputs))
for key, value := range inputs {
Expand Down
Loading