Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test that streaming and final answer contain the same text. #365

Closed
wants to merge 10 commits into from
Closed
4 changes: 3 additions & 1 deletion docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ are officially supported:
[2]: plugins/vertex-ai.md
[3]: plugins/ollama.md

See the docs for each plugin for setup and usage information.
See the docs for each plugin for setup and usage information. There's also
a wide variety of community supported models available you can discover by
[searching for packages starting with `genkitx-` on npmjs.org](https://www.npmjs.com/search?q=genkitx).

## How to generate content

Expand Down
53 changes: 26 additions & 27 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,39 @@ package ai

import (
"context"
"maps"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
"github.com/invopop/jsonschema"
)

// PromptRequest is a request to execute a prompt template and
// pass the result to a [ModelAction].
type PromptRequest struct {
// Input fields for the prompt. If not nil this should be a struct
// or pointer to a struct that matches the prompt's input schema.
Variables any `json:"variables,omitempty"`
// Number of candidates to return; if 0, will be taken
// from the prompt config; if still 0, will use 1.
Candidates int `json:"candidates,omitempty"`
// Model configuration. If nil will be taken from the prompt config.
Config *GenerationCommonConfig `json:"config,omitempty"`
// Context to pass to model, if any.
Context []any `json:"context,omitempty"`
// The model to use. This overrides any model specified by the prompt.
Model string `json:"model,omitempty"`
// A PromptAction is used to render a prompt template,
// producing a [GenerateRequest] that may be passed to a [ModelAction].
type PromptAction = core.Action[any, *GenerateRequest, struct{}]

// DefinePrompt takes a function that renders a prompt template
// into a [GenerateRequest] that may be passed to a [ModelAction].
// The prompt expects some input described by inputSchema.
// DefinePrompt registers the function as an action,
// and returns a [PromptAction] that runs it.
func DefinePrompt(provider, name string, metadata map[string]any, render func(context.Context, any) (*GenerateRequest, error), inputSchema *jsonschema.Schema) *PromptAction {
mm := maps.Clone(metadata)
if mm == nil {
mm = make(map[string]any)
}
mm["type"] = "prompt"
mm["prompt"] = true // required by genkit ui
return core.DefineActionWithInputSchema(provider, name, atype.Prompt, mm, render, inputSchema)
}

// Prompt is the interface used to execute a prompt template and
// pass the result to a [ModelAction].
type Prompt interface {
Generate(context.Context, *PromptRequest, func(context.Context, *GenerateResponseChunk) error) (*GenerateResponse, error)
// LookupPrompt looks up a [PromptAction] registered by [DefinePrompt].
// It returns nil if the prompt was not defined.
func LookupPrompt(provider, name string) *PromptAction {
return core.LookupActionFor[any, *GenerateRequest, struct{}](atype.Prompt, provider, name)
}

// RegisterPrompt registers a prompt in the global registry.
func RegisterPrompt(provider, name string, prompt Prompt) {
metadata := map[string]any{
"type": "prompt",
"prompt": prompt,
}
core.RegisterAction(provider,
core.NewStreamingAction(name, atype.Prompt, metadata, prompt.Generate))
// Render renders a [PromptAction] with some input data.
func Render(ctx context.Context, p *PromptAction, input any) (*GenerateRequest, error) {
return p.Run(ctx, input, nil)
}
77 changes: 53 additions & 24 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ type streamingCallback[Stream any] func(context.Context, Stream) error
//
// Each time an Action is run, it results in a new trace span.
type Action[In, Out, Stream any] struct {
name string
aname string
atype atype.ActionType
fn Func[In, Out, Stream]
tstate *tracing.State
inputSchema *jsonschema.Schema
outputSchema *jsonschema.Schema
// optional
Description string
Metadata map[string]any
description string
metadata map[string]any
}

// See js/core/src/action.ts
Expand All @@ -73,7 +73,7 @@ func DefineAction[In, Out any](provider, name string, atype atype.ActionType, me
}

func defineAction[In, Out any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
a := NewAction(name, atype, metadata, fn)
a := newAction(name, atype, metadata, fn)
r.registerAction(provider, a)
return a
}
Expand All @@ -83,7 +83,7 @@ func DefineStreamingAction[In, Out, Stream any](provider, name string, atype aty
}

func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
a := NewStreamingAction(name, atype, metadata, fn)
a := newStreamingAction(name, atype, metadata, fn)
r.registerAction(provider, a)
return a
}
Expand All @@ -92,32 +92,61 @@ func DefineCustomAction[In, Out, Stream any](provider, name string, metadata map
return DefineStreamingAction(provider, name, atype.Custom, metadata, fn)
}

// NewAction creates a new Action with the given name and non-streaming function.
func NewAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb NoStream) (Out, error) {
// DefineActionWithInputSchema creates a new Action and registers it.
// This differs from DefineAction in that the input schema is
// defined dynamically; the static input type is "any".
// This is used for prompts.
func DefineActionWithInputSchema[Out any](provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, any) (Out, error), inputSchema *jsonschema.Schema) *Action[any, Out, struct{}] {
return defineActionWithInputSchema(globalRegistry, provider, name, atype, metadata, fn, inputSchema)
}

func defineActionWithInputSchema[Out any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, any) (Out, error), inputSchema *jsonschema.Schema) *Action[any, Out, struct{}] {
a := newActionWithInputSchema(name, atype, metadata, fn, inputSchema)
r.registerAction(provider, a)
return a
}

// newAction creates a new Action with the given name and non-streaming function.
func newAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return newStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb NoStream) (Out, error) {
return fn(ctx, in)
})
}

// NewStreamingAction creates a new Action with the given name and streaming function.
func NewStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
// newStreamingAction creates a new Action with the given name and streaming function.
func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
var i In
var o Out
return &Action[In, Out, Stream]{
name: name,
aname: name,
atype: atype,
fn: func(ctx context.Context, input In, sc func(context.Context, Stream) error) (Out, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype))
return fn(ctx, input, sc)
},
inputSchema: inferJSONSchema(i),
outputSchema: inferJSONSchema(o),
Metadata: metadata,
metadata: metadata,
}
}

func newActionWithInputSchema[Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, any) (Out, error), inputSchema *jsonschema.Schema) *Action[any, Out, struct{}] {
var o Out
return &Action[any, Out, struct{}]{
aname: name,
atype: atype,
fn: func(ctx context.Context, input any, sc func(context.Context, struct{}) error) (Out, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype))
return fn(ctx, input)
},
inputSchema: inputSchema,
outputSchema: inferJSONSchema(o),
metadata: metadata,
}
}

// Name returns the Action's name.
func (a *Action[In, Out, Stream]) Name() string { return a.name }
// name returns the Action's name.
func (a *Action[In, Out, Stream]) name() string { return a.aname }

func (a *Action[In, Out, Stream]) actionType() atype.ActionType { return a.atype }

Expand All @@ -140,35 +169,35 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con
// This action has probably not been registered.
tstate = globalRegistry.tstate
}
return tracing.RunInNewSpan(ctx, tstate, a.name, "action", false, input,
return tracing.RunInNewSpan(ctx, tstate, a.aname, "action", false, input,
func(ctx context.Context, input In) (Out, error) {
start := time.Now()
var err error
if err = ValidateValue(input, a.inputSchema); err != nil {
if err = validateValue(input, a.inputSchema); err != nil {
err = fmt.Errorf("invalid input: %w", err)
}
var output Out
if err == nil {
output, err = a.fn(ctx, input, cb)
if err == nil {
if err = ValidateValue(output, a.outputSchema); err != nil {
if err = validateValue(output, a.outputSchema); err != nil {
err = fmt.Errorf("invalid output: %w", err)
}
}
}
latency := time.Since(start)
if err != nil {
writeActionFailure(ctx, a.name, latency, err)
writeActionFailure(ctx, a.aname, latency, err)
return internal.Zero[Out](), err
}
writeActionSuccess(ctx, a.name, latency)
writeActionSuccess(ctx, a.aname, latency)
return output, nil
})
}

func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) {
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
if err := ValidateJSON(input, a.inputSchema); err != nil {
if err := validateJSON(input, a.inputSchema); err != nil {
return nil, err
}
var in In
Expand Down Expand Up @@ -198,7 +227,7 @@ func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMes

// action is the type that all Action[I, O, S] have in common.
type action interface {
Name() string
name() string
actionType() atype.ActionType

// runJSON uses encoding/json to unmarshal the input,
Expand Down Expand Up @@ -234,9 +263,9 @@ func (d1 actionDesc) equal(d2 actionDesc) bool {

func (a *Action[I, O, S]) desc() actionDesc {
ad := actionDesc{
Name: a.name,
Description: a.Description,
Metadata: a.Metadata,
Name: a.aname,
Description: a.description,
Metadata: a.metadata,
InputSchema: a.inputSchema,
OutputSchema: a.outputSchema,
}
Expand Down
10 changes: 5 additions & 5 deletions go/core/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func inc(_ context.Context, x int) (int, error) {
}

func TestActionRun(t *testing.T) {
a := NewAction("inc", atype.Custom, nil, inc)
a := newAction("inc", atype.Custom, nil, inc)
got, err := a.Run(context.Background(), 3, nil)
if err != nil {
t.Fatal(err)
Expand All @@ -39,7 +39,7 @@ func TestActionRun(t *testing.T) {
}

func TestActionRunJSON(t *testing.T) {
a := NewAction("inc", atype.Custom, nil, inc)
a := newAction("inc", atype.Custom, nil, inc)
input := []byte("3")
want := []byte("4")
got, err := a.runJSON(context.Background(), input, nil)
Expand All @@ -53,7 +53,7 @@ func TestActionRunJSON(t *testing.T) {

func TestNewAction(t *testing.T) {
// Verify that struct{} can occur in the function signature.
_ = NewAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil })
_ = newAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil })
}

// count streams the numbers from 0 to n-1, then returns n.
Expand All @@ -70,7 +70,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int

func TestActionStreaming(t *testing.T) {
ctx := context.Background()
a := NewStreamingAction("count", atype.Custom, nil, count)
a := newStreamingAction("count", atype.Custom, nil, count)
const n = 3

// Non-streaming.
Expand Down Expand Up @@ -103,7 +103,7 @@ func TestActionStreaming(t *testing.T) {
func TestActionTracing(t *testing.T) {
ctx := context.Background()
const actionName = "TestTracing-inc"
a := NewAction(actionName, atype.Custom, nil, inc)
a := newAction(actionName, atype.Custom, nil, inc)
if _, err := a.Run(context.Background(), 3, nil); err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion go/core/conformance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (c *command) run(ctx context.Context, input string) (string, error) {
case c.Append != nil:
return input + *c.Append, nil
case c.Run != nil:
return Run(ctx, c.Run.Name, func() (string, error) {
return InternalRun(ctx, c.Run.Name, func() (string, error) {
return c.Run.Command.run(ctx, input)
})
default:
Expand Down
4 changes: 2 additions & 2 deletions go/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@
//go:generate go run ../internal/cmd/jsonschemagen -outdir .. -config schemas.config ../../genkit-tools/genkit-schema.json core

// Package core implements Genkit actions, flows and other essential machinery.
// This package is primarily intended for genkit internals and for plugins.
// Applications using genkit should use the genkit package.
// This package is primarily intended for Genkit internals and for plugins.
// Genkit applications should use the genkit package.
package core
29 changes: 13 additions & 16 deletions go/core/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ type Flow[In, Out, Stream any] struct {
// TODO(jba): middleware
}

// DefineFlow creates a Flow that runs fn, and registers it as an action.
func DefineFlow[In, Out, Stream any](name string, fn Func[In, Out, Stream]) *Flow[In, Out, Stream] {
// InternalDefineFlow is for use by genkit.DefineFlow exclusively.
// It is not subject to any backwards compatibility guarantees.
func InternalDefineFlow[In, Out, Stream any](name string, fn Func[In, Out, Stream]) *Flow[In, Out, Stream] {
return defineFlow(globalRegistry, name, fn)
}

Expand Down Expand Up @@ -259,7 +260,7 @@ func (f *Flow[In, Out, Stream]) action() *Action[*flowInstruction[In], *flowStat
tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true")
return f.runInstruction(ctx, inst, streamingCallback[Stream](cb))
}
return NewStreamingAction(f.name, atype.Flow, metadata, cback)
return newStreamingAction(f.name, atype.Flow, metadata, cback)
}

// runInstruction performs one of several actions on a flow, as determined by msg.
Expand Down Expand Up @@ -297,7 +298,7 @@ func (f *Flow[In, Out, Stream]) Name() string { return f.name }

func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) {
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
if err := ValidateJSON(input, f.inputSchema); err != nil {
if err := validateJSON(input, f.inputSchema); err != nil {
return nil, &httpError{http.StatusBadRequest, err}
}
var in In
Expand Down Expand Up @@ -380,14 +381,14 @@ func (f *Flow[In, Out, Stream]) execute(ctx context.Context, state *flowState[In
// TODO(jba): If input is missing, get it from state.input and overwrite metadata.input.
start := time.Now()
var err error
if err = ValidateValue(input, f.inputSchema); err != nil {
if err = validateValue(input, f.inputSchema); err != nil {
err = fmt.Errorf("invalid input: %w", err)
}
var output Out
if err == nil {
output, err = f.fn(ctx, input, cb)
if err == nil {
if err = ValidateValue(output, f.outputSchema); err != nil {
if err = validateValue(output, f.outputSchema); err != nil {
err = fmt.Errorf("invalid output: %w", err)
}
}
Expand Down Expand Up @@ -488,13 +489,9 @@ func (fc *flowContext[I, O]) uniqueStepName(name string) string {

var flowContextKey = internal.NewContextKey[flowContexter]()

// Run runs the function f in the context of the current flow.
// It returns an error if no flow is active.
//
// Each call to Run results in a new step in the flow.
// A step has its own span in the trace, and its result is cached so that if the flow
// is restarted, f will not be called a second time.
func Run[Out any](ctx context.Context, name string, f func() (Out, error)) (Out, error) {
// InternalRun is for use by genkit.Run exclusively.
// It is not subject to any backwards compatibility guarantees.
func InternalRun[Out any](ctx context.Context, name string, f func() (Out, error)) (Out, error) {
// from js/flow/src/steps.ts
fc := flowContextKey.FromContext(ctx)
if fc == nil {
Expand Down Expand Up @@ -541,9 +538,9 @@ func Run[Out any](ctx context.Context, name string, f func() (Out, error)) (Out,
})
}

// RunFlow runs flow in the context of another flow. The flow must run to completion when started
// (that is, it must not have interrupts).
func RunFlow[In, Out, Stream any](ctx context.Context, flow *Flow[In, Out, Stream], input In) (Out, error) {
// InternalRunFlow is for use by genkit.RunFlow exclusively.
// It is not subject to any backwards compatibility guarantees.
func InternalRunFlow[In, Out, Stream any](ctx context.Context, flow *Flow[In, Out, Stream], input In) (Out, error) {
state, err := flow.start(ctx, input, nil)
if err != nil {
return internal.Zero[Out](), err
Expand Down
Loading
Loading