Skip to content

Commit

Permalink
updated anthropic agent with tool
Browse files Browse the repository at this point in the history
  • Loading branch information
brnaba-aws committed Oct 11, 2024
1 parent 70ad728 commit d98c887
Showing 1 changed file with 87 additions and 22 deletions.
109 changes: 87 additions & 22 deletions typescript/src/agents/anthropicAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export interface AnthropicAgentOptions extends AgentOptions {
streaming?: boolean;
toolConfig?: {
tool: Anthropic.Tool[];
useToolHandler: (response: any, conversation: ConversationMessage[]) => void ;
useToolHandler: (response: any, conversation: any[]) => any ;
toolMaxRecursions?: number;
};
// Optional: Configuration for the inference process
Expand All @@ -36,21 +36,21 @@ export interface AnthropicAgentOptions extends AgentOptions {
customSystemPrompt?: {
template: string, variables?: TemplateVariables
};
}
}

type WithApiKey = {
type WithApiKey = {
apiKey: string;
client?: never;
};
};

type WithClient = {
type WithClient = {
client: Anthropic;
apiKey?: never;
};
};

export type AnthropicAgentOptionsWithAuth = AnthropicAgentOptions & (WithApiKey | WithClient);
export type AnthropicAgentOptionsWithAuth = AnthropicAgentOptions & (WithApiKey | WithClient);

export class AnthropicAgent extends Agent {
export class AnthropicAgent extends Agent {

private client: Anthropic;
protected streaming: boolean;
Expand All @@ -67,7 +67,7 @@ export class AnthropicAgent extends Agent {

private toolConfig?: {
tool: Anthropic.Tool[];
useToolHandler: (response: any, conversation: ConversationMessage[]) => void;
useToolHandler: (response: any, conversation: any[]) => any;
toolMaxRecursions?: number;
};

Expand Down Expand Up @@ -162,20 +162,56 @@ export class AnthropicAgent extends Agent {
systemPrompt = systemPrompt + contextPrompt;
}


try {
const response = await this.client.messages.create({
model: this.modelId,
max_tokens: this.inferenceConfig.maxTokens,
messages: messages,
system: systemPrompt,
temperature: this.inferenceConfig.temperature,
top_p: this.inferenceConfig.topP,
tools: this.toolConfig?.tool
});
const textContent = response.content.find(
(content): content is Anthropic.TextBlock => content.type === "text"
);
return {role: ParticipantRole.ASSISTANT, content:[{'text':textContent?.text}]}
if (this.streaming){
return this.handleStreamingResponse(messages, systemPrompt);

} else {
let finalMessage:string = '';
let toolUse = false;
let recursions = this.toolConfig?.toolMaxRecursions || 5;
let tools = this.toolConfig?.tool;
do {

// Call Anthropic
const response = await this.handleSingleResponse({
model: this.modelId,
max_tokens: this.inferenceConfig.maxTokens,
messages: messages,
system: systemPrompt,
temperature: this.inferenceConfig.temperature,
top_p: this.inferenceConfig.topP,
tools: tools,
});

const toolUseBlocks = response.content.filter<Anthropic.ToolUseBlock>(
(content) => content.type === "tool_use",
);

if (toolUseBlocks.length > 0) {
// Append current response to the conversation
messages.push({role:'assistant', content:response.content});
const toolResponse = await this.toolConfig!.useToolHandler(response, messages);
messages.push(toolResponse);
toolUse = true;
} else {
const textContent = response.content.find(
(content): content is Anthropic.TextBlock => content.type === "text"
);
finalMessage = textContent?.text || '';
}

if (response.stop_reason === 'end_turn'){
toolUse = false;
}

recursions--;
}while (toolUse && recursions > 0)


return {role: ParticipantRole.ASSISTANT, content:[{'text':finalMessage}]}
}
}
catch (error) {
Logger.logger.error("Error processing request:", error);
Expand All @@ -184,6 +220,35 @@ export class AnthropicAgent extends Agent {
}
}

protected async handleSingleResponse(input: any): Promise<Anthropic.Message> {
try {
const response = await this.client.messages.create(input);
return response as Anthropic.Message;

} catch (error) {
Logger.logger.error("Error invoking Anthropic:", error);
throw error;
}
}

private async *handleStreamingResponse(messages: any[], prompt:any): AsyncIterable<string> {
const stream = await this.client.messages.stream({
model: this.modelId,
max_tokens: this.inferenceConfig.maxTokens,
messages: messages,
system: prompt,
temperature: this.inferenceConfig.temperature,
top_p: this.inferenceConfig.topP,
tools: this.toolConfig?.tool,
});

for await (const event of stream) {
if (event.type === 'content_block_delta' && event.delta.type === 'text_delta') {
yield event.delta.text;
}
}
}

setSystemPrompt(template?: string, variables?: TemplateVariables): void {
if (template) {
this.promptTemplate = template;
Expand Down

0 comments on commit d98c887

Please sign in to comment.