Skip to content

Commit

Permalink
feat: update Google AI model struct to include optional fields and ex…
Browse files Browse the repository at this point in the history
…perimental flag
  • Loading branch information
Gyarbij committed Dec 20, 2024
1 parent e385611 commit 6456738
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 53 deletions.
Binary file modified azure-oai-proxy
Binary file not shown.
123 changes: 70 additions & 53 deletions pkg/google/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,21 +213,22 @@ func convertToOpenAIResponse(resp *genai.GenerateContentResponse) map[string]int
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
Capabilities Capabilities `json:"capabilities"`
LifecycleStatus string `json:"lifecycle_status"`
Status string `json:"status"`
Deprecation Deprecation `json:"deprecation"`
CreatedAt int64 `json:"created_at,omitempty"`
Capabilities Capabilities `json:"capabilities,omitempty"`
LifecycleStatus string `json:"lifecycle_status,omitempty"`
Status string `json:"status,omitempty"`
Deprecation Deprecation `json:"deprecation,omitempty"`
FineTune string `json:"fine_tune,omitempty"`
Name string `json:"name"`
Version string `json:"version"`
Description string `json:"description"`
InputTokenLimit int `json:"inputTokenLimit"`
OutputTokenLimit int `json:"outputTokenLimit"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"`
Description string `json:"description,omitempty"`
InputTokenLimit int `json:"input_token_limit,omitempty"`
OutputTokenLimit int `json:"output_token_limit,omitempty"`
SupportedGenerationMethods []string `json:"supported_generation_methods,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
IsExperimental bool `json:"is_experimental,omitempty"`
}

// Capabilities represents the capabilities of a Google AI model.
Expand All @@ -250,67 +251,83 @@ func FetchGoogleAIModels() ([]Model, error) {
return nil, fmt.Errorf("Google AI Studio API key not set")
}

url := fmt.Sprintf("%s/%s/models?key=%s", GoogleAIEndpoint, GoogleAIAPIVersion, GoogleAIAPIKey)
req, err := http.NewRequest("GET", url, nil)
ctx := context.Background()
client, err := genai.NewClient(ctx, option.WithAPIKey(GoogleAIAPIKey))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create Google AI client: %v", err)
}
defer client.Close()

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var models []Model
// Fetch regular models
regIter := client.ListModels(ctx)
for {
m, err := regIter.Next()
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("failed to list regular Google AI models: %v", err)
}

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("failed to fetch Google AI Studio models: %s", string(body))
}
modelID := strings.TrimPrefix(m.Name, "models/")

var googleModels struct {
Models []struct {
Name string `json:"name"`
Version string `json:"version"`
DisplayName string `json:"displayName"`
Description string `json:"description"`
InputTokenLimit int `json:"inputTokenLimit"`
OutputTokenLimit int `json:"outputTokenLimit"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
} `json:"models"`
models = append(models, Model{
ID: modelID,
Object: "model",
Name: m.DisplayName,
Version: m.Version,
Description: m.Description,
InputTokenLimit: int(m.InputTokenLimit),
OutputTokenLimit: int(m.OutputTokenLimit),
SupportedGenerationMethods: m.SupportedGenerationMethods,
Temperature: 0.0, // Default or set based on your needs
TopP: 0.0, // Default or set based on your needs
TopK: 0, // Default or set based on your needs
Capabilities: Capabilities{
Completion: true,
ChatCompletion: true,
Embeddings: strings.Contains(modelID, "embedding"),
},
LifecycleStatus: "active", // You may need to adjust this based on the actual model status
Status: "ready", // You may need to adjust this based on the actual model status
IsExperimental: false,
})
}

if err := json.NewDecoder(resp.Body).Decode(&googleModels); err != nil {
return nil, err
}
// Fetch experimental models
expIter := client.ListModels(ctx)
for {
m, err := expIter.Next()
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("failed to list experimental Google AI models: %v", err)
}

var models []Model
for _, m := range googleModels.Models {
// Extract model ID from the name field (e.g., "models/gemini-pro")
modelID := strings.TrimPrefix(m.Name, "models/")

models = append(models, Model{
ID: modelID,
Object: "model",
Name: m.Name,
Name: m.DisplayName,
Version: m.Version,
Description: m.Description,
InputTokenLimit: m.InputTokenLimit,
OutputTokenLimit: m.OutputTokenLimit,
InputTokenLimit: int(m.InputTokenLimit),
OutputTokenLimit: int(m.OutputTokenLimit),
SupportedGenerationMethods: m.SupportedGenerationMethods,
Temperature: m.Temperature,
TopP: m.TopP,
TopK: m.TopK,
Temperature: 0.0, // Default or set based on your needs
TopP: 0.0, // Default or set based on your needs
TopK: 0, // Default or set based on your needs
Capabilities: Capabilities{
Completion: true,
ChatCompletion: true,
Embeddings: strings.Contains(modelID, "embedding"),
},
LifecycleStatus: "active",
Status: "ready",
LifecycleStatus: "experimental", // You may need to adjust this based on the actual model status
Status: "ready", // You may need to adjust this based on the actual model status
IsExperimental: true,
})
}

Expand Down

0 comments on commit 6456738

Please sign in to comment.