Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] support tool requests when invoking a generator #181

Merged
merged 2 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions go/ai/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"

"github.com/firebase/genkit/go/genkit"
Expand Down Expand Up @@ -68,6 +69,26 @@ func RegisterGenerator(provider, name string, metadata *GeneratorMetadata, gener
}, generator.Generate))
}

// Generate runs generator on input and returns its final response.
//
// Each response is passed to handleToolRequest. If the response asks for
// a tool, the tool is run and the generator is invoked again with the
// updated request; this repeats until a response carries no tool request.
// The first error from the generator or from tool handling is returned
// with a nil response.
func Generate(ctx context.Context, generator Generator, input *GenerateRequest, cb genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) {
	req := input
	for {
		resp, err := generator.Generate(ctx, req, cb)
		if err != nil {
			return nil, err
		}

		next, err := handleToolRequest(ctx, req, resp)
		switch {
		case err != nil:
			return nil, err
		case next == nil:
			// No tool was requested, so this response is final.
			return resp, nil
		}

		// A tool ran; ask the generator again with the extended request.
		req = next
	}
}

// generatorActionType is the instantiated genkit.Action type registered
// by RegisterGenerator.
// Its type arguments match how the action is run in this file: it is given
// a *GenerateRequest and a streaming callback over *Candidate, and it
// yields a *GenerateResponse.
type generatorActionType = genkit.Action[*GenerateRequest, *GenerateResponse, *Candidate]
Expand All @@ -91,9 +112,65 @@ type generatorAction struct {
action *generatorActionType
}

// Generate implements Generator. This is like the [Generate] function,
// but invokes the [genkit.Action] rather than invoking the Generator
// directly.
//
// The action is run repeatedly: after each run the response is passed to
// handleToolRequest, and if a tool was requested the action is run again
// with the updated request. The loop ends with the first response that
// carries no tool request, or with the first error.
func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) {
	for {
		resp, err := ga.action.Run(ctx, input, cb)
		if err != nil {
			return nil, err
		}

		newReq, err := handleToolRequest(ctx, input, resp)
		if err != nil {
			return nil, err
		}
		if newReq == nil {
			// No tool was requested, so this response is final.
			return resp, nil
		}

		input = newReq
	}
}

// handleToolRequest checks if a tool was requested by a generator.
// If a tool was requested, this runs the tool and returns an
// updated GenerateRequest. If no tool was requested this returns nil.
//
// Only the first part of the first candidate's message is inspected.
// A nil response, a response without candidates, a candidate without
// message content, or a nil first part are all treated as "no tool
// requested" rather than causing a nil-pointer panic.
//
// The returned request is a shallow copy of req whose Messages have the
// generator's tool-request message and the tool's response appended;
// req itself is not modified. An error from running the tool is returned
// as is.
func handleToolRequest(ctx context.Context, req *GenerateRequest, resp *GenerateResponse) (*GenerateRequest, error) {
	if resp == nil || len(resp.Candidates) == 0 {
		return nil, nil
	}
	msg := resp.Candidates[0].Message
	if msg == nil || len(msg.Content) == 0 {
		return nil, nil
	}
	part := msg.Content[0]
	if part == nil || !part.IsToolRequest() {
		return nil, nil
	}

	toolReq := part.ToolRequest()
	output, err := RunTool(ctx, toolReq.Name, toolReq.Input)
	if err != nil {
		return nil, err
	}

	toolResp := &Message{
		Content: []*Part{
			NewToolResponsePart(&ToolResponse{
				Name:   toolReq.Name,
				Output: output,
			}),
		},
		Role: RoleTool,
	}

	// Copy the GenerateRequest rather than modifying it.
	// slices.Clip removes spare capacity, so the append below allocates a
	// new backing array instead of writing into the caller's slice.
	rreq := *req
	rreq.Messages = append(slices.Clip(rreq.Messages), msg, toolResp)

	return &rreq, nil
}

// Text returns the contents of the first candidate in a
Expand Down
65 changes: 65 additions & 0 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ai

import (
"context"
"fmt"
"maps"

"github.com/firebase/genkit/go/genkit"
)

// Tool pairs a tool's definition with the function that implements it.
// The JSON schemas in the ToolDefinition describe the tool's types.
// TODO: This should be generic over the function input and output types,
// and something in the general code should handle the JSON conversion.
type Tool struct {
	// Definition describes the tool, including its JSON schemas.
	Definition *ToolDefinition
	// Fn is the tool's implementation. Input and output are passed as
	// untyped maps.
	Fn func(ctx context.Context, input map[string]any) (map[string]any, error)
}

// RegisterTool registers a tool function.
//
// fn is wrapped in a genkit action named definition.Name and recorded in
// the global registry under genkit.ActionTypeTool with the provider "tool".
// The action's metadata is a copy of metadata with the "type" key set to
// "tool"; the caller's map is never modified, even when it is empty.
//
// NOTE(review): the name parameter is currently unused; the action is
// registered under definition.Name. Confirm which one is intended.
func RegisterTool(name string, definition *ToolDefinition, metadata map[string]any, fn func(ctx context.Context, input map[string]any) (map[string]any, error)) {
	// Always work on a private copy so that setting "type" below cannot
	// leak into the caller's map. maps.Clone preserves nil, so allocate
	// a fresh map when no metadata was supplied.
	metadata = maps.Clone(metadata)
	if metadata == nil {
		metadata = make(map[string]any)
	}
	metadata["type"] = "tool"

	// TODO: There is no provider for a tool.
	genkit.RegisterAction(genkit.ActionTypeTool, "tool",
		genkit.NewAction(definition.Name, metadata, fn))
}

// toolActionType is the instantiated genkit.Action type registered
// by RegisterTool.
// RunTool type-asserts looked-up actions to a pointer to this type
// before running them.
type toolActionType = genkit.Action[map[string]any, map[string]any, struct{}]

// RunTool finds the tool that [RegisterTool] recorded under name,
// invokes it with input, and returns whatever the tool returns.
// It reports an error if no such tool is registered or if the registered
// action does not have the type that RegisterTool creates.
func RunTool(ctx context.Context, name string, input map[string]any) (map[string]any, error) {
	registered := genkit.LookupAction(genkit.ActionTypeTool, "tool", name)
	if registered == nil {
		return nil, fmt.Errorf("no tool named %q", name)
	}

	if tool, isTool := registered.(*toolActionType); isTool {
		// No streaming callback is passed when running a tool.
		return tool.Run(ctx, input, nil)
	}
	return nil, fmt.Errorf("RunTool: tool action %q has type %T, want %T", name, registered, &toolActionType{})
}
2 changes: 1 addition & 1 deletion go/genkit/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (p *Prompt) Execute(ctx context.Context, input *ActionInput) (*ai.GenerateR
}
}

resp, err := generator.Generate(ctx, genReq, nil)
resp, err := ai.Generate(ctx, generator, genReq, nil)
if err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions go/genkit/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ const (
ActionTypeFlow ActionType = "flow"
ActionTypeModel ActionType = "model"
ActionTypePrompt ActionType = "prompt"
ActionTypeTool ActionType = "tool"
)

// RegisterAction records the action in the global registry.
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/googleai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ package googleai
import (
"context"

"github.com/firebase/genkit/go/ai"
"github.com/google/generative-ai-go/genai"
"github.com/google/genkit/go/ai"
)

type embedder struct {
Expand Down
85 changes: 42 additions & 43 deletions go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ package googleai_test

import (
"context"
"errors"
"flag"
"fmt"
"math"
"strings"
"testing"
Expand Down Expand Up @@ -128,6 +130,44 @@ func TestGeneratorStreaming(t *testing.T) {
}
}

// toolDef describes the "exponentiation" tool used by these tests.
// It is shared by the tool registration in init and by the Tools list
// of the request built in TestGeneratorTool.
var toolDef = &ai.ToolDefinition{
Name: "exponentiation",
InputSchema: map[string]any{"base": "float64", "exponent": "int"},
OutputSchema: map[string]any{"output": "float64"},
}

// init registers the "exponentiation" tool described by toolDef.
func init() {
	// exponentiate reads "base" and "exponent" from input and returns
	// math.Pow(base, exponent) under the "output" key. The base must be a
	// float64; the exponent may be a float64 or an int.
	exponentiate := func(ctx context.Context, input map[string]any) (map[string]any, error) {
		rawBase, found := input["base"]
		if !found {
			return nil, errors.New("exponentiation tool: missing base")
		}
		base, isFloat := rawBase.(float64)
		if !isFloat {
			return nil, fmt.Errorf("exponentiation tool: base is %T, want %T", rawBase, float64(0))
		}

		rawExp, found := input["exponent"]
		if !found {
			return nil, errors.New("exponentiation tool: missing exponent")
		}
		var exponent float64
		switch v := rawExp.(type) {
		case float64:
			exponent = v
		case int:
			exponent = float64(v)
		default:
			return nil, fmt.Errorf("exponentiation tool: exponent is %T, want %T or %T", rawExp, float64(0), int(0))
		}

		return map[string]any{"output": math.Pow(base, exponent)}, nil
	}

	ai.RegisterTool("exponentiation", toolDef, nil, exponentiate)
}

func TestGeneratorTool(t *testing.T) {
if *apiKey == "" {
t.Skipf("no -key provided")
Expand All @@ -145,55 +185,14 @@ func TestGeneratorTool(t *testing.T) {
Role: ai.RoleUser,
},
},
Tools: []*ai.ToolDefinition{
&ai.ToolDefinition{
Name: "exponentiation",
InputSchema: map[string]any{"base": "float64", "exponent": "int"},
OutputSchema: map[string]any{"output": "float64"},
},
},
}

resp, err := g.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}
p := resp.Candidates[0].Message.Content[0]
if !p.IsToolRequest() {
t.Fatalf("tool not requested")
}
toolReq := p.ToolRequest()
if toolReq.Name != "exponentiation" {
t.Errorf("tool name is %q, want \"exponentiation\"", toolReq.Name)
Tools: []*ai.ToolDefinition{toolDef},
}
if toolReq.Input["base"] != 3.5 {
t.Errorf("base is %f, want 3.5", toolReq.Input["base"])
}
if toolReq.Input["exponent"] != 2 && toolReq.Input["exponent"] != 2.0 {
// Note: 2.0 is wrong given the schema, but Gemini returns a float anyway.
t.Errorf("exponent is %f, want 2", toolReq.Input["exponent"])
}

// Update our conversation with the tool request the model made and our tool response.
// (Our "tool" is just math.Pow.)
req.Messages = append(req.Messages,
resp.Candidates[0].Message,
&ai.Message{
Content: []*ai.Part{ai.NewToolResponsePart(&ai.ToolResponse{
Name: "exponentiation",
Output: map[string]any{"output": math.Pow(3.5, 2)},
})},
Role: ai.RoleTool,
},
)

// Issue our request again.
resp, err = g.Generate(ctx, req, nil)
resp, err := ai.Generate(ctx, g, req, nil)
if err != nil {
t.Fatal(err)
}

// Check final response.
out := resp.Candidates[0].Message.Content[0].Text()
if !strings.Contains(out, "12.25") {
t.Errorf("got %s, expecting it to contain \"12.25\"", out)
Expand Down
Loading
Loading