Merge pull request #1312 from RooVetGit/count_tokens
Infrastructure to support calling token count APIs, starting with Anthropic
mrubens authored Mar 2, 2025
2 parents a2d441c + 7e62d34 commit 773e556
Showing 18 changed files with 543 additions and 361 deletions.
10 changes: 10 additions & 0 deletions src/api/index.ts
@@ -27,6 +27,16 @@ export interface ApiHandler {
export interface ApiHandler {
createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
getModel(): { id: string; info: ModelInfo }

/**
* Counts tokens for content blocks
* All providers extend BaseProvider which provides a default tiktoken implementation,
* but they can override this to use their native token counting endpoints
*
* @param content The content to count tokens for
* @returns A promise resolving to the token count
*/
countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number>
}

export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
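
For orientation, a minimal sketch of how a caller could exercise the new interface method once a handler has been built. This is illustrative only and not part of the diff: the import path and the configuration fields passed to buildApiHandler are assumptions, while the content-block shape follows the Anthropic SDK types used above.

import { Anthropic } from "@anthropic-ai/sdk"
import { buildApiHandler } from "../api" // hypothetical import path

async function estimateInputTokens(): Promise<number> {
	// Provider and key fields are placeholders; any configured handler
	// now exposes countTokens via the ApiHandler interface.
	const handler = buildApiHandler({ apiProvider: "anthropic", apiKey: "sk-..." } as any)
	const blocks: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Hello, world" }]
	// Resolves via the provider's native endpoint where available,
	// otherwise via the BaseProvider tiktoken fallback.
	return handler.countTokens(blocks)
}
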
38 changes: 35 additions & 3 deletions src/api/providers/anthropic.ts
@@ -9,16 +9,17 @@ import {
ModelInfo,
} from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./constants"
import { ApiHandler, SingleCompletionHandler, getModelParams } from "../index"
import { SingleCompletionHandler, getModelParams } from "../index"

export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
export class AnthropicHandler extends BaseProvider implements SingleCompletionHandler {
private options: ApiHandlerOptions
private client: Anthropic

constructor(options: ApiHandlerOptions) {
super()
this.options = options

this.client = new Anthropic({
apiKey: this.options.apiKey,
baseURL: this.options.anthropicBaseUrl || undefined,
@@ -212,4 +213,35 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
const content = message.content.find(({ type }) => type === "text")
return content?.type === "text" ? content.text : ""
}

/**
* Counts tokens for the given content using Anthropic's API
*
* @param content The content blocks to count tokens for
* @returns A promise resolving to the token count
*/
override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
try {
// Use the current model
const actualModelId = this.getModel().id

const response = await this.client.messages.countTokens({
model: actualModelId,
messages: [
{
role: "user",
content: content,
},
],
})

return response.input_tokens
} catch (error) {
// Log the error but fall back to tiktoken estimation
console.warn("Anthropic token counting failed, using fallback", error)

// Use the base provider's implementation as fallback
return super.countTokens(content)
}
}
}
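
As a usage note (not part of the diff): with an API key configured, the handler delegates to Anthropic's messages.countTokens endpoint for the currently selected model and quietly degrades to the tiktoken estimate on failure. A rough sketch, assuming the handler is constructed directly with a placeholder key:

import { AnthropicHandler } from "./anthropic"

async function example() {
	const handler = new AnthropicHandler({ apiKey: "sk-ant-..." }) // placeholder key
	// Counted against this.getModel().id; falls back to the BaseProvider
	// tiktoken estimate if the countTokens request fails.
	const tokens = await handler.countTokens([{ type: "text", text: "How many tokens is this?" }])
	console.log(`Estimated input tokens: ${tokens}`)
}
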
64 changes: 64 additions & 0 deletions src/api/providers/base-provider.ts
@@ -0,0 +1,64 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler } from ".."
import { ModelInfo } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { Tiktoken } from "js-tiktoken/lite"
import o200kBase from "js-tiktoken/ranks/o200k_base"

// Reuse the fudge factor used in the original code
const TOKEN_FUDGE_FACTOR = 1.5

/**
* Base class for API providers that implements common functionality
*/
export abstract class BaseProvider implements ApiHandler {
// Cache the Tiktoken encoder instance since it's stateless
private encoder: Tiktoken | null = null
abstract createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
abstract getModel(): { id: string; info: ModelInfo }

/**
* Default token counting implementation using tiktoken
* Providers can override this to use their native token counting endpoints
*
* Uses a cached Tiktoken encoder instance for performance since it's stateless.
* The encoder is created lazily on first use and reused for subsequent calls.
*
* @param content The content to count tokens for
* @returns A promise resolving to the token count
*/
async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
if (!content || content.length === 0) return 0

let totalTokens = 0

// Lazily create and cache the encoder if it doesn't exist
if (!this.encoder) {
this.encoder = new Tiktoken(o200kBase)
}

// Process each content block using the cached encoder
for (const block of content) {
if (block.type === "text") {
// Use tiktoken for text token counting
const text = block.text || ""
if (text.length > 0) {
const tokens = this.encoder.encode(text)
totalTokens += tokens.length
}
} else if (block.type === "image") {
// For images, calculate based on data size
const imageSource = block.source
if (imageSource && typeof imageSource === "object" && "data" in imageSource) {
const base64Data = imageSource.data as string
totalTokens += Math.ceil(Math.sqrt(base64Data.length))
} else {
totalTokens += 300 // Conservative estimate for unknown images
}
}
}

// Add a fudge factor to account for the fact that tiktoken is not always accurate
return Math.ceil(totalTokens * TOKEN_FUDGE_FACTOR)
}
}
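
To make the fallback estimate concrete (illustrative, not from the diff): a text block that tiktoken encodes to 100 tokens plus an image whose base64 payload is 10,000 characters comes out to ceil((100 + ceil(sqrt(10000))) * 1.5) = 300 tokens. A provider that is happy with this default only needs to implement the two abstract members, roughly as in the hypothetical sketch below; the class name, model id, and elided streaming logic are invented for illustration.

import { Anthropic } from "@anthropic-ai/sdk"
import { ModelInfo } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"

class ExampleProvider extends BaseProvider {
	override getModel(): { id: string; info: ModelInfo } {
		return { id: "example-model", info: {} as ModelInfo }
	}

	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		// Streaming elided in this sketch; countTokens is inherited unchanged
		// (tiktoken for text, sqrt of base64 length for images, times 1.5).
	}
}
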
12 changes: 7 additions & 5 deletions src/api/providers/bedrock.ts
@@ -6,10 +6,11 @@ import {
} from "@aws-sdk/client-bedrock-runtime"
import { fromIni } from "@aws-sdk/credential-providers"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, SingleCompletionHandler } from "../"
import { SingleCompletionHandler } from "../"
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { convertToBedrockConverseMessages } from "../transform/bedrock-converse-format"
import { BaseProvider } from "./base-provider"

const BEDROCK_DEFAULT_TEMPERATURE = 0.3

@@ -46,11 +47,12 @@ export interface StreamEvent {
}
}

export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: BedrockRuntimeClient

constructor(options: ApiHandlerOptions) {
super()
this.options = options

const clientConfig: BedrockRuntimeClientConfig = {
@@ -74,7 +76,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
this.client = new BedrockRuntimeClient(clientConfig)
}

async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const modelConfig = this.getModel()

// Handle cross-region inference
@@ -205,7 +207,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
}
}

getModel(): { id: BedrockModelId | string; info: ModelInfo } {
override getModel(): { id: BedrockModelId | string; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId) {
// For tests, allow any model ID
12 changes: 7 additions & 5 deletions src/api/providers/gemini.ts
@@ -1,22 +1,24 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { GoogleGenerativeAI } from "@google/generative-ai"
import { ApiHandler, SingleCompletionHandler } from "../"
import { SingleCompletionHandler } from "../"
import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api"
import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"

const GEMINI_DEFAULT_TEMPERATURE = 0

export class GeminiHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: GoogleGenerativeAI

constructor(options: ApiHandlerOptions) {
super()
this.options = options
this.client = new GoogleGenerativeAI(options.geminiApiKey ?? "not-provided")
}

async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const model = this.client.getGenerativeModel({
model: this.getModel().id,
systemInstruction: systemPrompt,
@@ -44,7 +46,7 @@ export class GeminiHandler implements ApiHandler, SingleCompletionHandler {
}
}

getModel(): { id: GeminiModelId; info: ModelInfo } {
override getModel(): { id: GeminiModelId; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId && modelId in geminiModels) {
const id = modelId as GeminiModelId
40 changes: 21 additions & 19 deletions src/api/providers/glama.ts
@@ -6,22 +6,39 @@ import { ApiHandlerOptions, ModelInfo, glamaDefaultModelId, glamaDefaultModelInf
import { parseApiPrice } from "../../utils/cost"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
import { ApiHandler, SingleCompletionHandler } from "../"
import { SingleCompletionHandler } from "../"
import { BaseProvider } from "./base-provider"

const GLAMA_DEFAULT_TEMPERATURE = 0

export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
export class GlamaHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: OpenAI

constructor(options: ApiHandlerOptions) {
super()
this.options = options
const baseURL = "https://glama.ai/api/gateway/openai/v1"
const apiKey = this.options.glamaApiKey ?? "not-provided"
this.client = new OpenAI({ baseURL, apiKey })
}

async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
private supportsTemperature(): boolean {
return !this.getModel().id.startsWith("openai/o3-mini")
}

override getModel(): { id: string; info: ModelInfo } {
const modelId = this.options.glamaModelId
const modelInfo = this.options.glamaModelInfo

if (modelId && modelInfo) {
return { id: modelId, info: modelInfo }
}

return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
}

override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
// Convert Anthropic messages to OpenAI format
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
@@ -152,21 +169,6 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
}
}

private supportsTemperature(): boolean {
return !this.getModel().id.startsWith("openai/o3-mini")
}

getModel(): { id: string; info: ModelInfo } {
const modelId = this.options.glamaModelId
const modelInfo = this.options.glamaModelInfo

if (modelId && modelInfo) {
return { id: modelId, info: modelInfo }
}

return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
}

async completePrompt(prompt: string): Promise<string> {
try {
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
12 changes: 7 additions & 5 deletions src/api/providers/lmstudio.ts
@@ -2,26 +2,28 @@ import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import axios from "axios"

import { ApiHandler, SingleCompletionHandler } from "../"
import { SingleCompletionHandler } from "../"
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"

const LMSTUDIO_DEFAULT_TEMPERATURE = 0

export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: OpenAI

constructor(options: ApiHandlerOptions) {
super()
this.options = options
this.client = new OpenAI({
baseURL: (this.options.lmStudioBaseUrl || "http://localhost:1234") + "/v1",
apiKey: "noop",
})
}

async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
@@ -51,7 +53,7 @@ export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
}
}

getModel(): { id: string; info: ModelInfo } {
override getModel(): { id: string; info: ModelInfo } {
return {
id: this.options.lmStudioModelId || "",
info: openAiModelInfoSaneDefaults,
12 changes: 7 additions & 5 deletions src/api/providers/mistral.ts
@@ -1,6 +1,6 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { Mistral } from "@mistralai/mistralai"
import { ApiHandler } from "../"
import { SingleCompletionHandler } from "../"
import {
ApiHandlerOptions,
mistralDefaultModelId,
@@ -13,14 +13,16 @@ import {
} from "../../shared/api"
import { convertToMistralMessages } from "../transform/mistral-format"
import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"

const MISTRAL_DEFAULT_TEMPERATURE = 0

export class MistralHandler implements ApiHandler {
private options: ApiHandlerOptions
export class MistralHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: Mistral

constructor(options: ApiHandlerOptions) {
super()
if (!options.mistralApiKey) {
throw new Error("Mistral API key is required")
}
@@ -48,7 +50,7 @@ export class MistralHandler implements ApiHandler {
return "https://api.mistral.ai"
}

async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const response = await this.client.chat.stream({
model: this.options.apiModelId || mistralDefaultModelId,
messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)],
@@ -81,7 +83,7 @@ export class MistralHandler implements ApiHandler {
}
}

getModel(): { id: MistralModelId; info: ModelInfo } {
override getModel(): { id: MistralModelId; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId && modelId in mistralModels) {
const id = modelId as MistralModelId