diff --git a/main.go b/main.go
index 007b909..151a6c9 100644
--- a/main.go
+++ b/main.go
@@ -1,9 +1,12 @@
 package main
 
 import (
-	"github.com/gyarbij/azure-oai-proxy/pkg/azure"
-	"github.com/gyarbij/azure-oai-proxy/pkg/openai"
+	"encoding/json"
+	"fmt"
 	"github.com/gin-gonic/gin"
+	"github.com/gyarbij/azure-oai-proxy/pkg/azure"
+	"github.com/gyarbij/azure-oai-proxy/pkg/openai"
+	"io"
 	"log"
 	"net/http"
 	"os"
@@ -31,17 +34,16 @@ func main() {
 	if ProxyMode == "azure" {
 		router.GET("/v1/models", handleGetModels)
 		router.OPTIONS("/v1/*path", handleOptions)
-		// Existing routes
 		router.POST("/v1/chat/completions", handleAzureProxy)
 		router.POST("/v1/completions", handleAzureProxy)
 		router.POST("/v1/embeddings", handleAzureProxy)
 		// DALL-E routes
 		router.POST("/v1/images/generations", handleAzureProxy)
-		// speech- routes
+		// Speech routes
 		router.POST("/v1/audio/speech", handleAzureProxy)
 		router.GET("/v1/audio/voices", handleAzureProxy)
 		router.POST("/v1/audio/transcriptions", handleAzureProxy)
 		router.POST("/v1/audio/translations", handleAzureProxy)
 		// Fine-tuning routes
 		router.POST("/v1/fine_tunes", handleAzureProxy)
 		router.GET("/v1/fine_tunes", handleAzureProxy)
@@ -66,51 +68,84 @@ func main() {
 }
 
 func handleGetModels(c *gin.Context) {
-	models := []string{
-		"gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-4o-2024-05-13", "gpt-4-turbo-2024-04-09", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview", "gpt-4-vision-preview", "gpt-4-32k-0613",
-		"gpt-35-turbo-0301", "gpt-35-turbo-0613", "gpt-35-turbo-1106", "gpt-35-turbo-0125", "gpt-35-turbo-16k",
-		"text-embedding-3-large", "text-embedding-3-small", "text-embedding-ada-002",
-		"dall-e-2", "dall-e-3",
-		"babbage-002", "davinci-002", "whisper-001",
-	}
-	result := azure.ListModelResponse{
-		Object: "list",
-	}
-	for _, model := range models {
-		result.Data = append(result.Data, azure.Model{
-			ID:      model,
-			Object:  "model",
-			Created: 1677649963,
-			OwnedBy: "openai",
-			Permission: []azure.ModelPermission{
-				{
-					ID:                 "",
-					Object:             "model",
-					Created:            1679602087,
-					AllowCreateEngine:  true,
-					AllowSampling:      true,
-					AllowLogprobs:      true,
-					AllowSearchIndices: true,
-					AllowView:          true,
-					AllowFineTuning:    true,
-					Organization:       "*",
-					Group:              nil,
-					IsBlocking:         false,
-				},
-			},
-			Root:   model,
-			Parent: nil,
-		})
-	}
-	c.JSON(200, result)
+	models, err := fetchDeployedModels()
+	if err != nil {
+		log.Printf("error fetching deployed models: %v", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch deployed models"})
+		return
+	}
+	result := azure.ListModelResponse{
+		Object: "list",
+		Data:   models,
+	}
+	c.JSON(http.StatusOK, result)
+}
+
+// fetchDeployedModels lists the deployments on the configured Azure OpenAI
+// resource and maps each one to an OpenAI-style model object.
+func fetchDeployedModels() ([]azure.Model, error) {
+	endpoint := os.Getenv("AZURE_OPENAI_ENDPOINT")
+	token := os.Getenv("AZURE_OPENAI_TOKEN")
+
+	req, err := http.NewRequest("GET", endpoint+"/openai/deployments?api-version=2024-05-01-preview", nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Authorization", "Bearer "+token)
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("failed to fetch deployed models: %s", string(body))
+	}
+
+	var deployedModelsResponse azure.ListDeployedModelsResponse
+	if err := json.NewDecoder(resp.Body).Decode(&deployedModelsResponse); err != nil {
+		return nil, err
+	}
+
+	models := []azure.Model{}
+	for _, deployedModel := range deployedModelsResponse.Data {
+		models = append(models, azure.Model{
+			ID:      deployedModel.ModelID,
+			Object:  "model",
+			Created: int(deployedModel.CreatedAt),
+			OwnedBy: "openai",
+			Permission: []azure.ModelPermission{
+				{
+					ID:                 "",
+					Object:             "model",
+					Created:            int(deployedModel.CreatedAt),
+					AllowCreateEngine:  true,
+					AllowSampling:      true,
+					AllowLogprobs:      true,
+					AllowSearchIndices: true,
+					AllowView:          true,
+					AllowFineTuning:    true,
+					Organization:       "*",
+					Group:              nil,
+					IsBlocking:         false,
+				},
+			},
+			Root:   deployedModel.ModelID,
+			Parent: nil,
+		})
+	}
+
+	return models, nil
 }
 
 func handleOptions(c *gin.Context) {
 	c.Header("Access-Control-Allow-Origin", "*")
 	c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
 	c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
-	c.Status(200)
-	return
+	c.Status(http.StatusOK)
 }
 
 func handleAzureProxy(c *gin.Context) {
@@ -128,13 +163,12 @@ func handleAzureProxy(c *gin.Context) {
 		}
 	}
 
-	// Enhanced error logging
 	if c.Writer.Status() >= 400 {
 		log.Printf("Azure API request failed: %s %s, Status: %d", c.Request.Method, c.Request.URL.Path, c.Writer.Status())
 	}
 }
 
 func handleOpenAIProxy(c *gin.Context) {
 	server := openai.NewOpenAIReverseProxy()
 	server.ServeHTTP(c.Writer, c.Request)
 }
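
The new `fetchDeployedModels` decodes into `azure.ListDeployedModelsResponse` and reads `deployedModel.ModelID` and `deployedModel.CreatedAt`, but the diff does not show those declarations. Here is a minimal sketch of what `pkg/azure` presumably declares; the field names and JSON tags are assumptions inferred from the call sites and the general shape of the Azure OpenAI deployments API, not the actual `pkg/azure` source:

```go
package azure

// Sketch only: field names and JSON tags are guesses based on how main.go
// uses these types; the real definitions in pkg/azure may differ.
type DeployedModel struct {
	DeploymentID string `json:"id"`         // deployment name, e.g. "my-gpt-4o"
	ModelID      string `json:"model"`      // underlying model, e.g. "gpt-4o"
	Status       string `json:"status"`     // e.g. "succeeded"
	CreatedAt    int64  `json:"created_at"` // Unix seconds
	UpdatedAt    int64  `json:"updated_at"`
}

type ListDeployedModelsResponse struct {
	Object string          `json:"object"` // "list"
	Data   []DeployedModel `json:"data"`
}
```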
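Because both the endpoint and the token come from environment variables, the fetch path can be exercised against a stub server with `net/http/httptest`. A minimal test sketch, assuming the JSON field names above (the payload is fabricated for the test, not a real Azure response):

```go
package main

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

func TestFetchDeployedModels(t *testing.T) {
	// Fake deployments endpoint returning a single deployment.
	fake := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/openai/deployments" {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{"object":"list","data":[{"id":"my-gpt-4o","model":"gpt-4o","status":"succeeded","created_at":1715600000,"updated_at":1715600000}]}`))
	}))
	defer fake.Close()

	// Point fetchDeployedModels at the stub instead of a real resource.
	t.Setenv("AZURE_OPENAI_ENDPOINT", fake.URL)
	t.Setenv("AZURE_OPENAI_TOKEN", "test-token")

	models, err := fetchDeployedModels()
	if err != nil {
		t.Fatalf("fetchDeployedModels: %v", err)
	}
	if len(models) != 1 || models[0].ID != "gpt-4o" {
		t.Fatalf("unexpected models: %+v", models)
	}
}
```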
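For an end-to-end check of the new `/v1/models` behavior against a running proxy, a quick smoke test looks like this; the `localhost:11437` address is an assumption, adjust it to wherever your instance listens:

```go
package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// Address is an assumption; point this at your running proxy.
	resp, err := http.Get("http://localhost:11437/v1/models")
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Fatal(err)
	}
	// Expect {"object":"list","data":[...]} listing the resource's actual
	// deployments instead of the old hardcoded model list.
	fmt.Println(resp.Status)
	fmt.Println(string(body))
}
```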