Added action-level and generate-level middleware.
apascal07 committed Feb 12, 2025
1 parent f7f8701 commit 2102a9c
Showing 25 changed files with 543 additions and 93 deletions.
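This commit threads an optional []ModelMiddleware through DefineModel (action-level) and adds a WithMiddleware generate option (generate-level). The snippet below is a hypothetical usage sketch, not code from this commit: it is written in the style of generator_test.go, assumes the package-level registry r and imports used there, relies only on signatures visible in the diffs below, and the provider name, model name, tagMW middleware, and ModelResponse stub are made up for illustration.

// Hypothetical usage sketch (not part of this commit); assumes the package-level
// registry r and imports from generator_test.go.
func sketchMiddlewareUsage() {
	// A middleware that appends a marker message before handing the request to the model.
	tagMW := func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback, next ModelFunc) (*ModelResponse, error) {
		req.Messages = append(req.Messages, NewUserTextMessage("seen by middleware"))
		return next(ctx, req, cb)
	}

	// Action-level: middleware passed to DefineModel runs on every request to this model.
	m := DefineModel(r, "example", "tagged", // hypothetical provider and model name
		&ModelInfo{Supports: &ModelInfoSupports{Multiturn: true, Tools: true}},
		[]ModelMiddleware{tagMW},
		func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
			return &ModelResponse{Request: req, Message: NewModelTextMessage("ok")}, nil
		})

	// Generate-level: middleware passed per call via the new WithMiddleware option.
	_, _ = Generate(context.Background(), r,
		WithModel(m),
		WithTextPrompt("hello"),
		WithMiddleware(tagMW),
	)
}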
6 changes: 5 additions & 1 deletion go/ai/action_test.go
@@ -56,7 +56,11 @@ func (pm *programmableModel) Generate(ctx context.Context, r *registry.Registry,

func defineProgrammableModel(r *registry.Registry) *programmableModel {
pm := &programmableModel{r: r}
DefineModel(r, "default", "programmableModel", nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
supports := &ModelInfoSupports{
Tools: true,
Multiturn: true,
}
DefineModel(r, "", "programmableModel", &ModelInfo{Supports: supports}, nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
return pm.Generate(ctx, r, req, &ToolConfig{MaxTurns: 5}, cb)
})
return pm
3 changes: 1 addition & 2 deletions go/ai/embedder.go
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ai

import (
@@ -52,7 +51,7 @@ func DefineEmbedder(
provider, name string,
embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
) Embedder {
return (*embedderActionDef)(core.DefineAction(r, provider, name, atype.Embedder, nil, embed))
return (*embedderActionDef)(core.DefineAction(r, provider, name, atype.Embedder, nil, nil, embed))
}

// IsDefinedEmbedder reports whether an embedder is defined.
6 changes: 3 additions & 3 deletions go/ai/gen.go
@@ -124,9 +124,9 @@ type ModelInfoSupports struct {

// A ModelRequest is a request to generate completions from a model.
type ModelRequest struct {
Config any `json:"config,omitempty"`
Context []any `json:"context,omitempty"`
Messages []*Message `json:"messages,omitempty"`
Config any `json:"config,omitempty"`
Context []*Document `json:"context,omitempty"`
Messages []*Message `json:"messages,omitempty"`
// Output describes the desired response format.
Output *ModelRequestOutput `json:"output,omitempty"`
ToolChoice ToolChoice `json:"toolChoice,omitempty"`
61 changes: 46 additions & 15 deletions go/ai/generate.go
@@ -24,9 +24,15 @@ type Model interface {
// Name returns the registry name of the model.
Name() string
// Generate applies the [Model] to the provided request, handling tool requests and streaming.
Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error)
Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error)
}

// ModelFunc is a function that generates a model response.
type ModelFunc = core.Func[*ModelRequest, *ModelResponse, *ModelResponseChunk]

// ModelMiddleware is middleware for model generate requests.
type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk]

type modelActionDef core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]

type modelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]
@@ -44,7 +50,7 @@ type ToolConfig struct {

// DefineGenerateAction defines a utility generate action.
func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAction {
return (*generateAction)(core.DefineStreamingAction(r, "", "generate", atype.Util, map[string]any{},
return (*generateAction)(core.DefineStreamingAction(r, "", "generate", atype.Util, map[string]any{}, nil,
func(ctx context.Context, req *GenerateActionOptions, cb ModelStreamingCallback) (output *ModelResponse, err error) {
logger.FromContext(ctx).Debug("GenerateAction",
"input", fmt.Sprintf("%#v", req))
@@ -53,9 +59,10 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc
"output", fmt.Sprintf("%#v", output),
"err", err)
}()

return tracing.RunInNewSpan(ctx, r.TracingState(), "generate", "util", false, req,
func(ctx context.Context, input *GenerateActionOptions) (*ModelResponse, error) {
model := LookupModel(r, "default", req.Model)
model := LookupModel(r, "", req.Model)
if model == nil {
return nil, fmt.Errorf("model %q not found", req.Model)
}
@@ -95,17 +102,17 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc
ReturnToolRequests: req.ReturnToolRequests,
}

return model.Generate(ctx, r, modelReq, toolCfg, cb)
return model.Generate(ctx, r, modelReq, nil, toolCfg, cb)
})
}))
}

// DefineModel registers the given generate function as an action, and returns a
// [Model] that runs it.
// DefineModel registers the given generate function as an action, and returns a [Model] that runs it.
func DefineModel(
r *registry.Registry,
provider, name string,
metadata *ModelInfo,
mw []ModelMiddleware,
generate func(context.Context, *ModelRequest, ModelStreamingCallback) (*ModelResponse, error),
) Model {
metadataMap := map[string]any{}
@@ -129,9 +136,9 @@ func DefineModel(
metadataMap["supports"] = supports
metadataMap["versions"] = metadata.Versions

return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{
"model": metadataMap,
}, generate))
mw = append([]ModelMiddleware{ValidateSupport(name, metadata.Supports)}, mw...)

return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{"model": metadataMap}, mw, generate))
}

// IsDefinedModel reports whether a model is defined.
@@ -158,6 +165,7 @@ type generateParams struct {
SystemPrompt *Message
MaxTurns int
ReturnToolRequests bool
Middleware []ModelMiddleware
}

// GenerateOption configures params of the Generate call.
@@ -224,10 +232,13 @@ func WithConfig(config any) GenerateOption {
}
}

// WithContext adds provided context to ModelRequest.
func WithContext(c ...any) GenerateOption {
// WithContext adds provided documents to ModelRequest.
func WithContext(docs ...*Document) GenerateOption {
return func(req *generateParams) error {
req.Request.Context = append(req.Request.Context, c...)
if req.Request.Context != nil {
return errors.New("generate.WithContext: cannot set context more than once")
}
req.Request.Context = docs
return nil
}
}
@@ -320,6 +331,17 @@ func WithToolChoice(toolChoice ToolChoice) GenerateOption {
}
}

// WithMiddleware adds middleware to the generate request.
func WithMiddleware(middleware ...ModelMiddleware) GenerateOption {
return func(req *generateParams) error {
if req.Middleware != nil {
return errors.New("generate.WithMiddleware: cannot set Middleware more than once")
}
req.Middleware = middleware
return nil
}
}

// Generate runs a generate request for this model. It returns a ModelResponse struct.
func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*ModelResponse, error) {
req := &generateParams{
@@ -368,7 +390,7 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
ReturnToolRequests: req.ReturnToolRequests,
}

return req.Model.Generate(ctx, r, req.Request, toolCfg, req.Stream)
return req.Model.Generate(ctx, r, req.Request, req.Middleware, toolCfg, req.Stream)
}

// validateModelVersion checks in the registry the action of the
Expand Down Expand Up @@ -435,7 +457,7 @@ func GenerateData(ctx context.Context, r *registry.Registry, value any, opts ...
}

// Generate applies the [Action] to the provided request, handling tool requests and streaming.
func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) {
func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) {
if m == nil {
return nil, errors.New("Generate called on a nil Model; check that all models are defined")
}
@@ -463,9 +485,18 @@ func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req
return nil, err
}

handler := (*modelAction)(m).Run
for i := len(mw) - 1; i >= 0; i-- {
currentHandler := handler
currentMiddleware := mw[i]
handler = func(ctx context.Context, in *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
return currentMiddleware(ctx, in, cb, currentHandler)
}
}

currentTurn := 0
for {
resp, err := (*modelAction)(m).Run(ctx, req, cb)
resp, err := handler(ctx, req, cb)
if err != nil {
return nil, err
}
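For reference, the wrapping loop added to modelActionDef.Generate composes the per-call middleware around the action's Run handler by iterating mw from last to first, so mw[0] ends up as the outermost wrapper: it sees the request first and the response last, with the model action innermost. Below is a minimal self-contained sketch of the same fold using plain function types instead of the genkit generics; all names are made up for illustration.

package main

import "fmt"

type handler func(req string) string

type middleware func(req string, next handler) string

func main() {
	base := handler(func(req string) string { return "model(" + req + ")" })

	mw := []middleware{
		func(req string, next handler) string { return "A(" + next(req) + ")" },
		func(req string, next handler) string { return "B(" + next(req) + ")" },
	}

	// Same fold as in Generate: walk mw backwards so mw[0] becomes the outermost wrapper.
	h := base
	for i := len(mw) - 1; i >= 0; i-- {
		current, m := h, mw[i]
		h = func(req string) string { return m(req, current) }
	}

	fmt.Println(h("req")) // prints A(B(model(req)))
}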
65 changes: 58 additions & 7 deletions go/ai/generator_test.go
@@ -37,7 +37,7 @@ var (
Versions: []string{"echo-001", "echo-002"},
}

echoModel = DefineModel(r, "test", modelName, &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
echoModel = DefineModel(r, "test", modelName, &metadata, nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
if msc != nil {
msc(ctx, &ModelResponseChunk{
Content: []*Part{NewTextPart("stream!")},
@@ -264,7 +264,7 @@ func TestGenerate(t *testing.T) {
},
},
Config: GenerationCommonConfig{Temperature: 1},
Context: []any{[]any{string("Banana")}},
Context: []*Document{&Document{Content: []*Part{NewTextPart("Banana")}}},
Output: &ModelRequestOutput{
Format: "json",
Schema: map[string]any{
@@ -310,7 +310,7 @@ func TestGenerate(t *testing.T) {
Temperature: 1,
}),
WithHistory(NewUserTextMessage("banana"), NewModelTextMessage("yes, banana")),
WithContext([]any{"Banana"}),
WithContext(&Document{Content: []*Part{NewTextPart("Banana")}}),
WithOutputSchema(&GameCharacter{}),
WithTools(gablorkenTool),
WithStreaming(func(ctx context.Context, grc *ModelResponseChunk) error {
@@ -346,7 +346,13 @@ func TestGenerate(t *testing.T) {
},
)

interruptModel := DefineModel(r, "test", "interrupt", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
interruptModel := DefineModel(r, "test", "interrupt", info, nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
return &ModelResponse{
Request: gr,
@@ -399,7 +405,13 @@ func TestGenerate(t *testing.T) {

t.Run("handles multiple parallel tool calls", func(t *testing.T) {
roundCount := 0
parallelModel := DefineModel(r, "test", "parallel", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
parallelModel := DefineModel(r, "test", "parallel", info, nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
@@ -458,7 +470,13 @@ func TestGenerate(t *testing.T) {

t.Run("handles multiple rounds of tool calls", func(t *testing.T) {
roundCount := 0
multiRoundModel := DefineModel(r, "test", "multiround", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
multiRoundModel := DefineModel(r, "test", "multiround", info, nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
@@ -520,7 +538,13 @@ func TestGenerate(t *testing.T) {
})

t.Run("exceeds maximum turns", func(t *testing.T) {
infiniteModel := DefineModel(r, "test", "infinite", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
infiniteModel := DefineModel(r, "test", "infinite", info, nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
return &ModelResponse{
Request: gr,
@@ -550,6 +574,33 @@ func TestGenerate(t *testing.T) {
t.Errorf("unexpected error message: %v", err)
}
})

t.Run("applies middleware", func(t *testing.T) {
middlewareCalled := false
testMiddleware := func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback, next ModelFunc) (*ModelResponse, error) {
middlewareCalled = true
req.Messages = append(req.Messages, NewUserTextMessage("middleware was here"))
return next(ctx, req, cb)
}

res, err := Generate(context.Background(), r,
WithModel(echoModel),
WithTextPrompt("test middleware"),
WithMiddleware(testMiddleware),
)
if err != nil {
t.Fatal(err)
}

if !middlewareCalled {
t.Error("middleware was not called")
}

// The echo test model reflects the request messages back as text, so the
// user message appended by the middleware appears after the prompt text.
expectedText := "test middlewaremiddleware was here"
if res.Text() != expectedText {
t.Errorf("got text %q, want %q", res.Text(), expectedText)
}
})
}

func TestModelVersion(t *testing.T) {
37 changes: 37 additions & 0 deletions go/ai/middleware.go
@@ -0,0 +1,37 @@
package ai

import (
"context"
"fmt"

"github.com/firebase/genkit/go/core"
)

// ValidateSupport creates middleware that validates whether a model supports the requested features.
func ValidateSupport(name string, supports *ModelInfoSupports) ModelMiddleware {
return func(ctx context.Context, input *ModelRequest, cb ModelStreamingCallback, next core.Func[*ModelRequest, *ModelResponse, *ModelResponseChunk]) (*ModelResponse, error) {
if supports == nil {
supports = &ModelInfoSupports{}
}

if !supports.Media {
for _, msg := range input.Messages {
for _, part := range msg.Content {
if part.IsMedia() {
return nil, fmt.Errorf("model %q does not support media, but media was provided. Request: %+v", name, input)
}
}
}
}

if !supports.Tools && len(input.Tools) > 0 {
return nil, fmt.Errorf("model %q does not support tool use, but tools were provided. Request: %+v", name, input)
}

if !supports.Multiturn && len(input.Messages) > 1 {
return nil, fmt.Errorf("model %q does not support multiple messages, but %d were provided. Request: %+v", name, len(input.Messages), input)
}

return next(ctx, input, cb)
}
}
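Because DefineModel now prepends ValidateSupport(name, metadata.Supports) to the action-level middleware (see the generate.go hunk above), capability checks run before any user middleware and before the model function itself. The snippet below is a hypothetical test-style sketch of the effect, reusing the package-level registry r and the gablorkenTool helper from generator_test.go; the provider and model name are made up, and the error text is only the expected shape, not asserted from this commit.

// Hypothetical sketch (not part of this commit): a model that declares no
// capabilities should be rejected by ValidateSupport when tools are requested.
noToolsModel := DefineModel(r, "example", "no-tools", &ModelInfo{}, nil,
	func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
		return &ModelResponse{Request: req, Message: NewModelTextMessage("ok")}, nil
	})

_, err := Generate(context.Background(), r,
	WithModel(noToolsModel),
	WithTextPrompt("use a tool"),
	WithTools(gablorkenTool),
)
// err is expected to report: model "no-tools" does not support tool use.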