Skip to content

Commit

Permalink
Merge pull request #1257 from RooVetGit/cte/unify-thinking-budget-setting
Browse files Browse the repository at this point in the history
  • Loading branch information
cte authored Feb 28, 2025
2 parents 54c6874 + 8cbce2d commit 8fae1ca
Show file tree
Hide file tree
Showing 11 changed files with 28 additions and 58 deletions.
18 changes: 3 additions & 15 deletions src/api/providers/__tests__/vertex.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ describe("VertexHandler", () => {
vertexProjectId: "test-project",
vertexRegion: "us-central1",
modelMaxTokens: 16384,
vertexThinking: 4096,
modelMaxThinkingTokens: 4096,
})

const modelInfo = thinkingHandler.getModel()
Expand All @@ -662,7 +662,7 @@ describe("VertexHandler", () => {
vertexProjectId: "test-project",
vertexRegion: "us-central1",
modelMaxTokens: 16384,
vertexThinking: 5000,
modelMaxThinkingTokens: 5000,
})

expect((handlerWithBudget.getModel().thinking as any).budget_tokens).toBe(5000)
Expand All @@ -688,25 +688,13 @@ describe("VertexHandler", () => {
expect((handlerWithSmallMaxTokens.getModel().thinking as any).budget_tokens).toBe(1024)
})

it("should use anthropicThinking value if vertexThinking is not provided", () => {
const handler = new VertexHandler({
apiModelId: "claude-3-7-sonnet@20250219:thinking",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
modelMaxTokens: 16384,
anthropicThinking: 6000, // Should be used as fallback
})

expect((handler.getModel().thinking as any).budget_tokens).toBe(6000)
})

it("should pass thinking configuration to API", async () => {
const thinkingHandler = new VertexHandler({
apiModelId: "claude-3-7-sonnet@20250219:thinking",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
modelMaxTokens: 16384,
vertexThinking: 4096,
modelMaxThinkingTokens: 4096,
})

const mockCreate = jest.fn().mockImplementation(async (options) => {
Expand Down
2 changes: 1 addition & 1 deletion src/api/providers/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
// least 1024 tokens.
const maxBudgetTokens = Math.floor(maxTokens * 0.8)
const budgetTokens = Math.max(
Math.min(this.options.anthropicThinking ?? maxBudgetTokens, maxBudgetTokens),
Math.min(this.options.modelMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens),
1024,
)

Expand Down
2 changes: 1 addition & 1 deletion src/api/providers/openrouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
// least 1024 tokens.
const maxBudgetTokens = Math.floor((maxTokens || 8192) * 0.8)
const budgetTokens = Math.max(
Math.min(this.options.anthropicThinking ?? maxBudgetTokens, maxBudgetTokens),
Math.min(this.options.modelMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens),
1024,
)

Expand Down
5 changes: 1 addition & 4 deletions src/api/providers/vertex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,7 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
temperature = 1.0 // Thinking requires temperature 1.0
const maxBudgetTokens = Math.floor(maxTokens * 0.8)
const budgetTokens = Math.max(
Math.min(
this.options.vertexThinking ?? this.options.anthropicThinking ?? maxBudgetTokens,
maxBudgetTokens,
),
Math.min(this.options.modelMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens),
1024,
)
thinking = { type: "enabled", budget_tokens: budgetTokens }
Expand Down
15 changes: 5 additions & 10 deletions src/core/webview/ClineProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1651,8 +1651,6 @@ export class ClineProvider implements vscode.WebviewViewProvider {
lmStudioModelId,
lmStudioBaseUrl,
anthropicBaseUrl,
anthropicThinking,
vertexThinking,
geminiApiKey,
openAiNativeApiKey,
deepSeekApiKey,
Expand All @@ -1673,6 +1671,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
requestyModelInfo,
modelTemperature,
modelMaxTokens,
modelMaxThinkingTokens,
} = apiConfiguration
await Promise.all([
this.updateGlobalState("apiProvider", apiProvider),
Expand Down Expand Up @@ -1701,8 +1700,6 @@ export class ClineProvider implements vscode.WebviewViewProvider {
this.updateGlobalState("lmStudioModelId", lmStudioModelId),
this.updateGlobalState("lmStudioBaseUrl", lmStudioBaseUrl),
this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl),
this.updateGlobalState("anthropicThinking", anthropicThinking),
this.updateGlobalState("vertexThinking", vertexThinking),
this.storeSecret("geminiApiKey", geminiApiKey),
this.storeSecret("openAiNativeApiKey", openAiNativeApiKey),
this.storeSecret("deepSeekApiKey", deepSeekApiKey),
Expand All @@ -1723,6 +1720,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
this.updateGlobalState("requestyModelInfo", requestyModelInfo),
this.updateGlobalState("modelTemperature", modelTemperature),
this.updateGlobalState("modelMaxTokens", modelMaxTokens),
this.updateGlobalState("anthropicThinking", modelMaxThinkingTokens),
])
if (this.cline) {
this.cline.api = buildApiHandler(apiConfiguration)
Expand Down Expand Up @@ -2159,8 +2157,6 @@ export class ClineProvider implements vscode.WebviewViewProvider {
lmStudioModelId,
lmStudioBaseUrl,
anthropicBaseUrl,
anthropicThinking,
vertexThinking,
geminiApiKey,
openAiNativeApiKey,
deepSeekApiKey,
Expand Down Expand Up @@ -2216,6 +2212,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
requestyModelInfo,
modelTemperature,
modelMaxTokens,
modelMaxThinkingTokens,
maxOpenTabsContext,
] = await Promise.all([
this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
Expand Down Expand Up @@ -2244,8 +2241,6 @@ export class ClineProvider implements vscode.WebviewViewProvider {
this.getGlobalState("lmStudioModelId") as Promise<string | undefined>,
this.getGlobalState("lmStudioBaseUrl") as Promise<string | undefined>,
this.getGlobalState("anthropicBaseUrl") as Promise<string | undefined>,
this.getGlobalState("anthropicThinking") as Promise<number | undefined>,
this.getGlobalState("vertexThinking") as Promise<number | undefined>,
this.getSecret("geminiApiKey") as Promise<string | undefined>,
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
Expand Down Expand Up @@ -2301,6 +2296,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
this.getGlobalState("requestyModelInfo") as Promise<ModelInfo | undefined>,
this.getGlobalState("modelTemperature") as Promise<number | undefined>,
this.getGlobalState("modelMaxTokens") as Promise<number | undefined>,
this.getGlobalState("anthropicThinking") as Promise<number | undefined>,
this.getGlobalState("maxOpenTabsContext") as Promise<number | undefined>,
])

Expand Down Expand Up @@ -2346,8 +2342,6 @@ export class ClineProvider implements vscode.WebviewViewProvider {
lmStudioModelId,
lmStudioBaseUrl,
anthropicBaseUrl,
anthropicThinking,
vertexThinking,
geminiApiKey,
openAiNativeApiKey,
deepSeekApiKey,
Expand All @@ -2368,6 +2362,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
requestyModelInfo,
modelTemperature,
modelMaxTokens,
modelMaxThinkingTokens,
},
lastShownAnnouncementId,
customInstructions,
Expand Down
2 changes: 1 addition & 1 deletion src/shared/__tests__/checkExistApiConfig.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ describe("checkExistKey", () => {
apiKey: "test-key",
apiProvider: undefined,
anthropicBaseUrl: undefined,
anthropicThinking: undefined,
modelMaxThinkingTokens: undefined,
}
expect(checkExistKey(config)).toBe(true)
})
Expand Down
3 changes: 1 addition & 2 deletions src/shared/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ export interface ApiHandlerOptions {
apiModelId?: string
apiKey?: string // anthropic
anthropicBaseUrl?: string
anthropicThinking?: number
vsCodeLmModelSelector?: vscode.LanguageModelChatSelector
glamaModelId?: string
glamaModelInfo?: ModelInfo
Expand All @@ -41,7 +40,6 @@ export interface ApiHandlerOptions {
awsUseProfile?: boolean
vertexProjectId?: string
vertexRegion?: string
vertexThinking?: number
openAiBaseUrl?: string
openAiApiKey?: string
openAiModelId?: string
Expand Down Expand Up @@ -70,6 +68,7 @@ export interface ApiHandlerOptions {
requestyModelInfo?: ModelInfo
modelTemperature?: number
modelMaxTokens?: number
modelMaxThinkingTokens?: number
}

export type ApiConfiguration = ApiHandlerOptions & {
Expand Down
3 changes: 1 addition & 2 deletions src/shared/globalState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ export type GlobalStateKey =
| "awsUseProfile"
| "vertexProjectId"
| "vertexRegion"
| "vertexThinking"
| "lastShownAnnouncementId"
| "customInstructions"
| "alwaysAllowReadOnly"
Expand All @@ -43,7 +42,6 @@ export type GlobalStateKey =
| "lmStudioModelId"
| "lmStudioBaseUrl"
| "anthropicBaseUrl"
| "anthropicThinking"
| "azureApiVersion"
| "openAiStreamingEnabled"
| "openRouterModelId"
Expand Down Expand Up @@ -83,5 +81,6 @@ export type GlobalStateKey =
| "unboundModelInfo"
| "modelTemperature"
| "modelMaxTokens"
| "anthropicThinking" // TODO: Rename to `modelMaxThinkingTokens`.
| "mistralCodestralUrl"
| "maxOpenTabsContext"
13 changes: 5 additions & 8 deletions webview-ui/src/components/settings/ThinkingBudget.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,24 @@ export const ThinkingBudget = ({
modelInfo,
provider,
}: ThinkingBudgetProps) => {
const isVertexProvider = provider === "vertex"
const budgetField = isVertexProvider ? "vertexThinking" : "anthropicThinking"

const tokens = apiConfiguration?.modelMaxTokens || modelInfo?.maxTokens || 64_000
const tokensMin = 8192
const tokensMax = modelInfo?.maxTokens || 64_000

// Get the appropriate thinking tokens based on provider
const thinkingTokens = useMemo(() => {
const value = isVertexProvider ? apiConfiguration?.vertexThinking : apiConfiguration?.anthropicThinking
const value = apiConfiguration?.modelMaxThinkingTokens
return value || Math.min(Math.floor(0.8 * tokens), 8192)
}, [apiConfiguration, isVertexProvider, tokens])
}, [apiConfiguration, tokens])

const thinkingTokensMin = 1024
const thinkingTokensMax = Math.floor(0.8 * tokens)

useEffect(() => {
if (thinkingTokens > thinkingTokensMax) {
setApiConfigurationField(budgetField, thinkingTokensMax)
setApiConfigurationField("modelMaxThinkingTokens", thinkingTokensMax)
}
}, [thinkingTokens, thinkingTokensMax, setApiConfigurationField, budgetField])
}, [thinkingTokens, thinkingTokensMax, setApiConfigurationField])

if (!modelInfo?.thinking) {
return null
Expand Down Expand Up @@ -66,7 +63,7 @@ export const ThinkingBudget = ({
max={thinkingTokensMax}
step={1024}
value={[thinkingTokens]}
onValueChange={([value]) => setApiConfigurationField(budgetField, value)}
onValueChange={([value]) => setApiConfigurationField("modelMaxThinkingTokens", value)}
/>
<div className="w-12 text-sm text-center">{thinkingTokens}</div>
</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,7 @@ jest.mock("../ThinkingBudget", () => ({
ThinkingBudget: ({ apiConfiguration, setApiConfigurationField, modelInfo, provider }: any) =>
modelInfo?.thinking ? (
<div data-testid="thinking-budget" data-provider={provider}>
<input
data-testid="thinking-tokens"
value={
provider === "vertex" ? apiConfiguration?.vertexThinking : apiConfiguration?.anthropicThinking
}
/>
<input data-testid="thinking-tokens" value={apiConfiguration?.modelMaxThinkingTokens} />
</div>
) : null,
}))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ describe("ThinkingBudget", () => {
expect(screen.getAllByTestId("slider")).toHaveLength(2)
})

it("should use anthropicThinking field for Anthropic provider", () => {
it("should use modelMaxThinkingTokens field for Anthropic provider", () => {
const setApiConfigurationField = jest.fn()

render(
<ThinkingBudget
{...defaultProps}
apiConfiguration={{ anthropicThinking: 4096 }}
apiConfiguration={{ modelMaxThinkingTokens: 4096 }}
setApiConfigurationField={setApiConfigurationField}
provider="anthropic"
/>,
Expand All @@ -75,16 +75,16 @@ describe("ThinkingBudget", () => {
const sliders = screen.getAllByTestId("slider")
fireEvent.change(sliders[1], { target: { value: "5000" } })

expect(setApiConfigurationField).toHaveBeenCalledWith("anthropicThinking", 5000)
expect(setApiConfigurationField).toHaveBeenCalledWith("modelMaxThinkingTokens", 5000)
})

it("should use vertexThinking field for Vertex provider", () => {
it("should use modelMaxThinkingTokens field for Vertex provider", () => {
const setApiConfigurationField = jest.fn()

render(
<ThinkingBudget
{...defaultProps}
apiConfiguration={{ vertexThinking: 4096 }}
apiConfiguration={{ modelMaxThinkingTokens: 4096 }}
setApiConfigurationField={setApiConfigurationField}
provider="vertex"
/>,
Expand All @@ -93,7 +93,7 @@ describe("ThinkingBudget", () => {
const sliders = screen.getAllByTestId("slider")
fireEvent.change(sliders[1], { target: { value: "5000" } })

expect(setApiConfigurationField).toHaveBeenCalledWith("vertexThinking", 5000)
expect(setApiConfigurationField).toHaveBeenCalledWith("modelMaxThinkingTokens", 5000)
})

it("should cap thinking tokens at 80% of max tokens", () => {
Expand All @@ -102,13 +102,13 @@ describe("ThinkingBudget", () => {
render(
<ThinkingBudget
{...defaultProps}
apiConfiguration={{ modelMaxTokens: 10000, anthropicThinking: 9000 }}
apiConfiguration={{ modelMaxTokens: 10000, modelMaxThinkingTokens: 9000 }}
setApiConfigurationField={setApiConfigurationField}
/>,
)

// Effect should trigger and cap the value
expect(setApiConfigurationField).toHaveBeenCalledWith("anthropicThinking", 8000) // 80% of 10000
expect(setApiConfigurationField).toHaveBeenCalledWith("modelMaxThinkingTokens", 8000) // 80% of 10000
})

it("should use default thinking tokens if not provided", () => {
Expand Down

0 comments on commit 8fae1ca

Please sign in to comment.