diff --git a/azure-oai-proxy b/azure-oai-proxy
index c7aa6e2..83a3c83 100755
Binary files a/azure-oai-proxy and b/azure-oai-proxy differ
diff --git a/pkg/google/proxy.go b/pkg/google/proxy.go
index 8eb3973..589c8aa 100644
--- a/pkg/google/proxy.go
+++ b/pkg/google/proxy.go
@@ -213,21 +213,22 @@ func convertToOpenAIResponse(resp *genai.GenerateContentResponse) map[string]int
 type Model struct {
     ID                         string       `json:"id"`
     Object                     string       `json:"object"`
-    CreatedAt                  int64        `json:"created_at"`
-    Capabilities               Capabilities `json:"capabilities"`
-    LifecycleStatus            string       `json:"lifecycle_status"`
-    Status                     string       `json:"status"`
-    Deprecation                Deprecation  `json:"deprecation"`
+    CreatedAt                  int64        `json:"created_at,omitempty"`
+    Capabilities               Capabilities `json:"capabilities,omitempty"`
+    LifecycleStatus            string       `json:"lifecycle_status,omitempty"`
+    Status                     string       `json:"status,omitempty"`
+    Deprecation                Deprecation  `json:"deprecation,omitempty"`
     FineTune                   string       `json:"fine_tune,omitempty"`
-    Name                       string       `json:"name"`
-    Version                    string       `json:"version"`
-    Description                string       `json:"description"`
-    InputTokenLimit            int          `json:"inputTokenLimit"`
-    OutputTokenLimit           int          `json:"outputTokenLimit"`
-    SupportedGenerationMethods []string     `json:"supportedGenerationMethods"`
+    Name                       string       `json:"name,omitempty"`
+    Version                    string       `json:"version,omitempty"`
+    Description                string       `json:"description,omitempty"`
+    InputTokenLimit            int          `json:"input_token_limit,omitempty"`
+    OutputTokenLimit           int          `json:"output_token_limit,omitempty"`
+    SupportedGenerationMethods []string     `json:"supported_generation_methods,omitempty"`
     Temperature                float64      `json:"temperature,omitempty"`
-    TopP                       float64      `json:"topP,omitempty"`
-    TopK                       int          `json:"topK,omitempty"`
+    TopP                       float64      `json:"top_p,omitempty"`
+    TopK                       int          `json:"top_k,omitempty"`
+    IsExperimental             bool         `json:"is_experimental,omitempty"`
 }
 
 // Capabilities represents the capabilities of a Google AI model.
@@ -250,67 +251,83 @@ func FetchGoogleAIModels() ([]Model, error) {
         return nil, fmt.Errorf("Google AI Studio API key not set")
     }
 
-    url := fmt.Sprintf("%s/%s/models?key=%s", GoogleAIEndpoint, GoogleAIAPIVersion, GoogleAIAPIKey)
-    req, err := http.NewRequest("GET", url, nil)
+    ctx := context.Background()
+    client, err := genai.NewClient(ctx, option.WithAPIKey(GoogleAIAPIKey))
     if err != nil {
-        return nil, err
+        return nil, fmt.Errorf("failed to create Google AI client: %v", err)
     }
+    defer client.Close()
 
-    client := &http.Client{}
-    resp, err := client.Do(req)
-    if err != nil {
-        return nil, err
-    }
-    defer resp.Body.Close()
+    var models []Model
+    // Fetch regular models
+    regIter := client.ListModels(ctx)
+    for {
+        m, err := regIter.Next()
+        if err == iterator.Done {
+            break
+        }
+        if err != nil {
+            return nil, fmt.Errorf("failed to list regular Google AI models: %v", err)
+        }
 
-    if resp.StatusCode != http.StatusOK {
-        body, _ := io.ReadAll(resp.Body)
-        return nil, fmt.Errorf("failed to fetch Google AI Studio models: %s", string(body))
-    }
+        modelID := strings.TrimPrefix(m.Name, "models/")
 
-    var googleModels struct {
-        Models []struct {
-            Name                       string   `json:"name"`
-            Version                    string   `json:"version"`
-            DisplayName                string   `json:"displayName"`
-            Description                string   `json:"description"`
-            InputTokenLimit            int      `json:"inputTokenLimit"`
-            OutputTokenLimit           int      `json:"outputTokenLimit"`
-            SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
-            Temperature                float64  `json:"temperature,omitempty"`
-            TopP                       float64  `json:"topP,omitempty"`
-            TopK                       int      `json:"topK,omitempty"`
-        } `json:"models"`
+        models = append(models, Model{
+            ID:                         modelID,
+            Object:                     "model",
+            Name:                       m.DisplayName,
+            Version:                    m.Version,
+            Description:                m.Description,
+            InputTokenLimit:            int(m.InputTokenLimit),
+            OutputTokenLimit:           int(m.OutputTokenLimit),
+            SupportedGenerationMethods: m.SupportedGenerationMethods,
+            Temperature:                0.0, // Default or set based on your needs
+            TopP:                       0.0, // Default or set based on your needs
+            TopK:                       0,   // Default or set based on your needs
+            Capabilities: Capabilities{
+                Completion:     true,
+                ChatCompletion: true,
+                Embeddings:     strings.Contains(modelID, "embedding"),
+            },
+            LifecycleStatus: "active", // You may need to adjust this based on the actual model status
+            Status:          "ready",  // You may need to adjust this based on the actual model status
+            IsExperimental:  false,
+        })
     }
 
-    if err := json.NewDecoder(resp.Body).Decode(&googleModels); err != nil {
-        return nil, err
-    }
+    // Fetch experimental models
+    expIter := client.ListModels(ctx)
+    for {
+        m, err := expIter.Next()
+        if err == iterator.Done {
+            break
+        }
+        if err != nil {
+            return nil, fmt.Errorf("failed to list experimental Google AI models: %v", err)
+        }
 
-    var models []Model
-    for _, m := range googleModels.Models {
-        // Extract model ID from the name field (e.g., "models/gemini-pro")
         modelID := strings.TrimPrefix(m.Name, "models/")
 
         models = append(models, Model{
             ID:                         modelID,
             Object:                     "model",
-            Name:                       m.Name,
+            Name:                       m.DisplayName,
             Version:                    m.Version,
             Description:                m.Description,
-            InputTokenLimit:            m.InputTokenLimit,
-            OutputTokenLimit:           m.OutputTokenLimit,
+            InputTokenLimit:            int(m.InputTokenLimit),
+            OutputTokenLimit:           int(m.OutputTokenLimit),
             SupportedGenerationMethods: m.SupportedGenerationMethods,
-            Temperature:                m.Temperature,
-            TopP:                       m.TopP,
-            TopK:                       m.TopK,
+            Temperature:                0.0, // Default or set based on your needs
+            TopP:                       0.0, // Default or set based on your needs
+            TopK:                       0,   // Default or set based on your needs
             Capabilities: Capabilities{
                 Completion:     true,
                 ChatCompletion: true,
                 Embeddings:     strings.Contains(modelID, "embedding"),
             },
-            LifecycleStatus: "active",
-            Status:          "ready",
+            LifecycleStatus: "experimental", // You may need to adjust this based on the actual model status
+            Status:          "ready",        // You may need to adjust this based on the actual model status
+            IsExperimental:  true,
         })
     }
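
A note on the first hunk: encoding/json's omitempty only drops a field whose value is its type's zero value (false, 0, "", nil pointer/slice/map). A struct-typed field is never "empty", so `json:"deprecation,omitempty"` will still emit "deprecation":{} for a zero-valued Deprecation; making the field a pointer (*Deprecation) is the usual way to actually omit it. A minimal runnable sketch of both behaviors (the trimmed-down structs are illustrative, not the real Model type):

package main

import (
	"encoding/json"
	"fmt"
)

// Deprecation stands in for the struct-typed field from the diff; its
// actual fields don't matter for the marshaling behavior shown here.
type Deprecation struct {
	Inference string `json:"inference,omitempty"`
}

// model is a cut-down stand-in for the Model type above.
type model struct {
	ID          string       `json:"id"`
	Version     string       `json:"version,omitempty"`     // omitted when ""
	TopK        int          `json:"top_k,omitempty"`       // omitted when 0
	Deprecation Deprecation  `json:"deprecation,omitempty"` // never omitted: structs have no zero test
	DepPtr      *Deprecation `json:"dep_ptr,omitempty"`     // hypothetical pointer variant, omitted when nil
}

func main() {
	b, _ := json.Marshal(model{ID: "gemini-pro"})
	fmt.Println(string(b))
	// {"id":"gemini-pro","deprecation":{}}
}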
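
On the second hunk: both loops call client.ListModels(ctx) with no filter, so the "experimental" pass iterates the exact same list and every model is appended twice, once with IsExperimental: false and once with IsExperimental: true. ListModels has no notion of experimental models, so flagging them needs an explicit predicate. The rewrite also pins Temperature, TopP, and TopK to zero where the removed JSON path passed through the API-reported values, which recent generative-ai-go releases expose on the ModelInfo the iterator yields (as float32/int32). A single-pass sketch under those assumptions follows; the name-based experimental check is a placeholder heuristic, not SDK behavior:

import (
	"context"
	"fmt"
	"strings"

	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)

// fetchGoogleAIModelsOnce lists every model exactly once and derives the
// experimental flag from the model ID instead of a second, identical listing.
func fetchGoogleAIModelsOnce(apiKey string) ([]Model, error) {
	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
	if err != nil {
		return nil, fmt.Errorf("failed to create Google AI client: %v", err)
	}
	defer client.Close()

	var models []Model
	it := client.ListModels(ctx)
	for {
		m, err := it.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("failed to list Google AI models: %v", err)
		}

		modelID := strings.TrimPrefix(m.Name, "models/")
		// Assumption: experimental variants are only identifiable by name
		// (e.g. "gemini-exp-..."); the API exposes no explicit flag.
		exp := strings.Contains(modelID, "exp")
		lifecycle := "active"
		if exp {
			lifecycle = "experimental"
		}

		models = append(models, Model{
			ID:                         modelID,
			Object:                     "model",
			Name:                       m.DisplayName,
			Version:                    m.Version,
			Description:                m.Description,
			InputTokenLimit:            int(m.InputTokenLimit),
			OutputTokenLimit:           int(m.OutputTokenLimit),
			SupportedGenerationMethods: m.SupportedGenerationMethods,
			Temperature:                float64(m.Temperature), // pass through API-reported defaults
			TopP:                       float64(m.TopP),
			TopK:                       int(m.TopK),
			Capabilities: Capabilities{
				Completion:     true,
				ChatCompletion: true,
				Embeddings:     strings.Contains(modelID, "embedding"),
			},
			LifecycleStatus: lifecycle,
			Status:          "ready",
			IsExperimental:  exp,
		})
	}
	return models, nil
}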
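
Finally, a consumer-side sketch of what the new is_experimental field enables, assuming the proxy wraps these entries in the OpenAI-style list envelope ({"object":"list","data":[...]}); the envelope and the response body here are illustrative, not captured output:

package main

import (
	"encoding/json"
	"fmt"
)

// listedModel is a minimal client-side view of one /v1/models entry;
// the field names follow the json tags added in the first hunk.
type listedModel struct {
	ID             string `json:"id"`
	IsExperimental bool   `json:"is_experimental,omitempty"`
}

func main() {
	// Hypothetical response body; real IDs depend on the account and API version.
	body := []byte(`{"object":"list","data":[
		{"id":"gemini-pro"},
		{"id":"gemini-exp-1206","is_experimental":true}]}`)

	var resp struct {
		Data []listedModel `json:"data"`
	}
	if err := json.Unmarshal(body, &resp); err != nil {
		panic(err)
	}
	for _, m := range resp.Data {
		if !m.IsExperimental {
			fmt.Println(m.ID) // keep only stable models
		}
	}
}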