Skip to content

Commit

Permalink
Merge pull request release-1.0.4 from Gyarbij/dev
Browse files Browse the repository at this point in the history
1.0.4
  • Loading branch information
Gyarbij authored Jul 20, 2024
2 parents dc26227 + 4178857 commit 82c87cb
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 57 deletions.
Binary file added .DS_Store
Binary file not shown.
4 changes: 1 addition & 3 deletions .github/workflows/ghcr-docker-publish.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
name: Docker

on:
schedule:
- cron: '25 8 * * *'
push:
branches: [ "main" ]
branches: [ "main", dev ]
# Publish semver tags as releases.
tags: [ 'v*.*.*', '*.*.*' ]
pull_request:
Expand Down
143 changes: 89 additions & 54 deletions pkg/azure/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package azure

import (
"bytes"
"fmt"
"io/ioutil"
"encoding/json"
"io"
"log"
"net/http"
"net/http/httputil"
Expand All @@ -18,7 +18,7 @@ import (

var (
AzureOpenAIToken = ""
AzureOpenAIAPIVersion = "2024-05-01-preview"
AzureOpenAIAPIVersion = "2024-06-01"
AzureOpenAIEndpoint = ""
AzureOpenAIModelMapper = map[string]string{
"gpt-3.5-turbo": "gpt-35-turbo",
Expand All @@ -31,6 +31,7 @@ var (
"gpt-4-32k": "gpt-4-32k",
"gpt-4-32k-0613": "gpt-4-32k-0613",
"gpt-4o": "gpt-4o",
"gpt-4o-mini": "gpt-4o-mini",
"gpt-4o-2024-05-13": "gpt-4o-2024-05-13",
"gpt-4-turbo": "gpt-4-turbo",
"gpt-4-vision-preview": "gpt-4-vision-preview",
Expand Down Expand Up @@ -92,14 +93,65 @@ func NewOpenAIReverseProxy() *httputil.ReverseProxy {
}
}

func getModelFromRequest(req *http.Request) string {
if req.Body == nil {
return ""
}
body, _ := io.ReadAll(req.Body)
req.Body = io.NopCloser(bytes.NewBuffer(body))
return gjson.GetBytes(body, "model").String()
}

// sanitizeHeaders returns a copy of the headers with sensitive information redacted
func sanitizeHeaders(headers http.Header) http.Header {
sanitized := make(http.Header)
for key, values := range headers {
if key == "Authorization" || key == "api-key" {
sanitized[key] = []string{"[REDACTED]"}
} else {
sanitized[key] = values
}
}
return sanitized
}

func HandleToken(req *http.Request) {
var token string

// Check for API Key in the api-key header
if apiKey := req.Header.Get("api-key"); apiKey != "" {
token = apiKey
} else if authHeader := req.Header.Get("Authorization"); authHeader != "" {
// If not found, check for Authorization header
token = strings.TrimPrefix(authHeader, "Bearer ")
} else if AzureOpenAIToken != "" {
// If neither is present, use the AzureOpenAIToken if set
token = AzureOpenAIToken
} else if envApiKey := os.Getenv("AZURE_OPENAI_API_KEY"); envApiKey != "" {
// As a last resort, check for API key in environment variable
token = envApiKey
}

if token != "" {
// Set the api-key header with the found token
req.Header.Set("api-key", token)
// Remove the Authorization header to avoid conflicts
req.Header.Del("Authorization")
} else {
log.Println("Warning: No authentication token found")
}
}

// Update the makeDirector function to handle the new endpoint structure
func makeDirector(remote *url.URL) func(*http.Request) {
return func(req *http.Request) {

// Get model and map it to deployment
model := getModelFromRequest(req)
deployment := GetDeploymentByModel(model)

// Handle token
handleToken(req)
HandleToken(req)

// Set the Host, Scheme, Path, and RawPath of the request
originURL := req.URL.String()
Expand All @@ -110,78 +162,61 @@ func makeDirector(remote *url.URL) func(*http.Request) {
// Handle different endpoints
switch {
case strings.HasPrefix(req.URL.Path, "/v1/chat/completions"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "chat/completions")
req.URL.Path = path.Join("/openai/deployments", deployment, "chat/completions")
case strings.HasPrefix(req.URL.Path, "/v1/completions"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "completions")
req.URL.Path = path.Join("/openai/deployments", deployment, "completions")
case strings.HasPrefix(req.URL.Path, "/v1/embeddings"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "embeddings")
req.URL.Path = path.Join("/openai/deployments", deployment, "embeddings")
case strings.HasPrefix(req.URL.Path, "/v1/images/generations"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "images/generations")
req.URL.Path = path.Join("/openai/deployments", deployment, "images/generations")
case strings.HasPrefix(req.URL.Path, "/v1/fine_tunes"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "fine-tunes")
req.URL.Path = path.Join("/openai/deployments", deployment, "fine-tunes")
case strings.HasPrefix(req.URL.Path, "/v1/files"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "files")
req.URL.Path = path.Join("/openai/deployments", deployment, "files")
case strings.HasPrefix(req.URL.Path, "/v1/audio/speech"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "audio/speech")
req.URL.Path = path.Join("/openai/deployments", deployment, "audio/speech")
case strings.HasPrefix(req.URL.Path, "/v1/audio/transcriptions"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "transcriptions")
req.URL.Path = path.Join("/openai/deployments", deployment, "transcriptions")
case strings.HasPrefix(req.URL.Path, "/v1/audio/translations"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "translations")
req.URL.Path = path.Join("/openai/deployments", deployment, "translations")
default:
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.TrimPrefix(req.URL.Path, "/v1/"))
req.URL.Path = path.Join("/openai/deployments", deployment, strings.TrimPrefix(req.URL.Path, "/v1/"))
}

req.URL.RawPath = req.URL.EscapedPath()

// Add logging for new parameters
if req.Body != nil {
var requestBody map[string]interface{}
bodyBytes, _ := io.ReadAll(req.Body)
json.Unmarshal(bodyBytes, &requestBody)

newParams := []string{"completion_config", "presence_penalty", "frequency_penalty", "best_of"}
for _, param := range newParams {
if val, ok := requestBody[param]; ok {
log.Printf("Request includes %s parameter: %v", param, val)
}
}

// Restore the body to the request
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}

// Add the api-version query parameter
query := req.URL.Query()
query.Add("api-version", AzureOpenAIAPIVersion)
req.URL.RawQuery = query.Encode()

log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
}
}

func getModelFromRequest(req *http.Request) string {
if req.Body == nil {
return ""
}
body, _ := ioutil.ReadAll(req.Body)
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
return gjson.GetBytes(body, "model").String()
}

func handleToken(req *http.Request) {
token := ""
if AzureOpenAIToken != "" {
token = AzureOpenAIToken
} else {
token = strings.ReplaceAll(req.Header.Get("Authorization"), "Bearer ", "")
}
req.Header.Set("api-key", token)
req.Header.Del("Authorization")
}

func HandleToken(req *http.Request) {
token := ""
if AzureOpenAIToken != "" {
token = AzureOpenAIToken
} else if authHeader := req.Header.Get("Authorization"); authHeader != "" {
token = strings.TrimPrefix(authHeader, "Bearer ")
} else if apiKey := os.Getenv("AZURE_OPENAI_API_KEY"); apiKey != "" {
token = apiKey
}

if token != "" {
req.Header.Set("api-key", token)
req.Header.Del("Authorization")
log.Printf("Proxying request [%s] %s -> %s", model, originURL, req.URL.String())
// log.Printf("Sanitized Request Headers: %v", sanitizeHeaders(req.Header))
}
}

func modifyResponse(res *http.Response) error {
// Handle rate limiting headers
if res.StatusCode == http.StatusTooManyRequests {
log.Printf("Rate limit exceeded: %s", res.Header.Get("Retry-After"))
if res.StatusCode >= 400 {
body, _ := io.ReadAll(res.Body)
log.Printf("Azure API Error Response: Status: %d, Body: %s", res.StatusCode, string(body))
res.Body = io.NopCloser(bytes.NewBuffer(body))
}

// Handle streaming responses
Expand Down

0 comments on commit 82c87cb

Please sign in to comment.