Skip to content

Commit

Permalink
config
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj committed Jun 18, 2024
1 parent d492b91 commit b6c53db
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type Config struct {
APIKey string
// Generative models to provide.
// If empty, a complete list will be obtained from the service.
Models []string
Models map[string]*ai.ModelCapabilities
// Embedding models to provide.
// If empty, a complete list will be obtained from the service.
Embedders []string
Expand All @@ -80,7 +80,10 @@ func Init(ctx context.Context, cfg Config) (err error) {
return err
}

needModels := len(cfg.Models) == 0
needModels := cfg.Models == nil || len(cfg.Models) == 0
if needModels {
cfg.Models = map[string]*ai.ModelCapabilities{}
}
needEmbedders := len(cfg.Embedders) == 0
if needModels || needEmbedders {
iter := client.ListModels(ctx)
Expand All @@ -95,30 +98,35 @@ func Init(ctx context.Context, cfg Config) (err error) {
// Model names are of the form "models/name".
name := path.Base(mi.Name)
if needModels && slices.Contains(mi.SupportedGenerationMethods, "generateContent") {
cfg.Models = append(cfg.Models, name)
cfg.Models[name] = nil
}
if needEmbedders && slices.Contains(mi.SupportedGenerationMethods, "embedContent") {
cfg.Embedders = append(cfg.Embedders, name)
}
}
}
for _, name := range cfg.Models {
defineModel(name, client)
for name, c := range cfg.Models {
defineModel(name, client, c)
}
for _, name := range cfg.Embedders {
defineEmbedder(name, client)
}
return nil
}

func defineModel(name string, client *genai.Client) {
capabilities, ok := knownModelsCapabilities[name]
if !ok {
capabilities = defaultModelCapabilities
func defineModel(name string, client *genai.Client, capabilities *ai.ModelCapabilities) {
c := defaultModelCapabilities
if capabilities == nil {
foundCapability, ok := knownModelsCapabilities[name]
if ok {
c = foundCapability
}
} else {
c = *capabilities
}
meta := &ai.ModelMetadata{
Label: "Google AI - " + name,
Supports: capabilities,
Supports: c,
}
g := generator{model: name, client: client}
ai.DefineModel(provider, name, meta, g.generate)
Expand Down

0 comments on commit b6c53db

Please sign in to comment.