diff --git a/api/api/ai.go b/api/api/ai.go index 3165949..1ba8fab 100644 --- a/api/api/ai.go +++ b/api/api/ai.go @@ -56,7 +56,14 @@ func chatHandler(w http.ResponseWriter, r *http.Request) { } defer conn.Close() - client := openai.NewClient(os.Getenv("OPENAI_API_KEY")) + apiKey := os.Getenv("API_KEY") + baseUrl := os.Getenv("BASE_URL") + model := os.Getenv("MODEL") + config := openai.DefaultConfig(apiKey) + if baseUrl != "" { + config.BaseURL = baseUrl + } + client := openai.NewClientWithConfig(config) for { messageType, p, err := conn.ReadMessage() if err != nil { @@ -77,7 +84,7 @@ func chatHandler(w http.ResponseWriter, r *http.Request) { stream, err := client.CreateChatCompletionStream( context.Background(), openai.ChatCompletionRequest{ - Model: openai.GPT3Dot5Turbo, + Model: model, Messages: messages, Stream: true, }, diff --git a/api/api/ai_test.go b/api/api/ai_test.go index 6089591..869a32e 100644 --- a/api/api/ai_test.go +++ b/api/api/ai_test.go @@ -11,7 +11,7 @@ import ( ) func TestAiFastaParse(t *testing.T) { - apiKey := os.Getenv("OPENAI_API_KEY") + apiKey := os.Getenv("API_KEY") if apiKey == "" { return }