Skip to content

Commit

Permalink
feat: Add support for fetching deployed models in Azure
Browse files Browse the repository at this point in the history
This commit adds the functionality to fetch deployed models in Azure. It includes the implementation of the `fetchDeployedModels` function, which makes a GET request to the Azure OpenAI API to retrieve the list of deployed models. The fetched models are then returned as a JSON response.

The `handleGetModels` function has been updated to use the `fetchDeployedModels` function and return the fetched models as a response.

This feature allows users to retrieve information about the models deployed in Azure, giving Open WebUI an experience more consistent with the OpenAI API.
  • Loading branch information
Gyarbij committed Jun 22, 2024
1 parent f818cda commit ba1ccb0
Showing 1 changed file with 84 additions and 54 deletions.
138 changes: 84 additions & 54 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package main

import (
"github.com/gyarbij/azure-oai-proxy/pkg/azure"
"github.com/gyarbij/azure-oai-proxy/pkg/openai"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/gyarbij/azure-oai-proxy/pkg/azure"
"log"
"net/http"
"os"
Expand Down Expand Up @@ -31,17 +31,16 @@ func main() {
if ProxyMode == "azure" {
router.GET("/v1/models", handleGetModels)
router.OPTIONS("/v1/*path", handleOptions)
// Existing routes
router.POST("/v1/chat/completions", handleAzureProxy)
router.POST("/v1/completions", handleAzureProxy)
router.POST("/v1/embeddings", handleAzureProxy)
// DALL-E routes
router.POST("/v1/images/generations", handleAzureProxy)
// speech- routes
router.POST("/v1/audio/speech", handleAzureProxy)
router.GET("/v1/audio/voices", handleAzureProxy)
router.POST("/v1/audio/transcriptions", handleAzureProxy)
router.POST("/v1/audio/translations", handleAzureProxy)
// speech- routes
router.POST("/v1/audio/speech", handleAzureProxy)
router.GET("/v1/audio/voices", handleAzureProxy)
router.POST("/v1/audio/transcriptions", handleAzureProxy)
router.POST("/v1/audio/translations", handleAzureProxy)
// Fine-tuning routes
router.POST("/v1/fine_tunes", handleAzureProxy)
router.GET("/v1/fine_tunes", handleAzureProxy)
Expand All @@ -57,7 +56,7 @@ func main() {
// Deployments management routes
router.GET("/deployments", handleAzureProxy)
router.GET("/deployments/:deployment_id", handleAzureProxy)
router.GET("/v1/models/:model_id/capabilities", handleAzureProxy)
router.GET("/v1/models/:model_id/capabilities", handleAzureProxy)
} else {
router.Any("*path", handleOpenAIProxy)
}
Expand All @@ -66,51 +65,83 @@ func main() {
}

// handleGetModels serves GET /v1/models in Azure proxy mode.
//
// It asks fetchDeployedModels for the current deployments and wraps them in an
// azure.ListModelResponse with Object set to "list". If the lookup fails, the
// underlying error is logged server-side and the client receives a 500 with a
// generic message, so upstream error bodies are not echoed to callers.
func handleGetModels(c *gin.Context) {
	models, err := fetchDeployedModels()
	if err != nil {
		log.Printf("error fetching deployed models: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch deployed models"})
		return
	}

	c.JSON(http.StatusOK, azure.ListModelResponse{
		Object: "list",
		Data:   models,
	})
}

// fetchDeployedModels lists the deployments of the configured Azure OpenAI
// resource and converts each one into an OpenAI-style azure.Model.
//
// The resource URL is read from AZURE_OPENAI_ENDPOINT and the credential from
// AZURE_OPENAI_TOKEN; the credential is sent as a Bearer token.
//
// An error is returned when the endpoint is unset, the request cannot be built
// or sent, the service answers with a non-200 status, or the response body
// cannot be decoded. On success the returned slice is never nil, so it
// serializes as a JSON array even when nothing is deployed.
func fetchDeployedModels() ([]azure.Model, error) {
	endpoint := os.Getenv("AZURE_OPENAI_ENDPOINT")
	// Drop trailing slashes so the joined URL never contains "//openai/...".
	for len(endpoint) > 0 && endpoint[len(endpoint)-1] == '/' {
		endpoint = endpoint[:len(endpoint)-1]
	}
	if endpoint == "" {
		return nil, fmt.Errorf("AZURE_OPENAI_ENDPOINT is not set")
	}
	token := os.Getenv("AZURE_OPENAI_TOKEN")

	req, err := http.NewRequest(http.MethodGet, endpoint+"/openai/deployments?api-version=2024-05-01-preview", nil)
	if err != nil {
		return nil, fmt.Errorf("building deployments request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+token)

	// NOTE(review): http.DefaultClient has no timeout, so a stalled upstream
	// blocks this call indefinitely. Consider a shared client with a Timeout.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("requesting deployments: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Surface the status alongside whatever error body could be read.
		body, readErr := ioutil.ReadAll(resp.Body)
		if readErr != nil {
			return nil, fmt.Errorf("failed to fetch deployed models: status %s (reading error body: %v)", resp.Status, readErr)
		}
		return nil, fmt.Errorf("failed to fetch deployed models: status %s: %s", resp.Status, body)
	}

	var deployedModelsResponse azure.ListDeployedModelsResponse
	if err := json.NewDecoder(resp.Body).Decode(&deployedModelsResponse); err != nil {
		return nil, fmt.Errorf("decoding deployments response: %w", err)
	}

	// Allocate with a known capacity; the non-nil slice also keeps the JSON
	// output as "[]" rather than "null" when there are no deployments.
	models := make([]azure.Model, 0, len(deployedModelsResponse.Data))
	for _, deployedModel := range deployedModelsResponse.Data {
		created := int(deployedModel.CreatedAt)
		models = append(models, azure.Model{
			ID:      deployedModel.ModelID,
			Object:  "model",
			Created: created,
			OwnedBy: "openai",
			Permission: []azure.ModelPermission{
				{
					ID:                 "",
					Object:             "model",
					Created:            created,
					AllowCreateEngine:  true,
					AllowSampling:      true,
					AllowLogprobs:      true,
					AllowSearchIndices: true,
					AllowView:          true,
					AllowFineTuning:    true,
					Organization:       "*",
					Group:              nil,
					IsBlocking:         false,
				},
			},
			Root:   deployedModel.ModelID,
			Parent: nil,
		})
	}

	return models, nil
}

// handleOptions answers OPTIONS requests under /v1/ (CORS preflight).
//
// It allows any origin, the POST/GET/OPTIONS/PUT/DELETE methods and the
// Content-Type and Authorization request headers, then replies with 200 and
// no body.
func handleOptions(c *gin.Context) {
	c.Header("Access-Control-Allow-Origin", "*")
	c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
	c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
	c.Status(http.StatusOK)
}

func handleAzureProxy(c *gin.Context) {
Expand All @@ -128,13 +159,12 @@ func handleAzureProxy(c *gin.Context) {
}
}

// Enhanced error logging
if c.Writer.Status() >= 400 {
log.Printf("Azure API request failed: %s %s, Status: %d", c.Request.Method, c.Request.URL.Path, c.Writer.Status())
}
}

// handleOpenAIProxy forwards the incoming request through the reverse proxy
// returned by openai.NewOpenAIReverseProxy, writing the proxied response
// straight to the gin response writer.
func handleOpenAIProxy(c *gin.Context) {
	server := openai.NewOpenAIReverseProxy()
	server.ServeHTTP(c.Writer, c.Request)
}

0 comments on commit ba1ccb0

Please sign in to comment.