Skip to content

Commit

Permalink
feat: chat model adapt new tool choice definition
Browse files Browse the repository at this point in the history
  • Loading branch information
N3kox committed Jan 23, 2025
1 parent 815f7d8 commit ee0ff42
Show file tree
Hide file tree
Showing 15 changed files with 137 additions and 68 deletions.
34 changes: 22 additions & 12 deletions components/model/ark/chatmodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ func (cm *ChatModel) genRequest(in []*schema.Message, opts ...fmodel.Option) (re
Model: &cm.config.Model,
TopP: cm.config.TopP,
Stop: cm.config.Stop,
Tools: nil,
}, opts...)

req = &model.ChatCompletionRequest{
Expand Down Expand Up @@ -341,19 +342,28 @@ func (cm *ChatModel) genRequest(in []*schema.Message, opts ...fmodel.Option) (re
})
}

req.Tools = make([]*model.Tool, 0, len(cm.tools))

for _, tool := range cm.tools {
arkTool := &model.Tool{
Type: model.ToolTypeFunction,
Function: &model.FunctionDefinition{
Name: tool.Function.Name,
Description: tool.Function.Description,
Parameters: tool.Function.Parameters,
},
tools := cm.tools
if options.Tools != nil {
if tools, err = toTools(options.Tools); err != nil {
return nil, err
}
}

if tools != nil {
req.Tools = make([]*model.Tool, 0, len(cm.tools))

req.Tools = append(req.Tools, arkTool)
for _, tool := range cm.tools {
arkTool := &model.Tool{
Type: model.ToolTypeFunction,
Function: &model.FunctionDefinition{
Name: tool.Function.Name,
Description: tool.Function.Description,
Parameters: tool.Function.Parameters,
},
}

req.Tools = append(req.Tools, arkTool)
}
}

return req, nil
Expand Down Expand Up @@ -564,7 +574,7 @@ func toTools(tls []*schema.ToolInfo) ([]tool, error) {
for i := range tls {
ti := tls[i]
if ti == nil {
return nil, fmt.Errorf("tool info cannot be nil in BindTools")
return nil, fmt.Errorf("tool info cannot be nil")
}

paramsJSONSchema, err := ti.ParamsOneOf.ToOpenAPIV3()
Expand Down
2 changes: 1 addition & 1 deletion components/model/ark/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.18

require (
github.com/bytedance/mockey v1.2.10
github.com/cloudwego/eino v0.3.4
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988
github.com/getkin/kin-openapi v0.118.0
github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.9.0
Expand Down
2 changes: 2 additions & 0 deletions components/model/ark/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/eino v0.3.4 h1:trWw8lKU1t1b7PMKSW1GXEJ4H2rLiGWFyVoMJJ3pRDg=
github.com/cloudwego/eino v0.3.4/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988 h1:DytkzO6vEF+0w0MwAmopW8SHzQir1B8iFmrdpfWqPyg=
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
34 changes: 29 additions & 5 deletions components/model/claude/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,21 +220,35 @@ func (c *claude) Stream(ctx context.Context, input []*schema.Message, opts ...mo
}

func (c *claude) BindTools(tools []*schema.ToolInfo) error {
result, err := toAnthropicToolParam(tools)
if err != nil {
return err
}

c.tools = result
c.origTools = tools
return nil
}

func toAnthropicToolParam(tools []*schema.ToolInfo) ([]anthropic.ToolParam, error) {
if len(tools) == 0 {
return nil, nil
}

result := make([]anthropic.ToolParam, 0, len(tools))
for _, tool := range tools {
s, err := tool.ToOpenAPIV3()
if err != nil {
return fmt.Errorf("convert to openapi v3 schema fail: %w", err)
return nil, fmt.Errorf("convert to openapi v3 schema fail: %w", err)
}
result = append(result, anthropic.ToolParam{
Name: anthropic.F(tool.Name),
Description: anthropic.F(tool.Desc),
InputSchema: anthropic.F[any](s),
})
}
c.tools = result
c.origTools = tools
return nil

return result, nil
}

func (c *claude) genMessageNewParams(input []*schema.Message, opts ...model.Option) (anthropic.MessageNewParams, error) {
Expand All @@ -248,6 +262,7 @@ func (c *claude) genMessageNewParams(input []*schema.Message, opts ...model.Opti
MaxTokens: &c.maxTokens,
TopP: c.topP,
Stop: c.stopSequences,
Tools: nil,
}, opts...)
claudeOptions := model.GetImplSpecificOptions(&options{TopK: c.topK}, opts...)

Expand All @@ -270,7 +285,16 @@ func (c *claude) genMessageNewParams(input []*schema.Message, opts ...model.Opti
if claudeOptions.TopK != nil {
param.TopK = anthropic.F(int64(*claudeOptions.TopK))
}
if len(c.tools) > 0 {

tools := c.tools
if commonOptions.Tools != nil {
var err error
if tools, err = toAnthropicToolParam(commonOptions.Tools); err != nil {
return anthropic.MessageNewParams{}, err
}
}

if len(tools) > 0 {
param.Tools = anthropic.F(c.tools)
}

Expand Down
2 changes: 1 addition & 1 deletion components/model/claude/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ toolchain go1.22.2
require (
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.8
github.com/bytedance/mockey v1.2.13
github.com/cloudwego/eino v0.3.2
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988
github.com/getkin/kin-openapi v0.118.0
github.com/stretchr/testify v1.9.0
)
Expand Down
2 changes: 2 additions & 0 deletions components/model/claude/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/eino v0.3.2 h1:GaMqt3zJAee8ybN4qsATNgSIDAbNruzKCMeMKBH4F1E=
github.com/cloudwego/eino v0.3.2/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988 h1:DytkzO6vEF+0w0MwAmopW8SHzQir1B8iFmrdpfWqPyg=
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
27 changes: 18 additions & 9 deletions components/model/ollama/chatmodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,22 @@ func (cm *ChatModel) IsCallbacksEnabled() bool {
}

func (cm *ChatModel) genRequest(ctx context.Context, stream bool, in []*schema.Message, opts ...model.Option) (req *api.ChatRequest, modelConfig *model.Config, err error) {
commonOptions := model.GetCommonOptions(&model.Options{}, opts...)
specificOptions := model.GetImplSpecificOptions(&options{}, opts...)
var (
o = &options{}
mo = &model.Options{
Model: &cm.config.Model,
Tools: cm.tools,
}
)
if cm.config.Options != nil {
mo.Temperature = &cm.config.Options.Temperature
mo.TopP = &cm.config.Options.TopP
mo.Stop = cm.config.Options.Stop
o.Seed = &cm.config.Options.Seed
}

commonOptions := model.GetCommonOptions(mo, opts...)
specificOptions := model.GetImplSpecificOptions(o, opts...)

ollamaOptions := &api.Options{}
conf := cm.config.Options
Expand All @@ -230,11 +244,6 @@ func (cm *ChatModel) genRequest(ctx context.Context, stream bool, in []*schema.M
ollamaOptions.Seed = *specificOptions.Seed
}

modelName := cm.config.Model
if commonOptions.Model != nil {
modelName = *commonOptions.Model
}

reqOptions := make(map[string]any, 5)
optBytes, err := json.Marshal(ollamaOptions)
if err != nil {
Expand All @@ -250,13 +259,13 @@ func (cm *ChatModel) genRequest(ctx context.Context, stream bool, in []*schema.M
return nil, nil, fmt.Errorf("error convert messages: %w", err)
}

tools, err := toOllamaTools(cm.tools)
tools, err := toOllamaTools(mo.Tools)
if err != nil {
return nil, nil, fmt.Errorf("error convert tools: %w", err)
}

req = &api.ChatRequest{
Model: modelName,
Model: *commonOptions.Model,
Messages: msgs,
Stream: ptrOf(stream),
Format: cm.config.Format,
Expand Down
2 changes: 1 addition & 1 deletion components/model/ollama/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.22.0

require (
github.com/bytedance/mockey v1.2.13
github.com/cloudwego/eino v0.3.4
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988
github.com/ollama/ollama v0.3.0
github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.9.0
Expand Down
4 changes: 2 additions & 2 deletions components/model/ollama/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/eino v0.3.4 h1:trWw8lKU1t1b7PMKSW1GXEJ4H2rLiGWFyVoMJJ3pRDg=
github.com/cloudwego/eino v0.3.4/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988 h1:DytkzO6vEF+0w0MwAmopW8SHzQir1B8iFmrdpfWqPyg=
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
10 changes: 9 additions & 1 deletion components/model/qianfan/chatmodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ func (c *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts ..
Model: &c.config.Model,
TopP: c.config.TopP,
Stop: c.config.Stop,
Tools: nil,
ToolChoice: c.toolChoice,
}, opts...)

Expand All @@ -216,6 +217,13 @@ func (c *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts ..
Config: cfg,
})

tools := c.tools
if options.Tools != nil {
if tools, err = toQianfanTools(options.Tools); err != nil {
return nil, err
}
}

req := &qianfan.ChatCompletionV2Request{
BaseRequestBody: qianfan.BaseRequestBody{},
Model: *options.Model,
Expand All @@ -230,7 +238,7 @@ func (c *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts ..
User: dereferenceOrZero(c.config.User),
FrequencyPenalty: dereferenceOrZero(c.config.FrequencyPenalty),
PresencePenalty: dereferenceOrZero(c.config.PresencePenalty),
Tools: c.tools,
Tools: tools,
ParallelToolCalls: dereferenceOrZero(c.config.ParallelToolCalls),
ResponseFormat: c.config.ResponseFormat,
}
Expand Down
2 changes: 1 addition & 1 deletion components/model/qianfan/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.18
require (
github.com/baidubce/bce-qianfan-sdk/go/qianfan v0.0.14
github.com/bytedance/mockey v1.2.13
github.com/cloudwego/eino v0.3.8-0.20250117083911-81a6676a6157
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988
github.com/smartystreets/goconvey v1.8.1
)

Expand Down
2 changes: 2 additions & 0 deletions components/model/qianfan/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ github.com/cloudwego/eino v0.3.4 h1:trWw8lKU1t1b7PMKSW1GXEJ4H2rLiGWFyVoMJJ3pRDg=
github.com/cloudwego/eino v0.3.4/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
github.com/cloudwego/eino v0.3.8-0.20250117083911-81a6676a6157 h1:PO1XjvUgTFOlMCp5o3EBpE3P70f7d1393ysnj+C6JOU=
github.com/cloudwego/eino v0.3.8-0.20250117083911-81a6676a6157/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988 h1:DytkzO6vEF+0w0MwAmopW8SHzQir1B8iFmrdpfWqPyg=
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
78 changes: 44 additions & 34 deletions libs/acl/openai/chat_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,18 @@ func (cm *Client) genRequest(in []*schema.Message, options *model.Options) (*ope
User: dereferenceOrZero(cm.config.User),
}

if len(cm.tools) > 0 {
tools := cm.tools
if options.Tools != nil {
var err error
if tools, err = toTools(options.Tools); err != nil {
return nil, err
}
}

if len(tools) > 0 {
req.Tools = make([]openai.Tool, len(cm.tools))
for i := range cm.tools {
t := cm.tools[i]
for i := range tools {
t := tools[i]

req.Tools[i] = openai.Tool{
Type: openai.ToolTypeFunction,
Expand All @@ -315,40 +323,40 @@ func (cm *Client) genRequest(in []*schema.Message, options *model.Options) (*ope
},
}
}
}

if options.ToolChoice != nil {
/*
tool_choice is string or object
Controls which (if any) tool is called by the model.
"none" means the model will not call any tool and instead generates a message.
"auto" means the model can pick between generating a message or calling one or more tools.
"required" means the model must call one or more tools.
Specifying a particular tool via {"type": "function", "function": {"name": "my_function"}} forces the model to call that tool.
"none" is the default when no tools are present.
"auto" is the default if tools are present.
*/

switch *options.ToolChoice {
case schema.ToolChoiceForbidden:
req.ToolChoice = toolChoiceNone
case schema.ToolChoiceAllowed:
req.ToolChoice = toolChoiceAuto
case schema.ToolChoiceForced:
if len(req.Tools) > 1 {
req.ToolChoice = toolChoiceRequired
} else {
req.ToolChoice = openai.ToolChoice{
Type: req.Tools[0].Type,
Function: openai.ToolFunction{
Name: req.Tools[0].Function.Name,
},
}
if options.ToolChoice != nil {
/*
tool_choice is string or object
Controls which (if any) tool is called by the model.
"none" means the model will not call any tool and instead generates a message.
"auto" means the model can pick between generating a message or calling one or more tools.
"required" means the model must call one or more tools.
Specifying a particular tool via {"type": "function", "function": {"name": "my_function"}} forces the model to call that tool.
"none" is the default when no tools are present.
"auto" is the default if tools are present.
*/

switch *options.ToolChoice {
case schema.ToolChoiceForbidden:
req.ToolChoice = toolChoiceNone
case schema.ToolChoiceAllowed:
req.ToolChoice = toolChoiceAuto
case schema.ToolChoiceForced:
if len(req.Tools) > 1 {
req.ToolChoice = toolChoiceRequired
} else {
req.ToolChoice = openai.ToolChoice{
Type: req.Tools[0].Type,
Function: openai.ToolFunction{
Name: req.Tools[0].Function.Name,
},
}
default:
return nil, fmt.Errorf("tool choice=%s not support", *options.ToolChoice)
}
default:
return nil, fmt.Errorf("tool choice=%s not support", *options.ToolChoice)
}
}

Expand Down Expand Up @@ -404,6 +412,7 @@ func (cm *Client) Generate(ctx context.Context, in []*schema.Message, opts ...mo
Model: &cm.config.Model,
TopP: cm.config.TopP,
Stop: cm.config.Stop,
Tools: nil,
ToolChoice: cm.toolChoice,
}, opts...)

Expand Down Expand Up @@ -490,6 +499,7 @@ func (cm *Client) Stream(ctx context.Context, in []*schema.Message,
Model: &cm.config.Model,
TopP: cm.config.TopP,
Stop: cm.config.Stop,
Tools: nil,
ToolChoice: cm.toolChoice,
}, opts...)

Expand Down
2 changes: 1 addition & 1 deletion libs/acl/openai/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.18

require (
github.com/bytedance/mockey v1.2.13
github.com/cloudwego/eino v0.3.8-0.20250117083911-81a6676a6157
github.com/cloudwego/eino v0.3.8-0.20250121122240-6d988e43f988
github.com/getkin/kin-openapi v0.118.0
github.com/sashabaranov/go-openai v1.32.5
github.com/stretchr/testify v1.9.0
Expand Down
Loading

0 comments on commit ee0ff42

Please sign in to comment.