From 9b646eb2bc47203a36592067113338036b9b72b1 Mon Sep 17 00:00:00 2001
From: cte
Date: Sat, 1 Mar 2025 21:51:17 -0800
Subject: [PATCH] DRY up getModel

---
 src/api/__tests__/index.test.ts | 257 ++++++++++++++++++++++++++++++++
 src/api/index.ts                |  43 +++++-
 src/api/providers/anthropic.ts  |  39 ++---
 src/api/providers/constants.ts  |   3 +
 src/api/providers/ollama.ts     |  10 +-
 src/api/providers/openai.ts     |   8 +-
 src/api/providers/openrouter.ts |  46 ++----
 src/api/providers/vertex.ts     |  56 ++-----
 8 files changed, 341 insertions(+), 121 deletions(-)
 create mode 100644 src/api/__tests__/index.test.ts
 create mode 100644 src/api/providers/constants.ts

diff --git a/src/api/__tests__/index.test.ts b/src/api/__tests__/index.test.ts
new file mode 100644
index 0000000000..4408ca0ffc
--- /dev/null
+++ b/src/api/__tests__/index.test.ts
@@ -0,0 +1,257 @@
+// npx jest src/api/__tests__/index.test.ts
+
+import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta/messages/index.mjs"
+
+import { getModelParams } from "../index"
+import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "../providers/constants"
+
+describe("getModelParams", () => {
+	it("should return default values when no custom values are provided", () => {
+		const options = {}
+		const model = {
+			id: "test-model",
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+			defaultMaxTokens: 1000,
+			defaultTemperature: 0.5,
+		})
+
+		expect(result).toEqual({
+			maxTokens: 1000,
+			thinking: undefined,
+			temperature: 0.5,
+		})
+	})
+
+	it("should use custom temperature from options when provided", () => {
+		const options = { modelTemperature: 0.7 }
+		const model = {
+			id: "test-model",
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+			defaultMaxTokens: 1000,
+			defaultTemperature: 0.5,
+		})
+
+		expect(result).toEqual({
+			maxTokens: 1000,
+			thinking: undefined,
+			temperature: 0.7,
+		})
+	})
+
+	it("should use model maxTokens when available", () => {
+		const options = {}
+		const model = {
+			id: "test-model",
+			maxTokens: 2000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+			defaultMaxTokens: 1000,
+		})
+
+		expect(result).toEqual({
+			maxTokens: 2000,
+			thinking: undefined,
+			temperature: 0,
+		})
+	})
+
+	it("should handle thinking models correctly", () => {
+		const options = {}
+		const model = {
+			id: "test-model",
+			thinking: true,
+			maxTokens: 2000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: 1600, // 80% of 2000
+		}
+
+		expect(result).toEqual({
+			maxTokens: 2000,
+			thinking: expectedThinking,
+			temperature: 1.0, // Thinking models require temperature 1.0.
+		})
+	})
+
+	it("should honor customMaxTokens for thinking models", () => {
+		const options = { modelMaxTokens: 3000 }
+		const model = {
+			id: "test-model",
+			thinking: true,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+			defaultMaxTokens: 2000,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: 2400, // 80% of 3000
+		}
+
+		expect(result).toEqual({
+			maxTokens: 3000,
+			thinking: expectedThinking,
+			temperature: 1.0,
+		})
+	})
+
+	it("should honor customMaxThinkingTokens for thinking models", () => {
+		const options = { modelMaxThinkingTokens: 1500 }
+		const model = {
+			id: "test-model",
+			thinking: true,
+			maxTokens: 4000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: 1500, // Using the custom value
+		}
+
+		expect(result).toEqual({
+			maxTokens: 4000,
+			thinking: expectedThinking,
+			temperature: 1.0,
+		})
+	})
+
+	it("should not honor customMaxThinkingTokens for non-thinking models", () => {
+		const options = { modelMaxThinkingTokens: 1500 }
+		const model = {
+			id: "test-model",
+			maxTokens: 4000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+			// Note: model.thinking is not set (so it's falsey).
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		expect(result).toEqual({
+			maxTokens: 4000,
+			thinking: undefined, // Should remain undefined despite customMaxThinkingTokens being set.
+			temperature: 0, // Using default temperature.
+		})
+	})
+
+	it("should clamp thinking budget to at least 1024 tokens", () => {
+		const options = { modelMaxThinkingTokens: 500 }
+		const model = {
+			id: "test-model",
+			thinking: true,
+			maxTokens: 2000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: 1024, // Minimum is 1024
+		}
+
+		expect(result).toEqual({
+			maxTokens: 2000,
+			thinking: expectedThinking,
+			temperature: 1.0,
+		})
+	})
+
+	it("should clamp thinking budget to at most 80% of max tokens", () => {
+		const options = { modelMaxThinkingTokens: 5000 }
+		const model = {
+			id: "test-model",
+			thinking: true,
+			maxTokens: 4000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: 3200, // 80% of 4000
+		}
+
+		expect(result).toEqual({
+			maxTokens: 4000,
+			thinking: expectedThinking,
+			temperature: 1.0,
+		})
+	})
+
+	it("should use ANTHROPIC_DEFAULT_MAX_TOKENS when no maxTokens is provided for thinking models", () => {
+		const options = {}
+		const model = {
+			id: "test-model",
+			thinking: true,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: Math.floor(ANTHROPIC_DEFAULT_MAX_TOKENS * 0.8),
+		}
+
+		expect(result).toEqual({
+			maxTokens: undefined,
+			thinking: expectedThinking,
+			temperature: 1.0,
+		})
+	})
+})
diff --git a/src/api/index.ts b/src/api/index.ts
index f68c9acd1f..6cf9317e40 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -1,6 +1,9 @@
 import { Anthropic } from "@anthropic-ai/sdk"
+import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta/messages/index.mjs"
+
+import { ApiConfiguration, ModelInfo, ApiHandlerOptions } from "../shared/api"
+import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./providers/constants"
 import { GlamaHandler } from "./providers/glama"
-import { ApiConfiguration, ModelInfo } from "../shared/api"
 import { AnthropicHandler } from "./providers/anthropic"
 import { AwsBedrockHandler } from "./providers/bedrock"
 import { OpenRouterHandler } from "./providers/openrouter"
@@ -63,3 +66,41 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new AnthropicHandler(options)
 	}
 }
+
+export function getModelParams({
+	options,
+	model,
+	defaultMaxTokens,
+	defaultTemperature = 0,
+}: {
+	options: ApiHandlerOptions
+	model: ModelInfo
+	defaultMaxTokens?: number
+	defaultTemperature?: number
+}) {
+	const {
+		modelMaxTokens: customMaxTokens,
+		modelMaxThinkingTokens: customMaxThinkingTokens,
+		modelTemperature: customTemperature,
+	} = options
+
+	let maxTokens = model.maxTokens ?? defaultMaxTokens
+	let thinking: BetaThinkingConfigParam | undefined = undefined
+	let temperature = customTemperature ?? defaultTemperature
+
+	if (model.thinking) {
+		// Only honor `customMaxTokens` for thinking models.
+		maxTokens = customMaxTokens ?? maxTokens
+
+		// Clamp the thinking budget to be at most 80% of max tokens and at
+		// least 1024 tokens.
+		const maxBudgetTokens = Math.floor((maxTokens || ANTHROPIC_DEFAULT_MAX_TOKENS) * 0.8)
+		const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
+		thinking = { type: "enabled", budget_tokens: budgetTokens }
+
+		// Anthropic "Thinking" models require a temperature of 1.0.
+		temperature = 1.0
+	}
+
+	return { maxTokens, thinking, temperature }
+}
diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts
index dc34b2eacd..acf5dcf9ec 100644
--- a/src/api/providers/anthropic.ts
+++ b/src/api/providers/anthropic.ts
@@ -1,7 +1,6 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
 import { CacheControlEphemeral } from "@anthropic-ai/sdk/resources"
-import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"
 import {
 	anthropicDefaultModelId,
 	AnthropicModelId,
@@ -9,8 +8,9 @@ import {
 	ApiHandlerOptions,
 	ModelInfo,
 } from "../../shared/api"
-import { ApiHandler, SingleCompletionHandler } from "../index"
 import { ApiStream } from "../transform/stream"
+import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./constants"
+import { ApiHandler, SingleCompletionHandler, getModelParams } from "../index"
 
 export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
@@ -51,7 +51,7 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 				stream = await this.client.messages.create(
 					{
 						model: modelId,
-						max_tokens: maxTokens,
+						max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
 						temperature,
 						thinking,
 						// Setting cache breakpoint for system prompt so new tasks can reuse it.
@@ -99,7 +99,7 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 			default: {
 				stream = (await this.client.messages.create({
 					model: modelId,
-					max_tokens: maxTokens,
+					max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
 					temperature,
 					system: [{ text: systemPrompt, type: "text" }],
 					messages,
@@ -180,13 +180,6 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 
 	getModel() {
 		const modelId = this.options.apiModelId
-
-		const {
-			modelMaxTokens: customMaxTokens,
-			modelMaxThinkingTokens: customMaxThinkingTokens,
-			modelTemperature: customTemperature,
-		} = this.options
-
 		let id = modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId
 		const info: ModelInfo = anthropicModels[id]
 
@@ -197,25 +190,11 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 			id = "claude-3-7-sonnet-20250219"
 		}
 
-		let maxTokens = info.maxTokens ?? 8192
-		let thinking: BetaThinkingConfigParam | undefined = undefined
-		let temperature = customTemperature ?? 0
-
-		if (info.thinking) {
-			// Only honor `customMaxTokens` for thinking models.
-			maxTokens = customMaxTokens ?? maxTokens
-
-			// Clamp the thinking budget to be at most 80% of max tokens and at
-			// least 1024 tokens.
-			const maxBudgetTokens = Math.floor(maxTokens * 0.8)
-			const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
-			thinking = { type: "enabled", budget_tokens: budgetTokens }
-
-			// Anthropic "Thinking" models require a temperature of 1.0.
-			temperature = 1.0
+		return {
+			id,
+			info,
+			...getModelParams({ options: this.options, model: info, defaultMaxTokens: ANTHROPIC_DEFAULT_MAX_TOKENS }),
 		}
-
-		return { id, info, maxTokens, thinking, temperature }
 	}
 
 	async completePrompt(prompt: string) {
@@ -223,7 +202,7 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 
 		const message = await this.client.messages.create({
 			model: modelId,
-			max_tokens: maxTokens,
+			max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
 			thinking,
 			temperature,
 			messages: [{ role: "user", content: prompt }],
diff --git a/src/api/providers/constants.ts b/src/api/providers/constants.ts
new file mode 100644
index 0000000000..86ca71746e
--- /dev/null
+++ b/src/api/providers/constants.ts
@@ -0,0 +1,3 @@
+export const ANTHROPIC_DEFAULT_MAX_TOKENS = 8192
+
+export const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6
diff --git a/src/api/providers/ollama.ts b/src/api/providers/ollama.ts
index de7df5d261..94f5394eb6 100644
--- a/src/api/providers/ollama.ts
+++ b/src/api/providers/ollama.ts
@@ -7,11 +7,9 @@ import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../..
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { convertToR1Format } from "../transform/r1-format"
 import { ApiStream } from "../transform/stream"
-import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./openai"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
 import { XmlMatcher } from "../../utils/xml-matcher"
 
-const OLLAMA_DEFAULT_TEMPERATURE = 0
-
 export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: OpenAI
@@ -35,7 +33,7 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
 		const stream = await this.client.chat.completions.create({
 			model: this.getModel().id,
 			messages: openAiMessages,
-			temperature: this.options.modelTemperature ?? OLLAMA_DEFAULT_TEMPERATURE,
+			temperature: this.options.modelTemperature ?? 0,
 			stream: true,
 		})
 		const matcher = new XmlMatcher(
@@ -76,9 +74,7 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
 				messages: useR1Format
 					? convertToR1Format([{ role: "user", content: prompt }])
 					: [{ role: "user", content: prompt }],
-				temperature:
-					this.options.modelTemperature ??
-					(useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : OLLAMA_DEFAULT_TEMPERATURE),
+				temperature: this.options.modelTemperature ?? (useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
 				stream: false,
 			})
 			return response.choices[0]?.message.content || ""
diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
index f1c404d50a..208184f60a 100644
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -13,14 +13,12 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
 import { convertToR1Format } from "../transform/r1-format"
 import { convertToSimpleMessages } from "../transform/simple-format"
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
 
 export interface OpenAiHandlerOptions extends ApiHandlerOptions {
 	defaultHeaders?: Record<string, string>
 }
 
-export const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6
-const OPENAI_DEFAULT_TEMPERATURE = 0
-
 export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
 	protected options: OpenAiHandlerOptions
 	private client: OpenAI
@@ -78,9 +76,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
 
 		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 			model: modelId,
-			temperature:
-				this.options.modelTemperature ??
-				(deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : OPENAI_DEFAULT_TEMPERATURE),
+			temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
 			messages: convertedMessages,
 			stream: true as const,
 			stream_options: { include_usage: true },
diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts
index 4ada184cd9..66aaf52d81 100644
--- a/src/api/providers/openrouter.ts
+++ b/src/api/providers/openrouter.ts
@@ -9,10 +9,8 @@ import { parseApiPrice } from "../../utils/cost"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
 import { convertToR1Format } from "../transform/r1-format"
-import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./openai"
-import { ApiHandler, SingleCompletionHandler } from ".."
-
-const OPENROUTER_DEFAULT_TEMPERATURE = 0
+import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
+import { ApiHandler, getModelParams, SingleCompletionHandler } from ".."
 
 // Add custom interface for OpenRouter params.
 type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
@@ -200,40 +198,16 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 		let id = modelId ?? openRouterDefaultModelId
 		const info = modelInfo ?? openRouterDefaultModelInfo
 
-		const {
-			modelMaxTokens: customMaxTokens,
-			modelMaxThinkingTokens: customMaxThinkingTokens,
-			modelTemperature: customTemperature,
-		} = this.options
-
-		let maxTokens = info.maxTokens
-		let thinking: BetaThinkingConfigParam | undefined = undefined
-		let temperature = customTemperature ?? OPENROUTER_DEFAULT_TEMPERATURE
-		let topP: number | undefined = undefined
-
-		// Handle models based on deepseek-r1
-		if (id.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning") {
-			// Recommended temperature for DeepSeek reasoning models.
-			temperature = customTemperature ?? DEEP_SEEK_DEFAULT_TEMPERATURE
-			// Some provider support topP and 0.95 is value that Deepseek used in their benchmarks.
-			topP = 0.95
-		}
-
-		if (info.thinking) {
-			// Only honor `customMaxTokens` for thinking models.
-			maxTokens = customMaxTokens ?? maxTokens
+		const isDeepSeekR1 = id.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning"
+		const defaultTemperature = isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0
+		const topP = isDeepSeekR1 ? 0.95 : undefined
 
-			// Clamp the thinking budget to be at most 80% of max tokens and at
-			// least 1024 tokens.
-			const maxBudgetTokens = Math.floor((maxTokens || 8192) * 0.8)
-			const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
-			thinking = { type: "enabled", budget_tokens: budgetTokens }
-
-			// Anthropic "Thinking" models require a temperature of 1.0.
-			temperature = 1.0
+		return {
+			id,
+			info,
+			...getModelParams({ options: this.options, model: info, defaultTemperature }),
+			topP,
 		}
-
-		return { id, info, maxTokens, thinking, temperature, topP }
 	}
 
 	async completePrompt(prompt: string) {
diff --git a/src/api/providers/vertex.ts b/src/api/providers/vertex.ts
index d60c27b4dd..02a8659886 100644
--- a/src/api/providers/vertex.ts
+++ b/src/api/providers/vertex.ts
@@ -1,12 +1,13 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
 import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
-import { ApiHandler, SingleCompletionHandler } from "../"
-import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"
+import { VertexAI } from "@google-cloud/vertexai"
+
 import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
-import { VertexAI } from "@google-cloud/vertexai"
 import { convertAnthropicMessageToVertexGemini } from "../transform/vertex-gemini-format"
+import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./constants"
+import { ApiHandler, getModelParams, SingleCompletionHandler } from "../"
 
 // Types for Vertex SDK
 
@@ -344,21 +345,8 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}
 
-	getModel(): {
-		id: VertexModelId
-		info: ModelInfo
-		temperature: number
-		maxTokens: number
-		thinking?: BetaThinkingConfigParam
-	} {
+	getModel() {
 		const modelId = this.options.apiModelId
-
-		const {
-			modelMaxTokens: customMaxTokens,
-			modelMaxThinkingTokens: customMaxThinkingTokens,
-			modelTemperature: customTemperature,
-		} = this.options
-
 		let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId
 		const info: ModelInfo = vertexModels[id]
 
@@ -368,25 +356,11 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 			id = id.replace(":thinking", "") as VertexModelId
 		}
 
-		let maxTokens = info.maxTokens || 8192
-		let thinking: BetaThinkingConfigParam | undefined = undefined
-		let temperature = customTemperature ?? 0
-
-		if (info.thinking) {
-			// Only honor `customMaxTokens` for thinking models.
-			maxTokens = customMaxTokens ?? maxTokens
-
-			// Clamp the thinking budget to be at most 80% of max tokens and at
-			// least 1024 tokens.
-			const maxBudgetTokens = Math.floor(maxTokens * 0.8)
-			const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
-			thinking = { type: "enabled", budget_tokens: budgetTokens }
-
-			// Anthropic "Thinking" models require a temperature of 1.0.
-			temperature = 1.0
+		return {
+			id,
+			info,
+			...getModelParams({ options: this.options, model: info, defaultMaxTokens: ANTHROPIC_DEFAULT_MAX_TOKENS }),
 		}
-
-		return { id, info, maxTokens, thinking, temperature }
 	}
 
 	private async completePromptGemini(prompt: string) {
@@ -423,9 +397,9 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 			let { id, info, temperature, maxTokens, thinking } = this.getModel()
 			const useCache = info.supportsPromptCache
 
-			const params = {
+			const params: Anthropic.Messages.MessageCreateParamsNonStreaming = {
 				model: id,
-				max_tokens: maxTokens,
+				max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
 				temperature,
 				thinking,
 				system: "", // No system prompt needed for single completions
@@ -446,19 +420,19 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 				stream: false,
 			}
 
-			const response = (await this.anthropicClient.messages.create(
-				params as Anthropic.Messages.MessageCreateParamsNonStreaming,
-			)) as unknown as VertexMessageResponse
-
+			const response = (await this.anthropicClient.messages.create(params)) as unknown as VertexMessageResponse
 			const content = response.content[0]
+
 			if (content.type === "text") {
 				return content.text
 			}
+
 			return ""
 		} catch (error) {
 			if (error instanceof Error) {
 				throw new Error(`Vertex completion error: ${error.message}`)
 			}
+
 			throw error
 		}
 	}