Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

agents: internationalize and add Chinese #1000

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions agents/conversational.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,19 @@ package agents

import (
"context"
_ "embed"
"fmt"
"regexp"
"strings"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/i18n"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
)

const (
_conversationalFinalAnswerAction = "AI:"
)

// ConversationalAgent is a struct that represents an agent responsible for deciding
// what to do or give the final output if the task is finished given a set of inputs
// and previous steps taken.
Expand All @@ -33,7 +29,9 @@ type ConversationalAgent struct {
// Tools is a list of the tools the agent can use.
Tools []tools.Tool
// Output key is the key where the final output is placed.
OutputKey string
OutputKey string
FinalAnswer string
Language i18n.Lang
// CallbacksHandler is the handler for callbacks.
CallbacksHandler callbacks.Handler
}
Expand All @@ -45,6 +43,7 @@ func NewConversationalAgent(llm llms.Model, tools []tools.Tool, opts ...Option)
for _, opt := range opts {
opt(&options)
}
options.loadConversationalTranslatable()

return &ConversationalAgent{
Chain: chains.NewLLMChain(
Expand All @@ -54,6 +53,8 @@ func NewConversationalAgent(llm llms.Model, tools []tools.Tool, opts ...Option)
),
Tools: tools,
OutputKey: options.outputKey,
FinalAnswer: i18n.AgentsMustPhrase(options.lang, "conversational final answer"),
Language: options.lang,
CallbacksHandler: options.callbacksHandler,
}
}
Expand All @@ -69,7 +70,7 @@ func (a *ConversationalAgent) Plan(
fullInputs[key] = value
}

fullInputs["agent_scratchpad"] = constructScratchPad(intermediateSteps)
fullInputs["agent_scratchpad"] = constructScratchPad(intermediateSteps, a.Language)

var stream func(ctx context.Context, chunk []byte) error

Expand All @@ -84,7 +85,10 @@ func (a *ConversationalAgent) Plan(
ctx,
a.Chain,
fullInputs,
chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}),
chains.WithStopWords([]string{
fmt.Sprintf("\n%s", i18n.AgentsMustPhrase(a.Language, "observation")),
fmt.Sprintf("\n\t%s", i18n.AgentsMustPhrase(a.Language, "observation")),
}),
chains.WithStreamingFunc(stream),
)
if err != nil {
Expand Down Expand Up @@ -117,22 +121,22 @@ func (a *ConversationalAgent) GetTools() []tools.Tool {
return a.Tools
}

func constructScratchPad(steps []schema.AgentStep) string {
func constructScratchPad(steps []schema.AgentStep, language i18n.Lang) string {
var scratchPad string
if len(steps) > 0 {
for _, step := range steps {
scratchPad += step.Action.Log
scratchPad += "\nObservation: " + step.Observation
scratchPad += fmt.Sprintf("\n%s %s", i18n.AgentsMustPhrase(language, "observation"), step.Observation)
}
scratchPad += "\n" + "Thought:"
scratchPad += fmt.Sprintf("\n%s", i18n.AgentsMustPhrase(language, "thought"))
}

return scratchPad
}

func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, error) {
if strings.Contains(output, _conversationalFinalAnswerAction) {
splits := strings.Split(output, _conversationalFinalAnswerAction)
if strings.Contains(output, a.FinalAnswer) {
splits := strings.Split(output, a.FinalAnswer)

finishAction := &schema.AgentFinish{
ReturnValues: map[string]any{
Expand All @@ -144,7 +148,9 @@ func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction,
return nil, finishAction, nil
}

r := regexp.MustCompile(`Action: (.*?)[\n]*Action Input: (.*)`)
action, actionInput := i18n.AgentsMustPhrase(a.Language, "action"),
i18n.AgentsMustPhrase(a.Language, "action input")
r := regexp.MustCompile(fmt.Sprintf(`%s (.*?)[\n]*%s (.*)`, action, actionInput))
matches := r.FindStringSubmatch(output)
if len(matches) == 0 {
return nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
Expand All @@ -155,15 +161,6 @@ func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction,
}, nil, nil
}

//go:embed prompts/conversational_prefix.txt
var _defaultConversationalPrefix string //nolint:gochecknoglobals

//go:embed prompts/conversational_format_instructions.txt
var _defaultConversationalFormatInstructions string //nolint:gochecknoglobals

//go:embed prompts/conversational_suffix.txt
var _defaultConversationalSuffix string //nolint:gochecknoglobals

func createConversationalPrompt(tools []tools.Tool, prefix, instructions, suffix string) prompts.PromptTemplate {
template := strings.Join([]string{prefix, instructions, suffix}, "\n\n")

Expand Down
1 change: 1 addition & 0 deletions agents/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func NewExecutor(agent Agent, opts ...Option) *Executor {
for _, opt := range opts {
opt(&options)
}
options.loadExecutorTranslatable()

return &Executor{
Agent: agent,
Expand Down
31 changes: 19 additions & 12 deletions agents/mrkl.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@ import (

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/i18n"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
)

const (
_finalAnswerAction = "Final Answer:"
_defaultOutputKey = "output"
)

// OneShotZeroAgent is a struct that represents an agent responsible for deciding
// what to do or give the final output if the task is finished given a set of inputs
// and previous steps taken.
Expand All @@ -31,7 +27,9 @@ type OneShotZeroAgent struct {
// Tools is a list of the tools the agent can use.
Tools []tools.Tool
// Output key is the key where the final output is placed.
OutputKey string
OutputKey string
FinalAnswer string
Lang i18n.Lang
// CallbacksHandler is the handler for callbacks.
CallbacksHandler callbacks.Handler
}
Expand All @@ -46,6 +44,7 @@ func NewOneShotAgent(llm llms.Model, tools []tools.Tool, opts ...Option) *OneSho
for _, opt := range opts {
opt(&options)
}
options.loadMrklTranslatable()

return &OneShotZeroAgent{
Chain: chains.NewLLMChain(
Expand All @@ -55,6 +54,8 @@ func NewOneShotAgent(llm llms.Model, tools []tools.Tool, opts ...Option) *OneSho
),
Tools: tools,
OutputKey: options.outputKey,
FinalAnswer: i18n.AgentsMustPhrase(options.lang, "mrkl final answer"),
Lang: options.lang,
CallbacksHandler: options.callbacksHandler,
}
}
Expand All @@ -70,8 +71,8 @@ func (a *OneShotZeroAgent) Plan(
fullInputs[key] = value
}

fullInputs["agent_scratchpad"] = constructScratchPad(intermediateSteps)
fullInputs["today"] = time.Now().Format("January 02, 2006")
fullInputs["agent_scratchpad"] = constructScratchPad(intermediateSteps, a.Lang)
fullInputs["today"] = time.Now().Format(i18n.AgentsMustPhrase(a.Lang, "today format"))

var stream func(ctx context.Context, chunk []byte) error

Expand All @@ -86,7 +87,10 @@ func (a *OneShotZeroAgent) Plan(
ctx,
a.Chain,
fullInputs,
chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}),
chains.WithStopWords([]string{
fmt.Sprintf("\n%s", i18n.AgentsMustPhrase(a.Lang, "observation")),
fmt.Sprintf("\n\t%s", i18n.AgentsMustPhrase(a.Lang, "observation")),
}),
chains.WithStreamingFunc(stream),
)
if err != nil {
Expand Down Expand Up @@ -120,8 +124,8 @@ func (a *OneShotZeroAgent) GetTools() []tools.Tool {
}

func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, error) {
if strings.Contains(output, _finalAnswerAction) {
splits := strings.Split(output, _finalAnswerAction)
if strings.Contains(output, a.FinalAnswer) {
splits := strings.Split(output, a.FinalAnswer)

return nil, &schema.AgentFinish{
ReturnValues: map[string]any{
Expand All @@ -131,7 +135,10 @@ func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *sc
}, nil
}

r := regexp.MustCompile(`Action:\s*(.+)\s*Action Input:\s(?s)*(.+)`)
action, actionInput, observation := i18n.AgentsMustPhrase(a.Lang, "action"),
i18n.AgentsMustPhrase(a.Lang, "action input"),
i18n.AgentsMustPhrase(a.Lang, "observation")
r := regexp.MustCompile(fmt.Sprintf(`%s\s*(.+)\s*%s\s(?s)*([^%s]+)`, action, actionInput, observation))
matches := r.FindStringSubmatch(output)
if len(matches) == 0 {
return nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
Expand Down
23 changes: 0 additions & 23 deletions agents/mrkl_prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,6 @@ import (
"github.com/tmc/langchaingo/tools"
)

const (
_defaultMrklPrefix = `Today is {{.today}}.
Answer the following questions as best you can. You have access to the following tools:

{{.tool_descriptions}}`

_defaultMrklFormatInstructions = `Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [ {{.tool_names}} ]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question`

_defaultMrklSuffix = `Begin!

Question: {{.input}}
Thought:{{.agent_scratchpad}}`
)

func createMRKLPrompt(tools []tools.Tool, prefix, instructions, suffix string) prompts.PromptTemplate {
template := strings.Join([]string{prefix, instructions, suffix}, "\n\n")

Expand Down
5 changes: 4 additions & 1 deletion agents/markl_test.go → agents/mrkl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/i18n"
"github.com/tmc/langchaingo/schema"
)

Expand Down Expand Up @@ -38,7 +39,9 @@ func TestMRKLOutputParser(t *testing.T) {
},
}

a := OneShotZeroAgent{}
a := OneShotZeroAgent{
Lang: i18n.EN,
}
for _, tc := range testCases {
actions, finish, err := a.parseOutput(tc.input)
require.ErrorIs(t, tc.expectedErr, err)
Expand Down
1 change: 1 addition & 0 deletions agents/openai_functions_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func NewOpenAIFunctionsAgent(llm llms.Model, tools []tools.Tool, opts ...Option)
for _, opt := range opts {
opt(&options)
}
options.loadOpenAIFunctionsTranslatable()

return &OpenAIFunctionsAgent{
LLM: llm,
Expand Down
Loading
Loading