Skip to content

Commit

Permalink
loadbalance
Browse files Browse the repository at this point in the history
  • Loading branch information
shariqriazz committed Mar 2, 2025
1 parent a2d441c commit 2c2547a
Show file tree
Hide file tree
Showing 11 changed files with 471 additions and 48 deletions.
51 changes: 32 additions & 19 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { VertexHandler } from "./providers/vertex"
import { OpenAiHandler } from "./providers/openai"
import { OllamaHandler } from "./providers/ollama"
import { LmStudioHandler } from "./providers/lmstudio"
import { GeminiHandler } from "./providers/gemini"
import { GeminiHandler, ApiKeyRotationCallback, RequestCountUpdateCallback } from "./providers/gemini"
import { OpenAiNativeHandler } from "./providers/openai-native"
import { DeepSeekHandler } from "./providers/deepseek"
import { MistralHandler } from "./providers/mistral"
Expand All @@ -29,41 +29,54 @@ export interface ApiHandler {
getModel(): { id: string; info: ModelInfo }
}

export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
const { apiProvider, ...options } = configuration
/**
* Callbacks that can be passed to API handlers
*/
export interface ApiHandlerCallbacks {
onGeminiApiKeyRotation?: ApiKeyRotationCallback
onGeminiRequestCountUpdate?: RequestCountUpdateCallback
geminiInitialRequestCount?: number
}

export function buildApiHandler(configuration: ApiConfiguration, callbacks?: ApiHandlerCallbacks): ApiHandler {
const { apiProvider, ...handlerOptions } = configuration
switch (apiProvider) {
case "anthropic":
return new AnthropicHandler(options)
return new AnthropicHandler(handlerOptions)
case "glama":
return new GlamaHandler(options)
return new GlamaHandler(handlerOptions)
case "openrouter":
return new OpenRouterHandler(options)
return new OpenRouterHandler(handlerOptions)
case "bedrock":
return new AwsBedrockHandler(options)
return new AwsBedrockHandler(handlerOptions)
case "vertex":
return new VertexHandler(options)
return new VertexHandler(handlerOptions)
case "openai":
return new OpenAiHandler(options)
return new OpenAiHandler(handlerOptions)
case "ollama":
return new OllamaHandler(options)
return new OllamaHandler(handlerOptions)
case "lmstudio":
return new LmStudioHandler(options)
return new LmStudioHandler(handlerOptions)
case "gemini":
return new GeminiHandler(options)
return new GeminiHandler(handlerOptions, {
onApiKeyRotation: callbacks?.onGeminiApiKeyRotation,
onRequestCountUpdate: callbacks?.onGeminiRequestCountUpdate,
initialRequestCount: callbacks?.geminiInitialRequestCount,
})
case "openai-native":
return new OpenAiNativeHandler(options)
return new OpenAiNativeHandler(handlerOptions)
case "deepseek":
return new DeepSeekHandler(options)
return new DeepSeekHandler(handlerOptions)
case "vscode-lm":
return new VsCodeLmHandler(options)
return new VsCodeLmHandler(handlerOptions)
case "mistral":
return new MistralHandler(options)
return new MistralHandler(handlerOptions)
case "unbound":
return new UnboundHandler(options)
return new UnboundHandler(handlerOptions)
case "requesty":
return new RequestyHandler(options)
return new RequestyHandler(handlerOptions)
default:
return new AnthropicHandler(options)
return new AnthropicHandler(handlerOptions)
}
}

Expand Down
146 changes: 144 additions & 2 deletions src/api/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,156 @@ import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
import { ApiStream } from "../transform/stream"

const GEMINI_DEFAULT_TEMPERATURE = 0
const DEFAULT_REQUEST_COUNT = 10 // Default number of requests before switching API keys

// Define a callback type for API key rotation
export type ApiKeyRotationCallback = (newIndex: number, totalKeys: number, apiKey: string) => void
export type RequestCountUpdateCallback = (newCount: number) => void

export class GeminiHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: GoogleGenerativeAI
private requestCount: number = 0
private onApiKeyRotation?: ApiKeyRotationCallback
private onRequestCountUpdate?: RequestCountUpdateCallback

constructor(options: ApiHandlerOptions) {
constructor(
options: ApiHandlerOptions,
callbacks?: {
onApiKeyRotation?: ApiKeyRotationCallback
onRequestCountUpdate?: RequestCountUpdateCallback
initialRequestCount?: number
},
) {
this.options = options
this.client = new GoogleGenerativeAI(options.geminiApiKey ?? "not-provided")
this.onApiKeyRotation = callbacks?.onApiKeyRotation
this.onRequestCountUpdate = callbacks?.onRequestCountUpdate

// Initialize request count from saved state if provided
if (callbacks?.initialRequestCount !== undefined) {
this.requestCount = callbacks.initialRequestCount
console.log(`[GeminiHandler] Initialized with request count: ${this.requestCount}`)
}

// Initialize with the current API key
const apiKey = this.getCurrentApiKey()
this.client = new GoogleGenerativeAI(apiKey)

// Log initial API key setup if load balancing is enabled
if (
this.options.geminiLoadBalancingEnabled &&
this.options.geminiApiKeys &&
this.options.geminiApiKeys.length > 0
) {
console.log(
`[GeminiHandler] Load balancing enabled with ${this.options.geminiApiKeys.length} keys. Current index: ${this.options.geminiCurrentApiKeyIndex ?? 0}`,
)
}
}

/**
* Get the current API key based on load balancing settings
*/
private getCurrentApiKey(): string {
// If load balancing is not enabled or there are no multiple API keys, use the single API key
if (
!this.options.geminiLoadBalancingEnabled ||
!this.options.geminiApiKeys ||
this.options.geminiApiKeys.length === 0
) {
return this.options.geminiApiKey ?? "not-provided"
}

// Get the current API key index, defaulting to 0 if not set
const currentIndex = this.options.geminiCurrentApiKeyIndex ?? 0

// Return the API key at the current index
return this.options.geminiApiKeys[currentIndex] ?? "not-provided"
}

/**
* Update the client with the next API key if load balancing is enabled
*/
private updateApiKeyIfNeeded(): void {
// If load balancing is not enabled or there are no multiple API keys, do nothing
if (
!this.options.geminiLoadBalancingEnabled ||
!this.options.geminiApiKeys ||
this.options.geminiApiKeys.length <= 1
) {
return
}

// Increment the request count
this.requestCount++
console.log(
`[GeminiHandler] Request count: ${this.requestCount}/${this.options.geminiLoadBalancingRequestCount ?? DEFAULT_REQUEST_COUNT}`,
)

// Notify about request count update
if (this.onRequestCountUpdate) {
this.onRequestCountUpdate(this.requestCount)
}

// Get the request count threshold, defaulting to DEFAULT_REQUEST_COUNT if not set
const requestCountThreshold = this.options.geminiLoadBalancingRequestCount ?? DEFAULT_REQUEST_COUNT

// If the request count has reached the threshold, switch to the next API key
if (this.requestCount >= requestCountThreshold) {
// Reset the request count
this.requestCount = 0

// Notify about request count reset
if (this.onRequestCountUpdate) {
this.onRequestCountUpdate(0)
}

// Get the current API key index, defaulting to 0 if not set
let currentIndex = this.options.geminiCurrentApiKeyIndex ?? 0

// Calculate the next index, wrapping around if necessary
currentIndex = (currentIndex + 1) % this.options.geminiApiKeys.length

// Notify callback first to update global state
if (this.onApiKeyRotation) {
// Get the API key for the new index
const apiKey = this.options.geminiApiKeys[currentIndex] ?? "not-provided"

// Only send the first few characters of the API key for security
const maskedKey = apiKey.substring(0, 4) + "..." + apiKey.substring(apiKey.length - 4)

// Call the callback to update global state
this.onApiKeyRotation(currentIndex, this.options.geminiApiKeys.length, maskedKey)

// Update the current index in the options AFTER the callback
// This ensures we're using the index that was just set in global state
this.options.geminiCurrentApiKeyIndex = currentIndex

// Update the client with the new API key
this.client = new GoogleGenerativeAI(apiKey)

console.log(
`[GeminiHandler] Rotated to API key index: ${currentIndex} (${this.options.geminiApiKeys.length} total keys)`,
)
} else {
// No callback provided, just update locally
this.options.geminiCurrentApiKeyIndex = currentIndex

// Update the client with the new API key
const apiKey = this.getCurrentApiKey()
this.client = new GoogleGenerativeAI(apiKey)

console.log(
`[GeminiHandler] Rotated to API key index: ${currentIndex} (${this.options.geminiApiKeys.length} total keys)`,
)
}
}
}

async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
// Update the API key if needed before making the request
this.updateApiKeyIfNeeded()

const model = this.client.getGenerativeModel({
model: this.getModel().id,
systemInstruction: systemPrompt,
Expand Down Expand Up @@ -55,6 +194,9 @@ export class GeminiHandler implements ApiHandler, SingleCompletionHandler {

async completePrompt(prompt: string): Promise<string> {
try {
// Update the API key if needed before making the request
this.updateApiKeyIfNeeded()

const model = this.client.getGenerativeModel({
model: this.getModel().id,
})
Expand Down
60 changes: 59 additions & 1 deletion src/core/Cline.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,17 @@ export class Cline {
this.taskId = historyItem ? historyItem.id : crypto.randomUUID()

this.apiConfiguration = apiConfiguration
this.api = buildApiHandler(apiConfiguration)
this.api = buildApiHandler(apiConfiguration, {
onGeminiApiKeyRotation: (newIndex, totalKeys, maskedKey) => {
// Update the global state with the new API key index
this.handleGeminiApiKeyRotation(newIndex, totalKeys, maskedKey)
},
onGeminiRequestCountUpdate: (newCount) => {
// Update the global state with the new request count
this.handleGeminiRequestCountUpdate(newCount)
},
geminiInitialRequestCount: apiConfiguration.geminiRequestCount,
})
this.terminalManager = new TerminalManager()
this.urlContentFetcher = new UrlContentFetcher(provider.context)
this.browserSession = new BrowserSession(provider.context)
Expand Down Expand Up @@ -202,6 +212,54 @@ export class Cline {
this.diffStrategy = getDiffStrategy(this.api.getModel().id, this.fuzzyMatchThreshold, experimentalDiffStrategy)
}

/**
* Handle Gemini API key rotation by updating the global state
* This is called by the GeminiHandler when it rotates to a new API key
*/
private async handleGeminiApiKeyRotation(newIndex: number, totalKeys: number, maskedKey: string) {
console.log(`[Cline] Gemini API key rotated to index ${newIndex} of ${totalKeys} keys (${maskedKey})`)

// Update the global state with the new API key index
const provider = this.providerRef.deref()
if (provider) {
// Update the specific state key for the API key index
await provider.updateGlobalState("geminiCurrentApiKeyIndex", newIndex)

// Also update the apiConfiguration in memory to ensure UI consistency
this.apiConfiguration.geminiCurrentApiKeyIndex = newIndex

// Log the rotation for debugging
provider.log(`Gemini API key rotated to index ${newIndex} of ${totalKeys} keys`)

// Notify the user that the API key has been rotated
await this.say("text", `Gemini API key rotated to key #${newIndex + 1} of ${totalKeys}`)

// Force a state update to the webview to ensure the UI reflects the change
await provider.postStateToWebview()
}
}

/**
* Handle Gemini request count update by updating the global state
* This is called by the GeminiHandler when the request count changes
*/
private async handleGeminiRequestCountUpdate(newCount: number) {
console.log(`[Cline] Gemini request count updated to ${newCount}`)

// Update the global state with the new request count
const provider = this.providerRef.deref()
if (provider) {
// Update the specific state key for the request count
await provider.updateGlobalState("geminiRequestCount", newCount)

// Also update the apiConfiguration in memory to ensure consistency
this.apiConfiguration.geminiRequestCount = newCount

// Log the update for debugging
provider.log(`Gemini request count updated to ${newCount}`)
}
}

// Storing task to disk for history

private async ensureTaskDirectoryExists(): Promise<string> {
Expand Down
Loading

0 comments on commit 2c2547a

Please sign in to comment.