Merge pull request #1254 from lupuletic/added-new-thinking-support
Added support for Claude 3.7 Sonnet extended thinking via Vertex AI
cte authored Feb 27, 2025
2 parents 277175f + dd4fb6b commit 54c6874
Showing 11 changed files with 593 additions and 34 deletions.
10 changes: 5 additions & 5 deletions package-lock.json

(Generated file; diff not rendered by default.)

2 changes: 1 addition & 1 deletion package.json
@@ -305,7 +305,7 @@
"dependencies": {
"@anthropic-ai/bedrock-sdk": "^0.10.2",
"@anthropic-ai/sdk": "^0.37.0",
"@anthropic-ai/vertex-sdk": "^0.4.1",
"@anthropic-ai/vertex-sdk": "^0.7.0",
"@aws-sdk/client-bedrock-runtime": "^3.706.0",
"@google/generative-ai": "^0.18.0",
"@mistralai/mistralai": "^1.3.6",
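The dependency bump above is presumably what unlocks the `thinking` request parameter the new tests assert on. A minimal sketch of a direct call against the newer SDK, with placeholder project and region values:

```typescript
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"

// Hypothetical GCP project/region; substitute your own values.
const client = new AnthropicVertex({ projectId: "my-project", region: "us-central1" })

const stream = await client.messages.create({
	model: "claude-3-7-sonnet@20250219",
	max_tokens: 16384,
	temperature: 1.0, // extended thinking requires temperature 1.0
	thinking: { type: "enabled", budget_tokens: 4096 },
	system: "You are a helpful assistant",
	messages: [{ role: "user", content: "Hello" }],
	stream: true,
})
```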
250 changes: 250 additions & 0 deletions src/api/providers/__tests__/vertex.test.ts
@@ -2,6 +2,7 @@

import { Anthropic } from "@anthropic-ai/sdk"
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"

import { VertexHandler } from "../vertex"
import { ApiStreamChunk } from "../../transform/stream"
@@ -431,6 +432,138 @@ describe("VertexHandler", () => {
})
})

describe("thinking functionality", () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello",
},
]

const systemPrompt = "You are a helpful assistant"

it("should handle thinking content blocks and deltas", async () => {
const mockStream = [
{
type: "message_start",
message: {
usage: {
input_tokens: 10,
output_tokens: 0,
},
},
},
{
type: "content_block_start",
index: 0,
content_block: {
type: "thinking",
thinking: "Let me think about this...",
},
},
{
type: "content_block_delta",
delta: {
type: "thinking_delta",
thinking: " I need to consider all options.",
},
},
{
type: "content_block_start",
index: 1,
content_block: {
type: "text",
text: "Here's my answer:",
},
},
]

// Setup async iterator for mock stream
const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk
}
},
}

const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
;(handler["client"].messages as any).create = mockCreate

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks: ApiStreamChunk[] = []

for await (const chunk of stream) {
chunks.push(chunk)
}

// Verify thinking content is processed correctly
const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
expect(reasoningChunks).toHaveLength(2)
expect(reasoningChunks[0].text).toBe("Let me think about this...")
expect(reasoningChunks[1].text).toBe(" I need to consider all options.")

// Verify text content is processed correctly
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(2) // One for the text block, one for the newline
expect(textChunks[0].text).toBe("\n")
expect(textChunks[1].text).toBe("Here's my answer:")
})

it("should handle multiple thinking blocks with line breaks", async () => {
const mockStream = [
{
type: "content_block_start",
index: 0,
content_block: {
type: "thinking",
thinking: "First thinking block",
},
},
{
type: "content_block_start",
index: 1,
content_block: {
type: "thinking",
thinking: "Second thinking block",
},
},
]

const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk
}
},
}

const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
;(handler["client"].messages as any).create = mockCreate

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks: ApiStreamChunk[] = []

for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks.length).toBe(3)
expect(chunks[0]).toEqual({
type: "reasoning",
text: "First thinking block",
})
expect(chunks[1]).toEqual({
type: "reasoning",
text: "\n",
})
expect(chunks[2]).toEqual({
type: "reasoning",
text: "Second thinking block",
})
})
})
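Taken together, the two tests above pin down a simple mapping from Anthropic stream events to `ApiStreamChunk`s: `thinking` blocks and `thinking_delta`s become `reasoning` chunks, and consecutive content blocks are separated by a `"\n"` chunk typed after the new block. A minimal sketch of that mapping (not the shipped handler; event shapes are assumed from the mocks above, and `message_start`/usage handling is omitted):

```typescript
type StreamChunk = { type: "reasoning" | "text"; text: string }

async function* mapContentEvents(events: AsyncIterable<any>): AsyncGenerator<StreamChunk> {
	for await (const event of events) {
		if (event.type === "content_block_start") {
			const isThinking = event.content_block.type === "thinking"
			// Separate consecutive blocks with a newline chunk, as the tests expect.
			if (event.index > 0) {
				yield { type: isThinking ? "reasoning" : "text", text: "\n" }
			}
			yield isThinking
				? { type: "reasoning", text: event.content_block.thinking }
				: { type: "text", text: event.content_block.text }
		} else if (event.type === "content_block_delta" && event.delta.type === "thinking_delta") {
			yield { type: "reasoning", text: event.delta.thinking }
		}
	}
}
```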

describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
@@ -500,4 +633,121 @@ describe("VertexHandler", () => {
expect(modelInfo.id).toBe("claude-3-7-sonnet@20250219") // Default model
})
})

describe("thinking model configuration", () => {
it("should configure thinking for models with :thinking suffix", () => {
const thinkingHandler = new VertexHandler({
apiModelId: "claude-3-7-sonnet@20250219:thinking",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
modelMaxTokens: 16384,
vertexThinking: 4096,
})

const modelInfo = thinkingHandler.getModel()

// Verify thinking configuration
expect(modelInfo.id).toBe("claude-3-7-sonnet@20250219")
expect(modelInfo.thinking).toBeDefined()
const thinkingConfig = modelInfo.thinking as { type: "enabled"; budget_tokens: number }
expect(thinkingConfig.type).toBe("enabled")
expect(thinkingConfig.budget_tokens).toBe(4096)
expect(modelInfo.temperature).toBe(1.0) // Thinking requires temperature 1.0
})
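A plausible reading of what the `:thinking` suffix does (the handler's actual parsing may differ): the suffix is stripped to recover the real Vertex model ID and flips thinking on, as sketched below.

```typescript
// Hypothetical helper; mirrors the behavior the test above verifies.
function parseModelId(apiModelId: string): { id: string; thinkingEnabled: boolean } {
	const suffix = ":thinking"
	const thinkingEnabled = apiModelId.endsWith(suffix)
	return {
		id: thinkingEnabled ? apiModelId.slice(0, -suffix.length) : apiModelId,
		thinkingEnabled,
	}
}

// parseModelId("claude-3-7-sonnet@20250219:thinking")
// => { id: "claude-3-7-sonnet@20250219", thinkingEnabled: true }
```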

it("should calculate thinking budget correctly", () => {
// Test with explicit thinking budget
const handlerWithBudget = new VertexHandler({
apiModelId: "claude-3-7-sonnet@20250219:thinking",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
modelMaxTokens: 16384,
vertexThinking: 5000,
})

expect((handlerWithBudget.getModel().thinking as any).budget_tokens).toBe(5000)

// Test with default thinking budget (80% of max tokens)
const handlerWithDefaultBudget = new VertexHandler({
apiModelId: "claude-3-7-sonnet@20250219:thinking",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
modelMaxTokens: 10000,
})

expect((handlerWithDefaultBudget.getModel().thinking as any).budget_tokens).toBe(8000) // 80% of 10000

// Test with minimum thinking budget (should be at least 1024)
const handlerWithSmallMaxTokens = new VertexHandler({
apiModelId: "claude-3-7-sonnet@20250219:thinking",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
modelMaxTokens: 1000, // This would result in 800 tokens for thinking, but minimum is 1024
})

expect((handlerWithSmallMaxTokens.getModel().thinking as any).budget_tokens).toBe(1024)
})

it("should use anthropicThinking value if vertexThinking is not provided", () => {
const handler = new VertexHandler({
apiModelId: "claude-3-7-sonnet@20250219:thinking",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
modelMaxTokens: 16384,
anthropicThinking: 6000, // Should be used as fallback
})

expect((handler.getModel().thinking as any).budget_tokens).toBe(6000)
})
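The assertions above encode a single budget rule; a sketch, assuming the clamp also applies to explicit values (only the default path is exercised here):

```typescript
// Prefer vertexThinking, fall back to anthropicThinking, else default to 80%
// of max tokens; never go below the API's 1024-token minimum.
function thinkingBudget(maxTokens: number, vertexThinking?: number, anthropicThinking?: number): number {
	const requested = vertexThinking ?? anthropicThinking ?? Math.floor(maxTokens * 0.8)
	return Math.max(1024, requested)
}

// Mirrors the tests: 5000 => 5000, default on 10000 => 8000,
// default on 1000 => 800, clamped up to 1024.
```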

it("should pass thinking configuration to API", async () => {
const thinkingHandler = new VertexHandler({
apiModelId: "claude-3-7-sonnet@20250219:thinking",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
modelMaxTokens: 16384,
vertexThinking: 4096,
})

const mockCreate = jest.fn().mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
content: [{ type: "text", text: "Test response" }],
role: "assistant",
model: options.model,
usage: {
input_tokens: 10,
output_tokens: 5,
},
}
}
return {
async *[Symbol.asyncIterator]() {
yield {
type: "message_start",
message: {
usage: {
input_tokens: 10,
output_tokens: 5,
},
},
}
},
}
})
;(thinkingHandler["client"].messages as any).create = mockCreate

await thinkingHandler
.createMessage("You are a helpful assistant", [{ role: "user", content: "Hello" }])
.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
thinking: { type: "enabled", budget_tokens: 4096 },
temperature: 1.0, // Thinking requires temperature 1.0
}),
)
})
})
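The last test reduces to one rule for request assembly: when thinking is enabled, forward the config and pin `temperature` to 1.0 (an API requirement for extended thinking). A sketch, with the option-merging shape assumed:

```typescript
type RequestOptions = {
	temperature?: number
	thinking?: { type: "enabled"; budget_tokens: number }
	[key: string]: unknown
}

// Hypothetical helper: overlay the thinking config the test checks for.
function withThinking(base: RequestOptions, thinking?: { type: "enabled"; budget_tokens: number }): RequestOptions {
	// When thinking is enabled, temperature must be 1.0; otherwise pass through.
	return thinking ? { ...base, thinking, temperature: 1.0 } : base
}
```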
})
(The remaining 8 changed files are not rendered in this view.)