Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalize to type/provider/name pattern for action names. #324

Merged
merged 3 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ type EmbedRequest struct {

// DefineEmbedder registers the given embed function as an action, and returns an
// [Embedder] whose Embed method runs it.
func DefineEmbedder(name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) Embedder {
return embedder{core.DefineAction(name, core.ActionTypeEmbedder, nil, embed)}
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) Embedder {
return embedder{core.DefineAction(provider, name, core.ActionTypeEmbedder, nil, embed)}
}

type embedder struct {
Expand Down
2 changes: 1 addition & 1 deletion go/ai/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func DefineGenerator(provider, name string, metadata *GeneratorMetadata, generat
}
metadataMap["supports"] = supports
}
a := core.DefineStreamingAction(provider+"/"+name, core.ActionTypeModel, map[string]any{
a := core.DefineStreamingAction(provider, name, core.ActionTypeModel, map[string]any{
"model": metadataMap,
}, generate)
return generator{a}
Expand Down
4 changes: 2 additions & 2 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ func DefineDocumentStore(
index func(context.Context, *IndexerRequest) error,
retrieve func(context.Context, *RetrieverRequest) (*RetrieverResponse, error),
) DocumentStore {
ia := core.DefineAction(name, core.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
ia := core.DefineAction("indexer", name, core.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
return struct{}{}, index(ctx, req)
})
ra := core.DefineAction(name, core.ActionTypeRetriever, nil, retrieve)
ra := core.DefineAction("retreiver", name, core.ActionTypeRetriever, nil, retrieve)
return &docStore{ia, ra}
}

Expand Down
6 changes: 3 additions & 3 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type Tool struct {
}

// RegisterTool registers a tool function.
func RegisterTool(name string, definition *ToolDefinition, metadata map[string]any, fn func(ctx context.Context, input map[string]any) (map[string]any, error)) {
func RegisterTool(definition *ToolDefinition, metadata map[string]any, fn func(ctx context.Context, input map[string]any) (map[string]any, error)) {
if len(metadata) > 0 {
metadata = maps.Clone(metadata)
}
Expand All @@ -41,7 +41,7 @@ func RegisterTool(name string, definition *ToolDefinition, metadata map[string]a
}
metadata["type"] = "tool"

core.DefineAction(definition.Name, core.ActionTypeTool, metadata, fn)
core.DefineAction("local", definition.Name, core.ActionTypeTool, metadata, fn)
}

// toolActionType is the instantiated core.Action type registered
Expand All @@ -51,7 +51,7 @@ type toolActionType = core.Action[map[string]any, map[string]any, struct{}]
// RunTool looks up a tool registered by [RegisterTool],
// runs it with the given input, and returns the result.
func RunTool(ctx context.Context, name string, input map[string]any) (map[string]any, error) {
action := core.LookupAction(core.ActionTypeTool, "tool", name)
action := core.LookupAction(core.ActionTypeTool, "local", name)
if action == nil {
return nil, fmt.Errorf("no tool named %q", name)
}
Expand Down
16 changes: 8 additions & 8 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,23 @@ type Action[In, Out, Stream any] struct {
// See js/core/src/action.ts

// DefineAction creates a new Action and registers it.
func DefineAction[In, Out any](name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return defineAction(globalRegistry, name, atype, metadata, fn)
func DefineAction[In, Out any](provider, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return defineAction(globalRegistry, provider, name, atype, metadata, fn)
}

func defineAction[In, Out any](r *registry, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
func defineAction[In, Out any](r *registry, provider, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
a := NewAction(name, atype, metadata, fn)
r.registerAction(name, a)
r.registerAction(provider, a)
return a
}

func DefineStreamingAction[In, Out, Stream any](name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
return defineStreamingAction(globalRegistry, name, atype, metadata, fn)
func DefineStreamingAction[In, Out, Stream any](provider, name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
return defineStreamingAction(globalRegistry, provider, name, atype, metadata, fn)
}

func defineStreamingAction[In, Out, Stream any](r *registry, name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
a := NewStreamingAction(name, atype, metadata, fn)
r.registerAction(name, a)
r.registerAction(provider, a)
return a
}

Expand Down
2 changes: 1 addition & 1 deletion go/plugins/googleai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func NewEmbedder(ctx context.Context, model, apiKey string) (ai.Embedder, error)
if err != nil {
return nil, err
}
e := ai.DefineEmbedder("google-genai", func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) {
e := ai.DefineEmbedder("google-genai", model, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) {
em := client.EmbeddingModel(model)
parts := convertParts(input.Document.Content)
res, err := em.EmbedContent(ctx, parts...)
Expand Down
3 changes: 1 addition & 2 deletions go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ var toolDef = &ai.ToolDefinition{
}

func init() {
ai.RegisterTool("exponentiation",
toolDef, nil,
ai.RegisterTool(toolDef, nil,
func(ctx context.Context, input map[string]any) (map[string]any, error) {
baseAny, ok := input["base"]
if !ok {
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOp
if err != nil {
return nil, err
}
return ai.DefineDocumentStore("devLocalVectorStore/"+name, r.Index, r.Retrieve), nil
return ai.DefineDocumentStore("devLocalVectorStore-"+name, r.Index, r.Retrieve), nil
}

// docStore implements the [ai.DocumentStore] interface
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/vertexai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func NewEmbedder(ctx context.Context, model, projectID, location string) (ai.Emb

reqEndpoint := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", projectID, location, model)

e := ai.DefineEmbedder("google-vertexai", func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
e := ai.DefineEmbedder("google-vertexai", model, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
preq, err := newPredictRequest(reqEndpoint, req)
if err != nil {
return nil, err
Expand Down
3 changes: 1 addition & 2 deletions go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ var toolDef = &ai.ToolDefinition{
}

func init() {
ai.RegisterTool("exponentiation",
toolDef, nil,
ai.RegisterTool(toolDef, nil,
func(ctx context.Context, input map[string]any) (map[string]any, error) {
baseAny, ok := input["base"]
if !ok {
Expand Down
2 changes: 1 addition & 1 deletion go/samples/menu/s02.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func menu(ctx context.Context, input map[string]any) (map[string]any, error) {
}

func setup02(ctx context.Context, g ai.Generator) error {
ai.RegisterTool("menu", menuToolDef, nil, menu)
ai.RegisterTool(menuToolDef, nil, menu)

dataMenuPrompt, err := dotprompt.Define("s02_dataMenu",
`You are acting as a helpful AI assistant named Walt that can answer
Expand Down
Loading