diff --git a/llms/anthropic/anthropicllm.go b/llms/anthropic/anthropicllm.go
index cddb7ce05..5ce6c2105 100644
--- a/llms/anthropic/anthropicllm.go
+++ b/llms/anthropic/anthropicllm.go
@@ -14,7 +14,6 @@ import (
 )
 
 var (
-	ErrEmptyResponse = errors.New("no response")
 	ErrMissingToken = errors.New("missing the Anthropic API key, set it in the ANTHROPIC_API_KEY environment variable")
 	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
 	ErrInvalidContentType = errors.New("invalid content type")
@@ -92,7 +91,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 
 func generateCompletionsContent(ctx context.Context, o *LLM, messages []llms.MessageContent, opts *llms.CallOptions) (*llms.ContentResponse, error) {
 	if len(messages) == 0 || len(messages[0].Parts) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	msg0 := messages[0]
@@ -153,7 +152,7 @@ func generateMessagesContent(ctx context.Context, o *LLM, messages []llms.Messag
 		return nil, fmt.Errorf("anthropic: failed to create message: %w", err)
 	}
 	if result == nil {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	choices := make([]*llms.ContentChoice, len(result.Content))
diff --git a/llms/anthropic/internal/anthropicclient/anthropicclient.go b/llms/anthropic/internal/anthropicclient/anthropicclient.go
index cbe9fd9bf..bdcc275e7 100644
--- a/llms/anthropic/internal/anthropicclient/anthropicclient.go
+++ b/llms/anthropic/internal/anthropicclient/anthropicclient.go
@@ -16,9 +16,6 @@ const (
 	defaultModel = "claude-3-5-sonnet-20240620"
 )
 
-// ErrEmptyResponse is returned when the Anthropic API returns an empty response.
-var ErrEmptyResponse = errors.New("empty response")
-
 // Client is a client for the Anthropic API.
 type Client struct {
 	token string
diff --git a/llms/bedrock/internal/bedrockclient/defaults.go b/llms/bedrock/internal/bedrockclient/defaults.go
new file mode 100644
index 000000000..0dd83d0fc
--- /dev/null
+++ b/llms/bedrock/internal/bedrockclient/defaults.go
@@ -0,0 +1,7 @@
+package bedrockclient
+
+const (
+	DefaultMaxTokenLength2048 = 2048
+	DefaultMaxTokenLength512 = 512
+	DefaultMaxTokenLength20 = 20
+)
diff --git a/llms/bedrock/internal/bedrockclient/provider_ai21.go b/llms/bedrock/internal/bedrockclient/provider_ai21.go
index 5ff482b15..94ebaf3ab 100644
--- a/llms/bedrock/internal/bedrockclient/provider_ai21.go
+++ b/llms/bedrock/internal/bedrockclient/provider_ai21.go
@@ -81,7 +81,7 @@ func createAi21Completion(ctx context.Context, client *bedrockruntime.Client, mo
 		Prompt: txt,
 		Temperature: options.Temperature,
 		TopP: options.TopP,
-		MaxTokens: getMaxTokens(options.MaxTokens, 2048),
+		MaxTokens: getMaxTokens(options.MaxTokens, DefaultMaxTokenLength2048),
 		StopSequences: options.StopWords,
 		CountPenalty: struct {
 			Scale float64 `json:"scale"`
diff --git a/llms/bedrock/internal/bedrockclient/provider_amazon.go b/llms/bedrock/internal/bedrockclient/provider_amazon.go
index bae6eb825..4a05deac6 100644
--- a/llms/bedrock/internal/bedrockclient/provider_amazon.go
+++ b/llms/bedrock/internal/bedrockclient/provider_amazon.go
@@ -67,7 +67,7 @@ func createAmazonCompletion(ctx context.Context,
 	inputContent := amazonTextGenerationInput{
 		InputText: txt,
 		TextGenerationConfig: amazonTextGenerationConfigInput{
-			MaxTokens: getMaxTokens(options.MaxTokens, 512),
+			MaxTokens: getMaxTokens(options.MaxTokens, DefaultMaxTokenLength512),
 			TopP: options.TopP,
 			Temperature: options.Temperature,
 			StopSequences: options.StopWords,
diff --git a/llms/bedrock/internal/bedrockclient/provider_anthropic.go b/llms/bedrock/internal/bedrockclient/provider_anthropic.go
index 5258da760..31255e48a 100644
--- a/llms/bedrock/internal/bedrockclient/provider_anthropic.go
+++ b/llms/bedrock/internal/bedrockclient/provider_anthropic.go
@@ -134,7 +134,7 @@ func createAnthropicCompletion(ctx context.Context,
 
 	input := anthropicTextGenerationInput{
 		AnthropicVersion: AnthropicLatestVersion,
-		MaxTokens: getMaxTokens(options.MaxTokens, 2048),
+		MaxTokens: getMaxTokens(options.MaxTokens, DefaultMaxTokenLength2048),
 		System: systemPrompt,
 		Messages: inputContents,
 		Temperature: options.Temperature,
diff --git a/llms/bedrock/internal/bedrockclient/provider_cohere.go b/llms/bedrock/internal/bedrockclient/provider_cohere.go
index 1ededb2f1..1d5ca9c9b 100644
--- a/llms/bedrock/internal/bedrockclient/provider_cohere.go
+++ b/llms/bedrock/internal/bedrockclient/provider_cohere.go
@@ -74,7 +74,7 @@ func createCohereCompletion(ctx context.Context,
 		Temperature: options.Temperature,
 		P: options.TopP,
 		K: options.TopK,
-		MaxTokens: getMaxTokens(options.MaxTokens, 20),
+		MaxTokens: getMaxTokens(options.MaxTokens, DefaultMaxTokenLength20),
 		StopSequences: options.StopWords,
 		NumGenerations: options.CandidateCount,
 	}
diff --git a/llms/bedrock/internal/bedrockclient/provider_meta.go b/llms/bedrock/internal/bedrockclient/provider_meta.go
index 737918712..1cf667f5c 100644
--- a/llms/bedrock/internal/bedrockclient/provider_meta.go
+++ b/llms/bedrock/internal/bedrockclient/provider_meta.go
@@ -56,7 +56,7 @@ func createMetaCompletion(ctx context.Context,
 		Prompt: txt,
 		Temperature: options.Temperature,
 		TopP: options.TopP,
-		MaxGenLen: getMaxTokens(options.MaxTokens, 512),
+		MaxGenLen: getMaxTokens(options.MaxTokens, DefaultMaxTokenLength512),
 	}
 
 	body, err := json.Marshal(input)
diff --git a/llms/cloudflare/cloudflarellm.go b/llms/cloudflare/cloudflarellm.go
index e6cb66b1b..f1d33da50 100644
--- a/llms/cloudflare/cloudflarellm.go
+++ b/llms/cloudflare/cloudflarellm.go
@@ -10,11 +10,6 @@ import (
 	"github.com/tmc/langchaingo/llms/cloudflare/internal/cloudflareclient"
 )
 
-var (
-	ErrEmptyResponse = errors.New("no response")
-	ErrIncompleteEmbedding = errors.New("not all input got embedded")
-)
-
 // LLM is a cloudflare LLM implementation.
 type LLM struct {
 	CallbacksHandler callbacks.Handler
@@ -147,11 +142,11 @@ func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]flo
 	}
 
 	if len(res.Result.Data) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	if len(inputTexts) != len(res.Result.Data) {
-		return res.Result.Data, ErrIncompleteEmbedding
+		return res.Result.Data, llms.ErrIncompleteEmbedding
 	}
 
 	return res.Result.Data, nil
diff --git a/llms/cloudflare/internal/cloudflareclient/api.go b/llms/cloudflare/internal/cloudflareclient/api.go
index 4eb54aee6..af5a1f479 100644
--- a/llms/cloudflare/internal/cloudflareclient/api.go
+++ b/llms/cloudflare/internal/cloudflareclient/api.go
@@ -37,7 +37,7 @@ func (c *Client) CreateEmbedding(ctx context.Context, texts *CreateEmbeddingRequ
 		return nil, err
 	}
 
-	if resp.StatusCode > 299 {
+	if resp.StatusCode >= http.StatusMultipleChoices {
 		return nil, fmt.Errorf("error: %s", body)
 	}
 
@@ -81,7 +81,7 @@ func (c *Client) GenerateContent(ctx context.Context, request *GenerateContentRe
 		return nil, err
 	}
 
-	if response.StatusCode > 299 {
+	if response.StatusCode >= http.StatusMultipleChoices {
 		return nil, fmt.Errorf("error: %s", body)
 	}
 
@@ -165,7 +165,7 @@ func (c *Client) Summarize(ctx context.Context, inputText string, maxLength int)
 		return nil, err
 	}
 
-	if resp.StatusCode > 299 {
+	if resp.StatusCode >= http.StatusMultipleChoices {
 		return nil, fmt.Errorf("error: %s", body)
 	}
 
diff --git a/llms/cohere/coherellm.go b/llms/cohere/coherellm.go
index ba2391b10..b02c4fbe4 100644
--- a/llms/cohere/coherellm.go
+++ b/llms/cohere/coherellm.go
@@ -11,8 +11,7 @@ import (
 )
 
 var (
-	ErrEmptyResponse = errors.New("no response")
-	ErrMissingToken = errors.New("missing the COHERE_API_KEY key, set it in the COHERE_API_KEY environment variable")
+	ErrMissingToken = errors.New("missing the COHERE_API_KEY key, set it in the COHERE_API_KEY environment variable")
 	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
 )
 
diff --git a/llms/cohere/internal/cohereclient/cohereclient.go b/llms/cohere/internal/cohereclient/cohereclient.go
index 0f45590e6..54f05e587 100644
--- a/llms/cohere/internal/cohereclient/cohereclient.go
+++ b/llms/cohere/internal/cohereclient/cohereclient.go
@@ -10,12 +10,10 @@ import (
 	"strings"
 
 	"github.com/cohere-ai/tokenizer"
+	"github.com/tmc/langchaingo/llms"
 )
 
-var (
-	ErrEmptyResponse = errors.New("empty response")
-	ErrModelNotFound = errors.New("model not found")
-)
+var ErrModelNotFound = errors.New("model not found")
 
 type Client struct {
 	token string
@@ -111,8 +109,8 @@ func (c *Client) CreateGeneration(ctx context.Context, r *GenerationRequest) (*G
 		return nil, fmt.Errorf("create request: %w", err)
 	}
 
-	req.Header.Set("content-type", "application/json")
-	req.Header.Set("authorization", "bearer "+c.token)
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "bearer "+c.token)
 
 	res, err := c.httpClient.Do(req)
 	if err != nil {
@@ -129,7 +127,7 @@ func (c *Client) CreateGeneration(ctx context.Context, r *GenerationRequest) (*G
 		if strings.HasPrefix(response.Message, "model not found") {
 			return nil, ErrModelNotFound
 		}
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	var generation Generation
diff --git a/llms/ernie/erniellm.go b/llms/ernie/erniellm.go
index ff47e3faf..568a70cd8 100644
--- a/llms/ernie/erniellm.go
+++ b/llms/ernie/erniellm.go
@@ -11,10 +11,7 @@ import (
 	"github.com/tmc/langchaingo/llms/ernie/internal/ernieclient"
 )
 
-var (
-	ErrEmptyResponse = errors.New("no response")
-	ErrCodeResponse = errors.New("has error code")
-)
+var ErrCodeResponse = errors.New("has error code")
 
 type LLM struct {
 	client *ernieclient.Client
diff --git a/llms/ernie/internal/ernieclient/ernieclient.go b/llms/ernie/internal/ernieclient/ernieclient.go
index 5a563460b..42e36f3b8 100644
--- a/llms/ernie/internal/ernieclient/ernieclient.go
+++ b/llms/ernie/internal/ernieclient/ernieclient.go
@@ -11,6 +11,8 @@ import (
 	"net/http"
 	"strings"
 	"time"
+
+	"github.com/tmc/langchaingo/llms"
 )
 
 var (
@@ -18,7 +20,6 @@
 	ErrCompletionCode = errors.New("completion API returned unexpected status code")
 	ErrAccessTokenCode = errors.New("get access_token API returned unexpected status code")
 	ErrEmbeddingCode = errors.New("embedding API returned unexpected status code")
-	ErrEmptyResponse = errors.New("empty response")
 )
 
 // Client is a client for the ERNIE API.
@@ -285,7 +286,7 @@ func (c *Client) CreateChat(ctx context.Context, r *ChatRequest) (*ChatResponse,
 	}
 
 	if resp.Result == "" && resp.FunctionCall == nil {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	return resp, nil
diff --git a/llms/errors.go b/llms/errors.go
new file mode 100644
index 000000000..7b1ac608b
--- /dev/null
+++ b/llms/errors.go
@@ -0,0 +1,11 @@
+package llms
+
+import "errors"
+
+var (
+	// ErrEmptyResponse is returned when an LLM returns an empty response.
+	ErrEmptyResponse = errors.New("no response")
+	// ErrIncompleteEmbedding is returned when the length of an embedding
+	// request does not match the length of the returned embeddings array.
+	ErrIncompleteEmbedding = errors.New("not all input got embedded")
+)
diff --git a/llms/googleai/internal/palmclient/palmclient.go b/llms/googleai/internal/palmclient/palmclient.go
index aaefe28d1..7eb531959 100644
--- a/llms/googleai/internal/palmclient/palmclient.go
+++ b/llms/googleai/internal/palmclient/palmclient.go
@@ -21,10 +21,10 @@ var (
 )
 
 var defaultParameters = map[string]interface{}{ //nolint:gochecknoglobals
-	"temperature": 0.2, //nolint:gomnd
-	"maxOutputTokens": 256, //nolint:gomnd
-	"topP": 0.8, //nolint:gomnd
-	"topK": 40, //nolint:gomnd
+	"temperature": 0.2, //nolint:all
+	"maxOutputTokens": 256, //nolint:all
+	"topP": 0.8, //nolint:all
+	"topK": 40, //nolint:all
 }
 
 const (
@@ -65,9 +65,6 @@ func New(ctx context.Context, projectID, location string, opts ...option.ClientO
 	}, nil
 }
 
-// ErrEmptyResponse is returned when the OpenAI API returns an empty response.
-var ErrEmptyResponse = errors.New("empty response")
-
 // CompletionRequest is a request to create a completion.
 type CompletionRequest struct {
 	Prompts []string `json:"prompts"`
@@ -290,7 +287,7 @@ func (c *PaLMClient) batchPredict(ctx context.Context, model string, prompts []s
 		return nil, err
 	}
 	if len(resp.GetPredictions()) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	return resp.GetPredictions(), nil
 }
@@ -329,7 +326,7 @@ func (c *PaLMClient) chat(ctx context.Context, r *ChatRequest) ([]*structpb.Valu
 		return nil, err
 	}
 	if len(resp.GetPredictions()) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	return resp.GetPredictions(), nil
 }
diff --git a/llms/googleai/option.go b/llms/googleai/option.go
index 6ba1b43b1..0f9d55078 100644
--- a/llms/googleai/option.go
+++ b/llms/googleai/option.go
@@ -25,18 +25,31 @@ type Options struct {
 	ClientOptions []option.ClientOption
 }
 
+const (
+	CloudProject = ""
+	CloudLocation = ""
+	DefaultModel = "gemini-pro"
+	DefaultEmbeddingModel = "embedding-001"
+	DefaultCandidateCount = 1
+	DefaultMaxTokens = 2048
+	DefaultTemperature = 0.5
+	DefaultTopK = 3
+	DefaultTopP = 0.95
+	DefaultHarmThreshold = HarmBlockOnlyHigh
+)
+
 func DefaultOptions() Options {
 	return Options{
-		CloudProject: "",
-		CloudLocation: "",
-		DefaultModel: "gemini-pro",
-		DefaultEmbeddingModel: "embedding-001",
-		DefaultCandidateCount: 1,
-		DefaultMaxTokens: 2048,
-		DefaultTemperature: 0.5,
-		DefaultTopK: 3,
-		DefaultTopP: 0.95,
-		HarmThreshold: HarmBlockOnlyHigh,
+		CloudProject: CloudProject,
+		CloudLocation: CloudLocation,
+		DefaultModel: DefaultModel,
+		DefaultEmbeddingModel: DefaultEmbeddingModel,
+		DefaultCandidateCount: DefaultCandidateCount,
+		DefaultMaxTokens: DefaultMaxTokens,
+		DefaultTemperature: DefaultTemperature,
+		DefaultTopK: DefaultTopK,
+		DefaultTopP: DefaultTopP,
+		HarmThreshold: DefaultHarmThreshold,
 	}
 }
 
diff --git a/llms/googleai/palm/palm_llm.go b/llms/googleai/palm/palm_llm.go
index 54d142632..f104142ce 100644
--- a/llms/googleai/palm/palm_llm.go
+++ b/llms/googleai/palm/palm_llm.go
@@ -13,7 +13,6 @@ import (
 )
 
 var (
-	ErrEmptyResponse = errors.New("no response")
 	ErrMissingProjectID = errors.New("missing the GCP Project ID, set it in the GOOGLE_CLOUD_PROJECT environment variable") //nolint:lll
 	ErrMissingLocation = errors.New("missing the GCP Location, set it in the GOOGLE_CLOUD_LOCATION environment variable") //nolint:lll
 	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
@@ -85,7 +84,7 @@ func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]flo
 	}
 
 	if len(embeddings) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	if len(inputTexts) != len(embeddings) {
 		return embeddings, ErrUnexpectedResponseLength
diff --git a/llms/huggingface/huggingfacellm.go b/llms/huggingface/huggingfacellm.go
index 63ef03b3a..d9625f045 100644
--- a/llms/huggingface/huggingfacellm.go
+++ b/llms/huggingface/huggingfacellm.go
@@ -11,7 +11,6 @@ import (
 )
 
 var (
-	ErrEmptyResponse = errors.New("empty response")
 	ErrMissingToken = errors.New("missing the Hugging Face API token. Set it in the HUGGINGFACEHUB_API_TOKEN environment variable") //nolint:lll
 	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
 )
@@ -115,7 +114,7 @@ func (o *LLM) CreateEmbedding(
 		return nil, err
 	}
 	if len(embeddings) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	if len(inputTexts) != len(embeddings) {
 		return embeddings, ErrUnexpectedResponseLength
diff --git a/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go b/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go
index d74c36b3b..ab17f1644 100644
--- a/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go
+++ b/llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go
@@ -4,13 +4,12 @@ import (
 	"context"
 	"errors"
 	"fmt"
-)
 
-var (
-	ErrInvalidToken = errors.New("invalid token")
-	ErrEmptyResponse = errors.New("empty response")
+	"github.com/tmc/langchaingo/llms"
 )
 
+var ErrInvalidToken = errors.New("invalid token")
+
 type Client struct {
 	Token string
 	Model string
@@ -64,7 +63,7 @@ func (c *Client) RunInference(ctx context.Context, request *InferenceRequest) (*
 		return nil, fmt.Errorf("failed to run inference: %w", err)
 	}
 	if len(resp) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	text := resp[0].Text
 	// TODO: Add response cleaning based on Model.
@@ -96,7 +95,7 @@ func (c *Client) CreateEmbedding(
 	}
 
 	if len(resp) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	return resp, nil
diff --git a/llms/huggingface/internal/huggingfaceclient/huggingfaceclient_test.go b/llms/huggingface/internal/huggingfaceclient/huggingfaceclient_test.go
index d36e66208..2d2188702 100644
--- a/llms/huggingface/internal/huggingfaceclient/huggingfaceclient_test.go
+++ b/llms/huggingface/internal/huggingfaceclient/huggingfaceclient_test.go
@@ -54,11 +54,15 @@ func mockServer(t *testing.T) *httptest.Server {
 	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		b, err := io.ReadAll(r.Body)
-		require.NoError(t, err)
+		if err != nil {
+			t.Error(err)
+		}
 
 		var infReq inferencePayload
 		err = json.Unmarshal(b, &infReq)
-		require.NoError(t, err)
+		if err != nil {
+			t.Error(err)
+		}
 
 		if infReq.Parameters.TopK == -1 {
 			w.WriteHeader(http.StatusBadRequest)
diff --git a/llms/llamafile/llamafilellm.go b/llms/llamafile/llamafilellm.go
index 8e9cc6088..4fef5f312 100644
--- a/llms/llamafile/llamafilellm.go
+++ b/llms/llamafile/llamafilellm.go
@@ -9,11 +9,6 @@ import (
 	"github.com/tmc/langchaingo/llms/llamafile/internal/llamafileclient"
 )
 
-var (
-	ErrEmptyResponse = errors.New("no response")
-	ErrIncompleteEmbedding = errors.New("not all input got embedded")
-)
-
 // LLM is a llamafile LLM implementation.
 type LLM struct {
 	CallbacksHandler callbacks.Handler
diff --git a/llms/local/internal/localclient/localclient.go b/llms/local/internal/localclient/localclient.go
index d11902ca3..b8c2f7589 100644
--- a/llms/local/internal/localclient/localclient.go
+++ b/llms/local/internal/localclient/localclient.go
@@ -2,12 +2,8 @@ package localclient
 
 import (
 	"context"
-	"errors"
 )
 
-// ErrEmptyResponse is returned when the OpenAI API returns an empty response.
-var ErrEmptyResponse = errors.New("empty response")
-
 // Client is a client for a local LLM.
 type Client struct {
 	BinPath string
diff --git a/llms/local/localllm.go b/llms/local/localllm.go
index bbbbe9d40..3cb6dfb4e 100644
--- a/llms/local/localllm.go
+++ b/llms/local/localllm.go
@@ -13,12 +13,8 @@ import (
 	"github.com/tmc/langchaingo/llms/local/internal/localclient"
 )
 
-var (
-	// ErrEmptyResponse is returned when the local LLM binary returns an empty response.
-	ErrEmptyResponse = errors.New("no response")
-	// ErrMissingBin is returned when the LOCAL_LLM_BIN environment variable is not set.
-	ErrMissingBin = errors.New("missing the local LLM binary path, set the LOCAL_LLM_BIN environment variable")
-)
+// ErrMissingBin is returned when the LOCAL_LLM_BIN environment variable is not set.
+var ErrMissingBin = errors.New("missing the local LLM binary path, set the LOCAL_LLM_BIN environment variable")
 
 // LLM is a local LLM implementation.
 type LLM struct {
diff --git a/llms/maritaca/internal/maritacaclient/maritacaclient.go b/llms/maritaca/internal/maritacaclient/maritacaclient.go
index 1da938710..21b3bdce6 100644
--- a/llms/maritaca/internal/maritacaclient/maritacaclient.go
+++ b/llms/maritaca/internal/maritacaclient/maritacaclient.go
@@ -131,7 +131,8 @@ func parseData(input string) (string, error) {
 	}
 
 	parts := strings.SplitAfter(input, "data:")
-	if len(parts) < 2 {
+	const expectedPartsLength = 2
+	if len(parts) < expectedPartsLength {
 		return "", nil
 	}
 
diff --git a/llms/maritaca/maritacallm.go b/llms/maritaca/maritacallm.go
index 890d7325b..458d6b189 100644
--- a/llms/maritaca/maritacallm.go
+++ b/llms/maritaca/maritacallm.go
@@ -10,11 +10,6 @@ import (
 	"github.com/tmc/langchaingo/llms/maritaca/internal/maritacaclient"
 )
 
-var (
-	ErrEmptyResponse = errors.New("no response")
-	ErrIncompleteEmbedding = errors.New("not all input got embedded")
-)
-
 // LLM is a maritaca LLM implementation.
 type LLM struct {
 	CallbacksHandler callbacks.Handler
diff --git a/llms/ollama/ollamallm.go b/llms/ollama/ollamallm.go
index 0de34a599..5767ddc87 100644
--- a/llms/ollama/ollamallm.go
+++ b/llms/ollama/ollamallm.go
@@ -9,11 +9,6 @@ import (
 	"github.com/tmc/langchaingo/llms/ollama/internal/ollamaclient"
 )
 
-var (
-	ErrEmptyResponse = errors.New("no response")
-	ErrIncompleteEmbedding = errors.New("not all input got embedded")
-)
-
 // LLM is a ollama LLM implementation.
 type LLM struct {
 	CallbacksHandler callbacks.Handler
@@ -185,14 +180,14 @@ func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]flo
 		}
 
 		if len(embedding.Embedding) == 0 {
-			return nil, ErrEmptyResponse
+			return nil, llms.ErrEmptyResponse
 		}
 
 		embeddings = append(embeddings, embedding.Embedding)
 	}
 
 	if len(inputTexts) != len(embeddings) {
-		return embeddings, ErrIncompleteEmbedding
+		return embeddings, llms.ErrIncompleteEmbedding
 	}
 
 	return embeddings, nil
diff --git a/llms/openai/internal/openaiclient/openaiclient.go b/llms/openai/internal/openaiclient/openaiclient.go
index 05c114e84..50cb4b11c 100644
--- a/llms/openai/internal/openaiclient/openaiclient.go
+++ b/llms/openai/internal/openaiclient/openaiclient.go
@@ -2,10 +2,11 @@ package openaiclient
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"net/http"
 	"strings"
+
+	"github.com/tmc/langchaingo/llms"
 )
 
 const (
@@ -13,9 +14,6 @@
 	defaultFunctionCallBehavior = "auto"
 )
 
-// ErrEmptyResponse is returned when the OpenAI API returns an empty response.
-var ErrEmptyResponse = errors.New("empty response")
-
 type APIType string
 
 const (
@@ -87,7 +85,7 @@ func (c *Client) CreateCompletion(ctx context.Context, r *CompletionRequest) (*C
 		return nil, err
 	}
 	if len(resp.Choices) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	return &Completion{
 		Text: resp.Choices[0].Message.Content,
@@ -115,7 +113,7 @@ func (c *Client) CreateEmbedding(ctx context.Context, r *EmbeddingRequest) ([][]
 	}
 
 	if len(resp.Data) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	embeddings := make([][]float32, 0)
@@ -140,7 +138,7 @@ func (c *Client) CreateChat(ctx context.Context, r *ChatRequest) (*ChatCompletio
 		return nil, err
 	}
 	if len(resp.Choices) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	return resp, nil
 }
@@ -154,10 +152,10 @@ func (c *Client) setHeaders(req *http.Request) {
 	if c.apiType == APITypeOpenAI || c.apiType == APITypeAzureAD {
 		req.Header.Set("Authorization", "Bearer "+c.token)
 	} else {
-		req.Header.Set("api-key", c.token)
+		req.Header.Set("Api-Key", c.token)
 	}
 	if c.organization != "" {
-		req.Header.Set("OpenAI-Organization", c.organization)
+		req.Header.Set("Openai-Organization", c.organization)
 	}
 }
 
diff --git a/llms/openai/llm.go b/llms/openai/llm.go
index 21a074774..9a09020c3 100644
--- a/llms/openai/llm.go
+++ b/llms/openai/llm.go
@@ -9,7 +9,6 @@ import (
 )
 
 var (
-	ErrEmptyResponse = errors.New("no response")
 	ErrMissingToken = errors.New("missing the OpenAI API key, set it in the OPENAI_API_KEY environment variable") //nolint:lll
 	ErrMissingAzureModel = errors.New("model needs to be provided when using Azure API")
 	ErrMissingAzureEmbeddingModel = errors.New("embeddings model needs to be provided when using Azure API")
diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go
index 78f8334d2..2c97bc35d 100644
--- a/llms/openai/openaillm.go
+++ b/llms/openai/openaillm.go
@@ -147,7 +147,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		return nil, err
 	}
 	if len(result.Choices) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	choices := make([]*llms.ContentChoice, len(result.Choices))
@@ -202,7 +202,7 @@ func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]flo
 		return nil, fmt.Errorf("failed to create openai embeddings: %w", err)
 	}
 	if len(embeddings) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	if len(inputTexts) != len(embeddings) {
 		return embeddings, ErrUnexpectedResponseLength
diff --git a/llms/watsonx/watsonxllm.go b/llms/watsonx/watsonxllm.go
index 42f00c272..f9f39f18e 100644
--- a/llms/watsonx/watsonxllm.go
+++ b/llms/watsonx/watsonxllm.go
@@ -9,10 +9,7 @@ import (
 	"github.com/tmc/langchaingo/llms"
 )
 
-var (
-	ErrInvalidPrompt = errors.New("invalid prompt")
-	ErrEmptyResponse = errors.New("no response")
-)
+var ErrInvalidPrompt = errors.New("invalid prompt")
 
 type LLM struct {
 	CallbacksHandler callbacks.Handler
@@ -53,7 +50,7 @@ func (wx *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConte
 	}
 
 	if result.Text == "" {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	resp := &llms.ContentResponse{
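Note (not part of the diff): a minimal caller-side sketch of how the sentinels consolidated into the new llms/errors.go can be matched with errors.Is, independent of which provider produced them. The helper name checkEmbedding and the choice of the ollama provider are hypothetical, used only to illustrate the pattern; CreateEmbedding and the two sentinel errors are the ones shown in the diff above.

package example

import (
	"context"
	"errors"
	"fmt"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/ollama"
)

// checkEmbedding branches on the shared sentinel errors from the llms package
// rather than on per-provider error variables.
func checkEmbedding(ctx context.Context, llm *ollama.LLM, texts []string) ([][]float32, error) {
	vecs, err := llm.CreateEmbedding(ctx, texts)
	switch {
	case errors.Is(err, llms.ErrEmptyResponse):
		// The provider returned no embeddings at all.
		return nil, fmt.Errorf("no embeddings returned: %w", err)
	case errors.Is(err, llms.ErrIncompleteEmbedding):
		// Partial result: vecs still holds whatever was embedded.
		return vecs, fmt.Errorf("embedded %d of %d inputs: %w", len(vecs), len(texts), err)
	case err != nil:
		return nil, err
	}
	return vecs, nil
}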