Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1.0.4-rc-patch-2 #20

Merged
merged 12 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
4 changes: 1 addition & 3 deletions .github/workflows/ghcr-docker-publish.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
name: Docker

on:
schedule:
- cron: '25 8 * * *'
push:
branches: [ "main" ]
branches: [ "main", dev ]
# Publish semver tags as releases.
tags: [ 'v*.*.*', '*.*.*' ]
pull_request:
Expand Down
143 changes: 89 additions & 54 deletions pkg/azure/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package azure

import (
"bytes"
"fmt"
"io/ioutil"
"encoding/json"
"io"
"log"
"net/http"
"net/http/httputil"
Expand All @@ -18,7 +18,7 @@ import (

var (
AzureOpenAIToken = ""
AzureOpenAIAPIVersion = "2024-05-01-preview"
AzureOpenAIAPIVersion = "2024-06-01"
AzureOpenAIEndpoint = ""
AzureOpenAIModelMapper = map[string]string{
"gpt-3.5-turbo": "gpt-35-turbo",
Expand All @@ -31,6 +31,7 @@ var (
"gpt-4-32k": "gpt-4-32k",
"gpt-4-32k-0613": "gpt-4-32k-0613",
"gpt-4o": "gpt-4o",
"gpt-4o-mini": "gpt-4o-mini",
"gpt-4o-2024-05-13": "gpt-4o-2024-05-13",
"gpt-4-turbo": "gpt-4-turbo",
"gpt-4-vision-preview": "gpt-4-vision-preview",
Expand Down Expand Up @@ -92,14 +93,65 @@ func NewOpenAIReverseProxy() *httputil.ReverseProxy {
}
}

// getModelFromRequest returns the value of the top-level "model" field of the
// request's JSON body. It returns "" when the request has no body, when the
// body cannot be read, or when no "model" value is found.
//
// Inspecting the body consumes it, so req.Body is replaced with an in-memory
// reader over the bytes that were read; later readers still see the payload.
func getModelFromRequest(req *http.Request) string {
	if req.Body == nil {
		return ""
	}

	body, err := io.ReadAll(req.Body)
	// Restore whatever was read, even on failure, so req.Body is never left
	// pointing at a drained stream.
	req.Body = io.NopCloser(bytes.NewReader(body))
	if err != nil {
		// A truncated payload cannot be trusted to name the model; report the
		// failure instead of silently guessing from partial JSON.
		log.Printf("reading request body to detect model: %v", err)
		return ""
	}

	return gjson.GetBytes(body, "model").String()
}

// sanitizeHeaders returns a copy of headers that is safe to log: the values of
// the credential-bearing headers (Authorization and api-key) are replaced with
// "[REDACTED]", and every other header is copied unchanged.
//
// Header names are compared in canonical form. http.Header stores keys
// canonicalized (Header.Set("api-key", ...) is stored under "Api-Key"), so an
// exact comparison against "api-key" would never match and the key would be
// logged in clear text. Value slices are cloned so that the result does not
// alias the caller's header map.
func sanitizeHeaders(headers http.Header) http.Header {
	sanitized := make(http.Header, len(headers))
	for key, values := range headers {
		switch http.CanonicalHeaderKey(key) {
		case "Authorization", "Api-Key":
			sanitized[key] = []string{"[REDACTED]"}
		default:
			sanitized[key] = append([]string(nil), values...)
		}
	}
	return sanitized
}

// HandleToken resolves the credential to forward with req and stores it in
// the "api-key" header.
//
// Sources are consulted in this order and the first non-empty one wins:
//  1. the request's own "api-key" header,
//  2. the request's "Authorization" header with a leading "Bearer " removed,
//  3. the package-level AzureOpenAIToken,
//  4. the AZURE_OPENAI_API_KEY environment variable.
//
// A header that is present but yields an empty token (for example a
// whitespace-only "api-key" value) does not stop the search; the next source
// is tried instead. When a token is found, the Authorization header is removed
// so only one credential is forwarded. When none is found, the request is left
// unchanged and a warning is logged.
func HandleToken(req *http.Request) {
	candidates := []string{
		strings.TrimSpace(req.Header.Get("api-key")),
		strings.TrimSpace(strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")),
		AzureOpenAIToken,
		os.Getenv("AZURE_OPENAI_API_KEY"),
	}

	var token string
	for _, candidate := range candidates {
		if candidate != "" {
			token = candidate
			break
		}
	}

	if token == "" {
		log.Println("Warning: No authentication token found")
		return
	}

	// Set the api-key header with the found token.
	req.Header.Set("api-key", token)
	// Remove the Authorization header to avoid conflicts.
	req.Header.Del("Authorization")
}

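// makeDirector returns the request-rewriting function for the reverse proxy: it maps the requested model to a deployment, rewrites /v1/... paths to /openai/deployments/<deployment>/..., and appends the api-version query parameter.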
func makeDirector(remote *url.URL) func(*http.Request) {
return func(req *http.Request) {

// Get model and map it to deployment
model := getModelFromRequest(req)
deployment := GetDeploymentByModel(model)

// Handle token
handleToken(req)
HandleToken(req)

// Set the Host, Scheme, Path, and RawPath of the request
originURL := req.URL.String()
Expand All @@ -110,78 +162,61 @@ func makeDirector(remote *url.URL) func(*http.Request) {
// Handle different endpoints
switch {
case strings.HasPrefix(req.URL.Path, "/v1/chat/completions"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "chat/completions")
req.URL.Path = path.Join("/openai/deployments", deployment, "chat/completions")
case strings.HasPrefix(req.URL.Path, "/v1/completions"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "completions")
req.URL.Path = path.Join("/openai/deployments", deployment, "completions")
case strings.HasPrefix(req.URL.Path, "/v1/embeddings"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "embeddings")
req.URL.Path = path.Join("/openai/deployments", deployment, "embeddings")
case strings.HasPrefix(req.URL.Path, "/v1/images/generations"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "images/generations")
req.URL.Path = path.Join("/openai/deployments", deployment, "images/generations")
case strings.HasPrefix(req.URL.Path, "/v1/fine_tunes"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "fine-tunes")
req.URL.Path = path.Join("/openai/deployments", deployment, "fine-tunes")
case strings.HasPrefix(req.URL.Path, "/v1/files"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "files")
req.URL.Path = path.Join("/openai/deployments", deployment, "files")
case strings.HasPrefix(req.URL.Path, "/v1/audio/speech"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "audio/speech")
req.URL.Path = path.Join("/openai/deployments", deployment, "audio/speech")
case strings.HasPrefix(req.URL.Path, "/v1/audio/transcriptions"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "transcriptions")
req.URL.Path = path.Join("/openai/deployments", deployment, "transcriptions")
case strings.HasPrefix(req.URL.Path, "/v1/audio/translations"):
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "translations")
req.URL.Path = path.Join("/openai/deployments", deployment, "translations")
default:
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.TrimPrefix(req.URL.Path, "/v1/"))
req.URL.Path = path.Join("/openai/deployments", deployment, strings.TrimPrefix(req.URL.Path, "/v1/"))
}

req.URL.RawPath = req.URL.EscapedPath()

// Add logging for new parameters
if req.Body != nil {
var requestBody map[string]interface{}
bodyBytes, _ := io.ReadAll(req.Body)
json.Unmarshal(bodyBytes, &requestBody)

newParams := []string{"completion_config", "presence_penalty", "frequency_penalty", "best_of"}
for _, param := range newParams {
if val, ok := requestBody[param]; ok {
log.Printf("Request includes %s parameter: %v", param, val)
}
}

// Restore the body to the request
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}

// Add the api-version query parameter
query := req.URL.Query()
query.Add("api-version", AzureOpenAIAPIVersion)
req.URL.RawQuery = query.Encode()

log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
}
}

// getModelFromRequest returns the value of the top-level "model" field of the
// request's JSON body. It returns "" when the request has no body, when the
// body cannot be read, or when no "model" value is found.
//
// Inspecting the body consumes it, so req.Body is replaced with an in-memory
// buffer over the bytes that were read; later readers still see the payload.
func getModelFromRequest(req *http.Request) string {
	if req.Body == nil {
		return ""
	}

	body, err := ioutil.ReadAll(req.Body)
	// Restore whatever was read, even on failure, so req.Body is never left
	// pointing at a drained stream.
	req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
	if err != nil {
		// A truncated payload cannot be trusted to name the model.
		log.Printf("reading request body to detect model: %v", err)
		return ""
	}

	return gjson.GetBytes(body, "model").String()
}

// handleToken writes the credential for req into the "api-key" header and
// drops the Authorization header.
//
// The package-level AzureOpenAIToken is used when it is set; otherwise the
// token is the Authorization header value with every "Bearer " occurrence
// removed. The "api-key" header is set unconditionally, so it may end up
// empty when neither source provides a value.
func handleToken(req *http.Request) {
	apiKey := AzureOpenAIToken
	if apiKey == "" {
		authorization := req.Header.Get("Authorization")
		apiKey = strings.ReplaceAll(authorization, "Bearer ", "")
	}

	headers := req.Header
	headers.Set("api-key", apiKey)
	headers.Del("Authorization")
}

func HandleToken(req *http.Request) {
token := ""
if AzureOpenAIToken != "" {
token = AzureOpenAIToken
} else if authHeader := req.Header.Get("Authorization"); authHeader != "" {
token = strings.TrimPrefix(authHeader, "Bearer ")
} else if apiKey := os.Getenv("AZURE_OPENAI_API_KEY"); apiKey != "" {
token = apiKey
}

if token != "" {
req.Header.Set("api-key", token)
req.Header.Del("Authorization")
log.Printf("Proxying request [%s] %s -> %s", model, originURL, req.URL.String())
// log.Printf("Sanitized Request Headers: %v", sanitizeHeaders(req.Header))
}
}

func modifyResponse(res *http.Response) error {
// Handle rate limiting headers
if res.StatusCode == http.StatusTooManyRequests {
log.Printf("Rate limit exceeded: %s", res.Header.Get("Retry-After"))
if res.StatusCode >= 400 {
body, _ := io.ReadAll(res.Body)
log.Printf("Azure API Error Response: Status: %d, Body: %s", res.StatusCode, string(body))
res.Body = io.NopCloser(bytes.NewBuffer(body))
}

// Handle streaming responses
Expand Down